From b6c3e4e51eeab703b40284459c6e9fff2151216c Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Wed, 18 Mar 2026 18:25:09 -0500
Subject: Initial release: VARS - personalized LLM with RAG and user vector
 learning

---
 src/personalization/serving/__init__.py         |   22 +
 src/personalization/serving/personalized_llm.py | 1835 +++++++++++++++++++++++
 2 files changed, 1857 insertions(+)
 create mode 100644 src/personalization/serving/__init__.py
 create mode 100644 src/personalization/serving/personalized_llm.py
(limited to 'src/personalization/serving')

diff --git a/src/personalization/serving/__init__.py b/src/personalization/serving/__init__.py
new file mode 100644
index 0000000..11adcf8
--- /dev/null
+++ b/src/personalization/serving/__init__.py
@@ -0,0 +1,22 @@
+# Personalization Serving Module
+#
+# This module provides the interface layer for the personalization system.
+
+from personalization.serving.personalized_llm import (
+    PersonalizedLLM,
+    AssistantResponse,
+    UsageStats,
+    DebugInfo,
+    Feedback,
+    create_personalized_llm,
+)
+
+__all__ = [
+    "PersonalizedLLM",
+    "AssistantResponse",
+    "UsageStats",
+    "DebugInfo",
+    "Feedback",
+    "create_personalized_llm",
+]
+

diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py
new file mode 100644
index 0000000..8032e6b
--- /dev/null
+++ b/src/personalization/serving/personalized_llm.py
@@ -0,0 +1,1835 @@
+#!/usr/bin/env python3
+"""
+Personalized LLM Interface for Evaluation.
+
+This module provides the `PersonalizedLLM` class that wraps the entire
+personalization system into a clean interface for evaluation frameworks
+and user simulators.
+
+Interface contract:
+- chat(user_id, query) -> AssistantResponse: Main online interface
+- reset_session(user_id): Clear session history and short-term state
+- reset_user(user_id): Completely reset user (long-term, short-term, memories)
+- apply_feedback(feedback): Apply external feedback for RL updates
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import yaml
+
+# Ensure src is in path for standalone usage.
+# Two levels up from this file reaches src/ (serving -> personalization -> src),
+# which is what the `from personalization...` imports below require.
+_src_path = os.path.join(os.path.dirname(__file__), "../..")
+if _src_path not in sys.path:
+    sys.path.insert(0, _src_path)
+
+from personalization.config.settings import load_local_models_config
+from personalization.config.registry import get_preference_extractor, get_chat_model
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.models.reranker.bge_reranker import BGEReranker
+from personalization.user_model.tensor_store import UserTensorStore, UserState
+from personalization.user_model.session_state import OnlineSessionState
+from personalization.user_model.features import ItemProjection
+from personalization.retrieval.preference_store.schemas import (
+    MemoryCard, ChatTurn, PreferenceList, Preference
+)
+from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy
+from personalization.feedback.handlers import eval_step, eval_step_llm
+from personalization.feedback.llm_reward import LLMRewardClient, LLMRewardConfig
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+
+
+# =============================================================================
+# Data Classes for Interface
+# 
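+# A minimal consumer sketch of the interface contract and the data classes
+# below (illustrative only; `judge` is a hypothetical scoring function
+# supplied by the evaluation harness, not part of this module):
+#
+#     llm = create_personalized_llm()
+#     llm.reset_user("u1")
+#     prev = None
+#     for turn_id, query in enumerate(queries):
+#         if prev is not None:  # feedback always refers to the previous turn
+#             reward, gating = judge(prev.answer, query)
+#             llm.apply_feedback(Feedback(user_id="u1", turn_id=turn_id - 1,
+#                                         reward=reward, gating=gating))
+#         prev = llm.chat("u1", query)
+#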
============================================================================= + +@dataclass +class UsageStats: + """Token usage statistics from a chat completion.""" + prompt_tokens: int + completion_tokens: int + total_tokens: int + model: str + + +@dataclass +class DebugInfo: + """ + Debug information for analysis and ablation studies. + All fields are optional - fill what you have, leave empty what you don't. + """ + selected_memory_ids: List[str] = field(default_factory=list) + selected_memory_notes: List[str] = field(default_factory=list) + selected_memory_scores: List[float] = field(default_factory=list) + user_vector_before: Optional[List[float]] = None + user_vector_after: Optional[List[float]] = None + extracted_preferences: List[Dict[str, Any]] = field(default_factory=list) + extra: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class AssistantResponse: + """Response from the personalized LLM chat interface.""" + answer: str + usage: UsageStats + debug: Optional[DebugInfo] = None + + +@dataclass +class Feedback: + """ + Feedback data structure for RL updates from user simulator or judge. + + Attributes: + user_id: The user this feedback is for. + turn_id: The turn this feedback refers to (from the previous turn). + reward: Reward scalar computed by user simulator / judge. + gating: Gating flag (1=valid learning signal, 0=skip update). + meta: Additional metadata for training/analysis. + """ + user_id: str + turn_id: int + reward: float + gating: float # Can be 0.0 or 1.0, or continuous + meta: Dict[str, Any] = field(default_factory=dict) + + +# ============================================================================= +# Internal Session State Extended +# ============================================================================= + +@dataclass +class _SessionContext: + """Extended session context for evaluation tracking.""" + session_state: OnlineSessionState + turn_counter: int = 0 + # Store info needed for apply_feedback + pending_rl_update: Optional[Dict[str, Any]] = None + + +# ============================================================================= +# Shared Model Singletons for Multi-threaded Efficiency +# ============================================================================= + +_shared_embed_model = None +_shared_reranker = None +_shared_extractor = None +_shared_models_lock = None # Will be initialized on first use + + +def _get_shared_models_lock(): + """Get or create the threading lock for shared models.""" + global _shared_models_lock + if _shared_models_lock is None: + import threading + _shared_models_lock = threading.Lock() + return _shared_models_lock + + +def get_shared_embedding_model(model_path: str, device_map: str = "auto"): + """Get or create shared embedding model (thread-safe singleton).""" + global _shared_embed_model + import torch + + lock = _get_shared_models_lock() + with lock: + if _shared_embed_model is None: + print(f"[SharedModels] Loading shared embedding model on {device_map}...") + _shared_embed_model = Qwen3Embedding8B( + model_path=model_path, + dtype=torch.bfloat16, + device_map=device_map, + ) + print("[SharedModels] Shared embedding model loaded.") + return _shared_embed_model + + +def get_shared_reranker(model_path: str, device_map: str = "auto", reranker_type: str = "qwen3"): + """Get or create shared reranker model (thread-safe singleton).""" + global _shared_reranker + import torch + + lock = _get_shared_models_lock() + with lock: + if _shared_reranker is None: + print(f"[SharedModels] Loading shared 
reranker ({reranker_type}) on {device_map}...") + if reranker_type == "bge": + _shared_reranker = BGEReranker( + model_path=model_path, + device_map=device_map, + dtype=torch.float16, + ) + else: + _shared_reranker = Qwen3Reranker( + model_path=model_path, + device_map=device_map, + dtype=torch.bfloat16, + ) + print("[SharedModels] Shared reranker model loaded.") + return _shared_reranker + + +def get_shared_extractor(model_path: str, device_map: str = "auto"): + """Get or create shared preference extractor model (thread-safe singleton).""" + global _shared_extractor + import torch + from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor + + lock = _get_shared_models_lock() + with lock: + if _shared_extractor is None: + print(f"[SharedModels] Loading shared preference extractor on {device_map}...") + _shared_extractor = QwenRuleExtractor( + model_path=model_path, + dtype=torch.bfloat16, + device_map=device_map, + ) + print("[SharedModels] Shared preference extractor loaded.") + return _shared_extractor + + +def clear_shared_models(): + """Free all shared singleton models to reclaim GPU memory between methods.""" + global _shared_embed_model, _shared_reranker, _shared_extractor + import gc + + lock = _get_shared_models_lock() + with lock: + freed = [] + if _shared_embed_model is not None: + freed.append("embedding") + del _shared_embed_model + _shared_embed_model = None + if _shared_reranker is not None: + freed.append("reranker") + del _shared_reranker + _shared_reranker = None + if _shared_extractor is not None: + freed.append("extractor") + del _shared_extractor + _shared_extractor = None + + if freed: + gc.collect() + try: + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except ImportError: + pass + print(f"[SharedModels] Cleared: {', '.join(freed)}") + + +# ============================================================================= +# PersonalizedLLM Class +# ============================================================================= + +class PersonalizedLLM: + """ + Personalized LLM wrapper for evaluation frameworks. 
+
+    This class provides a clean interface that accepts only (user_id, query)
+    for the main chat function, while internally managing:
+    - User state vectors (z_long, z_short)
+    - Session history
+    - Memory retrieval and policy
+    - Preference extraction and storage
+    - RL updates
+
+    Example usage:
+        llm = PersonalizedLLM()
+
+        # Reset user for fresh experiment
+        llm.reset_user("user_123")
+
+        # Start a session
+        llm.reset_session("user_123")
+
+        # Chat
+        response = llm.chat("user_123", "What's a good recipe for dinner?")
+        print(response.answer)
+
+        # Apply feedback from previous turn (from turn 2 onwards)
+        llm.apply_feedback(Feedback(
+            user_id="user_123",
+            turn_id=0,
+            reward=0.8,
+            gating=1.0
+        ))
+    """
+
+    def __init__(
+        self,
+        config_path: Optional[str] = None,
+        user_store_path: str = "data/users/user_store_eval.npz",
+        memory_cards_path: str = "data/corpora/memory_cards.jsonl",
+        memory_embeddings_path: str = "data/corpora/memory_embeddings.npy",
+        item_projection_path: str = "data/corpora/item_projection.npz",
+        only_own_memories: bool = True,
+        enable_preference_extraction: bool = True,
+        enable_rl_updates: bool = True,
+        mode: str = "full",  # "full", "nopersonal", "vanilla", or "contextual"
+        eval_mode: bool = True,  # True = greedy selection, False = stochastic sampling
+        device_assignment: Optional[Dict[str, str]] = None,  # Multi-GPU support
+        llm_name: Optional[str] = None,  # Override LLM name (e.g., "llama_8b_vllm" for vLLM)
+        use_shared_models: bool = False,  # Use shared singleton models for multi-threaded efficiency
+        reranker_type: str = "qwen3",  # "qwen3" (8B) or "bge" (278M)
+        best_of_n: int = 1,  # Generate N responses and pick best (for RAG methods)
+        reward_mode: str = "keyword",  # "keyword", "llm" (GPT-4o-mini), or "llm_local" (local vLLM)
+        llm_reward_config: Optional["LLMRewardConfig"] = None,  # Config for LLM judge
+        reward_vllm_url: Optional[str] = None,  # vLLM URL for local reward model (when reward_mode="llm_local")
+        enable_query_transform: bool = False,  # Transform queries for better retrieval matching
+        enable_global_preferences: bool = False,  # Separate global prefs that bypass retrieval
+        dynamic_topk: bool = False,  # Use dynamic topk based on rerank scores
+        dynamic_min_k: int = 3,  # Min preferences for dynamic topk
+        dynamic_max_k: int = 8,  # Max preferences for dynamic topk
+        dynamic_score_ratio: float = 0.5,  # Threshold = top_score * ratio
+        eta_long: Optional[float] = None,  # Override RL learning rate for z_long
+        eta_short: Optional[float] = None,  # Override RL learning rate for z_short
+        enable_preference_consolidation: bool = False,  # Consolidate preferences at session end
+        consolidation_threshold: int = 5,  # Min preferences before consolidation
+        enable_preference_rewrite: bool = False,  # Use LLM to rewrite/merge retrieved preferences
+    ):
+        """
+        Initialize the PersonalizedLLM.
+
+        Args:
+            config_path: Path to config file. If None, uses default locations.
+            user_store_path: Path to persist user state vectors.
+            memory_cards_path: Path to memory cards JSONL file.
+            memory_embeddings_path: Path to memory embeddings numpy file.
+            item_projection_path: Path to item projection (PCA) file.
+            only_own_memories: If True, only retrieve user's own memories (strict privacy).
+            enable_preference_extraction: If True, extract preferences from user turns.
+            enable_rl_updates: If True, apply RL updates via apply_feedback.
+ mode: "full" for full personalization, "nopersonal" for baseline (no user vector influence), + "vanilla" for pure LLM without any memory retrieval or preference extraction. + eval_mode: If True, use greedy/deterministic selection (for evaluation). + If False, use stochastic sampling (for training/exploration). + device_assignment: Optional dict to assign models to specific GPUs. + Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"} + If None, uses "auto" for all models. + use_shared_models: If True, use shared singleton models for embedding and reranker. + This is essential for multi-threaded/parallel profile processing to avoid + loading duplicate models. When enabled, the first thread loads the models, + and subsequent threads reuse the shared instances. + """ + self.only_own_memories = only_own_memories + self.use_shared_models = use_shared_models + self.enable_preference_extraction = enable_preference_extraction + self.enable_rl_updates = enable_rl_updates + self.mode = mode # "full" or "nopersonal" + self.eval_mode = eval_mode # True = greedy, False = sample + self.reranker_type = reranker_type # "qwen3" or "bge" + self.best_of_n = best_of_n # Generate N responses and pick best + self.reward_mode = reward_mode # "keyword", "llm", or "llm_local" + self.enable_query_transform = enable_query_transform + self.enable_global_preferences = enable_global_preferences + self.enable_preference_consolidation = enable_preference_consolidation + self.consolidation_threshold = consolidation_threshold + self.enable_preference_rewrite = enable_preference_rewrite + + # Initialize LLM reward client if using LLM judge + self._llm_reward_client = None # Can be LLMRewardClient or LocalLLMRewardClient + if reward_mode == "llm": + self._llm_reward_client = LLMRewardClient(llm_reward_config or LLMRewardConfig()) + elif reward_mode == "llm_local": + from personalization.feedback.local_llm_reward import ( + LocalLLMRewardClient, + LocalLLMRewardConfig, + ) + local_config = LocalLLMRewardConfig( + vllm_url=reward_vllm_url or "http://localhost:8005/v1", + ) + self._llm_reward_client = LocalLLMRewardClient(local_config) + + # Multi-GPU device assignment + self._device_assignment = device_assignment or { + "embed": "auto", + "reranker": "auto", + "chat": "auto", + "extractor": "auto", + } + + # Paths + self._memory_cards_path = memory_cards_path + self._memory_embeddings_path = memory_embeddings_path + self._item_projection_path = item_projection_path + + # RL Configuration + # Note: beta/eta increased for more significant z_u updates + self._rl_cfg = { + "item_dim": 256, + "beta_long": 2.0, # Increased from 0.1 for stronger personalization + "beta_short": 5.0, # Increased from 0.3 + "tau": 1.0, + "eta_long": eta_long if eta_long is not None else 0.01, + "eta_short": eta_short if eta_short is not None else 0.05, + "ema_alpha": 0.05, + "short_decay": 0.1, + "dense_topk": 64, + "rerank_topk": 5, + "max_new_tokens": 512, + # Dynamic topk settings + "dynamic_topk": dynamic_topk, + "dynamic_min_k": dynamic_min_k, + "dynamic_max_k": dynamic_max_k, + "dynamic_score_ratio": dynamic_score_ratio, + } + + # Store llm_name before loading config (needed in _load_config) + self._llm_name_override = llm_name + + # Load config and override RL params if available + self._load_config(config_path) + + # Load models + print("[PersonalizedLLM] Loading models...") + self._load_models() + + # Load memory store + print("[PersonalizedLLM] Loading memory store...") + self._load_memory_store() + + # 
Initialize user store + self._user_store = UserTensorStore( + k=self._rl_cfg["item_dim"], + path=user_store_path, + ) + + # Session contexts per user (in-memory) + self._sessions: Dict[str, _SessionContext] = {} + + print("[PersonalizedLLM] Initialization complete.") + + def _load_config(self, config_path: Optional[str]): + """Load configuration from yaml files.""" + self._cfg = load_local_models_config() + + # Try to load user_model.yaml for RL params + if config_path is None: + config_path = "configs/user_model.yaml" + + self._llm_name = self._llm_name_override or "qwen_1_5b" # Default, can be overridden + + try: + if os.path.exists(config_path): + with open(config_path, "r") as f: + user_cfg = yaml.safe_load(f) + if user_cfg: + # Override RL params if present + for key in self._rl_cfg: + if key in user_cfg: + self._rl_cfg[key] = user_cfg[key] + # LLM name (only from config if not already set via parameter) + if self._llm_name_override is None and "llm_name" in user_cfg: + self._llm_name = user_cfg["llm_name"] + except Exception as e: + print(f"[PersonalizedLLM] Warning: Failed to load config: {e}") + + def _load_models(self): + """Load all ML models with optional multi-GPU assignment.""" + import torch + + # Report GPU availability (only once, not for shared model instances) + if not self.use_shared_models: + num_gpus = torch.cuda.device_count() + print(f"[PersonalizedLLM] Available GPUs: {num_gpus}") + for i in range(num_gpus): + mem = torch.cuda.get_device_properties(i).total_memory / 1e9 + print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)") + + embed_device = self._device_assignment.get("embed", "auto") + reranker_device = self._device_assignment.get("reranker", "auto") + chat_device = self._device_assignment.get("chat", "auto") + extractor_device = self._device_assignment.get("extractor", "auto") + + # Embedding model - only load for modes that use RAG retrieval + # Vanilla and contextual modes don't need embedding/reranker + needs_retrieval = self.mode not in ("vanilla", "contextual") + + if needs_retrieval: + if self.use_shared_models: + print(f"[PersonalizedLLM] Using shared embedding model...") + self._embed_model = get_shared_embedding_model( + model_path=self._cfg.embedding.qwen3.local_path, + device_map=embed_device, + ) + else: + print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...") + self._embed_model = Qwen3Embedding8B( + model_path=self._cfg.embedding.qwen3.local_path, + dtype=torch.bfloat16, + device_map=embed_device, + ) + else: + print(f"[PersonalizedLLM] Skipping embedding model (not needed for {self.mode} mode)") + self._embed_model = None + + # Reranker - only load for modes that use RAG retrieval + # Support both qwen3 (8B) and bge (278M) rerankers + if needs_retrieval: + if self.reranker_type == "bge": + reranker_path = getattr(self._cfg.reranker, "bge_base", None) + reranker_path = reranker_path.local_path if reranker_path else "BAAI/bge-reranker-base" + else: + reranker_path = self._cfg.reranker.qwen3_8b.local_path + + if self.use_shared_models: + print(f"[PersonalizedLLM] Using shared reranker model ({self.reranker_type})...") + self._reranker = get_shared_reranker( + model_path=reranker_path, + device_map=reranker_device, + reranker_type=self.reranker_type, + ) + else: + print(f"[PersonalizedLLM] Loading Reranker ({self.reranker_type}) on {reranker_device}...") + if self.reranker_type == "bge": + self._reranker = BGEReranker( + model_path=reranker_path, + device_map=reranker_device, + dtype=torch.float16, + ) + else: + 
self._reranker = Qwen3Reranker( + model_path=reranker_path, + device_map=reranker_device, + dtype=torch.bfloat16, + ) + else: + print(f"[PersonalizedLLM] Skipping reranker (not needed for {self.mode} mode)") + self._reranker = None + + # Chat model (via registry for backend switching) + print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...") + # Pass device override if specified (not "auto") + device_for_chat = chat_device if chat_device != "auto" else None + self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat) + + # Preference extractor - use shared singleton if enabled + if self.enable_preference_extraction: + extractor_name = "qwen3_0_6b_sft" + if self.use_shared_models: + print(f"[PersonalizedLLM] Using shared preference extractor...") + try: + extractor_path = self._cfg.preference_extractor.get("qwen3_0_6b_sft", {}).get("path", None) + if extractor_path: + self._extractor = get_shared_extractor( + model_path=extractor_path, + device_map=extractor_device, + ) + else: + print(f"[PersonalizedLLM] Extractor path not found, using rule-based.") + self._extractor = get_preference_extractor("rule") + except Exception as e: + print(f"[PersonalizedLLM] Warning: Failed to load shared extractor: {e}. Trying fallbacks...") + try: + self._extractor = get_preference_extractor("rule") + except Exception as e2: + print(f"[PersonalizedLLM] Rule extractor also failed: {e2}. Using GPT-5-mini extractor.") + self._extractor = get_preference_extractor("gpt5_mini") + else: + print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...") + try: + self._extractor = get_preference_extractor(extractor_name) + except Exception as e: + print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Trying fallbacks...") + try: + self._extractor = get_preference_extractor("rule") + except Exception as e2: + print(f"[PersonalizedLLM] Rule extractor also failed: {e2}. 
Using GPT-5-mini extractor.") + self._extractor = get_preference_extractor("gpt5_mini") + else: + print("[PersonalizedLLM] Preference extraction disabled, skipping extractor.") + self._extractor = None + + def _load_memory_store(self): + """Load memory cards and embeddings.""" + if not os.path.exists(self._memory_cards_path): + print(f"[PersonalizedLLM] Warning: Memory cards not found at {self._memory_cards_path}") + self._memory_cards: List[MemoryCard] = [] + self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32) + self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + # Create default projection (truncation to first k dims) so preferences can be added + k = self._rl_cfg["item_dim"] + d = 4096 + P = np.zeros((k, d), dtype=np.float32) + P[:, :k] = np.eye(k, dtype=np.float32) + self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32)) + print(f"[PersonalizedLLM] Created default projection (truncation, k={k})") + return + + # Load cards + self._memory_cards = [] + with open(self._memory_cards_path, "r") as f: + for line in f: + line = line.strip() + if line: + self._memory_cards.append(MemoryCard.model_validate_json(line)) + + # Load embeddings + if os.path.exists(self._memory_embeddings_path): + self._memory_embeddings = np.load(self._memory_embeddings_path) + else: + self._memory_embeddings = np.zeros((len(self._memory_cards), 4096), dtype=np.float32) + + # Load projection + if os.path.exists(self._item_projection_path): + proj_data = np.load(self._item_projection_path) + self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"]) + self._item_vectors = proj_data["V"] + else: + # Create default projection so preferences can still be added + k = self._rl_cfg["item_dim"] + d = 4096 + P = np.zeros((k, d), dtype=np.float32) + P[:, :k] = np.eye(k, dtype=np.float32) + self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32)) + self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32) + print(f"[PersonalizedLLM] Created default projection (truncation, k={k})") + + print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.") + + def _get_or_create_session(self, user_id: str) -> _SessionContext: + """Get or create session context for a user.""" + if user_id not in self._sessions: + self._sessions[user_id] = _SessionContext( + session_state=OnlineSessionState(user_id=user_id), + turn_counter=0, + ) + return self._sessions[user_id] + + def _build_chat_turn(self, user_id: str, text: str, role: str, turn_id: int) -> ChatTurn: + """Build a ChatTurn object.""" + return ChatTurn( + user_id=user_id, + session_id=f"eval_session_{user_id}", + turn_id=turn_id, + role=role, + text=text, + meta={"source": "eval"} + ) + + def _count_tokens(self, text: str) -> int: + """Estimate token count using the tokenizer.""" + try: + # Use the chat model's tokenizer if available + if hasattr(self._chat_model, 'tokenizer'): + return len(self._chat_model.tokenizer.encode(text)) + else: + # Rough estimate: ~4 chars per token + return len(text) // 4 + except Exception: + return len(text) // 4 + + # Task type keywords for query transformation + _TASK_KEYWORDS = { + "math": ["solve", "calculate", "integral", "equation", "proof", "derivative", + "math", "algebra", "geometry", "trigonometry", "calculus", "arithmetic", + "formula", "compute", "evaluate", "simplify", "factor", "graph"], + "coding": ["code", "program", "function", "implement", "debug", "python", "java", + "javascript", 
"algorithm", "class", "method", "bug", "error", "compile", + "script", "html", "css", "sql", "api", "library", "framework"], + "writing": ["write", "essay", "paragraph", "summarize", "draft", "compose", + "article", "story", "letter", "email", "report", "review", "edit", + "rewrite", "paraphrase", "outline"], + "explanation": ["explain", "what is", "how does", "why", "describe", "define", + "meaning", "concept", "difference between", "compare", "contrast"], + } + + def _transform_query_for_retrieval(self, query: str) -> List[str]: + """ + Transform raw user query into multiple retrieval queries to bridge + the semantic gap between task queries and preference descriptions. + + Returns [original_query, transformed_query] or [original_query] if + no task type detected. + """ + import re + query_lower = query.lower() + detected_types = [] + for task_type, keywords in self._TASK_KEYWORDS.items(): + for kw in keywords: + # Use word boundary matching to avoid false positives + # e.g., "api" should not match "capital" + if re.search(r'\b' + re.escape(kw) + r'\b', query_lower): + detected_types.append(task_type) + break + + if not detected_types: + return [query] + + # Use first detected type (most specific match) + task_type = detected_types[0] + transformed = f"user preferences for {task_type} tasks: {query}" + return [query, transformed] + + # Patterns indicating a global/universal preference condition + _GLOBAL_PATTERNS = ["general", "any", "always", "all ", "every", "regardless", + "any task", "any topic", "any question", "all tasks", "all topics"] + + # Domain-specific terms that indicate a conditional preference + _DOMAIN_TERMS = ["math", "code", "coding", "program", "writing", "essay", "science", + "history", "language", "physics", "chemistry", "biology", "literature", + "creative", "technical", "formal", "informal", "academic", "casual"] + + def _classify_preference_scope(self, condition: str) -> bool: + """ + Classify whether a preference condition is global (always applicable) + or conditional (task-specific). + + Returns True if global, False if conditional. + """ + cond_lower = condition.lower().strip() + + # Check for explicit global patterns + for pattern in self._GLOBAL_PATTERNS: + if pattern in cond_lower: + return True + + # Very short/vague conditions with no domain terms are likely global + words = cond_lower.split() + if len(words) <= 2: + has_domain = any(term in cond_lower for term in self._DOMAIN_TERMS) + if not has_domain: + return True + + return False + + # Rewrite prompt for merging retrieved preferences + _REWRITE_PROMPT = """You are helping to prepare user preferences for an AI assistant. + +The user is asking: {query} + +Retrieved preferences about this user: +{preferences} + +Task: Create a concise preference summary that the assistant MUST follow. + +Rules: +1. PRESERVE all specific formatting requirements exactly (e.g., "type hints", "snake_case", "code fence with language") +2. PRESERVE all structural requirements (e.g., "numbered steps", "bullet points", "answer first then explanation") +3. Only MERGE preferences that are truly redundant (saying the same thing differently) +4. Output as a short bulleted list if there are multiple distinct requirements +5. 
Keep each point actionable and specific - NO vague generalizations like "follow best practices" + +Example input: +- Include type hints in Python code +- Use snake_case for variable names +- When explaining, use numbered steps + +Example output: +- Include type hints +- Use snake_case for variables +- Use numbered steps for explanations + +If no preferences are relevant to this query type, output: "No specific preferences apply." + +Preference summary:""" + + def _rewrite_preferences(self, memory_notes: List[str], query: str) -> List[str]: + """ + Use LLM to rewrite/merge multiple retrieved preferences into concise instructions. + + This is similar to Reflection's proper_scaffolding but focuses on merging + rather than just filtering. + + Args: + memory_notes: List of retrieved preference notes + query: Current user query + + Returns: + List with single rewritten instruction (or original if rewrite fails/disabled) + """ + if not memory_notes or len(memory_notes) <= 1: + return memory_notes + + try: + import requests + + # Format preferences for prompt + prefs_text = "\n".join(f"- {note}" for note in memory_notes) + prompt = self._REWRITE_PROMPT.format(query=query[:200], preferences=prefs_text) + + # Direct vLLM API call (simpler than going through chat model) + messages = [{"role": "user", "content": prompt}] + payload = { + "model": self._chat_model.model_name, + "messages": messages, + "max_tokens": 150, + "temperature": 0.3, # Lower temperature for more consistent output + } + + response = requests.post( + f"{self._chat_model.vllm_url}/chat/completions", + json=payload, + timeout=30 + ) + + if response.status_code != 200: + print(f"[REWRITE] API error {response.status_code}, keeping original notes") + return memory_notes + + result = response.json() + rewritten = result["choices"][0]["message"]["content"].strip().strip('"') + + # Validate response + if rewritten and len(rewritten) > 10 and "No specific preferences" not in rewritten: + print(f"[REWRITE] {len(memory_notes)} notes → 1 merged instruction") + return [rewritten] + else: + print(f"[REWRITE] Kept original {len(memory_notes)} notes (no valid merge)") + return memory_notes + + except Exception as e: + print(f"[REWRITE] Failed: {e}, keeping original notes") + return memory_notes + + # Consolidation prompt for session-end preference merging + _CONSOLIDATION_PROMPT = """You are analyzing user preferences extracted from conversations. + +Current preferences for this user: +{preferences} + +Task: Consolidate these preferences into a cleaner, more organized set by: +1. MERGE similar preferences (e.g., "use bullet points" + "format with bullets" → single preference) +2. REMOVE redundant or contradictory preferences (keep the more specific one) +3. PRESERVE all unique, meaningful preferences +4. Keep the same "When [condition], [action]." format + +Output ONLY the consolidated preferences, one per line, in this exact format: +When [condition], [action]. + +Do not add explanations or commentary. Just output the preference lines.""" + + def consolidate_user_preferences(self, user_id: str) -> int: + """ + Consolidate user preferences at session end using LLM. + + Merges similar preferences, removes redundancy, and creates cleaner + preference descriptions. Only runs if user has enough preferences. + + Args: + user_id: The user whose preferences to consolidate. + + Returns: + Number of preferences after consolidation (0 if skipped). 
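+
+        Example (illustrative): two stored notes such as
+            "When asked for Python code, include type hints."
+            "When writing Python, add type hints."
+        would be merged into a single note in the same
+        "When [condition], [action]." format.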
+ """ + if not self.enable_preference_consolidation: + return 0 + + # Get user's memory cards + user_cards = [c for c in self._memory_cards if c.user_id == user_id] + + if len(user_cards) < self.consolidation_threshold: + return len(user_cards) + + # Build preference list for prompt + pref_lines = [card.note_text for card in user_cards] + preferences_text = "\n".join(f"- {p}" for p in pref_lines) + + # Call LLM for consolidation + prompt = self._CONSOLIDATION_PROMPT.format(preferences=preferences_text) + messages = [{"role": "user", "content": prompt}] + + try: + result = self._chat_model.answer(messages, max_new_tokens=512) + consolidated_text = result.get("content", "").strip() + + if not consolidated_text: + return len(user_cards) + + # Parse consolidated preferences + new_prefs = [] + for line in consolidated_text.split("\n"): + line = line.strip() + if not line or not line.startswith("When "): + continue + # Parse "When [condition], [action]." + if ", " in line: + parts = line.split(", ", 1) + condition = parts[0].replace("When ", "").strip() + action = parts[1].rstrip(".").strip() + if condition and action: + new_prefs.append({ + "condition": condition, + "action": action, + "is_global": self._classify_preference_scope(condition) if self.enable_global_preferences else False, + }) + + if not new_prefs: + return len(user_cards) + + # Remove old cards for this user + keep_indices = [i for i, c in enumerate(self._memory_cards) if c.user_id != user_id] + self._memory_cards = [self._memory_cards[i] for i in keep_indices] + if len(keep_indices) > 0 and len(self._memory_embeddings) > 0: + self._memory_embeddings = self._memory_embeddings[keep_indices] + self._item_vectors = self._item_vectors[keep_indices] + else: + embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096 + self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32) + self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Add consolidated preferences + for pref in new_prefs: + note_text = f"When {pref['condition']}, {pref['action']}." + + # Compute embedding + e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0] + v_note = self._projection.transform_vector(np.array(e_note)) + + # Create card + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=user_id, + source_session_id=f"consolidated_{user_id}", + source_turn_ids=[], + raw_queries=[], + preference_list=PreferenceList(preferences=[ + Preference(condition=pref["condition"], action=pref["action"], confidence=1.0) + ]), + note_text=note_text, + embedding_e=list(e_note), + kind="pref", + is_global=pref["is_global"], + ) + + self._memory_cards.append(card) + self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])]) + self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])]) + + print(f"[PersonalizedLLM] Consolidated {len(user_cards)} → {len(new_prefs)} preferences for user {user_id}") + return len(new_prefs) + + except Exception as e: + print(f"[PersonalizedLLM] Consolidation failed for user {user_id}: {e}") + return len(user_cards) + + def _add_preferences_as_memory( + self, + prefs: PreferenceList, + query: str, + user_id: str, + turn_id: int, + ) -> List[Dict[str, Any]]: + """ + Add extracted preferences as new memory cards. + Returns list of preference dicts for debug info. 
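+
+        Example (illustrative): Preference(condition="asked for code",
+        action="include type hints") is stored as the note
+        "When asked for code, include type hints."; an identical note for
+        the same user is skipped by the dedup check below.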
+ """ + extracted = [] + + if not prefs.preferences or self._projection is None: + return extracted + + for pref in prefs.preferences: + note_text = f"When {pref.condition}, {pref.action}." + + # Record for debug + extracted.append({ + "condition": pref.condition, + "action": pref.action, + "confidence": pref.confidence, + }) + + # Deduplication check + is_duplicate = any( + card.user_id == user_id and card.note_text == note_text + for card in self._memory_cards + ) + + if is_duplicate: + continue + + # Compute embedding from note_text (NOT query) for proper semantic retrieval + # This ensures retrieval query "solve math problem" matches stored "When math problems..." + e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0] + v_note = self._projection.transform_vector(np.array(e_note)) + + # Classify as global or conditional + is_global = self._classify_preference_scope(pref.condition) if self.enable_global_preferences else False + + # Create new memory card + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=user_id, + source_session_id=f"eval_session_{user_id}", + source_turn_ids=[turn_id], + raw_queries=[query], + preference_list=PreferenceList(preferences=[pref]), + note_text=note_text, + embedding_e=list(e_note), + kind="pref", + is_global=is_global, + ) + + # Add to memory store + self._memory_cards.append(card) + self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])]) + self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])]) + + return extracted + + def _score_response(self, response: str) -> float: + """ + Score a response for best-of-N selection. + + Higher score = better response. Scoring heuristics: + 1. Length: Longer responses typically have more substance + 2. Solution indicators: Contains formulas, steps, answers + 3. Proactivity: Doesn't end with just a question + + Returns: + Float score (higher is better) + """ + score = 0.0 + response_lower = response.lower() + + # Length score (normalized, cap at 1000 chars) + score += min(len(response), 1000) / 1000 * 3.0 + + # Solution indicators (+1 each, max 5) + solution_indicators = ['=', 'step', 'answer', 'formula', 'result', 'therefore', 'solution'] + indicator_count = sum(1 for ind in solution_indicators if ind in response_lower) + score += min(indicator_count, 5) * 0.5 + + # Structured content (+1 for numbered/bulleted lists) + if any(marker in response for marker in ['1.', '2.', '- ', '* ', '##']): + score += 1.0 + + # Penalty for ending with question (passive behavior) + # Check last 100 chars for question marks + if '?' in response[-100:]: + score -= 1.5 + + # Bonus for providing concrete values/numbers + import re + numbers = re.findall(r'\d+\.?\d*', response) + if len(numbers) >= 3: + score += 1.0 + + return score + + # ========================================================================= + # Public Interface + # ========================================================================= + + def chat(self, user_id: str, query: str) -> AssistantResponse: + """ + Main online chat interface. + + Args: + user_id: Unique identifier for the user. + query: Current user query/message. + + Returns: + AssistantResponse containing the answer, usage stats, and debug info. 
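+
+        Example:
+            resp = llm.chat("user_123", "Explain quicksort")
+            print(resp.answer)
+            print(resp.usage.total_tokens, resp.debug.selected_memory_notes)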
+ + Notes: + - Internally manages user state, session history, memory retrieval + - After this call, you can call apply_feedback() with the turn's feedback + """ + ctx = self._get_or_create_session(user_id) + session = ctx.session_state + user_state = self._user_store.get_state(user_id) + + # Record user vector before for debug + z_long_before = user_state.z_long.copy().tolist() + z_short_before = user_state.z_short.copy().tolist() + + # Add user turn to history + user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter) + session.history.append(user_turn) + + # Vanilla mode: pure LLM without any memory or preference extraction + if self.mode == "vanilla": + # Skip embedding, preference extraction, and memory retrieval entirely + e_q_t = np.zeros(4096, dtype=np.float32) # Placeholder for vanilla mode + extracted_prefs = [] + candidates = [] + cand_item_vecs = np.array([]) + base_scores = np.array([]) + chosen_indices = [] + probs = np.array([]) + memories_t = [] + memory_notes = [] + else: + # Compute query embedding (only needed for non-vanilla modes) + # Explicitly normalize for consistent cosine similarity with stored embeddings + embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False) + if embed_result is None or len(embed_result) == 0: + raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}") + e_q_t = np.array(embed_result[0]) + + # Store pending RL update info from last turn (for apply_feedback) + if session.last_query is not None and self.enable_rl_updates: + ctx.pending_rl_update = { + "last_query": session.last_query, + "last_answer": session.last_answer, + "last_memories": session.last_memories, + "last_query_embedding": session.last_query_embedding, + "current_query_embedding": e_q_t, + "last_candidate_item_vectors": session.last_candidate_item_vectors, + "last_policy_probs": session.last_policy_probs, + "last_chosen_indices": session.last_chosen_indices, + } + + # Auto-compute reward via LLM judge if enabled + if self._llm_reward_client is not None: + import asyncio + try: + reward, gating = asyncio.run(eval_step_llm( + q_t=session.last_query, + answer_t=session.last_answer, + q_t1=query, + memories_t=session.last_memories or [], + client=self._llm_reward_client, + )) + if gating > 0.0: + self.apply_feedback(Feedback( + user_id=user_id, + turn_id=ctx.turn_counter - 1, + reward=reward, + gating=gating, + )) + except Exception as e: + # Graceful fallback: skip RL update if judge fails + print(f"[LLM-Reward] Judge call failed, skipping update: {e}") + + # Extract preferences from conversation (if enabled) + # extract_turn processes only the last user turn - efficient since called each turn + # Preferences accumulate in _memory_cards across turns (dedup prevents duplicates) + extracted_prefs = [] + if self.enable_preference_extraction: + prefs = self._extractor.extract_turn(session.history) + if prefs.preferences: + print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})") + extracted_prefs = self._add_preferences_as_memory( + prefs, query, user_id, ctx.turn_counter + ) + if extracted_prefs: + print(f"[DEBUG] Added {len(extracted_prefs)} to memory. 
Total cards: {len(self._memory_cards)}") + + # Separate global preferences (bypass retrieval) from conditional ones + global_notes = [] + retrieval_cards = self._memory_cards + retrieval_embeddings = self._memory_embeddings + retrieval_item_vectors = self._item_vectors + if self.enable_global_preferences: + global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id] + global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10 + # Filter out global cards for retrieval + cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global] + if cond_indices: + retrieval_cards = [self._memory_cards[i] for i in cond_indices] + retrieval_embeddings = self._memory_embeddings[cond_indices] + if len(self._item_vectors) > 0: + retrieval_item_vectors = self._item_vectors[cond_indices] + else: + retrieval_cards = [] + retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings + retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Query transformation for better retrieval matching + retrieval_queries = None + if self.enable_query_transform: + retrieval_queries = self._transform_query_for_retrieval(query) + + # Retrieve memories + if self.mode == "nopersonal": + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( + user_id=user_id, + query=query, + embed_model=self._embed_model, + reranker=self._reranker, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, + topk_dense=self._rl_cfg["dense_topk"], + topk_rerank=self._rl_cfg["rerank_topk"], + only_own_memories=self.only_own_memories, + queries=retrieval_queries, + dynamic_topk=self._rl_cfg["dynamic_topk"], + dynamic_min_k=self._rl_cfg["dynamic_min_k"], + dynamic_max_k=self._rl_cfg["dynamic_max_k"], + dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"], + ) + else: + beta_long = self._rl_cfg["beta_long"] + beta_short = self._rl_cfg["beta_short"] + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy( + user_id=user_id, + query=query, + embed_model=self._embed_model, + reranker=self._reranker, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, + user_store=self._user_store, + item_vectors=retrieval_item_vectors, + topk_dense=self._rl_cfg["dense_topk"], + topk_rerank=self._rl_cfg["rerank_topk"], + beta_long=beta_long, + beta_short=beta_short, + tau=self._rl_cfg["tau"], + only_own_memories=self.only_own_memories, + sample=not self.eval_mode, + queries=retrieval_queries, + ) + + # Get selected memories + memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] + memory_notes = [m.note_text for m in memories_t] + + # Apply preference rewrite if enabled + if self.enable_preference_rewrite and memory_notes: + memory_notes = self._rewrite_preferences(memory_notes, query) + + # Debug: show retrieval info + if memories_t or global_notes: + print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...") + print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}") + for i, m in enumerate(memories_t[:3]): # Show top 3 + score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 + print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...") + + # Combine all notes for prompt (global + retrieved) + # For chat(), we combine all notes; 
chat_prepare() handles them separately + if self.mode != "vanilla": + all_memory_notes = (global_notes if global_notes else []) + memory_notes + else: + all_memory_notes = memory_notes + + # Build prompt and count tokens + prompt_tokens = self._count_tokens(query) + for turn in session.history: + prompt_tokens += self._count_tokens(turn.text) + for note in all_memory_notes: + prompt_tokens += self._count_tokens(note) + + # Generate answer (with best-of-N if enabled) + if self.best_of_n > 1: + # Generate N responses and pick the best one + candidates_responses = [] + for i in range(self.best_of_n): + resp = self._chat_model.answer( + history=session.history, + memory_notes=all_memory_notes, + max_new_tokens=self._rl_cfg["max_new_tokens"], + temperature=0.8, # Slightly higher temp for diversity + ) + score = self._score_response(resp) + candidates_responses.append((resp, score)) + + # Sort by score (descending) and pick best + candidates_responses.sort(key=lambda x: x[1], reverse=True) + answer_t = candidates_responses[0][0] + best_score = candidates_responses[0][1] + + if len(candidates_responses) > 1: + print(f"[BEST-OF-{self.best_of_n}] Scores: {[f'{s:.2f}' for _, s in candidates_responses]}, picked score={best_score:.2f}") + else: + answer_t = self._chat_model.answer( + history=session.history, + memory_notes=all_memory_notes, + max_new_tokens=self._rl_cfg["max_new_tokens"], + ) + + completion_tokens = self._count_tokens(answer_t) + + # Add assistant turn to history + assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter) + session.history.append(assist_turn) + + # Update session state for next turn + session.last_query = query + session.last_answer = answer_t + session.last_memories = memories_t + session.last_query_embedding = e_q_t + session.last_candidate_item_vectors = cand_item_vecs + session.last_policy_probs = probs + session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else [] + + ctx.turn_counter += 1 + + # Build debug info + debug = DebugInfo( + selected_memory_ids=[m.card_id for m in memories_t], + selected_memory_notes=[m.note_text for m in memories_t], + selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [], + user_vector_before=z_long_before + z_short_before, # Concatenated for simplicity + user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(), + extracted_preferences=extracted_prefs, + extra={ + "num_candidates": len(candidates), + "num_total_memories": len(self._memory_cards), + "z_long_norm": float(np.linalg.norm(user_state.z_long)), + "z_short_norm": float(np.linalg.norm(user_state.z_short)), + } + ) + + # Build usage stats + usage = UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + model=self._llm_name, + ) + + return AssistantResponse( + answer=answer_t, + usage=usage, + debug=debug, + ) + + def chat_prepare(self, user_id: str, query: str, skip_extraction: bool = False, skip_auto_reward: bool = False) -> dict: + """ + Prepare for chat without calling the LLM. + + This does all the preparation work (embedding, memory retrieval, etc.) + and returns the messages to send to the LLM along with context needed + for post-processing. + + Used for batch processing where messages are collected first, then + sent in batch to vLLM for concurrent processing. + + Args: + user_id: Unique identifier for the user. + query: Current user query/message. 
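+
+        Example (illustrative batch loop; `batch_generate` stands in for
+        whatever batched vLLM call the framework uses):
+            prepped = [llm.chat_prepare(uid, q) for uid, q in batch]
+            answers = batch_generate([p["messages"] for p in prepped])
+            responses = [llm.chat_complete(a, p["context"])
+                         for a, p in zip(answers, prepped)]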
+ + Returns: + Dict containing: + - messages: List of messages to send to LLM + - context: Dict with all state needed for chat_complete() + """ + ctx = self._get_or_create_session(user_id) + session = ctx.session_state + user_state = self._user_store.get_state(user_id) + + # Record user vector before for debug + z_long_before = user_state.z_long.copy().tolist() + z_short_before = user_state.z_short.copy().tolist() + + # Add user turn to history + user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter) + session.history.append(user_turn) + + # Vanilla mode: pure LLM without any memory or preference extraction + if self.mode == "vanilla": + e_q_t = np.zeros(4096, dtype=np.float32) + extracted_prefs = [] + candidates = [] + cand_item_vecs = np.array([]) + base_scores = np.array([]) + chosen_indices = [] + probs = np.array([]) + memories_t = [] + memory_notes = [] + else: + # Compute query embedding + embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False) + if embed_result is None or len(embed_result) == 0: + raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}") + e_q_t = np.array(embed_result[0]) + + # Store pending RL update info from last turn + if session.last_query is not None and self.enable_rl_updates: + ctx.pending_rl_update = { + "last_query": session.last_query, + "last_answer": session.last_answer, + "last_memories": session.last_memories, + "last_query_embedding": session.last_query_embedding, + "current_query_embedding": e_q_t, + "last_candidate_item_vectors": session.last_candidate_item_vectors, + "last_policy_probs": session.last_policy_probs, + "last_chosen_indices": session.last_chosen_indices, + } + + # Auto-compute reward via LLM judge if enabled + # skip_auto_reward=True when batch framework handles rewards externally + if self._llm_reward_client is not None and not skip_auto_reward: + import asyncio + try: + reward, gating = asyncio.run(eval_step_llm( + q_t=session.last_query, + answer_t=session.last_answer, + q_t1=query, + memories_t=session.last_memories or [], + client=self._llm_reward_client, + )) + if gating > 0.0: + self.apply_feedback(Feedback( + user_id=user_id, + turn_id=ctx.turn_counter - 1, + reward=reward, + gating=gating, + )) + except Exception as e: + print(f"[LLM-Reward] Judge call failed, skipping update: {e}") + + # Extract preferences from conversation + extracted_prefs = [] + if self.enable_preference_extraction and not skip_extraction: + prefs = self._extractor.extract_turn(session.history) + if prefs.preferences: + print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})") + extracted_prefs = self._add_preferences_as_memory( + prefs, query, user_id, ctx.turn_counter + ) + if extracted_prefs: + print(f"[DEBUG] Added {len(extracted_prefs)} to memory. 
Total cards: {len(self._memory_cards)}") + + # Separate global preferences (bypass retrieval) from conditional ones + global_notes = [] + retrieval_cards = self._memory_cards + retrieval_embeddings = self._memory_embeddings + retrieval_item_vectors = self._item_vectors + if self.enable_global_preferences: + global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id] + global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10 + cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global] + if cond_indices: + retrieval_cards = [self._memory_cards[i] for i in cond_indices] + retrieval_embeddings = self._memory_embeddings[cond_indices] + if len(self._item_vectors) > 0: + retrieval_item_vectors = self._item_vectors[cond_indices] + else: + retrieval_cards = [] + retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings + retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Query transformation for better retrieval matching + retrieval_queries = None + if self.enable_query_transform: + retrieval_queries = self._transform_query_for_retrieval(query) + + # Retrieve memories + if self.mode == "nopersonal": + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( + user_id=user_id, + query=query, + embed_model=self._embed_model, + reranker=self._reranker, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, + topk_dense=self._rl_cfg["dense_topk"], + topk_rerank=self._rl_cfg["rerank_topk"], + only_own_memories=self.only_own_memories, + queries=retrieval_queries, + dynamic_topk=self._rl_cfg["dynamic_topk"], + dynamic_min_k=self._rl_cfg["dynamic_min_k"], + dynamic_max_k=self._rl_cfg["dynamic_max_k"], + dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"], + ) + else: + beta_long = self._rl_cfg["beta_long"] + beta_short = self._rl_cfg["beta_short"] + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy( + user_id=user_id, + query=query, + embed_model=self._embed_model, + reranker=self._reranker, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, + user_store=self._user_store, + item_vectors=retrieval_item_vectors, + topk_dense=self._rl_cfg["dense_topk"], + topk_rerank=self._rl_cfg["rerank_topk"], + beta_long=beta_long, + beta_short=beta_short, + tau=self._rl_cfg["tau"], + only_own_memories=self.only_own_memories, + sample=not self.eval_mode, + queries=retrieval_queries, + ) + + memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] + memory_notes = [m.note_text for m in memories_t] + + # Apply preference rewrite if enabled + if self.enable_preference_rewrite and memory_notes: + memory_notes = self._rewrite_preferences(memory_notes, query) + + if memories_t or global_notes: + print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...") + print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}") + for i, m in enumerate(memories_t[:3]): + score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 + print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...") + + # Build prompt token count + prompt_tokens = self._count_tokens(query) + for turn in session.history: + prompt_tokens += self._count_tokens(turn.text) + all_notes = memory_notes + (global_notes if self.mode != 
"vanilla" else []) + for note in all_notes: + prompt_tokens += self._count_tokens(note) + + # Build messages for LLM (pass global_notes separately for distinct prompt sections) + effective_global = global_notes if (self.enable_global_preferences and self.mode != "vanilla") else None + messages = self._chat_model.build_messages( + history=session.history, + memory_notes=memory_notes, + max_new_tokens=self._rl_cfg["max_new_tokens"], + global_notes=effective_global, + ) + + # Return messages and context for chat_complete + return { + "messages": messages, + "context": { + "user_id": user_id, + "query": query, + "ctx": ctx, + "session": session, + "user_state": user_state, + "z_long_before": z_long_before, + "z_short_before": z_short_before, + "e_q_t": e_q_t, + "extracted_prefs": extracted_prefs, + "candidates": candidates, + "cand_item_vecs": cand_item_vecs, + "base_scores": base_scores, + "chosen_indices": chosen_indices, + "probs": probs, + "memories_t": memories_t, + "memory_notes": memory_notes, + "prompt_tokens": prompt_tokens, + } + } + + def chat_complete(self, answer_t: str, context: dict) -> AssistantResponse: + """ + Complete chat with LLM response. + + This takes the LLM response and context from chat_prepare(), and + does all post-processing (add to history, debug info, etc.). + + Args: + answer_t: The LLM response text. + context: Context dict from chat_prepare(). + + Returns: + AssistantResponse containing the answer, usage stats, and debug info. + """ + # Unpack context + user_id = context["user_id"] + query = context["query"] + ctx = context["ctx"] + session = context["session"] + user_state = context["user_state"] + z_long_before = context["z_long_before"] + z_short_before = context["z_short_before"] + e_q_t = context["e_q_t"] + extracted_prefs = context["extracted_prefs"] + candidates = context["candidates"] + cand_item_vecs = context["cand_item_vecs"] + chosen_indices = context["chosen_indices"] + probs = context["probs"] + memories_t = context["memories_t"] + memory_notes = context["memory_notes"] + prompt_tokens = context["prompt_tokens"] + + completion_tokens = self._count_tokens(answer_t) + + # Add assistant turn to history + assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter) + session.history.append(assist_turn) + + # Update session state for next turn + session.last_query = query + session.last_answer = answer_t + session.last_memories = memories_t + session.last_query_embedding = e_q_t + session.last_candidate_item_vectors = cand_item_vecs + session.last_policy_probs = probs + session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else [] + + ctx.turn_counter += 1 + + # Build debug info + debug = DebugInfo( + selected_memory_ids=[m.card_id for m in memories_t], + selected_memory_notes=[m.note_text for m in memories_t], + selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [], + user_vector_before=z_long_before + z_short_before, + user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(), + extracted_preferences=extracted_prefs, + extra={ + "num_candidates": len(candidates), + "num_total_memories": len(self._memory_cards), + "z_long_norm": float(np.linalg.norm(user_state.z_long)), + "z_short_norm": float(np.linalg.norm(user_state.z_short)), + } + ) + + # Build usage stats + usage = UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + 
+
+    def apply_extracted_preferences(self, user_id: str, pref_dict: dict) -> list:
+        """Apply pre-computed extraction results (from batch extraction) to memory."""
+        prefs = PreferenceList.model_validate(pref_dict)
+        if not prefs.preferences:
+            return []
+        ctx = self._get_or_create_session(user_id)
+        query = ctx.session_state.history[-1].text if ctx.session_state.history else ""
+        extracted = self._add_preferences_as_memory(prefs, query, user_id, ctx.turn_counter)
+        if extracted:
+            print(f"[DEBUG] Batch-added {len(extracted)} preference cards to memory. Total cards: {len(self._memory_cards)}")
+        return extracted
+
+    def get_last_user_query(self, user_id: str) -> str:
+        """Return the text of the most recent user message in this user's session."""
+        ctx = self._sessions.get(user_id)
+        if ctx and ctx.session_state.history:
+            for t in reversed(ctx.session_state.history):
+                if t.role == "user":
+                    return t.text
+        return ""
+
+    def reset_session(self, user_id: str) -> None:
+        """
+        Reset the session for a user (new chat window).
+
+        This clears:
+        - Session conversation history
+        - Short-term user vector (z_short)
+        - Pending RL update info
+
+        This preserves:
+        - Long-term user vector (z_long)
+        - The user's memory cards (consolidated first, if enabled)
+
+        Args:
+            user_id: The user whose session to reset.
+        """
+        # Consolidate preferences at session end (before clearing the session)
+        if self.enable_preference_consolidation:
+            self.consolidate_user_preferences(user_id)
+
+        # Replace any existing session context with a fresh one
+        self._sessions[user_id] = _SessionContext(
+            session_state=OnlineSessionState(user_id=user_id),
+            turn_counter=0,
+        )
+
+        # Reset the short-term vector but keep the long-term one
+        user_state = self._user_store.get_state(user_id)
+        user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32)
+        self._user_store.save_state(user_state)
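+
+    # --- Reset-semantics sketch (illustrative only) ---
+    # reset_session() opens a fresh chat window but keeps what was learned
+    # long-term; reset_user() wipes the user entirely so evaluation runs do
+    # not contaminate each other. `personas` and its fields are hypothetical
+    # harness objects:
+    #
+    #     for persona in personas:
+    #         llm.reset_user(persona.user_id)            # fresh "life"
+    #         for sess in persona.sessions:
+    #             llm.reset_session(persona.user_id)     # fresh window, z_long kept
+    #             for q in sess.queries:
+    #                 resp = llm.chat(persona.user_id, q)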
+
+    def reset_user(self, user_id: str) -> None:
+        """
+        Completely reset a user (new "life").
+
+        This clears:
+        - Long-term user vector (z_long)
+        - Short-term user vector (z_short)
+        - The user's memory cards
+        - Session history
+        - All cached state
+
+        Args:
+            user_id: The user to reset.
+        """
+        # Clear session
+        if user_id in self._sessions:
+            del self._sessions[user_id]
+
+        # Reset user state vectors
+        user_state = self._user_store.get_state(user_id)
+        user_state.z_long = self._user_store.global_init_z.copy()
+        user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32)
+        user_state.reward_ma = 0.0
+        self._user_store.save_state(user_state)
+
+        # Find indices to KEEP (cards NOT belonging to this user).
+        # Must be computed BEFORE modifying _memory_cards.
+        keep_indices = [
+            i for i, card in enumerate(self._memory_cards)
+            if card.user_id != user_id
+        ]
+
+        # Filter memory cards
+        self._memory_cards = [self._memory_cards[i] for i in keep_indices]
+
+        # Filter embeddings and item vectors to match
+        if len(keep_indices) > 0 and len(self._memory_embeddings) > 0:
+            self._memory_embeddings = self._memory_embeddings[keep_indices]
+            self._item_vectors = self._item_vectors[keep_indices]
+        else:
+            # No cards left or no embeddings; 4096 is the Qwen3-8B embedding dim
+            embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096
+            self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32)
+            self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+    def apply_feedback(self, feedback: Feedback) -> None:
+        """
+        Apply feedback from the user simulator or judge.
+
+        Performs a REINFORCE update to the user vectors based on the
+        reward signal from the previous turn.
+
+        Args:
+            feedback: Feedback object containing reward, gating, and metadata.
+
+        Notes:
+        - Must be called AFTER chat() and BEFORE the next chat() call.
+        - Uses the stored context from the previous turn.
+        - If enable_rl_updates is False, this is a no-op.
+        - In "nopersonal" or "vanilla" mode, this is a no-op (baseline comparison).
+        """
+        if not self.enable_rl_updates:
+            return
+
+        # In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline)
+        if self.mode in ("nopersonal", "vanilla"):
+            return
+
+        user_id = feedback.user_id
+        ctx = self._sessions.get(user_id)
+
+        if ctx is None or ctx.pending_rl_update is None:
+            return
+
+        pending = ctx.pending_rl_update
+        user_state = self._user_store.get_state(user_id)
+
+        # Check that we have the data needed for an RL update
+        if (pending.get("last_candidate_item_vectors") is not None and
+                pending.get("last_policy_probs") is not None and
+                pending.get("last_chosen_indices") is not None and
+                len(pending["last_chosen_indices"]) > 0):
+
+            chosen_indices = pending["last_chosen_indices"]
+            candidate_vectors = pending["last_candidate_item_vectors"]
+
+            if len(candidate_vectors) > 0:
+                print(f"[DEBUG-REINFORCE] User={user_id} reward={feedback.reward:.2f} "
+                      f"n_candidates={len(candidate_vectors)} chosen={chosen_indices} "
+                      f"probs_shape={pending['last_policy_probs'].shape if hasattr(pending['last_policy_probs'], 'shape') else 'N/A'}")
+                # REINFORCE expects:
+                # - item_vectors: ALL candidate vectors [K, k]
+                # - chosen_indices: indices into those candidates
+                # - policy_probs: probabilities over all K candidates [K]
+                updated = reinforce_update_user_state(
+                    user_state=user_state,
+                    item_vectors=candidate_vectors,   # All candidates, not just the chosen ones
+                    chosen_indices=chosen_indices,    # Original indices into candidates
+                    policy_probs=pending["last_policy_probs"],
+                    reward_hat=feedback.reward,
+                    gating=feedback.gating,
+                    tau=self._rl_cfg["tau"],
+                    eta_long=self._rl_cfg["eta_long"],
+                    eta_short=self._rl_cfg["eta_short"],
+                    ema_alpha=self._rl_cfg["ema_alpha"],
+                    short_decay=self._rl_cfg["short_decay"],
+                )
+
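+                # The exact math lives in reinforce_update_user_state(); the
+                # parameter names suggest a standard REINFORCE step with an
+                # EMA reward baseline (a sketch of the assumed form, not the
+                # verified internals):
+                #   advantage = gating * (reward_hat - reward_ma)
+                #   grad_z    = sum_{i in chosen} (x_i - sum_j p_j x_j) / tau
+                #   z_long   += eta_long * advantage * grad_z
+                #   z_short   = short_decay * z_short + eta_short * advantage * grad_z
+                #   reward_ma = ema_alpha * reward_ma + (1 - ema_alpha) * reward_hat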
print(f"[DEBUG-REINFORCE] updated={updated} z_long_norm={np.linalg.norm(user_state.z_long):.15e}") + if updated: + self._user_store.save_state(user_state) + + # Clear pending update + ctx.pending_rl_update = None + + def get_user_state_summary(self, user_id: str) -> Dict[str, Any]: + """ + Get a summary of the user's current state (for debugging/analysis). + + Args: + user_id: The user to query. + + Returns: + Dictionary with user state information. + """ + user_state = self._user_store.get_state(user_id) + ctx = self._sessions.get(user_id) + + user_memory_count = sum( + 1 for card in self._memory_cards if card.user_id == user_id + ) + + return { + "user_id": user_id, + "z_long_norm": float(np.linalg.norm(user_state.z_long)), + "z_short_norm": float(np.linalg.norm(user_state.z_short)), + "reward_ma": user_state.reward_ma, + "session_history_length": len(ctx.session_state.history) if ctx else 0, + "turn_counter": ctx.turn_counter if ctx else 0, + "user_memory_count": user_memory_count, + "total_memory_count": len(self._memory_cards), + } + + def persist(self) -> None: + """ + Persist all state to disk. + + Call this at the end of an evaluation run to save: + - User state vectors + - Memory cards + """ + # Save user store + self._user_store.persist() + + # Save memory cards + with open(self._memory_cards_path, "w", encoding="utf-8") as f: + for card in self._memory_cards: + f.write(card.model_dump_json() + "\n") + + # Save embeddings + np.save(self._memory_embeddings_path, self._memory_embeddings) + + # Save item projection with updated vectors + if self._projection is not None: + np.savez( + self._item_projection_path, + P=self._projection.P, + mean=self._projection.mean, + V=self._item_vectors, + ) + + print("[PersonalizedLLM] State persisted to disk.") + + +# ============================================================================= +# Convenience Factory +# ============================================================================= + +def create_personalized_llm( + config_path: Optional[str] = None, + **kwargs +) -> PersonalizedLLM: + """ + Factory function to create a PersonalizedLLM instance. + + Args: + config_path: Optional path to configuration file. + **kwargs: Additional arguments passed to PersonalizedLLM constructor. + + Returns: + Configured PersonalizedLLM instance. + """ + return PersonalizedLLM(config_path=config_path, **kwargs) + -- cgit v1.2.3