Diffstat (limited to 'src/personalization/serving')
-rw-r--r--  src/personalization/serving/__init__.py          22
-rw-r--r--  src/personalization/serving/personalized_llm.py  1835
2 files changed, 1857 insertions, 0 deletions
diff --git a/src/personalization/serving/__init__.py b/src/personalization/serving/__init__.py
new file mode 100644
index 0000000..11adcf8
--- /dev/null
+++ b/src/personalization/serving/__init__.py
@@ -0,0 +1,22 @@
+# Personalization Serving Module
+#
+# This module provides the interface layer for the personalization system.
+
+from personalization.serving.personalized_llm import (
+ PersonalizedLLM,
+ AssistantResponse,
+ UsageStats,
+ DebugInfo,
+ Feedback,
+ create_personalized_llm,
+)
+
+__all__ = [
+ "PersonalizedLLM",
+ "AssistantResponse",
+ "UsageStats",
+ "DebugInfo",
+ "Feedback",
+ "create_personalized_llm",
+]
+
diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py
new file mode 100644
index 0000000..8032e6b
--- /dev/null
+++ b/src/personalization/serving/personalized_llm.py
@@ -0,0 +1,1835 @@
+#!/usr/bin/env python3
+"""
+Personalized LLM Interface for Evaluation.
+
+This module provides the `PersonalizedLLM` class that wraps the entire
+personalization system into a clean interface for evaluation frameworks
+and user simulators.
+
+Interface contract:
+- chat(user_id, query) -> AssistantResponse: Main online interface
+- reset_session(user_id): Clear session history and short-term state
+- reset_user(user_id): Completely reset user (long-term, short-term, memories)
+- apply_feedback(feedback): Apply external feedback for RL updates
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import yaml
+
+# Ensure src is in path for standalone usage
+_src_path = os.path.join(os.path.dirname(__file__), "../..")
+if _src_path not in sys.path:
+ sys.path.insert(0, _src_path)
+
+from personalization.config.settings import load_local_models_config
+from personalization.config.registry import get_preference_extractor, get_chat_model
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.models.reranker.bge_reranker import BGEReranker
+from personalization.user_model.tensor_store import UserTensorStore, UserState
+from personalization.user_model.session_state import OnlineSessionState
+from personalization.user_model.features import ItemProjection
+from personalization.retrieval.preference_store.schemas import (
+ MemoryCard, ChatTurn, PreferenceList, Preference
+)
+from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy
+from personalization.feedback.handlers import eval_step, eval_step_llm
+from personalization.feedback.llm_reward import LLMRewardClient, LLMRewardConfig
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+
+
+# =============================================================================
+# Data Classes for Interface
+# =============================================================================
+
+@dataclass
+class UsageStats:
+ """Token usage statistics from a chat completion."""
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ model: str
+
+
+@dataclass
+class DebugInfo:
+ """
+ Debug information for analysis and ablation studies.
+ All fields are optional - fill what you have, leave empty what you don't.
+ """
+ selected_memory_ids: List[str] = field(default_factory=list)
+ selected_memory_notes: List[str] = field(default_factory=list)
+ selected_memory_scores: List[float] = field(default_factory=list)
+ user_vector_before: Optional[List[float]] = None
+ user_vector_after: Optional[List[float]] = None
+ extracted_preferences: List[Dict[str, Any]] = field(default_factory=list)
+ extra: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class AssistantResponse:
+ """Response from the personalized LLM chat interface."""
+ answer: str
+ usage: UsageStats
+ debug: Optional[DebugInfo] = None
+
+
+@dataclass
+class Feedback:
+ """
+ Feedback data structure for RL updates from user simulator or judge.
+
+ Attributes:
+ user_id: The user this feedback is for.
+ turn_id: The turn this feedback refers to (from the previous turn).
+ reward: Reward scalar computed by user simulator / judge.
+ gating: Gating flag (1=valid learning signal, 0=skip update).
+ meta: Additional metadata for training/analysis.
+ """
+ user_id: str
+ turn_id: int
+ reward: float
+ gating: float # Can be 0.0 or 1.0, or continuous
+ meta: Dict[str, Any] = field(default_factory=dict)
+
+
+# =============================================================================
+# Internal Session State Extended
+# =============================================================================
+
+@dataclass
+class _SessionContext:
+ """Extended session context for evaluation tracking."""
+ session_state: OnlineSessionState
+ turn_counter: int = 0
+ # Store info needed for apply_feedback
+ pending_rl_update: Optional[Dict[str, Any]] = None
+
+
+# =============================================================================
+# Shared Model Singletons for Multi-threaded Efficiency
+# =============================================================================
+
+_shared_embed_model = None
+_shared_reranker = None
+_shared_extractor = None
+_shared_models_lock = None # Will be initialized on first use
+
+
+def _get_shared_models_lock():
+ """Get or create the threading lock for shared models."""
+ global _shared_models_lock
+ if _shared_models_lock is None:
+ import threading
+ _shared_models_lock = threading.Lock()
+ return _shared_models_lock
+
+
+def get_shared_embedding_model(model_path: str, device_map: str = "auto"):
+ """Get or create shared embedding model (thread-safe singleton)."""
+ global _shared_embed_model
+ import torch
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_embed_model is None:
+ print(f"[SharedModels] Loading shared embedding model on {device_map}...")
+ _shared_embed_model = Qwen3Embedding8B(
+ model_path=model_path,
+ dtype=torch.bfloat16,
+ device_map=device_map,
+ )
+ print("[SharedModels] Shared embedding model loaded.")
+ return _shared_embed_model
+
+
+def get_shared_reranker(model_path: str, device_map: str = "auto", reranker_type: str = "qwen3"):
+ """Get or create shared reranker model (thread-safe singleton)."""
+ global _shared_reranker
+ import torch
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_reranker is None:
+ print(f"[SharedModels] Loading shared reranker ({reranker_type}) on {device_map}...")
+ if reranker_type == "bge":
+ _shared_reranker = BGEReranker(
+ model_path=model_path,
+ device_map=device_map,
+ dtype=torch.float16,
+ )
+ else:
+ _shared_reranker = Qwen3Reranker(
+ model_path=model_path,
+ device_map=device_map,
+ dtype=torch.bfloat16,
+ )
+ print("[SharedModels] Shared reranker model loaded.")
+ return _shared_reranker
+
+
+def get_shared_extractor(model_path: str, device_map: str = "auto"):
+ """Get or create shared preference extractor model (thread-safe singleton)."""
+ global _shared_extractor
+ import torch
+ from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_extractor is None:
+ print(f"[SharedModels] Loading shared preference extractor on {device_map}...")
+ _shared_extractor = QwenRuleExtractor(
+ model_path=model_path,
+ dtype=torch.bfloat16,
+ device_map=device_map,
+ )
+ print("[SharedModels] Shared preference extractor loaded.")
+ return _shared_extractor
+
+
+def clear_shared_models():
+ """Free all shared singleton models to reclaim GPU memory between methods."""
+ global _shared_embed_model, _shared_reranker, _shared_extractor
+ import gc
+
+ lock = _get_shared_models_lock()
+ with lock:
+ freed = []
+ if _shared_embed_model is not None:
+ freed.append("embedding")
+ del _shared_embed_model
+ _shared_embed_model = None
+ if _shared_reranker is not None:
+ freed.append("reranker")
+ del _shared_reranker
+ _shared_reranker = None
+ if _shared_extractor is not None:
+ freed.append("extractor")
+ del _shared_extractor
+ _shared_extractor = None
+
+ if freed:
+ gc.collect()
+ try:
+ import torch
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ except ImportError:
+ pass
+ print(f"[SharedModels] Cleared: {', '.join(freed)}")
+
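+# Usage sketch (illustrative; the model paths and thread pool below are
+# assumptions, not part of this module): worker threads call the getters,
+# only the first call actually loads weights, and clear_shared_models()
+# frees them once a method's evaluation run is finished.
+#
+#     embed = get_shared_embedding_model("/models/qwen3-embed-8b", device_map="cuda:0")
+#     reranker = get_shared_reranker("/models/bge-reranker-base", reranker_type="bge")
+#     ...  # per-profile threads reuse embed / reranker
+#     clear_shared_models()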
+
+# =============================================================================
+# PersonalizedLLM Class
+# =============================================================================
+
+class PersonalizedLLM:
+ """
+ Personalized LLM wrapper for evaluation frameworks.
+
+ This class provides a clean interface that accepts only (user_id, query)
+ for the main chat function, while internally managing:
+ - User state vectors (z_long, z_short)
+ - Session history
+ - Memory retrieval and policy
+ - Preference extraction and storage
+ - RL updates
+
+ Example usage:
+ llm = PersonalizedLLM()
+
+ # Reset user for fresh experiment
+ llm.reset_user("user_123")
+
+ # Start a session
+ llm.reset_session("user_123")
+
+ # Chat
+ response = llm.chat("user_123", "What's a good recipe for dinner?")
+ print(response.answer)
+
+ # Apply feedback from previous turn (from turn 2 onwards)
+ llm.apply_feedback(Feedback(
+ user_id="user_123",
+ turn_id=0,
+ reward=0.8,
+ gating=1.0
+ ))
+ """
+
+ def __init__(
+ self,
+ config_path: Optional[str] = None,
+ user_store_path: str = "data/users/user_store_eval.npz",
+ memory_cards_path: str = "data/corpora/memory_cards.jsonl",
+ memory_embeddings_path: str = "data/corpora/memory_embeddings.npy",
+ item_projection_path: str = "data/corpora/item_projection.npz",
+ only_own_memories: bool = True,
+ enable_preference_extraction: bool = True,
+ enable_rl_updates: bool = True,
+ mode: str = "full", # "full", "nopersonal", or "vanilla"
+ eval_mode: bool = True, # True = greedy selection, False = stochastic sampling
+ device_assignment: Optional[Dict[str, str]] = None, # Multi-GPU support
+ llm_name: Optional[str] = None, # Override LLM name (e.g., "llama_8b_vllm" for vLLM)
+ use_shared_models: bool = False, # Use shared singleton models for multi-threaded efficiency
+ reranker_type: str = "qwen3", # "qwen3" (8B) or "bge" (278M)
+ best_of_n: int = 1, # Generate N responses and pick best (for RAG methods)
+ reward_mode: str = "keyword", # "keyword", "llm" (GPT-4o-mini), or "llm_local" (local vLLM)
+ llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge
+ reward_vllm_url: Optional[str] = None, # vLLM URL for local reward model (when reward_mode="llm_local")
+ enable_query_transform: bool = False, # Transform queries for better retrieval matching
+ enable_global_preferences: bool = False, # Separate global prefs that bypass retrieval
+ dynamic_topk: bool = False, # Use dynamic topk based on rerank scores
+ dynamic_min_k: int = 3, # Min preferences for dynamic topk
+ dynamic_max_k: int = 8, # Max preferences for dynamic topk
+ dynamic_score_ratio: float = 0.5, # Threshold = top_score * ratio
+        eta_long: Optional[float] = None,  # Override RL learning rate for z_long
+        eta_short: Optional[float] = None,  # Override RL learning rate for z_short
+ enable_preference_consolidation: bool = False, # Consolidate preferences at session end
+ consolidation_threshold: int = 5, # Min preferences before consolidation
+ enable_preference_rewrite: bool = False, # Use LLM to rewrite/merge retrieved preferences
+ ):
+ """
+ Initialize the PersonalizedLLM.
+
+ Args:
+ config_path: Path to config file. If None, uses default locations.
+ user_store_path: Path to persist user state vectors.
+ memory_cards_path: Path to memory cards JSONL file.
+ memory_embeddings_path: Path to memory embeddings numpy file.
+ item_projection_path: Path to item projection (PCA) file.
+ only_own_memories: If True, only retrieve user's own memories (strict privacy).
+ enable_preference_extraction: If True, extract preferences from user turns.
+ enable_rl_updates: If True, apply RL updates via apply_feedback.
+ mode: "full" for full personalization, "nopersonal" for baseline (no user vector influence),
+ "vanilla" for pure LLM without any memory retrieval or preference extraction.
+ eval_mode: If True, use greedy/deterministic selection (for evaluation).
+ If False, use stochastic sampling (for training/exploration).
+ device_assignment: Optional dict to assign models to specific GPUs.
+ Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"}
+ If None, uses "auto" for all models.
+ use_shared_models: If True, use shared singleton models for embedding and reranker.
+ This is essential for multi-threaded/parallel profile processing to avoid
+ loading duplicate models. When enabled, the first thread loads the models,
+ and subsequent threads reuse the shared instances.
+ """
+ self.only_own_memories = only_own_memories
+ self.use_shared_models = use_shared_models
+ self.enable_preference_extraction = enable_preference_extraction
+ self.enable_rl_updates = enable_rl_updates
+        self.mode = mode  # "full", "nopersonal", or "vanilla"
+ self.eval_mode = eval_mode # True = greedy, False = sample
+ self.reranker_type = reranker_type # "qwen3" or "bge"
+ self.best_of_n = best_of_n # Generate N responses and pick best
+ self.reward_mode = reward_mode # "keyword", "llm", or "llm_local"
+ self.enable_query_transform = enable_query_transform
+ self.enable_global_preferences = enable_global_preferences
+ self.enable_preference_consolidation = enable_preference_consolidation
+ self.consolidation_threshold = consolidation_threshold
+ self.enable_preference_rewrite = enable_preference_rewrite
+
+ # Initialize LLM reward client if using LLM judge
+ self._llm_reward_client = None # Can be LLMRewardClient or LocalLLMRewardClient
+ if reward_mode == "llm":
+ self._llm_reward_client = LLMRewardClient(llm_reward_config or LLMRewardConfig())
+ elif reward_mode == "llm_local":
+ from personalization.feedback.local_llm_reward import (
+ LocalLLMRewardClient,
+ LocalLLMRewardConfig,
+ )
+ local_config = LocalLLMRewardConfig(
+ vllm_url=reward_vllm_url or "http://localhost:8005/v1",
+ )
+ self._llm_reward_client = LocalLLMRewardClient(local_config)
+
+ # Multi-GPU device assignment
+ self._device_assignment = device_assignment or {
+ "embed": "auto",
+ "reranker": "auto",
+ "chat": "auto",
+ "extractor": "auto",
+ }
+
+ # Paths
+ self._memory_cards_path = memory_cards_path
+ self._memory_embeddings_path = memory_embeddings_path
+ self._item_projection_path = item_projection_path
+
+ # RL Configuration
+ # Note: beta/eta increased for more significant z_u updates
+ self._rl_cfg = {
+ "item_dim": 256,
+ "beta_long": 2.0, # Increased from 0.1 for stronger personalization
+ "beta_short": 5.0, # Increased from 0.3
+ "tau": 1.0,
+ "eta_long": eta_long if eta_long is not None else 0.01,
+ "eta_short": eta_short if eta_short is not None else 0.05,
+ "ema_alpha": 0.05,
+ "short_decay": 0.1,
+ "dense_topk": 64,
+ "rerank_topk": 5,
+ "max_new_tokens": 512,
+ # Dynamic topk settings
+ "dynamic_topk": dynamic_topk,
+ "dynamic_min_k": dynamic_min_k,
+ "dynamic_max_k": dynamic_max_k,
+ "dynamic_score_ratio": dynamic_score_ratio,
+ }
+
+ # Store llm_name before loading config (needed in _load_config)
+ self._llm_name_override = llm_name
+
+ # Load config and override RL params if available
+ self._load_config(config_path)
+
+ # Load models
+ print("[PersonalizedLLM] Loading models...")
+ self._load_models()
+
+ # Load memory store
+ print("[PersonalizedLLM] Loading memory store...")
+ self._load_memory_store()
+
+ # Initialize user store
+ self._user_store = UserTensorStore(
+ k=self._rl_cfg["item_dim"],
+ path=user_store_path,
+ )
+
+ # Session contexts per user (in-memory)
+ self._sessions: Dict[str, _SessionContext] = {}
+
+ print("[PersonalizedLLM] Initialization complete.")
+
+ def _load_config(self, config_path: Optional[str]):
+ """Load configuration from yaml files."""
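+        # Any key already present in self._rl_cfg may be overridden from the
+        # yaml, plus an optional "llm_name". Illustrative configs/user_model.yaml:
+        #
+        #   beta_long: 2.0
+        #   rerank_topk: 5
+        #   llm_name: qwen_1_5b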
+ self._cfg = load_local_models_config()
+
+ # Try to load user_model.yaml for RL params
+ if config_path is None:
+ config_path = "configs/user_model.yaml"
+
+ self._llm_name = self._llm_name_override or "qwen_1_5b" # Default, can be overridden
+
+ try:
+ if os.path.exists(config_path):
+ with open(config_path, "r") as f:
+ user_cfg = yaml.safe_load(f)
+ if user_cfg:
+ # Override RL params if present
+ for key in self._rl_cfg:
+ if key in user_cfg:
+ self._rl_cfg[key] = user_cfg[key]
+ # LLM name (only from config if not already set via parameter)
+ if self._llm_name_override is None and "llm_name" in user_cfg:
+ self._llm_name = user_cfg["llm_name"]
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load config: {e}")
+
+ def _load_models(self):
+ """Load all ML models with optional multi-GPU assignment."""
+ import torch
+
+ # Report GPU availability (only once, not for shared model instances)
+ if not self.use_shared_models:
+ num_gpus = torch.cuda.device_count()
+ print(f"[PersonalizedLLM] Available GPUs: {num_gpus}")
+ for i in range(num_gpus):
+ mem = torch.cuda.get_device_properties(i).total_memory / 1e9
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")
+
+ embed_device = self._device_assignment.get("embed", "auto")
+ reranker_device = self._device_assignment.get("reranker", "auto")
+ chat_device = self._device_assignment.get("chat", "auto")
+ extractor_device = self._device_assignment.get("extractor", "auto")
+
+ # Embedding model - only load for modes that use RAG retrieval
+ # Vanilla and contextual modes don't need embedding/reranker
+ needs_retrieval = self.mode not in ("vanilla", "contextual")
+
+ if needs_retrieval:
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared embedding model...")
+ self._embed_model = get_shared_embedding_model(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ device_map=embed_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...")
+ self._embed_model = Qwen3Embedding8B(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ dtype=torch.bfloat16,
+ device_map=embed_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Skipping embedding model (not needed for {self.mode} mode)")
+ self._embed_model = None
+
+ # Reranker - only load for modes that use RAG retrieval
+ # Support both qwen3 (8B) and bge (278M) rerankers
+ if needs_retrieval:
+ if self.reranker_type == "bge":
+ reranker_path = getattr(self._cfg.reranker, "bge_base", None)
+ reranker_path = reranker_path.local_path if reranker_path else "BAAI/bge-reranker-base"
+ else:
+ reranker_path = self._cfg.reranker.qwen3_8b.local_path
+
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared reranker model ({self.reranker_type})...")
+ self._reranker = get_shared_reranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ reranker_type=self.reranker_type,
+ )
+ else:
+ print(f"[PersonalizedLLM] Loading Reranker ({self.reranker_type}) on {reranker_device}...")
+ if self.reranker_type == "bge":
+ self._reranker = BGEReranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ dtype=torch.float16,
+ )
+ else:
+ self._reranker = Qwen3Reranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ dtype=torch.bfloat16,
+ )
+ else:
+ print(f"[PersonalizedLLM] Skipping reranker (not needed for {self.mode} mode)")
+ self._reranker = None
+
+ # Chat model (via registry for backend switching)
+ print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...")
+ # Pass device override if specified (not "auto")
+ device_for_chat = chat_device if chat_device != "auto" else None
+ self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat)
+
+ # Preference extractor - use shared singleton if enabled
+ if self.enable_preference_extraction:
+ extractor_name = "qwen3_0_6b_sft"
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared preference extractor...")
+ try:
+ extractor_path = self._cfg.preference_extractor.get("qwen3_0_6b_sft", {}).get("path", None)
+ if extractor_path:
+ self._extractor = get_shared_extractor(
+ model_path=extractor_path,
+ device_map=extractor_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Extractor path not found, using rule-based.")
+ self._extractor = get_preference_extractor("rule")
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load shared extractor: {e}. Trying fallbacks...")
+ try:
+ self._extractor = get_preference_extractor("rule")
+ except Exception as e2:
+ print(f"[PersonalizedLLM] Rule extractor also failed: {e2}. Using GPT-5-mini extractor.")
+ self._extractor = get_preference_extractor("gpt5_mini")
+ else:
+ print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...")
+ try:
+ self._extractor = get_preference_extractor(extractor_name)
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Trying fallbacks...")
+ try:
+ self._extractor = get_preference_extractor("rule")
+ except Exception as e2:
+ print(f"[PersonalizedLLM] Rule extractor also failed: {e2}. Using GPT-5-mini extractor.")
+ self._extractor = get_preference_extractor("gpt5_mini")
+ else:
+ print("[PersonalizedLLM] Preference extraction disabled, skipping extractor.")
+ self._extractor = None
+
+ def _load_memory_store(self):
+ """Load memory cards and embeddings."""
+ if not os.path.exists(self._memory_cards_path):
+ print(f"[PersonalizedLLM] Warning: Memory cards not found at {self._memory_cards_path}")
+ self._memory_cards: List[MemoryCard] = []
+ self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+ # Create default projection (truncation to first k dims) so preferences can be added
+ k = self._rl_cfg["item_dim"]
+ d = 4096
+ P = np.zeros((k, d), dtype=np.float32)
+ P[:, :k] = np.eye(k, dtype=np.float32)
+ self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32))
+ print(f"[PersonalizedLLM] Created default projection (truncation, k={k})")
+ return
+
+ # Load cards
+ self._memory_cards = []
+ with open(self._memory_cards_path, "r") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ self._memory_cards.append(MemoryCard.model_validate_json(line))
+
+ # Load embeddings
+ if os.path.exists(self._memory_embeddings_path):
+ self._memory_embeddings = np.load(self._memory_embeddings_path)
+ else:
+ self._memory_embeddings = np.zeros((len(self._memory_cards), 4096), dtype=np.float32)
+
+ # Load projection
+ if os.path.exists(self._item_projection_path):
+ proj_data = np.load(self._item_projection_path)
+ self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+ self._item_vectors = proj_data["V"]
+ else:
+ # Create default projection so preferences can still be added
+ k = self._rl_cfg["item_dim"]
+ d = 4096
+ P = np.zeros((k, d), dtype=np.float32)
+ P[:, :k] = np.eye(k, dtype=np.float32)
+ self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32))
+ self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32)
+ print(f"[PersonalizedLLM] Created default projection (truncation, k={k})")
+
+ print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.")
+
+ def _get_or_create_session(self, user_id: str) -> _SessionContext:
+ """Get or create session context for a user."""
+ if user_id not in self._sessions:
+ self._sessions[user_id] = _SessionContext(
+ session_state=OnlineSessionState(user_id=user_id),
+ turn_counter=0,
+ )
+ return self._sessions[user_id]
+
+ def _build_chat_turn(self, user_id: str, text: str, role: str, turn_id: int) -> ChatTurn:
+ """Build a ChatTurn object."""
+ return ChatTurn(
+ user_id=user_id,
+ session_id=f"eval_session_{user_id}",
+ turn_id=turn_id,
+ role=role,
+ text=text,
+ meta={"source": "eval"}
+ )
+
+ def _count_tokens(self, text: str) -> int:
+ """Estimate token count using the tokenizer."""
+ try:
+ # Use the chat model's tokenizer if available
+ if hasattr(self._chat_model, 'tokenizer'):
+ return len(self._chat_model.tokenizer.encode(text))
+ else:
+ # Rough estimate: ~4 chars per token
+ return len(text) // 4
+ except Exception:
+ return len(text) // 4
+
+ # Task type keywords for query transformation
+ _TASK_KEYWORDS = {
+ "math": ["solve", "calculate", "integral", "equation", "proof", "derivative",
+ "math", "algebra", "geometry", "trigonometry", "calculus", "arithmetic",
+ "formula", "compute", "evaluate", "simplify", "factor", "graph"],
+ "coding": ["code", "program", "function", "implement", "debug", "python", "java",
+ "javascript", "algorithm", "class", "method", "bug", "error", "compile",
+ "script", "html", "css", "sql", "api", "library", "framework"],
+ "writing": ["write", "essay", "paragraph", "summarize", "draft", "compose",
+ "article", "story", "letter", "email", "report", "review", "edit",
+ "rewrite", "paraphrase", "outline"],
+ "explanation": ["explain", "what is", "how does", "why", "describe", "define",
+ "meaning", "concept", "difference between", "compare", "contrast"],
+ }
+
+ def _transform_query_for_retrieval(self, query: str) -> List[str]:
+ """
+ Transform raw user query into multiple retrieval queries to bridge
+ the semantic gap between task queries and preference descriptions.
+
+ Returns [original_query, transformed_query] or [original_query] if
+ no task type detected.
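+
+        Example (illustrative): "debug my python function" matches the
+        "coding" keywords, so this returns
+        ["debug my python function",
+         "user preferences for coding tasks: debug my python function"].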
+ """
+ import re
+ query_lower = query.lower()
+ detected_types = []
+ for task_type, keywords in self._TASK_KEYWORDS.items():
+ for kw in keywords:
+ # Use word boundary matching to avoid false positives
+ # e.g., "api" should not match "capital"
+ if re.search(r'\b' + re.escape(kw) + r'\b', query_lower):
+ detected_types.append(task_type)
+ break
+
+ if not detected_types:
+ return [query]
+
+ # Use first detected type (most specific match)
+ task_type = detected_types[0]
+ transformed = f"user preferences for {task_type} tasks: {query}"
+ return [query, transformed]
+
+ # Patterns indicating a global/universal preference condition
+ _GLOBAL_PATTERNS = ["general", "any", "always", "all ", "every", "regardless",
+ "any task", "any topic", "any question", "all tasks", "all topics"]
+
+ # Domain-specific terms that indicate a conditional preference
+ _DOMAIN_TERMS = ["math", "code", "coding", "program", "writing", "essay", "science",
+ "history", "language", "physics", "chemistry", "biology", "literature",
+ "creative", "technical", "formal", "informal", "academic", "casual"]
+
+ def _classify_preference_scope(self, condition: str) -> bool:
+ """
+ Classify whether a preference condition is global (always applicable)
+ or conditional (task-specific).
+
+ Returns True if global, False if conditional.
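+
+        Example (illustrative): "general conversations" matches a global
+        pattern and returns True; "writing Python code" matches no global
+        pattern and is longer than two words, so it is treated as
+        conditional and returns False.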
+ """
+ cond_lower = condition.lower().strip()
+
+ # Check for explicit global patterns
+ for pattern in self._GLOBAL_PATTERNS:
+ if pattern in cond_lower:
+ return True
+
+ # Very short/vague conditions with no domain terms are likely global
+ words = cond_lower.split()
+ if len(words) <= 2:
+ has_domain = any(term in cond_lower for term in self._DOMAIN_TERMS)
+ if not has_domain:
+ return True
+
+ return False
+
+ # Rewrite prompt for merging retrieved preferences
+ _REWRITE_PROMPT = """You are helping to prepare user preferences for an AI assistant.
+
+The user is asking: {query}
+
+Retrieved preferences about this user:
+{preferences}
+
+Task: Create a concise preference summary that the assistant MUST follow.
+
+Rules:
+1. PRESERVE all specific formatting requirements exactly (e.g., "type hints", "snake_case", "code fence with language")
+2. PRESERVE all structural requirements (e.g., "numbered steps", "bullet points", "answer first then explanation")
+3. Only MERGE preferences that are truly redundant (saying the same thing differently)
+4. Output as a short bulleted list if there are multiple distinct requirements
+5. Keep each point actionable and specific - NO vague generalizations like "follow best practices"
+
+Example input:
+- Include type hints in Python code
+- Use snake_case for variable names
+- When explaining, use numbered steps
+
+Example output:
+- Include type hints
+- Use snake_case for variables
+- Use numbered steps for explanations
+
+If no preferences are relevant to this query type, output: "No specific preferences apply."
+
+Preference summary:"""
+
+ def _rewrite_preferences(self, memory_notes: List[str], query: str) -> List[str]:
+ """
+ Use LLM to rewrite/merge multiple retrieved preferences into concise instructions.
+
+ This is similar to Reflection's proper_scaffolding but focuses on merging
+ rather than just filtering.
+
+ Args:
+ memory_notes: List of retrieved preference notes
+ query: Current user query
+
+ Returns:
+ List with single rewritten instruction (or original if rewrite fails/disabled)
+ """
+ if not memory_notes or len(memory_notes) <= 1:
+ return memory_notes
+
+ try:
+ import requests
+
+ # Format preferences for prompt
+ prefs_text = "\n".join(f"- {note}" for note in memory_notes)
+ prompt = self._REWRITE_PROMPT.format(query=query[:200], preferences=prefs_text)
+
+ # Direct vLLM API call (simpler than going through chat model)
+ messages = [{"role": "user", "content": prompt}]
+ payload = {
+ "model": self._chat_model.model_name,
+ "messages": messages,
+ "max_tokens": 150,
+ "temperature": 0.3, # Lower temperature for more consistent output
+ }
+
+ response = requests.post(
+ f"{self._chat_model.vllm_url}/chat/completions",
+ json=payload,
+ timeout=30
+ )
+
+ if response.status_code != 200:
+ print(f"[REWRITE] API error {response.status_code}, keeping original notes")
+ return memory_notes
+
+ result = response.json()
+ rewritten = result["choices"][0]["message"]["content"].strip().strip('"')
+
+ # Validate response
+ if rewritten and len(rewritten) > 10 and "No specific preferences" not in rewritten:
+ print(f"[REWRITE] {len(memory_notes)} notes → 1 merged instruction")
+ return [rewritten]
+ else:
+ print(f"[REWRITE] Kept original {len(memory_notes)} notes (no valid merge)")
+ return memory_notes
+
+ except Exception as e:
+ print(f"[REWRITE] Failed: {e}, keeping original notes")
+ return memory_notes
+
+ # Consolidation prompt for session-end preference merging
+ _CONSOLIDATION_PROMPT = """You are analyzing user preferences extracted from conversations.
+
+Current preferences for this user:
+{preferences}
+
+Task: Consolidate these preferences into a cleaner, more organized set by:
+1. MERGE similar preferences (e.g., "use bullet points" + "format with bullets" → single preference)
+2. REMOVE redundant or contradictory preferences (keep the more specific one)
+3. PRESERVE all unique, meaningful preferences
+4. Keep the same "When [condition], [action]." format
+
+Output ONLY the consolidated preferences, one per line, in this exact format:
+When [condition], [action].
+
+Do not add explanations or commentary. Just output the preference lines."""
+
+ def consolidate_user_preferences(self, user_id: str) -> int:
+ """
+ Consolidate user preferences at session end using LLM.
+
+ Merges similar preferences, removes redundancy, and creates cleaner
+ preference descriptions. Only runs if user has enough preferences.
+
+ Args:
+ user_id: The user whose preferences to consolidate.
+
+ Returns:
+ Number of preferences after consolidation (0 if skipped).
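+
+        Example (illustrative): a consolidated line
+        "When asking for Python code, include type hints." is parsed into
+        condition="asking for Python code" and action="include type hints",
+        then re-embedded and stored as a fresh MemoryCard.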
+ """
+ if not self.enable_preference_consolidation:
+ return 0
+
+ # Get user's memory cards
+ user_cards = [c for c in self._memory_cards if c.user_id == user_id]
+
+ if len(user_cards) < self.consolidation_threshold:
+ return len(user_cards)
+
+ # Build preference list for prompt
+ pref_lines = [card.note_text for card in user_cards]
+ preferences_text = "\n".join(f"- {p}" for p in pref_lines)
+
+ # Call LLM for consolidation
+ prompt = self._CONSOLIDATION_PROMPT.format(preferences=preferences_text)
+ messages = [{"role": "user", "content": prompt}]
+
+ try:
+ result = self._chat_model.answer(messages, max_new_tokens=512)
+ consolidated_text = result.get("content", "").strip()
+
+ if not consolidated_text:
+ return len(user_cards)
+
+ # Parse consolidated preferences
+ new_prefs = []
+ for line in consolidated_text.split("\n"):
+ line = line.strip()
+ if not line or not line.startswith("When "):
+ continue
+ # Parse "When [condition], [action]."
+ if ", " in line:
+ parts = line.split(", ", 1)
+ condition = parts[0].replace("When ", "").strip()
+ action = parts[1].rstrip(".").strip()
+ if condition and action:
+ new_prefs.append({
+ "condition": condition,
+ "action": action,
+ "is_global": self._classify_preference_scope(condition) if self.enable_global_preferences else False,
+ })
+
+ if not new_prefs:
+ return len(user_cards)
+
+ # Remove old cards for this user
+ keep_indices = [i for i, c in enumerate(self._memory_cards) if c.user_id != user_id]
+ self._memory_cards = [self._memory_cards[i] for i in keep_indices]
+ if len(keep_indices) > 0 and len(self._memory_embeddings) > 0:
+ self._memory_embeddings = self._memory_embeddings[keep_indices]
+ self._item_vectors = self._item_vectors[keep_indices]
+ else:
+ embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096
+ self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Add consolidated preferences
+ for pref in new_prefs:
+ note_text = f"When {pref['condition']}, {pref['action']}."
+
+ # Compute embedding
+ e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
+ v_note = self._projection.transform_vector(np.array(e_note))
+
+ # Create card
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id=f"consolidated_{user_id}",
+ source_turn_ids=[],
+ raw_queries=[],
+ preference_list=PreferenceList(preferences=[
+ Preference(condition=pref["condition"], action=pref["action"], confidence=1.0)
+ ]),
+ note_text=note_text,
+ embedding_e=list(e_note),
+ kind="pref",
+ is_global=pref["is_global"],
+ )
+
+ self._memory_cards.append(card)
+ self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
+ self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])
+
+ print(f"[PersonalizedLLM] Consolidated {len(user_cards)} → {len(new_prefs)} preferences for user {user_id}")
+ return len(new_prefs)
+
+ except Exception as e:
+ print(f"[PersonalizedLLM] Consolidation failed for user {user_id}: {e}")
+ return len(user_cards)
+
+ def _add_preferences_as_memory(
+ self,
+ prefs: PreferenceList,
+ query: str,
+ user_id: str,
+ turn_id: int,
+ ) -> List[Dict[str, Any]]:
+ """
+ Add extracted preferences as new memory cards.
+ Returns list of preference dicts for debug info.
+ """
+ extracted = []
+
+ if not prefs.preferences or self._projection is None:
+ return extracted
+
+ for pref in prefs.preferences:
+ note_text = f"When {pref.condition}, {pref.action}."
+
+ # Record for debug
+ extracted.append({
+ "condition": pref.condition,
+ "action": pref.action,
+ "confidence": pref.confidence,
+ })
+
+ # Deduplication check
+ is_duplicate = any(
+ card.user_id == user_id and card.note_text == note_text
+ for card in self._memory_cards
+ )
+
+ if is_duplicate:
+ continue
+
+ # Compute embedding from note_text (NOT query) for proper semantic retrieval
+ # This ensures retrieval query "solve math problem" matches stored "When math problems..."
+ e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
+ v_note = self._projection.transform_vector(np.array(e_note))
+
+ # Classify as global or conditional
+ is_global = self._classify_preference_scope(pref.condition) if self.enable_global_preferences else False
+
+ # Create new memory card
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id=f"eval_session_{user_id}",
+ source_turn_ids=[turn_id],
+ raw_queries=[query],
+ preference_list=PreferenceList(preferences=[pref]),
+ note_text=note_text,
+ embedding_e=list(e_note),
+ kind="pref",
+ is_global=is_global,
+ )
+
+ # Add to memory store
+ self._memory_cards.append(card)
+ self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
+ self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])
+
+ return extracted
+
+ def _score_response(self, response: str) -> float:
+ """
+ Score a response for best-of-N selection.
+
+ Higher score = better response. Scoring heuristics:
+ 1. Length: Longer responses typically have more substance
+ 2. Solution indicators: Contains formulas, steps, answers
+ 3. Proactivity: Doesn't end with just a question
+
+ Returns:
+ Float score (higher is better)
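+
+        Example (illustrative): "1. Compute 2+3 = 5\n2. Therefore the answer is 5."
+        scores roughly 0.14 (length) + 1.5 (three indicators) + 1.0 (numbered
+        list) + 1.0 (three or more numbers) ≈ 3.6, with no trailing-question
+        penalty.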
+ """
+ score = 0.0
+ response_lower = response.lower()
+
+ # Length score (normalized, cap at 1000 chars)
+ score += min(len(response), 1000) / 1000 * 3.0
+
+ # Solution indicators (+1 each, max 5)
+ solution_indicators = ['=', 'step', 'answer', 'formula', 'result', 'therefore', 'solution']
+ indicator_count = sum(1 for ind in solution_indicators if ind in response_lower)
+ score += min(indicator_count, 5) * 0.5
+
+ # Structured content (+1 for numbered/bulleted lists)
+ if any(marker in response for marker in ['1.', '2.', '- ', '* ', '##']):
+ score += 1.0
+
+ # Penalty for ending with question (passive behavior)
+ # Check last 100 chars for question marks
+ if '?' in response[-100:]:
+ score -= 1.5
+
+ # Bonus for providing concrete values/numbers
+ import re
+ numbers = re.findall(r'\d+\.?\d*', response)
+ if len(numbers) >= 3:
+ score += 1.0
+
+ return score
+
+ # =========================================================================
+ # Public Interface
+ # =========================================================================
+
+ def chat(self, user_id: str, query: str) -> AssistantResponse:
+ """
+ Main online chat interface.
+
+ Args:
+ user_id: Unique identifier for the user.
+ query: Current user query/message.
+
+ Returns:
+ AssistantResponse containing the answer, usage stats, and debug info.
+
+ Notes:
+ - Internally manages user state, session history, memory retrieval
+ - After this call, you can call apply_feedback() with the turn's feedback
+ """
+ ctx = self._get_or_create_session(user_id)
+ session = ctx.session_state
+ user_state = self._user_store.get_state(user_id)
+
+ # Record user vector before for debug
+ z_long_before = user_state.z_long.copy().tolist()
+ z_short_before = user_state.z_short.copy().tolist()
+
+ # Add user turn to history
+ user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
+ session.history.append(user_turn)
+
+ # Vanilla mode: pure LLM without any memory or preference extraction
+ if self.mode == "vanilla":
+ # Skip embedding, preference extraction, and memory retrieval entirely
+ e_q_t = np.zeros(4096, dtype=np.float32) # Placeholder for vanilla mode
+ extracted_prefs = []
+ candidates = []
+ cand_item_vecs = np.array([])
+ base_scores = np.array([])
+ chosen_indices = []
+ probs = np.array([])
+ memories_t = []
+ memory_notes = []
+ else:
+ # Compute query embedding (only needed for non-vanilla modes)
+ # Explicitly normalize for consistent cosine similarity with stored embeddings
+ embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
+ if embed_result is None or len(embed_result) == 0:
+ raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
+ e_q_t = np.array(embed_result[0])
+
+ # Store pending RL update info from last turn (for apply_feedback)
+ if session.last_query is not None and self.enable_rl_updates:
+ ctx.pending_rl_update = {
+ "last_query": session.last_query,
+ "last_answer": session.last_answer,
+ "last_memories": session.last_memories,
+ "last_query_embedding": session.last_query_embedding,
+ "current_query_embedding": e_q_t,
+ "last_candidate_item_vectors": session.last_candidate_item_vectors,
+ "last_policy_probs": session.last_policy_probs,
+ "last_chosen_indices": session.last_chosen_indices,
+ }
+
+ # Auto-compute reward via LLM judge if enabled
+ if self._llm_reward_client is not None:
+ import asyncio
+ try:
+ reward, gating = asyncio.run(eval_step_llm(
+ q_t=session.last_query,
+ answer_t=session.last_answer,
+ q_t1=query,
+ memories_t=session.last_memories or [],
+ client=self._llm_reward_client,
+ ))
+ if gating > 0.0:
+ self.apply_feedback(Feedback(
+ user_id=user_id,
+ turn_id=ctx.turn_counter - 1,
+ reward=reward,
+ gating=gating,
+ ))
+ except Exception as e:
+ # Graceful fallback: skip RL update if judge fails
+ print(f"[LLM-Reward] Judge call failed, skipping update: {e}")
+
+ # Extract preferences from conversation (if enabled)
+ # extract_turn processes only the last user turn - efficient since called each turn
+ # Preferences accumulate in _memory_cards across turns (dedup prevents duplicates)
+ extracted_prefs = []
+ if self.enable_preference_extraction:
+ prefs = self._extractor.extract_turn(session.history)
+ if prefs.preferences:
+ print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
+ extracted_prefs = self._add_preferences_as_memory(
+ prefs, query, user_id, ctx.turn_counter
+ )
+ if extracted_prefs:
+ print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
+
+ # Separate global preferences (bypass retrieval) from conditional ones
+ global_notes = []
+ retrieval_cards = self._memory_cards
+ retrieval_embeddings = self._memory_embeddings
+ retrieval_item_vectors = self._item_vectors
+ if self.enable_global_preferences:
+ global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id]
+ global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10
+ # Filter out global cards for retrieval
+ cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global]
+ if cond_indices:
+ retrieval_cards = [self._memory_cards[i] for i in cond_indices]
+ retrieval_embeddings = self._memory_embeddings[cond_indices]
+ if len(self._item_vectors) > 0:
+ retrieval_item_vectors = self._item_vectors[cond_indices]
+ else:
+ retrieval_cards = []
+ retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings
+ retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Query transformation for better retrieval matching
+ retrieval_queries = None
+ if self.enable_query_transform:
+ retrieval_queries = self._transform_query_for_retrieval(query)
+
+ # Retrieve memories
+ if self.mode == "nopersonal":
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ only_own_memories=self.only_own_memories,
+ queries=retrieval_queries,
+ dynamic_topk=self._rl_cfg["dynamic_topk"],
+ dynamic_min_k=self._rl_cfg["dynamic_min_k"],
+ dynamic_max_k=self._rl_cfg["dynamic_max_k"],
+ dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"],
+ )
+ else:
+ beta_long = self._rl_cfg["beta_long"]
+ beta_short = self._rl_cfg["beta_short"]
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
+ user_store=self._user_store,
+ item_vectors=retrieval_item_vectors,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=self._rl_cfg["tau"],
+ only_own_memories=self.only_own_memories,
+ sample=not self.eval_mode,
+ queries=retrieval_queries,
+ )
+
+ # Get selected memories
+            memories_t = [candidates[int(i)] for i in chosen_indices] if len(chosen_indices) > 0 else []
+ memory_notes = [m.note_text for m in memories_t]
+
+ # Apply preference rewrite if enabled
+ if self.enable_preference_rewrite and memory_notes:
+ memory_notes = self._rewrite_preferences(memory_notes, query)
+
+ # Debug: show retrieval info
+ if memories_t or global_notes:
+ print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
+ print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}")
+ for i, m in enumerate(memories_t[:3]): # Show top 3
+ score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
+ print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
+
+ # Combine all notes for prompt (global + retrieved)
+ # For chat(), we combine all notes; chat_prepare() handles them separately
+ if self.mode != "vanilla":
+ all_memory_notes = (global_notes if global_notes else []) + memory_notes
+ else:
+ all_memory_notes = memory_notes
+
+ # Build prompt and count tokens
+ prompt_tokens = self._count_tokens(query)
+ for turn in session.history:
+ prompt_tokens += self._count_tokens(turn.text)
+ for note in all_memory_notes:
+ prompt_tokens += self._count_tokens(note)
+
+ # Generate answer (with best-of-N if enabled)
+ if self.best_of_n > 1:
+ # Generate N responses and pick the best one
+ candidates_responses = []
+ for i in range(self.best_of_n):
+ resp = self._chat_model.answer(
+ history=session.history,
+ memory_notes=all_memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ temperature=0.8, # Slightly higher temp for diversity
+ )
+ score = self._score_response(resp)
+ candidates_responses.append((resp, score))
+
+ # Sort by score (descending) and pick best
+ candidates_responses.sort(key=lambda x: x[1], reverse=True)
+ answer_t = candidates_responses[0][0]
+ best_score = candidates_responses[0][1]
+
+ if len(candidates_responses) > 1:
+ print(f"[BEST-OF-{self.best_of_n}] Scores: {[f'{s:.2f}' for _, s in candidates_responses]}, picked score={best_score:.2f}")
+ else:
+ answer_t = self._chat_model.answer(
+ history=session.history,
+ memory_notes=all_memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ )
+
+ completion_tokens = self._count_tokens(answer_t)
+
+ # Add assistant turn to history
+ assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter)
+ session.history.append(assist_turn)
+
+ # Update session state for next turn
+ session.last_query = query
+ session.last_answer = answer_t
+ session.last_memories = memories_t
+ session.last_query_embedding = e_q_t
+ session.last_candidate_item_vectors = cand_item_vecs
+ session.last_policy_probs = probs
+ session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else []
+
+ ctx.turn_counter += 1
+
+ # Build debug info
+ debug = DebugInfo(
+ selected_memory_ids=[m.card_id for m in memories_t],
+ selected_memory_notes=[m.note_text for m in memories_t],
+ selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [],
+ user_vector_before=z_long_before + z_short_before, # Concatenated for simplicity
+ user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(),
+ extracted_preferences=extracted_prefs,
+ extra={
+ "num_candidates": len(candidates),
+ "num_total_memories": len(self._memory_cards),
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ }
+ )
+
+ # Build usage stats
+ usage = UsageStats(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ model=self._llm_name,
+ )
+
+ return AssistantResponse(
+ answer=answer_t,
+ usage=usage,
+ debug=debug,
+ )
+
+ def chat_prepare(self, user_id: str, query: str, skip_extraction: bool = False, skip_auto_reward: bool = False) -> dict:
+ """
+ Prepare for chat without calling the LLM.
+
+ This does all the preparation work (embedding, memory retrieval, etc.)
+ and returns the messages to send to the LLM along with context needed
+ for post-processing.
+
+ Used for batch processing where messages are collected first, then
+ sent in batch to vLLM for concurrent processing.
+
+ Args:
+ user_id: Unique identifier for the user.
+            query: Current user query/message.
+            skip_extraction: If True, skip preference extraction for this turn
+                (e.g., when extraction is handled in a separate batch pass).
+            skip_auto_reward: If True, skip the automatic LLM-judge reward for
+                the previous turn (e.g., when the batch framework applies
+                rewards externally).
+
+ Returns:
+ Dict containing:
+ - messages: List of messages to send to LLM
+ - context: Dict with all state needed for chat_complete()
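+
+        Example (a sketch; llm is a PersonalizedLLM instance and
+        batch_generate is a hypothetical batching helper, not part of
+        this module):
+
+            prepared = [llm.chat_prepare(uid, q, skip_auto_reward=True)
+                        for uid, q in batch]
+            answers = batch_generate([p["messages"] for p in prepared])
+            responses = [llm.chat_complete(a, p["context"])
+                         for a, p in zip(answers, prepared)]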
+ """
+ ctx = self._get_or_create_session(user_id)
+ session = ctx.session_state
+ user_state = self._user_store.get_state(user_id)
+
+ # Record user vector before for debug
+ z_long_before = user_state.z_long.copy().tolist()
+ z_short_before = user_state.z_short.copy().tolist()
+
+ # Add user turn to history
+ user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
+ session.history.append(user_turn)
+
+ # Vanilla mode: pure LLM without any memory or preference extraction
+ if self.mode == "vanilla":
+ e_q_t = np.zeros(4096, dtype=np.float32)
+ extracted_prefs = []
+ candidates = []
+ cand_item_vecs = np.array([])
+ base_scores = np.array([])
+ chosen_indices = []
+ probs = np.array([])
+ memories_t = []
+ memory_notes = []
+ else:
+ # Compute query embedding
+ embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
+ if embed_result is None or len(embed_result) == 0:
+ raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
+ e_q_t = np.array(embed_result[0])
+
+ # Store pending RL update info from last turn
+ if session.last_query is not None and self.enable_rl_updates:
+ ctx.pending_rl_update = {
+ "last_query": session.last_query,
+ "last_answer": session.last_answer,
+ "last_memories": session.last_memories,
+ "last_query_embedding": session.last_query_embedding,
+ "current_query_embedding": e_q_t,
+ "last_candidate_item_vectors": session.last_candidate_item_vectors,
+ "last_policy_probs": session.last_policy_probs,
+ "last_chosen_indices": session.last_chosen_indices,
+ }
+
+ # Auto-compute reward via LLM judge if enabled
+ # skip_auto_reward=True when batch framework handles rewards externally
+ if self._llm_reward_client is not None and not skip_auto_reward:
+ import asyncio
+ try:
+ reward, gating = asyncio.run(eval_step_llm(
+ q_t=session.last_query,
+ answer_t=session.last_answer,
+ q_t1=query,
+ memories_t=session.last_memories or [],
+ client=self._llm_reward_client,
+ ))
+ if gating > 0.0:
+ self.apply_feedback(Feedback(
+ user_id=user_id,
+ turn_id=ctx.turn_counter - 1,
+ reward=reward,
+ gating=gating,
+ ))
+ except Exception as e:
+ print(f"[LLM-Reward] Judge call failed, skipping update: {e}")
+
+ # Extract preferences from conversation
+ extracted_prefs = []
+ if self.enable_preference_extraction and not skip_extraction:
+ prefs = self._extractor.extract_turn(session.history)
+ if prefs.preferences:
+ print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
+ extracted_prefs = self._add_preferences_as_memory(
+ prefs, query, user_id, ctx.turn_counter
+ )
+ if extracted_prefs:
+ print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
+
+ # Separate global preferences (bypass retrieval) from conditional ones
+ global_notes = []
+ retrieval_cards = self._memory_cards
+ retrieval_embeddings = self._memory_embeddings
+ retrieval_item_vectors = self._item_vectors
+ if self.enable_global_preferences:
+ global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id]
+ global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10
+ cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global]
+ if cond_indices:
+ retrieval_cards = [self._memory_cards[i] for i in cond_indices]
+ retrieval_embeddings = self._memory_embeddings[cond_indices]
+ if len(self._item_vectors) > 0:
+ retrieval_item_vectors = self._item_vectors[cond_indices]
+ else:
+ retrieval_cards = []
+ retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings
+ retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Query transformation for better retrieval matching
+ retrieval_queries = None
+ if self.enable_query_transform:
+ retrieval_queries = self._transform_query_for_retrieval(query)
+
+ # Retrieve memories
+ if self.mode == "nopersonal":
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ only_own_memories=self.only_own_memories,
+ queries=retrieval_queries,
+ dynamic_topk=self._rl_cfg["dynamic_topk"],
+ dynamic_min_k=self._rl_cfg["dynamic_min_k"],
+ dynamic_max_k=self._rl_cfg["dynamic_max_k"],
+ dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"],
+ )
+ else:
+ beta_long = self._rl_cfg["beta_long"]
+ beta_short = self._rl_cfg["beta_short"]
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
+ user_store=self._user_store,
+ item_vectors=retrieval_item_vectors,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=self._rl_cfg["tau"],
+ only_own_memories=self.only_own_memories,
+ sample=not self.eval_mode,
+ queries=retrieval_queries,
+ )
+
+            memories_t = [candidates[int(i)] for i in chosen_indices] if len(chosen_indices) > 0 else []
+ memory_notes = [m.note_text for m in memories_t]
+
+ # Apply preference rewrite if enabled
+ if self.enable_preference_rewrite and memory_notes:
+ memory_notes = self._rewrite_preferences(memory_notes, query)
+
+ if memories_t or global_notes:
+ print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
+ print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}")
+ for i, m in enumerate(memories_t[:3]):
+ score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
+ print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
+
+ # Build prompt token count
+ prompt_tokens = self._count_tokens(query)
+ for turn in session.history:
+ prompt_tokens += self._count_tokens(turn.text)
+ all_notes = memory_notes + (global_notes if self.mode != "vanilla" else [])
+ for note in all_notes:
+ prompt_tokens += self._count_tokens(note)
+
+ # Build messages for LLM (pass global_notes separately for distinct prompt sections)
+ effective_global = global_notes if (self.enable_global_preferences and self.mode != "vanilla") else None
+ messages = self._chat_model.build_messages(
+ history=session.history,
+ memory_notes=memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ global_notes=effective_global,
+ )
+
+ # Return messages and context for chat_complete
+ return {
+ "messages": messages,
+ "context": {
+ "user_id": user_id,
+ "query": query,
+ "ctx": ctx,
+ "session": session,
+ "user_state": user_state,
+ "z_long_before": z_long_before,
+ "z_short_before": z_short_before,
+ "e_q_t": e_q_t,
+ "extracted_prefs": extracted_prefs,
+ "candidates": candidates,
+ "cand_item_vecs": cand_item_vecs,
+ "base_scores": base_scores,
+ "chosen_indices": chosen_indices,
+ "probs": probs,
+ "memories_t": memories_t,
+ "memory_notes": memory_notes,
+ "prompt_tokens": prompt_tokens,
+ }
+ }
+
+ def chat_complete(self, answer_t: str, context: dict) -> AssistantResponse:
+ """
+ Complete chat with LLM response.
+
+ This takes the LLM response and context from chat_prepare(), and
+ does all post-processing (add to history, debug info, etc.).
+
+ Args:
+ answer_t: The LLM response text.
+ context: Context dict from chat_prepare().
+
+ Returns:
+ AssistantResponse containing the answer, usage stats, and debug info.
+ """
+ # Unpack context
+ user_id = context["user_id"]
+ query = context["query"]
+ ctx = context["ctx"]
+ session = context["session"]
+ user_state = context["user_state"]
+ z_long_before = context["z_long_before"]
+ z_short_before = context["z_short_before"]
+ e_q_t = context["e_q_t"]
+ extracted_prefs = context["extracted_prefs"]
+ candidates = context["candidates"]
+ cand_item_vecs = context["cand_item_vecs"]
+ chosen_indices = context["chosen_indices"]
+ probs = context["probs"]
+ memories_t = context["memories_t"]
+ memory_notes = context["memory_notes"]
+ prompt_tokens = context["prompt_tokens"]
+
+ completion_tokens = self._count_tokens(answer_t)
+
+ # Add assistant turn to history
+ assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter)
+ session.history.append(assist_turn)
+
+ # Update session state for next turn
+ session.last_query = query
+ session.last_answer = answer_t
+ session.last_memories = memories_t
+ session.last_query_embedding = e_q_t
+ session.last_candidate_item_vectors = cand_item_vecs
+ session.last_policy_probs = probs
+ session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else []
+
+ ctx.turn_counter += 1
+
+ # Build debug info
+ debug = DebugInfo(
+ selected_memory_ids=[m.card_id for m in memories_t],
+ selected_memory_notes=[m.note_text for m in memories_t],
+ selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [],
+ user_vector_before=z_long_before + z_short_before,
+ user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(),
+ extracted_preferences=extracted_prefs,
+ extra={
+ "num_candidates": len(candidates),
+ "num_total_memories": len(self._memory_cards),
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ }
+ )
+
+ # Build usage stats
+ usage = UsageStats(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ model=self._llm_name,
+ )
+
+ return AssistantResponse(
+ answer=answer_t,
+ usage=usage,
+ debug=debug,
+ )
+
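+ # Illustrative sketch of the split chat interface. `generate(messages)` is a
+ # hypothetical caller-side helper that invokes the LLM; it is not part of
+ # this module.
+ #
+ #     prep = llm.chat_prepare("u1", "What should I cook tonight?")
+ #     answer_text = generate(prep["messages"])            # external LLM call
+ #     response = llm.chat_complete(answer_text, prep["context"])
+ #     print(response.answer, response.usage.total_tokens)
+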
+ def apply_extracted_preferences(self, user_id: str, pref_dict: dict) -> list:
+ """Apply pre-computed extraction results (from batch extraction) to memory."""
+ prefs = PreferenceList.model_validate(pref_dict)
+ if not prefs.preferences:
+ return []
+ ctx = self._get_or_create_session(user_id)
+ query = ctx.session_state.history[-1].text if ctx.session_state.history else ""
+ extracted = self._add_preferences_as_memory(prefs, query, user_id, ctx.turn_counter)
+ if extracted:
+ print(f"[DEBUG] Batch-added {len(extracted)} to memory. Total cards: {len(self._memory_cards)}")
+ return extracted
+
+ def get_last_user_query(self, user_id: str) -> str:
+ """Get the last user message text for this user's session."""
+ ctx = self._sessions.get(user_id)
+ if ctx and ctx.session_state.history:
+ for t in reversed(ctx.session_state.history):
+ if t.role == "user":
+ return t.text
+ return ""
+
+ def reset_session(self, user_id: str) -> None:
+ """
+ Reset session for a user (new chat window).
+
+ This clears:
+ - Session conversation history
+ - Short-term user vector (z_short)
+ - Pending RL update info
+
+ This preserves:
+ - Long-term user vector (z_long)
+ - User's memory cards (may be consolidated if enabled)
+
+ Args:
+ user_id: The user whose session to reset.
+ """
+ # Consolidate preferences at session end (before clearing session)
+ if self.enable_preference_consolidation:
+ self.consolidate_user_preferences(user_id)
+
+ # Clear session context
+ if user_id in self._sessions:
+ del self._sessions[user_id]
+
+ # Create fresh session
+ self._sessions[user_id] = _SessionContext(
+ session_state=OnlineSessionState(user_id=user_id),
+ turn_counter=0,
+ )
+
+ # Reset short-term vector but keep long-term
+ user_state = self._user_store.get_state(user_id)
+ user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32)
+ self._user_store.save_state(user_state)
+
+ def reset_user(self, user_id: str) -> None:
+ """
+ Completely reset a user (new "life").
+
+ This clears:
+ - Long-term user vector (z_long)
+ - Short-term user vector (z_short)
+ - User's memory cards
+ - Session history
+ - All cached state
+
+ Args:
+ user_id: The user to reset.
+ """
+ # Clear session
+ if user_id in self._sessions:
+ del self._sessions[user_id]
+
+ # Reset user state vectors
+ user_state = self._user_store.get_state(user_id)
+ user_state.z_long = self._user_store.global_init_z.copy()
+ user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32)
+ user_state.reward_ma = 0.0
+ self._user_store.save_state(user_state)
+
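+ # _memory_cards, _memory_embeddings, and _item_vectors are row-aligned,
+ # so the same index filter below is applied to all three.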
+ # Find indices to KEEP (cards NOT belonging to this user)
+ # Must do this BEFORE modifying _memory_cards
+ keep_indices = [
+ i for i, card in enumerate(self._memory_cards)
+ if card.user_id != user_id
+ ]
+
+ # Filter memory cards
+ self._memory_cards = [self._memory_cards[i] for i in keep_indices]
+
+ # Filter embeddings and item vectors to match
+ if len(keep_indices) > 0 and len(self._memory_embeddings) > 0:
+ self._memory_embeddings = self._memory_embeddings[keep_indices]
+ self._item_vectors = self._item_vectors[keep_indices]
+ else:
+ # No cards left or no embeddings
+ embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096
+ self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ def apply_feedback(self, feedback: Feedback) -> None:
+ """
+ Apply feedback from user simulator or judge.
+
+ This performs the REINFORCE update to user vectors based on
+ the reward signal from the previous turn.
+
+ Args:
+ feedback: Feedback object containing reward, gating, and metadata.
+
+ Notes:
+ - Should be called AFTER chat() but BEFORE the next chat() call
+ - Uses the stored context from the previous turn
+ - If enable_rl_updates is False, this is a no-op (logging only)
+ - If mode is "nopersonal", this is a no-op (baseline comparison)
+ """
+ if not self.enable_rl_updates:
+ return
+
+ # In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline)
+ if self.mode in ("nopersonal", "vanilla"):
+ return
+
+ user_id = feedback.user_id
+ ctx = self._sessions.get(user_id)
+
+ if ctx is None or ctx.pending_rl_update is None:
+ return
+
+ pending = ctx.pending_rl_update
+ user_state = self._user_store.get_state(user_id)
+
+ # Check if we have the necessary data for RL update
+ if (pending.get("last_candidate_item_vectors") is not None and
+ pending.get("last_policy_probs") is not None and
+ pending.get("last_chosen_indices") is not None and
+ len(pending["last_chosen_indices"]) > 0):
+
+ # Extract chosen vectors
+ chosen_indices = pending["last_chosen_indices"]
+ candidate_vectors = pending["last_candidate_item_vectors"]
+
+ if len(candidate_vectors) > 0:
+ print(f"[DEBUG-REINFORCE] User={user_id} reward={feedback.reward:.2f} "
+ f"n_candidates={len(candidate_vectors)} chosen={chosen_indices} "
+ f"probs_shape={pending['last_policy_probs'].shape if hasattr(pending['last_policy_probs'], 'shape') else 'N/A'}")
+ # REINFORCE expects:
+ # - item_vectors: ALL candidate vectors [K, k]
+ # - chosen_indices: indices into those candidates
+ # - policy_probs: probabilities over all K candidates [K]
+ updated = reinforce_update_user_state(
+ user_state=user_state,
+ item_vectors=candidate_vectors, # All candidates, not just chosen
+ chosen_indices=chosen_indices, # Original indices into candidates
+ policy_probs=pending["last_policy_probs"],
+ reward_hat=feedback.reward,
+ gating=feedback.gating,
+ tau=self._rl_cfg["tau"],
+ eta_long=self._rl_cfg["eta_long"],
+ eta_short=self._rl_cfg["eta_short"],
+ ema_alpha=self._rl_cfg["ema_alpha"],
+ short_decay=self._rl_cfg["short_decay"],
+ )
+
+ print(f"[DEBUG-REINFORCE] updated={updated} z_long_norm={np.linalg.norm(user_state.z_long):.15e}")
+ if updated:
+ self._user_store.save_state(user_state)
+
+ # Clear pending update
+ ctx.pending_rl_update = None
+
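+ # Illustrative per-turn feedback loop. `judge_reward` is a hypothetical
+ # external scorer (user simulator or LLM judge); it is not part of this module.
+ #
+ #     response = llm.chat(user_id, query)
+ #     reward, gating = judge_reward(query, response.answer)
+ #     llm.apply_feedback(Feedback(user_id=user_id, reward=reward, gating=gating))
+ #     # ...then the next llm.chat(...) turn.
+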
+ def get_user_state_summary(self, user_id: str) -> Dict[str, Any]:
+ """
+ Get a summary of the user's current state (for debugging/analysis).
+
+ Args:
+ user_id: The user to query.
+
+ Returns:
+ Dictionary with user state information.
+ """
+ user_state = self._user_store.get_state(user_id)
+ ctx = self._sessions.get(user_id)
+
+ user_memory_count = sum(
+ 1 for card in self._memory_cards if card.user_id == user_id
+ )
+
+ return {
+ "user_id": user_id,
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ "reward_ma": user_state.reward_ma,
+ "session_history_length": len(ctx.session_state.history) if ctx else 0,
+ "turn_counter": ctx.turn_counter if ctx else 0,
+ "user_memory_count": user_memory_count,
+ "total_memory_count": len(self._memory_cards),
+ }
+
+ def persist(self) -> None:
+ """
+ Persist all state to disk.
+
+ Call this at the end of an evaluation run to save:
+ - User state vectors
+ - Memory cards
+ """
+ # Save user store
+ self._user_store.persist()
+
+ # Save memory cards
+ with open(self._memory_cards_path, "w", encoding="utf-8") as f:
+ for card in self._memory_cards:
+ f.write(card.model_dump_json() + "\n")
+
+ # Save embeddings
+ np.save(self._memory_embeddings_path, self._memory_embeddings)
+
+ # Save item projection with updated vectors
+ if self._projection is not None:
+ np.savez(
+ self._item_projection_path,
+ P=self._projection.P,
+ mean=self._projection.mean,
+ V=self._item_vectors,
+ )
+
+ print("[PersonalizedLLM] State persisted to disk.")
+
+
+# =============================================================================
+# Convenience Factory
+# =============================================================================
+
+def create_personalized_llm(
+ config_path: Optional[str] = None,
+ **kwargs
+) -> PersonalizedLLM:
+ """
+ Factory function to create a PersonalizedLLM instance.
+
+ Args:
+ config_path: Optional path to configuration file.
+ **kwargs: Additional arguments passed to PersonalizedLLM constructor.
+
+ Returns:
+ Configured PersonalizedLLM instance.
+ """
+ return PersonalizedLLM(config_path=config_path, **kwargs)
+
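+
+ # Example end-to-end usage (sketch; the config path is illustrative):
+ #
+ #     llm = create_personalized_llm(config_path="configs/serve.yaml")
+ #     resp = llm.chat("user-1", "Recommend a weekend trip.")
+ #     llm.reset_session("user-1")   # new chat window: clears history and z_short
+ #     llm.persist()                 # flush user vectors and memory cards to disk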