Diffstat (limited to 'src/personalization/serving/personalized_llm.py')
-rw-r--r--  src/personalization/serving/personalized_llm.py  721
1 file changed, 637 insertions, 84 deletions
diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py index 2c4d5a8..733ff87 100644 --- a/src/personalization/serving/personalized_llm.py +++ b/src/personalization/serving/personalized_llm.py @@ -33,6 +33,7 @@ from personalization.config.settings import load_local_models_config from personalization.config.registry import get_preference_extractor, get_chat_model from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B from personalization.models.reranker.qwen3_reranker import Qwen3Reranker +from personalization.models.reranker.bge_reranker import BGEReranker from personalization.user_model.tensor_store import UserTensorStore, UserState from personalization.user_model.session_state import OnlineSessionState from personalization.user_model.features import ItemProjection @@ -40,7 +41,8 @@ from personalization.retrieval.preference_store.schemas import ( MemoryCard, ChatTurn, PreferenceList, Preference ) from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy -from personalization.feedback.handlers import eval_step +from personalization.feedback.handlers import eval_step, eval_step_llm +from personalization.feedback.llm_reward import LLMRewardClient, LLMRewardConfig from personalization.user_model.policy.reinforce import reinforce_update_user_state @@ -113,6 +115,119 @@ class _SessionContext: # ============================================================================= +# Shared Model Singletons for Multi-threaded Efficiency +# ============================================================================= + +_shared_embed_model = None +_shared_reranker = None +_shared_extractor = None +_shared_models_lock = None # Will be initialized on first use + + +def _get_shared_models_lock(): + """Get or create the threading lock for shared models.""" + global _shared_models_lock + if _shared_models_lock is None: + import threading + _shared_models_lock = threading.Lock() + return _shared_models_lock + + +def get_shared_embedding_model(model_path: str, device_map: str = "auto"): + """Get or create shared embedding model (thread-safe singleton).""" + global _shared_embed_model + import torch + + lock = _get_shared_models_lock() + with lock: + if _shared_embed_model is None: + print(f"[SharedModels] Loading shared embedding model on {device_map}...") + _shared_embed_model = Qwen3Embedding8B( + model_path=model_path, + dtype=torch.bfloat16, + device_map=device_map, + ) + print("[SharedModels] Shared embedding model loaded.") + return _shared_embed_model + + +def get_shared_reranker(model_path: str, device_map: str = "auto", reranker_type: str = "qwen3"): + """Get or create shared reranker model (thread-safe singleton).""" + global _shared_reranker + import torch + + lock = _get_shared_models_lock() + with lock: + if _shared_reranker is None: + print(f"[SharedModels] Loading shared reranker ({reranker_type}) on {device_map}...") + if reranker_type == "bge": + _shared_reranker = BGEReranker( + model_path=model_path, + device_map=device_map, + dtype=torch.float16, + ) + else: + _shared_reranker = Qwen3Reranker( + model_path=model_path, + device_map=device_map, + dtype=torch.bfloat16, + ) + print("[SharedModels] Shared reranker model loaded.") + return _shared_reranker + + +def get_shared_extractor(model_path: str, device_map: str = "auto"): + """Get or create shared preference extractor model (thread-safe singleton).""" + global _shared_extractor + import torch + from 
personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor + + lock = _get_shared_models_lock() + with lock: + if _shared_extractor is None: + print(f"[SharedModels] Loading shared preference extractor on {device_map}...") + _shared_extractor = QwenRuleExtractor( + model_path=model_path, + dtype=torch.bfloat16, + device_map=device_map, + ) + print("[SharedModels] Shared preference extractor loaded.") + return _shared_extractor + + +def clear_shared_models(): + """Free all shared singleton models to reclaim GPU memory between methods.""" + global _shared_embed_model, _shared_reranker, _shared_extractor + import gc + + lock = _get_shared_models_lock() + with lock: + freed = [] + if _shared_embed_model is not None: + freed.append("embedding") + del _shared_embed_model + _shared_embed_model = None + if _shared_reranker is not None: + freed.append("reranker") + del _shared_reranker + _shared_reranker = None + if _shared_extractor is not None: + freed.append("extractor") + del _shared_extractor + _shared_extractor = None + + if freed: + gc.collect() + try: + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except ImportError: + pass + print(f"[SharedModels] Cleared: {', '.join(freed)}") + + +# ============================================================================= # PersonalizedLLM Class # ============================================================================= @@ -163,6 +278,12 @@ class PersonalizedLLM: mode: str = "full", # "full", "nopersonal", or "vanilla" eval_mode: bool = True, # True = greedy selection, False = stochastic sampling device_assignment: Optional[Dict[str, str]] = None, # Multi-GPU support + llm_name: Optional[str] = None, # Override LLM name (e.g., "llama_8b_vllm" for vLLM) + use_shared_models: bool = False, # Use shared singleton models for multi-threaded efficiency + reranker_type: str = "qwen3", # "qwen3" (8B) or "bge" (278M) + best_of_n: int = 1, # Generate N responses and pick best (for RAG methods) + reward_mode: str = "keyword", # "keyword" (legacy heuristic) or "llm" (GPT-5-nano judge) + llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge ): """ Initialize the PersonalizedLLM. @@ -183,12 +304,25 @@ class PersonalizedLLM: device_assignment: Optional dict to assign models to specific GPUs. Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"} If None, uses "auto" for all models. + use_shared_models: If True, use shared singleton models for embedding and reranker. + This is essential for multi-threaded/parallel profile processing to avoid + loading duplicate models. When enabled, the first thread loads the models, + and subsequent threads reuse the shared instances. 
""" self.only_own_memories = only_own_memories + self.use_shared_models = use_shared_models self.enable_preference_extraction = enable_preference_extraction self.enable_rl_updates = enable_rl_updates self.mode = mode # "full" or "nopersonal" self.eval_mode = eval_mode # True = greedy, False = sample + self.reranker_type = reranker_type # "qwen3" or "bge" + self.best_of_n = best_of_n # Generate N responses and pick best + self.reward_mode = reward_mode # "keyword" or "llm" + + # Initialize LLM reward client if using LLM judge + self._llm_reward_client: Optional[LLMRewardClient] = None + if reward_mode == "llm": + self._llm_reward_client = LLMRewardClient(llm_reward_config or LLMRewardConfig()) # Multi-GPU device assignment self._device_assignment = device_assignment or { @@ -219,6 +353,9 @@ class PersonalizedLLM: "max_new_tokens": 512, } + # Store llm_name before loading config (needed in _load_config) + self._llm_name_override = llm_name + # Load config and override RL params if available self._load_config(config_path) @@ -249,8 +386,8 @@ class PersonalizedLLM: if config_path is None: config_path = "configs/user_model.yaml" - self._llm_name = "qwen_1_5b" # Default - + self._llm_name = self._llm_name_override or "qwen_1_5b" # Default, can be overridden + try: if os.path.exists(config_path): with open(config_path, "r") as f: @@ -260,8 +397,8 @@ class PersonalizedLLM: for key in self._rl_cfg: if key in user_cfg: self._rl_cfg[key] = user_cfg[key] - # LLM name - if "llm_name" in user_cfg: + # LLM name (only from config if not already set via parameter) + if self._llm_name_override is None and "llm_name" in user_cfg: self._llm_name = user_cfg["llm_name"] except Exception as e: print(f"[PersonalizedLLM] Warning: Failed to load config: {e}") @@ -269,53 +406,110 @@ class PersonalizedLLM: def _load_models(self): """Load all ML models with optional multi-GPU assignment.""" import torch - - # Report GPU availability - num_gpus = torch.cuda.device_count() - print(f"[PersonalizedLLM] Available GPUs: {num_gpus}") - for i in range(num_gpus): - mem = torch.cuda.get_device_properties(i).total_memory / 1e9 - print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)") - + + # Report GPU availability (only once, not for shared model instances) + if not self.use_shared_models: + num_gpus = torch.cuda.device_count() + print(f"[PersonalizedLLM] Available GPUs: {num_gpus}") + for i in range(num_gpus): + mem = torch.cuda.get_device_properties(i).total_memory / 1e9 + print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)") + embed_device = self._device_assignment.get("embed", "auto") reranker_device = self._device_assignment.get("reranker", "auto") chat_device = self._device_assignment.get("chat", "auto") extractor_device = self._device_assignment.get("extractor", "auto") - - # Embedding model - print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...") - self._embed_model = Qwen3Embedding8B( - model_path=self._cfg.embedding.qwen3.local_path, - dtype=torch.bfloat16, - device_map=embed_device, - ) - - # Reranker - print(f"[PersonalizedLLM] Loading Reranker on {reranker_device}...") - self._reranker = Qwen3Reranker( - model_path=self._cfg.reranker.qwen3_8b.local_path, - device_map=reranker_device, - dtype=torch.bfloat16, - ) - + + # Embedding model - only load for modes that use RAG retrieval + # Vanilla and contextual modes don't need embedding/reranker + needs_retrieval = self.mode not in ("vanilla", "contextual") + + if needs_retrieval: + if self.use_shared_models: + 
print(f"[PersonalizedLLM] Using shared embedding model...") + self._embed_model = get_shared_embedding_model( + model_path=self._cfg.embedding.qwen3.local_path, + device_map=embed_device, + ) + else: + print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...") + self._embed_model = Qwen3Embedding8B( + model_path=self._cfg.embedding.qwen3.local_path, + dtype=torch.bfloat16, + device_map=embed_device, + ) + else: + print(f"[PersonalizedLLM] Skipping embedding model (not needed for {self.mode} mode)") + self._embed_model = None + + # Reranker - only load for modes that use RAG retrieval + # Support both qwen3 (8B) and bge (278M) rerankers + if needs_retrieval: + if self.reranker_type == "bge": + reranker_path = getattr(self._cfg.reranker, "bge_base", None) + reranker_path = reranker_path.local_path if reranker_path else "BAAI/bge-reranker-base" + else: + reranker_path = self._cfg.reranker.qwen3_8b.local_path + + if self.use_shared_models: + print(f"[PersonalizedLLM] Using shared reranker model ({self.reranker_type})...") + self._reranker = get_shared_reranker( + model_path=reranker_path, + device_map=reranker_device, + reranker_type=self.reranker_type, + ) + else: + print(f"[PersonalizedLLM] Loading Reranker ({self.reranker_type}) on {reranker_device}...") + if self.reranker_type == "bge": + self._reranker = BGEReranker( + model_path=reranker_path, + device_map=reranker_device, + dtype=torch.float16, + ) + else: + self._reranker = Qwen3Reranker( + model_path=reranker_path, + device_map=reranker_device, + dtype=torch.bfloat16, + ) + else: + print(f"[PersonalizedLLM] Skipping reranker (not needed for {self.mode} mode)") + self._reranker = None + # Chat model (via registry for backend switching) print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...") # Pass device override if specified (not "auto") device_for_chat = chat_device if chat_device != "auto" else None self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat) - - # Preference extractor + + # Preference extractor - use shared singleton if enabled if self.enable_preference_extraction: extractor_name = "qwen3_0_6b_sft" - print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...") - try: - self._extractor = get_preference_extractor(extractor_name) - except Exception as e: - print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Using rule-based.") - self._extractor = get_preference_extractor("rule") + if self.use_shared_models: + print(f"[PersonalizedLLM] Using shared preference extractor...") + try: + extractor_path = self._cfg.preference_extractor.get("qwen3_0_6b_sft", {}).get("path", None) + if extractor_path: + self._extractor = get_shared_extractor( + model_path=extractor_path, + device_map=extractor_device, + ) + else: + print(f"[PersonalizedLLM] Extractor path not found, using rule-based.") + self._extractor = get_preference_extractor("rule") + except Exception as e: + print(f"[PersonalizedLLM] Warning: Failed to load shared extractor: {e}. Using rule-based.") + self._extractor = get_preference_extractor("rule") + else: + print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...") + try: + self._extractor = get_preference_extractor(extractor_name) + except Exception as e: + print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. 
Using rule-based.") + self._extractor = get_preference_extractor("rule") else: - print("[PersonalizedLLM] Preference extraction disabled, using rule-based extractor.") - self._extractor = get_preference_extractor("rule") + print("[PersonalizedLLM] Preference extraction disabled, skipping extractor.") + self._extractor = None def _load_memory_store(self): """Load memory cards and embeddings.""" @@ -396,33 +590,34 @@ class PersonalizedLLM: Returns list of preference dicts for debug info. """ extracted = [] - + if not prefs.preferences or self._projection is None: return extracted - - # Compute embedding for the query - e_q = self._embed_model.encode([query], return_tensor=False)[0] - v_q = self._projection.transform_vector(np.array(e_q)) - + for pref in prefs.preferences: note_text = f"When {pref.condition}, {pref.action}." - + # Record for debug extracted.append({ "condition": pref.condition, "action": pref.action, "confidence": pref.confidence, }) - + # Deduplication check is_duplicate = any( card.user_id == user_id and card.note_text == note_text for card in self._memory_cards ) - + if is_duplicate: continue - + + # Compute embedding from note_text (NOT query) for proper semantic retrieval + # This ensures retrieval query "solve math problem" matches stored "When math problems..." + e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0] + v_note = self._projection.transform_vector(np.array(e_note)) + # Create new memory card card = MemoryCard( card_id=str(uuid.uuid4()), @@ -432,21 +627,61 @@ class PersonalizedLLM: raw_queries=[query], preference_list=PreferenceList(preferences=[pref]), note_text=note_text, - embedding_e=list(e_q), + embedding_e=list(e_note), kind="pref", ) - + # Add to memory store self._memory_cards.append(card) - self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_q])]) - self._item_vectors = np.vstack([self._item_vectors, np.array([v_q])]) + self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])]) + self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])]) return extracted - + + def _score_response(self, response: str) -> float: + """ + Score a response for best-of-N selection. + + Higher score = better response. Scoring heuristics: + 1. Length: Longer responses typically have more substance + 2. Solution indicators: Contains formulas, steps, answers + 3. Proactivity: Doesn't end with just a question + + Returns: + Float score (higher is better) + """ + score = 0.0 + response_lower = response.lower() + + # Length score (normalized, cap at 1000 chars) + score += min(len(response), 1000) / 1000 * 3.0 + + # Solution indicators (+1 each, max 5) + solution_indicators = ['=', 'step', 'answer', 'formula', 'result', 'therefore', 'solution'] + indicator_count = sum(1 for ind in solution_indicators if ind in response_lower) + score += min(indicator_count, 5) * 0.5 + + # Structured content (+1 for numbered/bulleted lists) + if any(marker in response for marker in ['1.', '2.', '- ', '* ', '##']): + score += 1.0 + + # Penalty for ending with question (passive behavior) + # Check last 100 chars for question marks + if '?' 
in response[-100:]: + score -= 1.5 + + # Bonus for providing concrete values/numbers + import re + numbers = re.findall(r'\d+\.?\d*', response) + if len(numbers) >= 3: + score += 1.0 + + return score + # ========================================================================= # Public Interface # ========================================================================= - + def chat(self, user_id: str, query: str) -> AssistantResponse: """ Main online chat interface. @@ -465,34 +700,19 @@ class PersonalizedLLM: ctx = self._get_or_create_session(user_id) session = ctx.session_state user_state = self._user_store.get_state(user_id) - + # Record user vector before for debug z_long_before = user_state.z_long.copy().tolist() z_short_before = user_state.z_short.copy().tolist() - - # Compute query embedding - e_q_t = np.array(self._embed_model.encode([query], return_tensor=False)[0]) - - # Store pending RL update info from last turn (for apply_feedback) - if session.last_query is not None and self.enable_rl_updates: - ctx.pending_rl_update = { - "last_query": session.last_query, - "last_answer": session.last_answer, - "last_memories": session.last_memories, - "last_query_embedding": session.last_query_embedding, - "current_query_embedding": e_q_t, - "last_candidate_item_vectors": session.last_candidate_item_vectors, - "last_policy_probs": session.last_policy_probs, - "last_chosen_indices": session.last_chosen_indices, - } - + # Add user turn to history user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter) session.history.append(user_turn) - + # Vanilla mode: pure LLM without any memory or preference extraction if self.mode == "vanilla": - # Skip preference extraction and memory retrieval entirely + # Skip embedding, preference extraction, and memory retrieval entirely + e_q_t = np.zeros(4096, dtype=np.float32) # Placeholder for vanilla mode extracted_prefs = [] candidates = [] cand_item_vecs = np.array([]) @@ -502,13 +722,61 @@ class PersonalizedLLM: memories_t = [] memory_notes = [] else: + # Compute query embedding (only needed for non-vanilla modes) + # Explicitly normalize for consistent cosine similarity with stored embeddings + embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False) + if embed_result is None or len(embed_result) == 0: + raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}") + e_q_t = np.array(embed_result[0]) + + # Store pending RL update info from last turn (for apply_feedback) + if session.last_query is not None and self.enable_rl_updates: + ctx.pending_rl_update = { + "last_query": session.last_query, + "last_answer": session.last_answer, + "last_memories": session.last_memories, + "last_query_embedding": session.last_query_embedding, + "current_query_embedding": e_q_t, + "last_candidate_item_vectors": session.last_candidate_item_vectors, + "last_policy_probs": session.last_policy_probs, + "last_chosen_indices": session.last_chosen_indices, + } + + # Auto-compute reward via LLM judge if enabled + if self.reward_mode == "llm" and self._llm_reward_client is not None: + import asyncio + try: + reward, gating = asyncio.run(eval_step_llm( + q_t=session.last_query, + answer_t=session.last_answer, + q_t1=query, + memories_t=session.last_memories or [], + client=self._llm_reward_client, + )) + if gating > 0.0: + self.apply_feedback(Feedback( + user_id=user_id, + turn_id=ctx.turn_counter - 1, + reward=reward, + gating=gating, + )) + except Exception as e: + # Graceful fallback: skip RL update if 
judge fails + print(f"[LLM-Reward] Judge call failed, skipping update: {e}") + # Extract preferences from conversation (if enabled) + # extract_turn processes only the last user turn - efficient since called each turn + # Preferences accumulate in _memory_cards across turns (dedup prevents duplicates) extracted_prefs = [] if self.enable_preference_extraction: prefs = self._extractor.extract_turn(session.history) + if prefs.preferences: + print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})") extracted_prefs = self._add_preferences_as_memory( prefs, query, user_id, ctx.turn_counter ) + if extracted_prefs: + print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}") # Retrieve memories # In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector @@ -551,6 +819,14 @@ class PersonalizedLLM: # Get selected memories memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] memory_notes = [m.note_text for m in memories_t] + + # Debug: show retrieval info + if memories_t: + print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...") + print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}") + for i, m in enumerate(memories_t[:3]): # Show top 3 + score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 + print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...") # Build prompt and count tokens prompt_tokens = self._count_tokens(query) @@ -559,13 +835,34 @@ class PersonalizedLLM: for note in memory_notes: prompt_tokens += self._count_tokens(note) - # Generate answer - answer_t = self._chat_model.answer( - history=session.history, - memory_notes=memory_notes, - max_new_tokens=self._rl_cfg["max_new_tokens"], - ) - + # Generate answer (with best-of-N if enabled) + if self.best_of_n > 1: + # Generate N responses and pick the best one + candidates_responses = [] + for i in range(self.best_of_n): + resp = self._chat_model.answer( + history=session.history, + memory_notes=memory_notes, + max_new_tokens=self._rl_cfg["max_new_tokens"], + temperature=0.8, # Slightly higher temp for diversity + ) + score = self._score_response(resp) + candidates_responses.append((resp, score)) + + # Sort by score (descending) and pick best + candidates_responses.sort(key=lambda x: x[1], reverse=True) + answer_t = candidates_responses[0][0] + best_score = candidates_responses[0][1] + + if len(candidates_responses) > 1: + print(f"[BEST-OF-{self.best_of_n}] Scores: {[f'{s:.2f}' for _, s in candidates_responses]}, picked score={best_score:.2f}") + else: + answer_t = self._chat_model.answer( + history=session.history, + memory_notes=memory_notes, + max_new_tokens=self._rl_cfg["max_new_tokens"], + ) + completion_tokens = self._count_tokens(answer_t) # Add assistant turn to history @@ -612,7 +909,263 @@ class PersonalizedLLM: usage=usage, debug=debug, ) - + + def chat_prepare(self, user_id: str, query: str) -> dict: + """ + Prepare for chat without calling the LLM. + + This does all the preparation work (embedding, memory retrieval, etc.) + and returns the messages to send to the LLM along with context needed + for post-processing. + + Used for batch processing where messages are collected first, then + sent in batch to vLLM for concurrent processing. + + Args: + user_id: Unique identifier for the user. + query: Current user query/message. 
+ + Returns: + Dict containing: + - messages: List of messages to send to LLM + - context: Dict with all state needed for chat_complete() + """ + ctx = self._get_or_create_session(user_id) + session = ctx.session_state + user_state = self._user_store.get_state(user_id) + + # Record user vector before for debug + z_long_before = user_state.z_long.copy().tolist() + z_short_before = user_state.z_short.copy().tolist() + + # Add user turn to history + user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter) + session.history.append(user_turn) + + # Vanilla mode: pure LLM without any memory or preference extraction + if self.mode == "vanilla": + e_q_t = np.zeros(4096, dtype=np.float32) + extracted_prefs = [] + candidates = [] + cand_item_vecs = np.array([]) + base_scores = np.array([]) + chosen_indices = [] + probs = np.array([]) + memories_t = [] + memory_notes = [] + else: + # Compute query embedding + embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False) + if embed_result is None or len(embed_result) == 0: + raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}") + e_q_t = np.array(embed_result[0]) + + # Store pending RL update info from last turn + if session.last_query is not None and self.enable_rl_updates: + ctx.pending_rl_update = { + "last_query": session.last_query, + "last_answer": session.last_answer, + "last_memories": session.last_memories, + "last_query_embedding": session.last_query_embedding, + "current_query_embedding": e_q_t, + "last_candidate_item_vectors": session.last_candidate_item_vectors, + "last_policy_probs": session.last_policy_probs, + "last_chosen_indices": session.last_chosen_indices, + } + + # Auto-compute reward via LLM judge if enabled + if self.reward_mode == "llm" and self._llm_reward_client is not None: + import asyncio + try: + reward, gating = asyncio.run(eval_step_llm( + q_t=session.last_query, + answer_t=session.last_answer, + q_t1=query, + memories_t=session.last_memories or [], + client=self._llm_reward_client, + )) + if gating > 0.0: + self.apply_feedback(Feedback( + user_id=user_id, + turn_id=ctx.turn_counter - 1, + reward=reward, + gating=gating, + )) + except Exception as e: + print(f"[LLM-Reward] Judge call failed, skipping update: {e}") + + # Extract preferences from conversation + extracted_prefs = [] + if self.enable_preference_extraction: + prefs = self._extractor.extract_turn(session.history) + if prefs.preferences: + print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})") + extracted_prefs = self._add_preferences_as_memory( + prefs, query, user_id, ctx.turn_counter + ) + if extracted_prefs: + print(f"[DEBUG] Added {len(extracted_prefs)} to memory. 
Total cards: {len(self._memory_cards)}") + + # Retrieve memories + if self.mode == "nopersonal": + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( + user_id=user_id, + query=query, + embed_model=self._embed_model, + reranker=self._reranker, + memory_cards=self._memory_cards, + memory_embeddings=self._memory_embeddings, + topk_dense=self._rl_cfg["dense_topk"], + topk_rerank=self._rl_cfg["rerank_topk"], + only_own_memories=self.only_own_memories, + ) + else: + beta_long = self._rl_cfg["beta_long"] + beta_short = self._rl_cfg["beta_short"] + candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy( + user_id=user_id, + query=query, + embed_model=self._embed_model, + reranker=self._reranker, + memory_cards=self._memory_cards, + memory_embeddings=self._memory_embeddings, + user_store=self._user_store, + item_vectors=self._item_vectors, + topk_dense=self._rl_cfg["dense_topk"], + topk_rerank=self._rl_cfg["rerank_topk"], + beta_long=beta_long, + beta_short=beta_short, + tau=self._rl_cfg["tau"], + only_own_memories=self.only_own_memories, + sample=not self.eval_mode, + ) + + memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] + memory_notes = [m.note_text for m in memories_t] + + if memories_t: + print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...") + print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}") + for i, m in enumerate(memories_t[:3]): + score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 + print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...") + + # Build prompt token count + prompt_tokens = self._count_tokens(query) + for turn in session.history: + prompt_tokens += self._count_tokens(turn.text) + for note in memory_notes: + prompt_tokens += self._count_tokens(note) + + # Build messages for LLM + messages = self._chat_model.build_messages( + history=session.history, + memory_notes=memory_notes, + max_new_tokens=self._rl_cfg["max_new_tokens"], + ) + + # Return messages and context for chat_complete + return { + "messages": messages, + "context": { + "user_id": user_id, + "query": query, + "ctx": ctx, + "session": session, + "user_state": user_state, + "z_long_before": z_long_before, + "z_short_before": z_short_before, + "e_q_t": e_q_t, + "extracted_prefs": extracted_prefs, + "candidates": candidates, + "cand_item_vecs": cand_item_vecs, + "chosen_indices": chosen_indices, + "probs": probs, + "memories_t": memories_t, + "memory_notes": memory_notes, + "prompt_tokens": prompt_tokens, + } + } + + def chat_complete(self, answer_t: str, context: dict) -> AssistantResponse: + """ + Complete chat with LLM response. + + This takes the LLM response and context from chat_prepare(), and + does all post-processing (add to history, debug info, etc.). + + Args: + answer_t: The LLM response text. + context: Context dict from chat_prepare(). + + Returns: + AssistantResponse containing the answer, usage stats, and debug info. 
+ """ + # Unpack context + user_id = context["user_id"] + query = context["query"] + ctx = context["ctx"] + session = context["session"] + user_state = context["user_state"] + z_long_before = context["z_long_before"] + z_short_before = context["z_short_before"] + e_q_t = context["e_q_t"] + extracted_prefs = context["extracted_prefs"] + candidates = context["candidates"] + cand_item_vecs = context["cand_item_vecs"] + chosen_indices = context["chosen_indices"] + probs = context["probs"] + memories_t = context["memories_t"] + memory_notes = context["memory_notes"] + prompt_tokens = context["prompt_tokens"] + + completion_tokens = self._count_tokens(answer_t) + + # Add assistant turn to history + assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter) + session.history.append(assist_turn) + + # Update session state for next turn + session.last_query = query + session.last_answer = answer_t + session.last_memories = memories_t + session.last_query_embedding = e_q_t + session.last_candidate_item_vectors = cand_item_vecs + session.last_policy_probs = probs + session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else [] + + ctx.turn_counter += 1 + + # Build debug info + debug = DebugInfo( + selected_memory_ids=[m.card_id for m in memories_t], + selected_memory_notes=[m.note_text for m in memories_t], + selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [], + user_vector_before=z_long_before + z_short_before, + user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(), + extracted_preferences=extracted_prefs, + extra={ + "num_candidates": len(candidates), + "num_total_memories": len(self._memory_cards), + "z_long_norm": float(np.linalg.norm(user_state.z_long)), + "z_short_norm": float(np.linalg.norm(user_state.z_short)), + } + ) + + # Build usage stats + usage = UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + model=self._llm_name, + ) + + return AssistantResponse( + answer=answer_t, + usage=usage, + debug=debug, + ) + def reset_session(self, user_id: str) -> None: """ Reset session for a user (new chat window). |

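The shared-singleton helpers introduced in this diff (get_shared_embedding_model, get_shared_reranker, get_shared_extractor, clear_shared_models) exist so that parallel per-profile runs load the heavy encoders only once. A minimal sketch of that pattern, assuming the remaining PersonalizedLLM constructor arguments keep their defaults and using hypothetical profile IDs:

```python
from concurrent.futures import ThreadPoolExecutor

from personalization.serving.personalized_llm import (
    PersonalizedLLM,
    clear_shared_models,
)

PROFILE_IDS = ["user_001", "user_002", "user_003"]  # hypothetical IDs

def run_profile(profile_id: str) -> None:
    # Each worker builds its own PersonalizedLLM, but with use_shared_models=True
    # the embedding model, reranker, and extractor are loaded once (behind the
    # module-level lock) and reused by every thread. Remaining constructor
    # arguments are assumed to keep their defaults.
    llm = PersonalizedLLM(
        mode="full",
        use_shared_models=True,
        reranker_type="bge",  # the 278M BGE reranker instead of the 8B Qwen3 one
    )
    llm.chat(user_id=profile_id, query="Help me plan this week's training.")

with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(run_profile, PROFILE_IDS))

# Release the shared encoders before running the next method/configuration.
clear_shared_models()
```

Per-user session state stays inside each PersonalizedLLM instance; only the encoder weights are shared, and clear_shared_models() reclaims the GPU memory between methods.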
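For best_of_n > 1 the chat path samples several answers at temperature 0.8 and keeps the highest-scoring one according to _score_response (length, solution markers, list structure, numeric content, minus a penalty for ending on a question). A standalone sketch of the same selection, with the heuristic re-implemented here purely for illustration:

```python
import re

def score_response(response: str) -> float:
    """Simplified copy of the best-of-N heuristic from the diff above."""
    score = 0.0
    lower = response.lower()
    # Length, capped at 1000 characters, contributes up to 3 points.
    score += min(len(response), 1000) / 1000 * 3.0
    # Up to 5 "solution indicator" hits at 0.5 points each.
    indicators = ["=", "step", "answer", "formula", "result", "therefore", "solution"]
    score += min(sum(ind in lower for ind in indicators), 5) * 0.5
    # Bonus for numbered/bulleted structure, penalty for ending on a question.
    if any(marker in response for marker in ["1.", "2.", "- ", "* ", "##"]):
        score += 1.0
    if "?" in response[-100:]:
        score -= 1.5
    # Bonus for concrete numeric values.
    if len(re.findall(r"\d+\.?\d*", response)) >= 3:
        score += 1.0
    return score

candidates = [
    "Could you tell me more about what you need?",
    "Step 1: compute 3 * 4 = 12. Step 2: add 5, so the answer is 17.",
]
best = max(candidates, key=score_response)  # picks the worked solution
```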
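With reward_mode="llm", the arrival of the next user turn triggers eval_step_llm against the previous answer, and when the gating value is positive the resulting reward is fed into apply_feedback. Enabling it is a constructor flag; a sketch, assuming the other arguments keep their defaults and the default LLMRewardConfig is acceptable:

```python
from personalization.feedback.llm_reward import LLMRewardConfig
from personalization.serving.personalized_llm import PersonalizedLLM

# Sketch: remaining constructor arguments assumed to keep their defaults;
# LLMRewardConfig() uses whatever defaults it defines (not shown in the diff).
llm = PersonalizedLLM(
    mode="full",
    reward_mode="llm",                    # GPT-5-nano judge instead of the keyword heuristic
    llm_reward_config=LLMRewardConfig(),
)

# The judge runs automatically when the next query for the same user arrives.
llm.chat(user_id="user_001", query="Plan a tempo run for Tuesday.")
llm.chat(user_id="user_001", query="Make it shorter, I only have 30 minutes.")
```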
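chat_prepare/chat_complete split each turn so that prompts from many users can be collected and sent to the chat backend in a single batched call (the diff mentions a vLLM-backed llama_8b_vllm model for this). A sketch of the intended flow; batch_generate is a placeholder for whatever batched generation entry point the backend exposes and is not defined in this diff:

```python
from personalization.serving.personalized_llm import PersonalizedLLM

def batch_generate(message_batches):
    """Placeholder (assumption): one batched call into the vLLM chat backend."""
    raise NotImplementedError

# Other constructor arguments assumed to keep their defaults.
llm = PersonalizedLLM(mode="full", llm_name="llama_8b_vllm")

queries = {
    "user_001": "Summarize my last workout plan.",
    "user_002": "What pace should I target for a 10k?",
}

# Stage 1: embedding, preference extraction, retrieval, prompt building -- no LLM call.
prepared = {uid: llm.chat_prepare(uid, q) for uid, q in queries.items()}

# Stage 2: one batched generation call instead of one call per user.
answers = batch_generate([p["messages"] for p in prepared.values()])

# Stage 3: post-processing (history, session state, usage and debug info).
responses = {
    uid: llm.chat_complete(answer, p["context"])
    for (uid, p), answer in zip(prepared.items(), answers)
}
```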