#!/usr/bin/env python3 """ Personalized LLM Interface for Evaluation. This module provides the `PersonalizedLLM` class that wraps the entire personalization system into a clean interface for evaluation frameworks and user simulators. Interface contract: - chat(user_id, query) -> AssistantResponse: Main online interface - reset_session(user_id): Clear session history and short-term state - reset_user(user_id): Completely reset user (long-term, short-term, memories) - apply_feedback(feedback): Apply external feedback for RL updates """ from __future__ import annotations import os import sys import uuid from dataclasses import dataclass, field from typing import Any, Dict, List, Optional import numpy as np import yaml # Ensure src is in path for standalone usage _src_path = os.path.join(os.path.dirname(__file__), "../../..") if _src_path not in sys.path: sys.path.insert(0, _src_path) from personalization.config.settings import load_local_models_config from personalization.config.registry import get_preference_extractor, get_chat_model from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B from personalization.models.reranker.qwen3_reranker import Qwen3Reranker from personalization.user_model.tensor_store import UserTensorStore, UserState from personalization.user_model.session_state import OnlineSessionState from personalization.user_model.features import ItemProjection from personalization.retrieval.preference_store.schemas import ( MemoryCard, ChatTurn, PreferenceList, Preference ) from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy from personalization.feedback.handlers import eval_step from personalization.user_model.policy.reinforce import reinforce_update_user_state # ============================================================================= # Data Classes for Interface # ============================================================================= @dataclass class UsageStats: """Token usage statistics from a chat completion.""" prompt_tokens: int completion_tokens: int total_tokens: int model: str @dataclass class DebugInfo: """ Debug information for analysis and ablation studies. All fields are optional - fill what you have, leave empty what you don't. """ selected_memory_ids: List[str] = field(default_factory=list) selected_memory_notes: List[str] = field(default_factory=list) selected_memory_scores: List[float] = field(default_factory=list) user_vector_before: Optional[List[float]] = None user_vector_after: Optional[List[float]] = None extracted_preferences: List[Dict[str, Any]] = field(default_factory=list) extra: Dict[str, Any] = field(default_factory=dict) @dataclass class AssistantResponse: """Response from the personalized LLM chat interface.""" answer: str usage: UsageStats debug: Optional[DebugInfo] = None @dataclass class Feedback: """ Feedback data structure for RL updates from user simulator or judge. Attributes: user_id: The user this feedback is for. turn_id: The turn this feedback refers to (from the previous turn). reward: Reward scalar computed by user simulator / judge. gating: Gating flag (1=valid learning signal, 0=skip update). meta: Additional metadata for training/analysis. """ user_id: str turn_id: int reward: float gating: float # Can be 0.0 or 1.0, or continuous meta: Dict[str, Any] = field(default_factory=dict) # ============================================================================= # Internal Session State Extended # ============================================================================= @dataclass class _SessionContext: """Extended session context for evaluation tracking.""" session_state: OnlineSessionState turn_counter: int = 0 # Store info needed for apply_feedback pending_rl_update: Optional[Dict[str, Any]] = None # ============================================================================= # PersonalizedLLM Class # ============================================================================= class PersonalizedLLM: """ Personalized LLM wrapper for evaluation frameworks. This class provides a clean interface that accepts only (user_id, query) for the main chat function, while internally managing: - User state vectors (z_long, z_short) - Session history - Memory retrieval and policy - Preference extraction and storage - RL updates Example usage: llm = PersonalizedLLM() # Reset user for fresh experiment llm.reset_user("user_123") # Start a session llm.reset_session("user_123") # Chat response = llm.chat("user_123", "What's a good recipe for dinner?") print(response.answer) # Apply feedback from previous turn (from turn 2 onwards) llm.apply_feedback(Feedback( user_id="user_123", turn_id=0, reward=0.8, gating=1.0 )) """ def __init__( self, config_path: Optional[str] = None, user_store_path: str = "data/users/user_store_eval.npz", memory_cards_path: str = "data/corpora/memory_cards.jsonl", memory_embeddings_path: str = "data/corpora/memory_embeddings.npy", item_projection_path: str = "data/corpora/item_projection.npz", only_own_memories: bool = True, enable_preference_extraction: bool = True, enable_rl_updates: bool = True, mode: str = "full", # "full", "nopersonal", or "vanilla" eval_mode: bool = True, # True = greedy selection, False = stochastic sampling device_assignment: Optional[Dict[str, str]] = None, # Multi-GPU support ): """ Initialize the PersonalizedLLM. Args: config_path: Path to config file. If None, uses default locations. user_store_path: Path to persist user state vectors. memory_cards_path: Path to memory cards JSONL file. memory_embeddings_path: Path to memory embeddings numpy file. item_projection_path: Path to item projection (PCA) file. only_own_memories: If True, only retrieve user's own memories (strict privacy). enable_preference_extraction: If True, extract preferences from user turns. enable_rl_updates: If True, apply RL updates via apply_feedback. mode: "full" for full personalization, "nopersonal" for baseline (no user vector influence), "vanilla" for pure LLM without any memory retrieval or preference extraction. eval_mode: If True, use greedy/deterministic selection (for evaluation). If False, use stochastic sampling (for training/exploration). device_assignment: Optional dict to assign models to specific GPUs. Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"} If None, uses "auto" for all models. """ self.only_own_memories = only_own_memories self.enable_preference_extraction = enable_preference_extraction self.enable_rl_updates = enable_rl_updates self.mode = mode # "full" or "nopersonal" self.eval_mode = eval_mode # True = greedy, False = sample # Multi-GPU device assignment self._device_assignment = device_assignment or { "embed": "auto", "reranker": "auto", "chat": "auto", "extractor": "auto", } # Paths self._memory_cards_path = memory_cards_path self._memory_embeddings_path = memory_embeddings_path self._item_projection_path = item_projection_path # RL Configuration # Note: beta/eta increased for more significant z_u updates self._rl_cfg = { "item_dim": 256, "beta_long": 2.0, # Increased from 0.1 for stronger personalization "beta_short": 5.0, # Increased from 0.3 "tau": 1.0, "eta_long": 0.01, # Increased from 1e-3 for faster learning "eta_short": 0.05, # Increased from 5e-3 "ema_alpha": 0.05, "short_decay": 0.1, "dense_topk": 64, "rerank_topk": 3, "max_new_tokens": 512, } # Load config and override RL params if available self._load_config(config_path) # Load models print("[PersonalizedLLM] Loading models...") self._load_models() # Load memory store print("[PersonalizedLLM] Loading memory store...") self._load_memory_store() # Initialize user store self._user_store = UserTensorStore( k=self._rl_cfg["item_dim"], path=user_store_path, ) # Session contexts per user (in-memory) self._sessions: Dict[str, _SessionContext] = {} print("[PersonalizedLLM] Initialization complete.") def _load_config(self, config_path: Optional[str]): """Load configuration from yaml files.""" self._cfg = load_local_models_config() # Try to load user_model.yaml for RL params if config_path is None: config_path = "configs/user_model.yaml" self._llm_name = "qwen_1_5b" # Default try: if os.path.exists(config_path): with open(config_path, "r") as f: user_cfg = yaml.safe_load(f) if user_cfg: # Override RL params if present for key in self._rl_cfg: if key in user_cfg: self._rl_cfg[key] = user_cfg[key] # LLM name if "llm_name" in user_cfg: self._llm_name = user_cfg["llm_name"] except Exception as e: print(f"[PersonalizedLLM] Warning: Failed to load config: {e}") def _load_models(self): """Load all ML models with optional multi-GPU assignment.""" import torch # Report GPU availability num_gpus = torch.cuda.device_count() print(f"[PersonalizedLLM] Available GPUs: {num_gpus}") for i in range(num_gpus): mem = torch.cuda.get_device_properties(i).total_memory / 1e9 print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)") embed_device = self._device_assignment.get("embed", "auto") reranker_device = self._device_assignment.get("reranker", "auto") chat_device = self._device_assignment.get("chat", "auto") extractor_device = self._device_assignment.get("extractor", "auto") # Embedding model print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...") self._embed_model = Qwen3Embedding8B( model_path=self._cfg.embedding.qwen3.local_path, dtype=torch.bfloat16, device_map=embed_device, ) # Reranker print(f"[PersonalizedLLM] Loading Reranker on {reranker_device}...") self._reranker = Qwen3Reranker( model_path=self._cfg.reranker.qwen3_8b.local_path, device_map=reranker_device, dtype=torch.bfloat16, ) # Chat model (via registry for backend switching) print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...") # Pass device override if specified (not "auto") device_for_chat = chat_device if chat_device != "auto" else None self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat) # Preference extractor if self.enable_preference_extraction: extractor_name = "qwen3_0_6b_sft" print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...") try: self._extractor = get_preference_extractor(extractor_name) except Exception as e: print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Using rule-based.") self._extractor = get_preference_extractor("rule") else: print("[PersonalizedLLM] Preference extraction disabled, using rule-based extractor.") self._extractor = get_preference_extractor("rule") def _load_memory_store(self): """Load memory cards and embeddings.""" if not os.path.exists(self._memory_cards_path): print(f"[PersonalizedLLM] Warning: Memory cards not found at {self._memory_cards_path}") self._memory_cards: List[MemoryCard] = [] self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32) self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) self._projection = None return # Load cards self._memory_cards = [] with open(self._memory_cards_path, "r") as f: for line in f: line = line.strip() if line: self._memory_cards.append(MemoryCard.model_validate_json(line)) # Load embeddings if os.path.exists(self._memory_embeddings_path): self._memory_embeddings = np.load(self._memory_embeddings_path) else: self._memory_embeddings = np.zeros((len(self._memory_cards), 4096), dtype=np.float32) # Load projection if os.path.exists(self._item_projection_path): proj_data = np.load(self._item_projection_path) self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"]) self._item_vectors = proj_data["V"] else: self._projection = None self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32) print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.") def _get_or_create_session(self, user_id: str) -> _SessionContext: """Get or create session context for a user.""" if user_id not in self._sessions: self._sessions[user_id] = _SessionContext( session_state=OnlineSessionState(user_id=user_id), turn_counter=0, ) return self._sessions[user_id] def _build_chat_turn(self, user_id: str, text: str, role: str, turn_id: int) -> ChatTurn: """Build a ChatTurn object.""" return ChatTurn( user_id=user_id, session_id=f"eval_session_{user_id}", turn_id=turn_id, role=role, text=text, meta={"source": "eval"} ) def _count_tokens(self, text: str) -> int: """Estimate token count using the tokenizer.""" try: # Use the chat model's tokenizer if available if hasattr(self._chat_model, 'tokenizer'): return len(self._chat_model.tokenizer.encode(text)) else: # Rough estimate: ~4 chars per token return len(text) // 4 except Exception: return len(text) // 4 def _add_preferences_as_memory( self, prefs: PreferenceList, query: str, user_id: str, turn_id: int, ) -> List[Dict[str, Any]]: """ Add extracted preferences as new memory cards. Returns list of preference dicts for debug info. """ extracted = [] if not prefs.preferences or self._projection is None: return extracted # Compute embedding for the query e_q = self._embed_model.encode([query], return_tensor=False)[0] v_q = self._projection.transform_vector(np.array(e_q)) for pref in prefs.preferences: note_text = f"When {pref.condition}, {pref.action}." # Record for debug extracted.append({ "condition": pref.condition, "action": pref.action, "confidence": pref.confidence, }) # Deduplication check is_duplicate = any( card.user_id == user_id and card.note_text == note_text for card in self._memory_cards ) if is_duplicate: continue # Create new memory card card = MemoryCard( card_id=str(uuid.uuid4()), user_id=user_id, source_session_id=f"eval_session_{user_id}", source_turn_ids=[turn_id], raw_queries=[query], preference_list=PreferenceList(preferences=[pref]), note_text=note_text, embedding_e=list(e_q), kind="pref", ) # Add to memory store self._memory_cards.append(card) self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_q])]) self._item_vectors = np.vstack([self._item_vectors, np.array([v_q])]) return extracted # ========================================================================= # Public Interface # ========================================================================= def chat(self, user_id: str, query: str) -> AssistantResponse: """ Main online chat interface. Args: user_id: Unique identifier for the user. query: Current user query/message. Returns: AssistantResponse containing the answer, usage stats, and debug info. Notes: - Internally manages user state, session history, memory retrieval - After this call, you can call apply_feedback() with the turn's feedback """ ctx = self._get_or_create_session(user_id) session = ctx.session_state user_state = self._user_store.get_state(user_id) # Record user vector before for debug z_long_before = user_state.z_long.copy().tolist() z_short_before = user_state.z_short.copy().tolist() # Compute query embedding e_q_t = np.array(self._embed_model.encode([query], return_tensor=False)[0]) # Store pending RL update info from last turn (for apply_feedback) if session.last_query is not None and self.enable_rl_updates: ctx.pending_rl_update = { "last_query": session.last_query, "last_answer": session.last_answer, "last_memories": session.last_memories, "last_query_embedding": session.last_query_embedding, "current_query_embedding": e_q_t, "last_candidate_item_vectors": session.last_candidate_item_vectors, "last_policy_probs": session.last_policy_probs, "last_chosen_indices": session.last_chosen_indices, } # Add user turn to history user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter) session.history.append(user_turn) # Vanilla mode: pure LLM without any memory or preference extraction if self.mode == "vanilla": # Skip preference extraction and memory retrieval entirely extracted_prefs = [] candidates = [] cand_item_vecs = np.array([]) base_scores = np.array([]) chosen_indices = [] probs = np.array([]) memories_t = [] memory_notes = [] else: # Extract preferences from conversation (if enabled) extracted_prefs = [] if self.enable_preference_extraction: prefs = self._extractor.extract_turn(session.history) extracted_prefs = self._add_preferences_as_memory( prefs, query, user_id, ctx.turn_counter ) # Retrieve memories # In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector # In "full" mode: policy-based retrieval with user vector influence if self.mode == "nopersonal": candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( user_id=user_id, query=query, embed_model=self._embed_model, reranker=self._reranker, memory_cards=self._memory_cards, memory_embeddings=self._memory_embeddings, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], only_own_memories=self.only_own_memories, ) else: beta_long = self._rl_cfg["beta_long"] beta_short = self._rl_cfg["beta_short"] # eval_mode=True -> sample=False (greedy/deterministic) # eval_mode=False -> sample=True (stochastic/exploration) candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy( user_id=user_id, query=query, embed_model=self._embed_model, reranker=self._reranker, memory_cards=self._memory_cards, memory_embeddings=self._memory_embeddings, user_store=self._user_store, item_vectors=self._item_vectors, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], beta_long=beta_long, beta_short=beta_short, tau=self._rl_cfg["tau"], only_own_memories=self.only_own_memories, sample=not self.eval_mode, ) # Get selected memories memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] memory_notes = [m.note_text for m in memories_t] # Build prompt and count tokens prompt_tokens = self._count_tokens(query) for turn in session.history: prompt_tokens += self._count_tokens(turn.text) for note in memory_notes: prompt_tokens += self._count_tokens(note) # Generate answer answer_t = self._chat_model.answer( history=session.history, memory_notes=memory_notes, max_new_tokens=self._rl_cfg["max_new_tokens"], ) completion_tokens = self._count_tokens(answer_t) # Add assistant turn to history assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter) session.history.append(assist_turn) # Update session state for next turn session.last_query = query session.last_answer = answer_t session.last_memories = memories_t session.last_query_embedding = e_q_t session.last_candidate_item_vectors = cand_item_vecs session.last_policy_probs = probs session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else [] ctx.turn_counter += 1 # Build debug info debug = DebugInfo( selected_memory_ids=[m.card_id for m in memories_t], selected_memory_notes=[m.note_text for m in memories_t], selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [], user_vector_before=z_long_before + z_short_before, # Concatenated for simplicity user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(), extracted_preferences=extracted_prefs, extra={ "num_candidates": len(candidates), "num_total_memories": len(self._memory_cards), "z_long_norm": float(np.linalg.norm(user_state.z_long)), "z_short_norm": float(np.linalg.norm(user_state.z_short)), } ) # Build usage stats usage = UsageStats( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, model=self._llm_name, ) return AssistantResponse( answer=answer_t, usage=usage, debug=debug, ) def reset_session(self, user_id: str) -> None: """ Reset session for a user (new chat window). This clears: - Session conversation history - Short-term user vector (z_short) - Pending RL update info This preserves: - Long-term user vector (z_long) - User's memory cards Args: user_id: The user whose session to reset. """ # Clear session context if user_id in self._sessions: del self._sessions[user_id] # Create fresh session self._sessions[user_id] = _SessionContext( session_state=OnlineSessionState(user_id=user_id), turn_counter=0, ) # Reset short-term vector but keep long-term user_state = self._user_store.get_state(user_id) user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32) self._user_store.save_state(user_state) def reset_user(self, user_id: str) -> None: """ Completely reset a user (new "life"). This clears: - Long-term user vector (z_long) - Short-term user vector (z_short) - User's memory cards - Session history - All cached state Args: user_id: The user to reset. """ # Clear session if user_id in self._sessions: del self._sessions[user_id] # Reset user state vectors user_state = self._user_store.get_state(user_id) user_state.z_long = self._user_store.global_init_z.copy() user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32) user_state.reward_ma = 0.0 self._user_store.save_state(user_state) # Find indices to KEEP (cards NOT belonging to this user) # Must do this BEFORE modifying _memory_cards keep_indices = [ i for i, card in enumerate(self._memory_cards) if card.user_id != user_id ] # Filter memory cards self._memory_cards = [self._memory_cards[i] for i in keep_indices] # Filter embeddings and item vectors to match if len(keep_indices) > 0 and len(self._memory_embeddings) > 0: self._memory_embeddings = self._memory_embeddings[keep_indices] self._item_vectors = self._item_vectors[keep_indices] else: # No cards left or no embeddings embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096 self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32) self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) def apply_feedback(self, feedback: Feedback) -> None: """ Apply feedback from user simulator or judge. This performs the REINFORCE update to user vectors based on the reward signal from the previous turn. Args: feedback: Feedback object containing reward, gating, and metadata. Notes: - Should be called AFTER chat() but BEFORE the next chat() call - Uses the stored context from the previous turn - If enable_rl_updates is False, this is a no-op (logging only) - If mode is "nopersonal", this is a no-op (baseline comparison) """ if not self.enable_rl_updates: return # In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline) if self.mode in ("nopersonal", "vanilla"): return user_id = feedback.user_id ctx = self._sessions.get(user_id) if ctx is None or ctx.pending_rl_update is None: return pending = ctx.pending_rl_update user_state = self._user_store.get_state(user_id) # Check if we have the necessary data for RL update if (pending.get("last_candidate_item_vectors") is not None and pending.get("last_policy_probs") is not None and pending.get("last_chosen_indices") is not None and len(pending["last_chosen_indices"]) > 0): # Extract chosen vectors chosen_indices = pending["last_chosen_indices"] candidate_vectors = pending["last_candidate_item_vectors"] if len(candidate_vectors) > 0: # REINFORCE expects: # - item_vectors: ALL candidate vectors [K, k] # - chosen_indices: indices into those candidates # - policy_probs: probabilities over all K candidates [K] updated = reinforce_update_user_state( user_state=user_state, item_vectors=candidate_vectors, # All candidates, not just chosen chosen_indices=chosen_indices, # Original indices into candidates policy_probs=pending["last_policy_probs"], reward_hat=feedback.reward, gating=feedback.gating, tau=self._rl_cfg["tau"], eta_long=self._rl_cfg["eta_long"], eta_short=self._rl_cfg["eta_short"], ema_alpha=self._rl_cfg["ema_alpha"], short_decay=self._rl_cfg["short_decay"], ) if updated: self._user_store.save_state(user_state) # Clear pending update ctx.pending_rl_update = None def get_user_state_summary(self, user_id: str) -> Dict[str, Any]: """ Get a summary of the user's current state (for debugging/analysis). Args: user_id: The user to query. Returns: Dictionary with user state information. """ user_state = self._user_store.get_state(user_id) ctx = self._sessions.get(user_id) user_memory_count = sum( 1 for card in self._memory_cards if card.user_id == user_id ) return { "user_id": user_id, "z_long_norm": float(np.linalg.norm(user_state.z_long)), "z_short_norm": float(np.linalg.norm(user_state.z_short)), "reward_ma": user_state.reward_ma, "session_history_length": len(ctx.session_state.history) if ctx else 0, "turn_counter": ctx.turn_counter if ctx else 0, "user_memory_count": user_memory_count, "total_memory_count": len(self._memory_cards), } def persist(self) -> None: """ Persist all state to disk. Call this at the end of an evaluation run to save: - User state vectors - Memory cards """ # Save user store self._user_store.persist() # Save memory cards with open(self._memory_cards_path, "w", encoding="utf-8") as f: for card in self._memory_cards: f.write(card.model_dump_json() + "\n") # Save embeddings np.save(self._memory_embeddings_path, self._memory_embeddings) # Save item projection with updated vectors if self._projection is not None: np.savez( self._item_projection_path, P=self._projection.P, mean=self._projection.mean, V=self._item_vectors, ) print("[PersonalizedLLM] State persisted to disk.") # ============================================================================= # Convenience Factory # ============================================================================= def create_personalized_llm( config_path: Optional[str] = None, **kwargs ) -> PersonalizedLLM: """ Factory function to create a PersonalizedLLM instance. Args: config_path: Optional path to configuration file. **kwargs: Additional arguments passed to PersonalizedLLM constructor. Returns: Configured PersonalizedLLM instance. """ return PersonalizedLLM(config_path=config_path, **kwargs)