diff options
Diffstat (limited to 'src/personalization/serving/personalized_llm.py')
| -rw-r--r-- | src/personalization/serving/personalized_llm.py | 512 |
1 file changed, 469 insertions, 43 deletions
diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py index 45d002b..785995b 100644 --- a/src/personalization/serving/personalized_llm.py +++ b/src/personalization/serving/personalized_llm.py @@ -285,6 +285,17 @@ class PersonalizedLLM: reward_mode: str = "keyword", # "keyword", "llm" (GPT-4o-mini), or "llm_local" (local vLLM) llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge reward_vllm_url: Optional[str] = None, # vLLM URL for local reward model (when reward_mode="llm_local") + enable_query_transform: bool = False, # Transform queries for better retrieval matching + enable_global_preferences: bool = False, # Separate global prefs that bypass retrieval + dynamic_topk: bool = False, # Use dynamic topk based on rerank scores + dynamic_min_k: int = 3, # Min preferences for dynamic topk + dynamic_max_k: int = 8, # Max preferences for dynamic topk + dynamic_score_ratio: float = 0.5, # Threshold = top_score * ratio + eta_long: float = None, # Override RL learning rate for z_long + eta_short: float = None, # Override RL learning rate for z_short + enable_preference_consolidation: bool = False, # Consolidate preferences at session end + consolidation_threshold: int = 5, # Min preferences before consolidation + enable_preference_rewrite: bool = False, # Use LLM to rewrite/merge retrieved preferences ): """ Initialize the PersonalizedLLM. 
@@ -319,6 +330,11 @@ class PersonalizedLLM: self.reranker_type = reranker_type # "qwen3" or "bge" self.best_of_n = best_of_n # Generate N responses and pick best self.reward_mode = reward_mode # "keyword", "llm", or "llm_local" + self.enable_query_transform = enable_query_transform + self.enable_global_preferences = enable_global_preferences + self.enable_preference_consolidation = enable_preference_consolidation + self.consolidation_threshold = consolidation_threshold + self.enable_preference_rewrite = enable_preference_rewrite # Initialize LLM reward client if using LLM judge self._llm_reward_client = None # Can be LLMRewardClient or LocalLLMRewardClient @@ -354,13 +370,18 @@ class PersonalizedLLM: "beta_long": 2.0, # Increased from 0.1 for stronger personalization "beta_short": 5.0, # Increased from 0.3 "tau": 1.0, - "eta_long": 0.01, # Increased from 1e-3 for faster learning - "eta_short": 0.05, # Increased from 5e-3 + "eta_long": eta_long if eta_long is not None else 0.01, + "eta_short": eta_short if eta_short is not None else 0.05, "ema_alpha": 0.05, "short_decay": 0.1, "dense_topk": 64, - "rerank_topk": 3, + "rerank_topk": 5, "max_new_tokens": 512, + # Dynamic topk settings + "dynamic_topk": dynamic_topk, + "dynamic_min_k": dynamic_min_k, + "dynamic_max_k": dynamic_max_k, + "dynamic_score_ratio": dynamic_score_ratio, } # Store llm_name before loading config (needed in _load_config) @@ -528,7 +549,13 @@ class PersonalizedLLM: self._memory_cards: List[MemoryCard] = [] self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32) self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) - self._projection = None + # Create default projection (truncation to first k dims) so preferences can be added + k = self._rl_cfg["item_dim"] + d = 4096 + P = np.zeros((k, d), dtype=np.float32) + P[:, :k] = np.eye(k, dtype=np.float32) + self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32)) + print(f"[PersonalizedLLM] Created default 
projection (truncation, k={k})") return # Load cards @@ -551,8 +578,14 @@ class PersonalizedLLM: self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"]) self._item_vectors = proj_data["V"] else: - self._projection = None + # Create default projection so preferences can still be added + k = self._rl_cfg["item_dim"] + d = 4096 + P = np.zeros((k, d), dtype=np.float32) + P[:, :k] = np.eye(k, dtype=np.float32) + self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32)) self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32) + print(f"[PersonalizedLLM] Created default projection (truncation, k={k})") print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.") @@ -588,6 +621,290 @@ class PersonalizedLLM: except Exception: return len(text) // 4 + # Task type keywords for query transformation + _TASK_KEYWORDS = { + "math": ["solve", "calculate", "integral", "equation", "proof", "derivative", + "math", "algebra", "geometry", "trigonometry", "calculus", "arithmetic", + "formula", "compute", "evaluate", "simplify", "factor", "graph"], + "coding": ["code", "program", "function", "implement", "debug", "python", "java", + "javascript", "algorithm", "class", "method", "bug", "error", "compile", + "script", "html", "css", "sql", "api", "library", "framework"], + "writing": ["write", "essay", "paragraph", "summarize", "draft", "compose", + "article", "story", "letter", "email", "report", "review", "edit", + "rewrite", "paraphrase", "outline"], + "explanation": ["explain", "what is", "how does", "why", "describe", "define", + "meaning", "concept", "difference between", "compare", "contrast"], + } + + def _transform_query_for_retrieval(self, query: str) -> List[str]: + """ + Transform raw user query into multiple retrieval queries to bridge + the semantic gap between task queries and preference descriptions. 
+ + Returns [original_query, transformed_query] or [original_query] if + no task type detected. + """ + import re + query_lower = query.lower() + detected_types = [] + for task_type, keywords in self._TASK_KEYWORDS.items(): + for kw in keywords: + # Use word boundary matching to avoid false positives + # e.g., "api" should not match "capital" + if re.search(r'\b' + re.escape(kw) + r'\b', query_lower): + detected_types.append(task_type) + break + + if not detected_types: + return [query] + + # Use first detected type (most specific match) + task_type = detected_types[0] + transformed = f"user preferences for {task_type} tasks: {query}" + return [query, transformed] + + # Patterns indicating a global/universal preference condition + _GLOBAL_PATTERNS = ["general", "any", "always", "all ", "every", "regardless", + "any task", "any topic", "any question", "all tasks", "all topics"] + + # Domain-specific terms that indicate a conditional preference + _DOMAIN_TERMS = ["math", "code", "coding", "program", "writing", "essay", "science", + "history", "language", "physics", "chemistry", "biology", "literature", + "creative", "technical", "formal", "informal", "academic", "casual"] + + def _classify_preference_scope(self, condition: str) -> bool: + """ + Classify whether a preference condition is global (always applicable) + or conditional (task-specific). + + Returns True if global, False if conditional. + """ + cond_lower = condition.lower().strip() + + # Check for explicit global patterns + for pattern in self._GLOBAL_PATTERNS: + if pattern in cond_lower: + return True + + # Very short/vague conditions with no domain terms are likely global + words = cond_lower.split() + if len(words) <= 2: + has_domain = any(term in cond_lower for term in self._DOMAIN_TERMS) + if not has_domain: + return True + + return False + + # Rewrite prompt for merging retrieved preferences + _REWRITE_PROMPT = """You are helping to prepare user preferences for an AI assistant. 
+ +The user is asking: {query} + +Retrieved preferences about this user: +{preferences} + +Task: Create a concise preference summary that the assistant MUST follow. + +Rules: +1. PRESERVE all specific formatting requirements exactly (e.g., "type hints", "snake_case", "code fence with language") +2. PRESERVE all structural requirements (e.g., "numbered steps", "bullet points", "answer first then explanation") +3. Only MERGE preferences that are truly redundant (saying the same thing differently) +4. Output as a short bulleted list if there are multiple distinct requirements +5. Keep each point actionable and specific - NO vague generalizations like "follow best practices" + +Example input: +- Include type hints in Python code +- Use snake_case for variable names +- When explaining, use numbered steps + +Example output: +- Include type hints +- Use snake_case for variables +- Use numbered steps for explanations + +If no preferences are relevant to this query type, output: "No specific preferences apply." + +Preference summary:""" + + def _rewrite_preferences(self, memory_notes: List[str], query: str) -> List[str]: + """ + Use LLM to rewrite/merge multiple retrieved preferences into concise instructions. + + This is similar to Reflection's proper_scaffolding but focuses on merging + rather than just filtering. 
+ + Args: + memory_notes: List of retrieved preference notes + query: Current user query + + Returns: + List with single rewritten instruction (or original if rewrite fails/disabled) + """ + if not memory_notes or len(memory_notes) <= 1: + return memory_notes + + try: + import requests + + # Format preferences for prompt + prefs_text = "\n".join(f"- {note}" for note in memory_notes) + prompt = self._REWRITE_PROMPT.format(query=query[:200], preferences=prefs_text) + + # Direct vLLM API call (simpler than going through chat model) + messages = [{"role": "user", "content": prompt}] + payload = { + "model": self._chat_model.model_name, + "messages": messages, + "max_tokens": 150, + "temperature": 0.3, # Lower temperature for more consistent output + } + + response = requests.post( + f"{self._chat_model.vllm_url}/chat/completions", + json=payload, + timeout=30 + ) + + if response.status_code != 200: + print(f"[REWRITE] API error {response.status_code}, keeping original notes") + return memory_notes + + result = response.json() + rewritten = result["choices"][0]["message"]["content"].strip().strip('"') + + # Validate response + if rewritten and len(rewritten) > 10 and "No specific preferences" not in rewritten: + print(f"[REWRITE] {len(memory_notes)} notes → 1 merged instruction") + return [rewritten] + else: + print(f"[REWRITE] Kept original {len(memory_notes)} notes (no valid merge)") + return memory_notes + + except Exception as e: + print(f"[REWRITE] Failed: {e}, keeping original notes") + return memory_notes + + # Consolidation prompt for session-end preference merging + _CONSOLIDATION_PROMPT = """You are analyzing user preferences extracted from conversations. + +Current preferences for this user: +{preferences} + +Task: Consolidate these preferences into a cleaner, more organized set by: +1. MERGE similar preferences (e.g., "use bullet points" + "format with bullets" → single preference) +2. 
REMOVE redundant or contradictory preferences (keep the more specific one) +3. PRESERVE all unique, meaningful preferences +4. Keep the same "When [condition], [action]." format + +Output ONLY the consolidated preferences, one per line, in this exact format: +When [condition], [action]. + +Do not add explanations or commentary. Just output the preference lines.""" + + def consolidate_user_preferences(self, user_id: str) -> int: + """ + Consolidate user preferences at session end using LLM. + + Merges similar preferences, removes redundancy, and creates cleaner + preference descriptions. Only runs if user has enough preferences. + + Args: + user_id: The user whose preferences to consolidate. + + Returns: + Number of preferences after consolidation (0 if skipped). + """ + if not self.enable_preference_consolidation: + return 0 + + # Get user's memory cards + user_cards = [c for c in self._memory_cards if c.user_id == user_id] + + if len(user_cards) < self.consolidation_threshold: + return len(user_cards) + + # Build preference list for prompt + pref_lines = [card.note_text for card in user_cards] + preferences_text = "\n".join(f"- {p}" for p in pref_lines) + + # Call LLM for consolidation + prompt = self._CONSOLIDATION_PROMPT.format(preferences=preferences_text) + messages = [{"role": "user", "content": prompt}] + + try: + result = self._chat_model.answer(messages, max_new_tokens=512) + consolidated_text = result.get("content", "").strip() + + if not consolidated_text: + return len(user_cards) + + # Parse consolidated preferences + new_prefs = [] + for line in consolidated_text.split("\n"): + line = line.strip() + if not line or not line.startswith("When "): + continue + # Parse "When [condition], [action]." 
+ if ", " in line: + parts = line.split(", ", 1) + condition = parts[0].replace("When ", "").strip() + action = parts[1].rstrip(".").strip() + if condition and action: + new_prefs.append({ + "condition": condition, + "action": action, + "is_global": self._classify_preference_scope(condition) if self.enable_global_preferences else False, + }) + + if not new_prefs: + return len(user_cards) + + # Remove old cards for this user + keep_indices = [i for i, c in enumerate(self._memory_cards) if c.user_id != user_id] + self._memory_cards = [self._memory_cards[i] for i in keep_indices] + if len(keep_indices) > 0 and len(self._memory_embeddings) > 0: + self._memory_embeddings = self._memory_embeddings[keep_indices] + self._item_vectors = self._item_vectors[keep_indices] + else: + embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096 + self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32) + self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Add consolidated preferences + for pref in new_prefs: + note_text = f"When {pref['condition']}, {pref['action']}." 
+ + # Compute embedding + e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0] + v_note = self._projection.transform_vector(np.array(e_note)) + + # Create card + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=user_id, + source_session_id=f"consolidated_{user_id}", + source_turn_ids=[], + raw_queries=[], + preference_list=PreferenceList(preferences=[ + Preference(condition=pref["condition"], action=pref["action"], confidence=1.0) + ]), + note_text=note_text, + embedding_e=list(e_note), + kind="pref", + is_global=pref["is_global"], + ) + + self._memory_cards.append(card) + self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])]) + self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])]) + + print(f"[PersonalizedLLM] Consolidated {len(user_cards)} → {len(new_prefs)} preferences for user {user_id}") + return len(new_prefs) + + except Exception as e: + print(f"[PersonalizedLLM] Consolidation failed for user {user_id}: {e}") + return len(user_cards) + def _add_preferences_as_memory( self, prefs: PreferenceList, @@ -628,6 +945,9 @@ class PersonalizedLLM: e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0] v_note = self._projection.transform_vector(np.array(e_note)) + # Classify as global or conditional + is_global = self._classify_preference_scope(pref.condition) if self.enable_global_preferences else False + # Create new memory card card = MemoryCard( card_id=str(uuid.uuid4()), @@ -639,6 +959,7 @@ class PersonalizedLLM: note_text=note_text, embedding_e=list(e_note), kind="pref", + is_global=is_global, ) # Add to memory store @@ -788,35 +1109,61 @@ class PersonalizedLLM: if extracted_prefs: print(f"[DEBUG] Added {len(extracted_prefs)} to memory. 
Total cards: {len(self._memory_cards)}") + # Separate global preferences (bypass retrieval) from conditional ones + global_notes = [] + retrieval_cards = self._memory_cards + retrieval_embeddings = self._memory_embeddings + retrieval_item_vectors = self._item_vectors + if self.enable_global_preferences: + global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id] + global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10 + # Filter out global cards for retrieval + cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global] + if cond_indices: + retrieval_cards = [self._memory_cards[i] for i in cond_indices] + retrieval_embeddings = self._memory_embeddings[cond_indices] + if len(self._item_vectors) > 0: + retrieval_item_vectors = self._item_vectors[cond_indices] + else: + retrieval_cards = [] + retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings + retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Query transformation for better retrieval matching + retrieval_queries = None + if self.enable_query_transform: + retrieval_queries = self._transform_query_for_retrieval(query) + # Retrieve memories - # In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector - # In "full" mode: policy-based retrieval with user vector influence if self.mode == "nopersonal": candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( user_id=user_id, query=query, embed_model=self._embed_model, reranker=self._reranker, - memory_cards=self._memory_cards, - memory_embeddings=self._memory_embeddings, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], only_own_memories=self.only_own_memories, + queries=retrieval_queries, + 
dynamic_topk=self._rl_cfg["dynamic_topk"], + dynamic_min_k=self._rl_cfg["dynamic_min_k"], + dynamic_max_k=self._rl_cfg["dynamic_max_k"], + dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"], ) else: beta_long = self._rl_cfg["beta_long"] beta_short = self._rl_cfg["beta_short"] - # eval_mode=True -> sample=False (greedy/deterministic) - # eval_mode=False -> sample=True (stochastic/exploration) candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy( user_id=user_id, query=query, embed_model=self._embed_model, reranker=self._reranker, - memory_cards=self._memory_cards, - memory_embeddings=self._memory_embeddings, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, user_store=self._user_store, - item_vectors=self._item_vectors, + item_vectors=retrieval_item_vectors, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], beta_long=beta_long, @@ -824,27 +1171,39 @@ class PersonalizedLLM: tau=self._rl_cfg["tau"], only_own_memories=self.only_own_memories, sample=not self.eval_mode, + queries=retrieval_queries, ) - + # Get selected memories memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] memory_notes = [m.note_text for m in memories_t] + # Apply preference rewrite if enabled + if self.enable_preference_rewrite and memory_notes: + memory_notes = self._rewrite_preferences(memory_notes, query) + # Debug: show retrieval info - if memories_t: + if memories_t or global_notes: print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...") - print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}") + print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}") for i, m in enumerate(memories_t[:3]): # Show top 3 score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: 
{m.note_text[:80]}...") - + + # Combine all notes for prompt (global + retrieved) + # For chat(), we combine all notes; chat_prepare() handles them separately + if self.mode != "vanilla": + all_memory_notes = (global_notes if global_notes else []) + memory_notes + else: + all_memory_notes = memory_notes + # Build prompt and count tokens prompt_tokens = self._count_tokens(query) for turn in session.history: prompt_tokens += self._count_tokens(turn.text) - for note in memory_notes: + for note in all_memory_notes: prompt_tokens += self._count_tokens(note) - + # Generate answer (with best-of-N if enabled) if self.best_of_n > 1: # Generate N responses and pick the best one @@ -852,7 +1211,7 @@ class PersonalizedLLM: for i in range(self.best_of_n): resp = self._chat_model.answer( history=session.history, - memory_notes=memory_notes, + memory_notes=all_memory_notes, max_new_tokens=self._rl_cfg["max_new_tokens"], temperature=0.8, # Slightly higher temp for diversity ) @@ -869,7 +1228,7 @@ class PersonalizedLLM: else: answer_t = self._chat_model.answer( history=session.history, - memory_notes=memory_notes, + memory_notes=all_memory_notes, max_new_tokens=self._rl_cfg["max_new_tokens"], ) @@ -920,7 +1279,7 @@ class PersonalizedLLM: debug=debug, ) - def chat_prepare(self, user_id: str, query: str) -> dict: + def chat_prepare(self, user_id: str, query: str, skip_extraction: bool = False, skip_auto_reward: bool = False) -> dict: """ Prepare for chat without calling the LLM. 
@@ -984,7 +1343,8 @@ class PersonalizedLLM: } # Auto-compute reward via LLM judge if enabled - if self._llm_reward_client is not None: + # skip_auto_reward=True when batch framework handles rewards externally + if self._llm_reward_client is not None and not skip_auto_reward: import asyncio try: reward, gating = asyncio.run(eval_step_llm( @@ -1006,7 +1366,7 @@ class PersonalizedLLM: # Extract preferences from conversation extracted_prefs = [] - if self.enable_preference_extraction: + if self.enable_preference_extraction and not skip_extraction: prefs = self._extractor.extract_turn(session.history) if prefs.preferences: print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})") @@ -1016,6 +1376,30 @@ class PersonalizedLLM: if extracted_prefs: print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}") + # Separate global preferences (bypass retrieval) from conditional ones + global_notes = [] + retrieval_cards = self._memory_cards + retrieval_embeddings = self._memory_embeddings + retrieval_item_vectors = self._item_vectors + if self.enable_global_preferences: + global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id] + global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10 + cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global] + if cond_indices: + retrieval_cards = [self._memory_cards[i] for i in cond_indices] + retrieval_embeddings = self._memory_embeddings[cond_indices] + if len(self._item_vectors) > 0: + retrieval_item_vectors = self._item_vectors[cond_indices] + else: + retrieval_cards = [] + retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings + retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Query transformation for better retrieval matching + retrieval_queries = None + if 
self.enable_query_transform: + retrieval_queries = self._transform_query_for_retrieval(query) + # Retrieve memories if self.mode == "nopersonal": candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( @@ -1023,11 +1407,16 @@ class PersonalizedLLM: query=query, embed_model=self._embed_model, reranker=self._reranker, - memory_cards=self._memory_cards, - memory_embeddings=self._memory_embeddings, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], only_own_memories=self.only_own_memories, + queries=retrieval_queries, + dynamic_topk=self._rl_cfg["dynamic_topk"], + dynamic_min_k=self._rl_cfg["dynamic_min_k"], + dynamic_max_k=self._rl_cfg["dynamic_max_k"], + dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"], ) else: beta_long = self._rl_cfg["beta_long"] @@ -1037,10 +1426,10 @@ class PersonalizedLLM: query=query, embed_model=self._embed_model, reranker=self._reranker, - memory_cards=self._memory_cards, - memory_embeddings=self._memory_embeddings, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, user_store=self._user_store, - item_vectors=self._item_vectors, + item_vectors=retrieval_item_vectors, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], beta_long=beta_long, @@ -1048,14 +1437,19 @@ class PersonalizedLLM: tau=self._rl_cfg["tau"], only_own_memories=self.only_own_memories, sample=not self.eval_mode, + queries=retrieval_queries, ) memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] memory_notes = [m.note_text for m in memories_t] - if memories_t: + # Apply preference rewrite if enabled + if self.enable_preference_rewrite and memory_notes: + memory_notes = self._rewrite_preferences(memory_notes, query) + + if memories_t or global_notes: print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...") - print(f"[DEBUG-RETRIEVAL] 
Candidates={len(candidates)}, Selected={len(memories_t)}") + print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}") for i, m in enumerate(memories_t[:3]): score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...") @@ -1064,14 +1458,17 @@ class PersonalizedLLM: prompt_tokens = self._count_tokens(query) for turn in session.history: prompt_tokens += self._count_tokens(turn.text) - for note in memory_notes: + all_notes = memory_notes + (global_notes if self.mode != "vanilla" else []) + for note in all_notes: prompt_tokens += self._count_tokens(note) - # Build messages for LLM + # Build messages for LLM (pass global_notes separately for distinct prompt sections) + effective_global = global_notes if (self.enable_global_preferences and self.mode != "vanilla") else None messages = self._chat_model.build_messages( history=session.history, memory_notes=memory_notes, max_new_tokens=self._rl_cfg["max_new_tokens"], + global_notes=effective_global, ) # Return messages and context for chat_complete @@ -1176,22 +1573,47 @@ class PersonalizedLLM: debug=debug, ) + def apply_extracted_preferences(self, user_id: str, pref_dict: dict) -> list: + """Apply pre-computed extraction results (from batch extraction) to memory.""" + prefs = PreferenceList.model_validate(pref_dict) + if not prefs.preferences: + return [] + ctx = self._get_or_create_session(user_id) + query = ctx.session_state.history[-1].text if ctx.session_state.history else "" + extracted = self._add_preferences_as_memory(prefs, query, user_id, ctx.turn_counter) + if extracted: + print(f"[DEBUG] Batch-added {len(extracted)} to memory. 
Total cards: {len(self._memory_cards)}") + return extracted + + def get_last_user_query(self, user_id: str) -> str: + """Get the last user message text for this user's session.""" + ctx = self._sessions.get(user_id) + if ctx and ctx.session_state.history: + for t in reversed(ctx.session_state.history): + if t.role == "user": + return t.text + return "" + def reset_session(self, user_id: str) -> None: """ Reset session for a user (new chat window). - + This clears: - Session conversation history - Short-term user vector (z_short) - Pending RL update info - + This preserves: - Long-term user vector (z_long) - - User's memory cards - + - User's memory cards (may be consolidated if enabled) + Args: user_id: The user whose session to reset. """ + # Consolidate preferences at session end (before clearing session) + if self.enable_preference_consolidation: + self.consolidate_user_preferences(user_id) + # Clear session context if user_id in self._sessions: del self._sessions[user_id] @@ -1270,14 +1692,14 @@ class PersonalizedLLM: """ if not self.enable_rl_updates: return - + # In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline) if self.mode in ("nopersonal", "vanilla"): return - + user_id = feedback.user_id ctx = self._sessions.get(user_id) - + if ctx is None or ctx.pending_rl_update is None: return @@ -1289,12 +1711,15 @@ class PersonalizedLLM: pending.get("last_policy_probs") is not None and pending.get("last_chosen_indices") is not None and len(pending["last_chosen_indices"]) > 0): - + # Extract chosen vectors chosen_indices = pending["last_chosen_indices"] candidate_vectors = pending["last_candidate_item_vectors"] - + if len(candidate_vectors) > 0: + print(f"[DEBUG-REINFORCE] User={user_id} reward={feedback.reward:.2f} " + f"n_candidates={len(candidate_vectors)} chosen={chosen_indices} " + f"probs_shape={pending['last_policy_probs'].shape if hasattr(pending['last_policy_probs'], 'shape') else 'N/A'}") # REINFORCE expects: # - item_vectors: ALL 
candidate vectors [K, k] # - chosen_indices: indices into those candidates @@ -1313,6 +1738,7 @@ class PersonalizedLLM: short_decay=self._rl_cfg["short_decay"], ) + print(f"[DEBUG-REINFORCE] updated={updated} z_long_norm={np.linalg.norm(user_state.z_long):.15e}") if updated: self._user_store.save_state(user_state) |
