diff options
Diffstat (limited to 'src/personalization')
| -rw-r--r-- | src/personalization/config/registry.py | 5 | ||||
| -rw-r--r-- | src/personalization/config/settings.py | 4 | ||||
| -rw-r--r-- | src/personalization/feedback/local_llm_reward.py | 36 | ||||
| -rw-r--r-- | src/personalization/models/llm/vllm_chat.py | 37 | ||||
| -rw-r--r-- | src/personalization/models/preference_extractor/rule_extractor.py | 53 | ||||
| -rw-r--r-- | src/personalization/retrieval/pipeline.py | 184 | ||||
| -rw-r--r-- | src/personalization/retrieval/preference_store/schemas.py | 1 | ||||
| -rw-r--r-- | src/personalization/serving/personalized_llm.py | 512 |
8 files changed, 750 insertions, 82 deletions
diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py index 6048044..c7a6a09 100644 --- a/src/personalization/config/registry.py +++ b/src/personalization/config/registry.py @@ -7,6 +7,9 @@ import yaml from personalization.config import settings +# Project root for resolving config paths +_PROJECT_ROOT = Path(__file__).parent.parent.parent.parent + # Avoid circular imports by NOT importing extractors here at top level # from personalization.models.preference_extractor.base import PreferenceExtractorBase # from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor @@ -54,7 +57,7 @@ def get_chat_model(name: str, device_override: Optional[str] = None): cfg = settings.load_local_models_config() # Try to load raw config to support multi-backend map - with open("configs/local_models.yaml", "r") as f: + with open(_PROJECT_ROOT / "configs/local_models.yaml", "r") as f: raw_cfg = yaml.safe_load(f) models = raw_cfg.get("models", {}).get("llm", {}) diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py index 1bb1bbe..8f0cc8a 100644 --- a/src/personalization/config/settings.py +++ b/src/personalization/config/settings.py @@ -37,7 +37,9 @@ def _resolve_config_path(env_key: str, default_rel: str) -> Path: value = os.getenv(env_key) if value: return Path(value).expanduser().resolve() - return (Path.cwd() / default_rel).resolve() + # Use project root (parent of src/personalization/config) instead of cwd + project_root = Path(__file__).parent.parent.parent.parent + return (project_root / default_rel).resolve() def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig: diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py index 9837ff0..70bbeb8 100644 --- a/src/personalization/feedback/local_llm_reward.py +++ b/src/personalization/feedback/local_llm_reward.py @@ -307,11 +307,39 @@ class LocalLLMRewardClient: This is the main entry point for batch reward estimation. """ - return asyncio.run(self.judge_batch_async(samples)) - - def judge(self, sample: TurnSample) -> RewardResult: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in an event loop - create a new thread to run the coroutine + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.judge_batch_async(samples)) + return future.result() + else: + return asyncio.run(self.judge_batch_async(samples)) + + async def judge(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async interface for compatibility with LLMRewardClient).""" + return await self.judge_async(sample) + + def judge_sync(self, sample: TurnSample) -> RewardResult: """Judge a single turn (sync wrapper).""" - return asyncio.run(self.judge_async(sample)) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in an event loop - create a new thread to run the coroutine + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.judge_async(sample)) + return future.result() + else: + return asyncio.run(self.judge_async(sample)) # --- Convenience Functions --- diff --git a/src/personalization/models/llm/vllm_chat.py b/src/personalization/models/llm/vllm_chat.py index b5c3a05..d577a30 100644 --- a/src/personalization/models/llm/vllm_chat.py +++ b/src/personalization/models/llm/vllm_chat.py @@ -78,27 +78,53 @@ class VLLMChatModel(ChatModel): history: List[ChatTurn], memory_notes: List[str], max_new_tokens: int = 512, + global_notes: List[str] = None, ) -> List[dict]: """Build messages list for chat completion API with auto-truncation. If the context exceeds max_context_length, older conversation turns are removed to keep only the most recent context that fits. + + Args: + global_notes: If provided, these are always-applicable preferences + displayed in a separate section from task-specific retrieved notes. """ # Use CollaborativeAgents-style system prompt - if memory_notes: - bullet = "\n".join(f"- {n}" for n in memory_notes) + has_any_notes = memory_notes or global_notes + if has_any_notes: + # Build preference sections + pref_sections = "" + if global_notes: + global_bullet = "\n".join(f"- {n}" for n in global_notes) + pref_sections += f"## General Preferences (always apply)\n{global_bullet}\n\n" + if memory_notes: + task_bullet = "\n".join(f"- {n}" for n in memory_notes) + if global_notes: + pref_sections += f"## Task-Specific Preferences\n{task_bullet}\n" + else: + pref_sections += f"{task_bullet}\n" + system_content = ( "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n" "# User Preferences\n" "The user has a set of preferences for how you should behave. If you do not follow these preferences, " "the user will be unable to learn from your response and you will need to adjust your response to adhere " - "to these preferences (so it is best to follow them initially).\n" + "to these preferences (so it is best to follow them initially).\n\n" + "**IMPORTANT**: If the user explicitly requests something in THIS conversation (e.g., asks you to change " + "your format, style, or approach), that request takes PRIORITY over the remembered preferences below. " + "Always adapt to the user's direct feedback first.\n\n" "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n" - f"{bullet}\n\n" + f"{pref_sections}\n" + "# Before Responding\n" + "Before writing your response, briefly consider:\n" + "1. Which preferences above are relevant to this specific request?\n" + "2. How will you satisfy each relevant preference in your response?\n\n" "# Conversation Guidelines:\n" + "- If the user asks you to adjust your response (e.g., 'be more concise', 'focus on intuition'), you MUST change your approach accordingly. Do NOT repeat the same response.\n" "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, " "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n" "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n" + "- **Verify**: Before finalizing, check that your response satisfies the relevant preferences listed above.\n" ) else: # Vanilla mode - no preferences @@ -152,13 +178,14 @@ class VLLMChatModel(ChatModel): history: List[ChatTurn], memory_notes: List[str], max_new_tokens: int = 512, + global_notes: List[str] = None, ) -> List[dict]: """Public method to build messages without calling the API. Used for batch processing where messages are collected first, then sent in batch to vLLM for concurrent processing. """ - return self._build_messages(history, memory_notes, max_new_tokens) + return self._build_messages(history, memory_notes, max_new_tokens, global_notes=global_notes) def answer( self, diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py index 0f743d9..42f43ed 100644 --- a/src/personalization/models/preference_extractor/rule_extractor.py +++ b/src/personalization/models/preference_extractor/rule_extractor.py @@ -119,6 +119,59 @@ class QwenRuleExtractor(PreferenceExtractor): return text[start : end + 1] return None + @torch.inference_mode() + def batch_extract_preferences(self, queries: List[str], batch_size: int = 64) -> List[Dict[str, Any]]: + """ + Batch extract preferences from multiple queries using left-padded batching. + """ + if not queries: + return [] + + # Save and set padding side for decoder-only batched generation + orig_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + + all_results = [] + prompts = [self.build_preference_prompt(q) for q in queries] + + for start in range(0, len(prompts), batch_size): + batch_prompts = prompts[start:start + batch_size] + inputs = self.tokenizer( + batch_prompts, return_tensors="pt", padding=True, truncation=True + ).to(self.model.device) + + outputs = self.model.generate( + **inputs, + do_sample=False, + max_new_tokens=512, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + + for i in range(len(batch_prompts)): + input_len = (inputs["attention_mask"][i] == 1).sum().item() + gen_ids = outputs[i][input_len:] + text = self.tokenizer.decode(gen_ids, skip_special_tokens=True) + + try: + data = json.loads(text) + validated = PreferenceList.model_validate(data) + all_results.append(validated.model_dump()) + except Exception: + extracted_json = self._extract_json_substring(text) + if extracted_json: + try: + data = json.loads(extracted_json) + validated = PreferenceList.model_validate(data) + all_results.append(validated.model_dump()) + continue + except Exception: + pass + all_results.append({"preferences": []}) + + self.tokenizer.padding_side = orig_padding_side + return all_results + def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList: """ Extract preferences from the LAST user turn in the history. diff --git a/src/personalization/retrieval/pipeline.py b/src/personalization/retrieval/pipeline.py index e83940d..6cc7f3e 100644 --- a/src/personalization/retrieval/pipeline.py +++ b/src/personalization/retrieval/pipeline.py @@ -12,6 +12,51 @@ def cosine_similarity_matrix(E: np.ndarray, e_q: np.ndarray) -> np.ndarray: # E: [M, d], e_q: [d] return np.dot(E, e_q) + +def dynamic_topk_selection( + scores: np.ndarray, + min_k: int = 3, + max_k: int = 8, + score_ratio: float = 0.5, +) -> List[int]: + """ + Dynamically select top-k indices based on score distribution. + + Strategy: + 1. Sort by score descending + 2. Compute threshold = top_score * score_ratio + 3. Select all indices with score > threshold + 4. Clamp to [min_k, max_k] range + + Args: + scores: Array of scores (higher = better) + min_k: Minimum number of items to select + max_k: Maximum number of items to select + score_ratio: Threshold ratio relative to top score + + Returns: + List of selected indices (in descending score order) + """ + if len(scores) == 0: + return [] + + # Sort indices by score descending + sorted_indices = np.argsort(scores)[::-1] + sorted_scores = scores[sorted_indices] + + # Compute threshold + top_score = sorted_scores[0] + threshold = top_score * score_ratio + + # Find how many pass threshold + n_above_threshold = np.sum(sorted_scores > threshold) + + # Clamp to [min_k, max_k] + n_select = max(min_k, min(max_k, n_above_threshold)) + n_select = min(n_select, len(scores)) # Don't exceed available + + return sorted_indices[:n_select].tolist() + def dense_topk_indices( query: str, embed_model: EmbeddingModel, @@ -58,6 +103,49 @@ def dense_topk_indices( idx = np.argsort(sims)[-k:][::-1] return idx.tolist() +def dense_topk_indices_multi_query( + queries: List[str], + embed_model: EmbeddingModel, + memory_embeddings: np.ndarray, + valid_indices: List[int] = None, + topk: int = 64 +) -> List[int]: + """ + Multi-query dense retrieval: embed all queries, take max similarity per memory, + return top-k by max similarity (union effect). + """ + if len(memory_embeddings) == 0: + return [] + + # Embed all queries at once + e_qs = embed_model.encode(queries, normalize=True, return_tensor=False) + e_qs = np.array(e_qs, dtype=np.float32) # [Q, d] + + if valid_indices is not None: + if len(valid_indices) == 0: + return [] + E_sub = memory_embeddings[valid_indices] + # sims: [Q, M_sub] + sims = np.dot(e_qs, E_sub.T) + # max across queries per memory + max_sims = sims.max(axis=0) # [M_sub] + k = min(topk, len(max_sims)) + if k == 0: + return [] + idx_sub = np.argsort(max_sims)[-k:][::-1] + return [valid_indices[i] for i in idx_sub] + + # Global search + # sims: [Q, M] + sims = np.dot(e_qs, memory_embeddings.T) + max_sims = sims.max(axis=0) # [M] + k = min(topk, len(max_sims)) + if k == 0: + return [] + idx = np.argsort(max_sims)[-k:][::-1] + return idx.tolist() + + def retrieve_with_policy( user_id: str, query: str, @@ -74,6 +162,7 @@ def retrieve_with_policy( tau: float = 1.0, only_own_memories: bool = False, sample: bool = False, + queries: List[str] = None, ) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]: """ Returns extended info for policy update: @@ -90,28 +179,37 @@ def retrieve_with_policy( if not valid_indices: return [], np.array([]), np.array([]), [], np.array([]) - # 1. Dense retrieval - dense_idx = dense_topk_indices( - query, - embed_model, - memory_embeddings, - valid_indices=valid_indices, - topk=topk_dense - ) + # 1. Dense retrieval (multi-query if available) + if queries and len(queries) > 1: + dense_idx = dense_topk_indices_multi_query( + queries, + embed_model, + memory_embeddings, + valid_indices=valid_indices, + topk=topk_dense + ) + else: + dense_idx = dense_topk_indices( + query, + embed_model, + memory_embeddings, + valid_indices=valid_indices, + topk=topk_dense + ) # DEBUG: Check for duplicates or out of bounds if len(dense_idx) > 0: import os if os.getenv("RETRIEVAL_DEBUG") == "1": print(f" [Pipeline] Dense Indices (Top {len(dense_idx)}): {dense_idx[:10]}...") print(f" [Pipeline] Max Index: {max(dense_idx)} | Memory Size: {len(memory_cards)}") - + if not dense_idx: return [], np.array([]), np.array([]), [], np.array([]) candidates = [memory_cards[i] for i in dense_idx] candidate_docs = [c.note_text for c in candidates] - # 2. Rerank base score (P(yes|q,m)) + # 2. Rerank base score (P(yes|q,m)) - always use original query for reranking # Skip reranking if we have fewer candidates than topk_rerank (saves GPU memory) if len(candidates) <= topk_rerank: base_scores = np.ones(len(candidates)) # Uniform scores @@ -165,14 +263,25 @@ def retrieve_no_policy( topk_dense: int = 64, topk_rerank: int = 8, only_own_memories: bool = False, + queries: List[str] = None, + dynamic_topk: bool = False, + dynamic_min_k: int = 3, + dynamic_max_k: int = 8, + dynamic_score_ratio: float = 0.5, ) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]: """ Deterministic retrieval baseline (NoPersonal mode): - Dense retrieval -> Rerank -> Top-K (no policy sampling, no user vector influence) - + + Args: + dynamic_topk: If True, use dynamic selection based on score distribution + dynamic_min_k: Minimum items to select (when dynamic_topk=True) + dynamic_max_k: Maximum items to select (when dynamic_topk=True) + dynamic_score_ratio: Threshold = top_score * ratio (when dynamic_topk=True) + Returns same structure as retrieve_with_policy for compatibility: (candidates, candidate_item_vectors, base_scores, chosen_indices, rerank_scores_for_chosen) - + Note: candidate_item_vectors is empty array (not used in NoPersonal mode) The last return value is rerank scores instead of policy probs """ @@ -183,14 +292,23 @@ def retrieve_no_policy( if not valid_indices: return [], np.array([]), np.array([]), [], np.array([]) - # 1. Dense retrieval - dense_idx = dense_topk_indices( - query, - embed_model, - memory_embeddings, - valid_indices=valid_indices, - topk=topk_dense - ) + # 1. Dense retrieval (multi-query if available) + if queries and len(queries) > 1: + dense_idx = dense_topk_indices_multi_query( + queries, + embed_model, + memory_embeddings, + valid_indices=valid_indices, + topk=topk_dense + ) + else: + dense_idx = dense_topk_indices( + query, + embed_model, + memory_embeddings, + valid_indices=valid_indices, + topk=topk_dense + ) if not dense_idx: return [], np.array([]), np.array([]), [], np.array([]) @@ -198,23 +316,33 @@ def retrieve_no_policy( candidates = [memory_cards[i] for i in dense_idx] candidate_docs = [c.note_text for c in candidates] - # 2. Rerank base score (P(yes|q,m)) - # Skip reranking if we have fewer candidates than topk_rerank (saves GPU memory) - if len(candidates) <= topk_rerank: + # 2. Rerank base score (P(yes|q,m)) - always use original query for reranking + max_k = dynamic_max_k if dynamic_topk else topk_rerank + + # Skip reranking if we have fewer candidates than needed + if len(candidates) <= max_k: # Just return all candidates without reranking base_scores = np.ones(len(candidates)) # Uniform scores chosen_indices = list(range(len(candidates))) else: base_scores = np.array(reranker.score(query, candidate_docs)) - # 3. Deterministic Top-K selection based on rerank scores ONLY (no policy) - k = min(topk_rerank, len(base_scores)) - top_indices_local = base_scores.argsort()[-k:][::-1] - chosen_indices = top_indices_local.tolist() + # 3. Selection: dynamic or fixed top-K + if dynamic_topk: + chosen_indices = dynamic_topk_selection( + base_scores, + min_k=dynamic_min_k, + max_k=dynamic_max_k, + score_ratio=dynamic_score_ratio, + ) + else: + k = min(topk_rerank, len(base_scores)) + top_indices_local = base_scores.argsort()[-k:][::-1] + chosen_indices = top_indices_local.tolist() # Get scores for chosen items (for logging compatibility) chosen_scores = base_scores[chosen_indices] - + # Return empty item vectors (not used in NoPersonal mode) # Return rerank scores as the "probs" field for logging compatibility return candidates, np.array([]), base_scores, chosen_indices, chosen_scores diff --git a/src/personalization/retrieval/preference_store/schemas.py b/src/personalization/retrieval/preference_store/schemas.py index eb82558..5245025 100644 --- a/src/personalization/retrieval/preference_store/schemas.py +++ b/src/personalization/retrieval/preference_store/schemas.py @@ -45,3 +45,4 @@ class MemoryCard(BaseModel): note_text: str # Summarized "condition: action" text embedding_e: List[float] # The embedding vector kind: Literal["pref", "fact"] = "pref" + is_global: bool = False # True = always include in prompt, bypass retrieval diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py index 45d002b..785995b 100644 --- a/src/personalization/serving/personalized_llm.py +++ b/src/personalization/serving/personalized_llm.py @@ -285,6 +285,17 @@ class PersonalizedLLM: reward_mode: str = "keyword", # "keyword", "llm" (GPT-4o-mini), or "llm_local" (local vLLM) llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge reward_vllm_url: Optional[str] = None, # vLLM URL for local reward model (when reward_mode="llm_local") + enable_query_transform: bool = False, # Transform queries for better retrieval matching + enable_global_preferences: bool = False, # Separate global prefs that bypass retrieval + dynamic_topk: bool = False, # Use dynamic topk based on rerank scores + dynamic_min_k: int = 3, # Min preferences for dynamic topk + dynamic_max_k: int = 8, # Max preferences for dynamic topk + dynamic_score_ratio: float = 0.5, # Threshold = top_score * ratio + eta_long: float = None, # Override RL learning rate for z_long + eta_short: float = None, # Override RL learning rate for z_short + enable_preference_consolidation: bool = False, # Consolidate preferences at session end + consolidation_threshold: int = 5, # Min preferences before consolidation + enable_preference_rewrite: bool = False, # Use LLM to rewrite/merge retrieved preferences ): """ Initialize the PersonalizedLLM. @@ -319,6 +330,11 @@ class PersonalizedLLM: self.reranker_type = reranker_type # "qwen3" or "bge" self.best_of_n = best_of_n # Generate N responses and pick best self.reward_mode = reward_mode # "keyword", "llm", or "llm_local" + self.enable_query_transform = enable_query_transform + self.enable_global_preferences = enable_global_preferences + self.enable_preference_consolidation = enable_preference_consolidation + self.consolidation_threshold = consolidation_threshold + self.enable_preference_rewrite = enable_preference_rewrite # Initialize LLM reward client if using LLM judge self._llm_reward_client = None # Can be LLMRewardClient or LocalLLMRewardClient @@ -354,13 +370,18 @@ class PersonalizedLLM: "beta_long": 2.0, # Increased from 0.1 for stronger personalization "beta_short": 5.0, # Increased from 0.3 "tau": 1.0, - "eta_long": 0.01, # Increased from 1e-3 for faster learning - "eta_short": 0.05, # Increased from 5e-3 + "eta_long": eta_long if eta_long is not None else 0.01, + "eta_short": eta_short if eta_short is not None else 0.05, "ema_alpha": 0.05, "short_decay": 0.1, "dense_topk": 64, - "rerank_topk": 3, + "rerank_topk": 5, "max_new_tokens": 512, + # Dynamic topk settings + "dynamic_topk": dynamic_topk, + "dynamic_min_k": dynamic_min_k, + "dynamic_max_k": dynamic_max_k, + "dynamic_score_ratio": dynamic_score_ratio, } # Store llm_name before loading config (needed in _load_config) @@ -528,7 +549,13 @@ class PersonalizedLLM: self._memory_cards: List[MemoryCard] = [] self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32) self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) - self._projection = None + # Create default projection (truncation to first k dims) so preferences can be added + k = self._rl_cfg["item_dim"] + d = 4096 + P = np.zeros((k, d), dtype=np.float32) + P[:, :k] = np.eye(k, dtype=np.float32) + self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32)) + print(f"[PersonalizedLLM] Created default projection (truncation, k={k})") return # Load cards @@ -551,8 +578,14 @@ class PersonalizedLLM: self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"]) self._item_vectors = proj_data["V"] else: - self._projection = None + # Create default projection so preferences can still be added + k = self._rl_cfg["item_dim"] + d = 4096 + P = np.zeros((k, d), dtype=np.float32) + P[:, :k] = np.eye(k, dtype=np.float32) + self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32)) self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32) + print(f"[PersonalizedLLM] Created default projection (truncation, k={k})") print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.") @@ -588,6 +621,290 @@ class PersonalizedLLM: except Exception: return len(text) // 4 + # Task type keywords for query transformation + _TASK_KEYWORDS = { + "math": ["solve", "calculate", "integral", "equation", "proof", "derivative", + "math", "algebra", "geometry", "trigonometry", "calculus", "arithmetic", + "formula", "compute", "evaluate", "simplify", "factor", "graph"], + "coding": ["code", "program", "function", "implement", "debug", "python", "java", + "javascript", "algorithm", "class", "method", "bug", "error", "compile", + "script", "html", "css", "sql", "api", "library", "framework"], + "writing": ["write", "essay", "paragraph", "summarize", "draft", "compose", + "article", "story", "letter", "email", "report", "review", "edit", + "rewrite", "paraphrase", "outline"], + "explanation": ["explain", "what is", "how does", "why", "describe", "define", + "meaning", "concept", "difference between", "compare", "contrast"], + } + + def _transform_query_for_retrieval(self, query: str) -> List[str]: + """ + Transform raw user query into multiple retrieval queries to bridge + the semantic gap between task queries and preference descriptions. + + Returns [original_query, transformed_query] or [original_query] if + no task type detected. + """ + import re + query_lower = query.lower() + detected_types = [] + for task_type, keywords in self._TASK_KEYWORDS.items(): + for kw in keywords: + # Use word boundary matching to avoid false positives + # e.g., "api" should not match "capital" + if re.search(r'\b' + re.escape(kw) + r'\b', query_lower): + detected_types.append(task_type) + break + + if not detected_types: + return [query] + + # Use first detected type (most specific match) + task_type = detected_types[0] + transformed = f"user preferences for {task_type} tasks: {query}" + return [query, transformed] + + # Patterns indicating a global/universal preference condition + _GLOBAL_PATTERNS = ["general", "any", "always", "all ", "every", "regardless", + "any task", "any topic", "any question", "all tasks", "all topics"] + + # Domain-specific terms that indicate a conditional preference + _DOMAIN_TERMS = ["math", "code", "coding", "program", "writing", "essay", "science", + "history", "language", "physics", "chemistry", "biology", "literature", + "creative", "technical", "formal", "informal", "academic", "casual"] + + def _classify_preference_scope(self, condition: str) -> bool: + """ + Classify whether a preference condition is global (always applicable) + or conditional (task-specific). + + Returns True if global, False if conditional. + """ + cond_lower = condition.lower().strip() + + # Check for explicit global patterns + for pattern in self._GLOBAL_PATTERNS: + if pattern in cond_lower: + return True + + # Very short/vague conditions with no domain terms are likely global + words = cond_lower.split() + if len(words) <= 2: + has_domain = any(term in cond_lower for term in self._DOMAIN_TERMS) + if not has_domain: + return True + + return False + + # Rewrite prompt for merging retrieved preferences + _REWRITE_PROMPT = """You are helping to prepare user preferences for an AI assistant. + +The user is asking: {query} + +Retrieved preferences about this user: +{preferences} + +Task: Create a concise preference summary that the assistant MUST follow. + +Rules: +1. PRESERVE all specific formatting requirements exactly (e.g., "type hints", "snake_case", "code fence with language") +2. PRESERVE all structural requirements (e.g., "numbered steps", "bullet points", "answer first then explanation") +3. Only MERGE preferences that are truly redundant (saying the same thing differently) +4. Output as a short bulleted list if there are multiple distinct requirements +5. Keep each point actionable and specific - NO vague generalizations like "follow best practices" + +Example input: +- Include type hints in Python code +- Use snake_case for variable names +- When explaining, use numbered steps + +Example output: +- Include type hints +- Use snake_case for variables +- Use numbered steps for explanations + +If no preferences are relevant to this query type, output: "No specific preferences apply." + +Preference summary:""" + + def _rewrite_preferences(self, memory_notes: List[str], query: str) -> List[str]: + """ + Use LLM to rewrite/merge multiple retrieved preferences into concise instructions. + + This is similar to Reflection's proper_scaffolding but focuses on merging + rather than just filtering. + + Args: + memory_notes: List of retrieved preference notes + query: Current user query + + Returns: + List with single rewritten instruction (or original if rewrite fails/disabled) + """ + if not memory_notes or len(memory_notes) <= 1: + return memory_notes + + try: + import requests + + # Format preferences for prompt + prefs_text = "\n".join(f"- {note}" for note in memory_notes) + prompt = self._REWRITE_PROMPT.format(query=query[:200], preferences=prefs_text) + + # Direct vLLM API call (simpler than going through chat model) + messages = [{"role": "user", "content": prompt}] + payload = { + "model": self._chat_model.model_name, + "messages": messages, + "max_tokens": 150, + "temperature": 0.3, # Lower temperature for more consistent output + } + + response = requests.post( + f"{self._chat_model.vllm_url}/chat/completions", + json=payload, + timeout=30 + ) + + if response.status_code != 200: + print(f"[REWRITE] API error {response.status_code}, keeping original notes") + return memory_notes + + result = response.json() + rewritten = result["choices"][0]["message"]["content"].strip().strip('"') + + # Validate response + if rewritten and len(rewritten) > 10 and "No specific preferences" not in rewritten: + print(f"[REWRITE] {len(memory_notes)} notes → 1 merged instruction") + return [rewritten] + else: + print(f"[REWRITE] Kept original {len(memory_notes)} notes (no valid merge)") + return memory_notes + + except Exception as e: + print(f"[REWRITE] Failed: {e}, keeping original notes") + return memory_notes + + # Consolidation prompt for session-end preference merging + _CONSOLIDATION_PROMPT = """You are analyzing user preferences extracted from conversations. + +Current preferences for this user: +{preferences} + +Task: Consolidate these preferences into a cleaner, more organized set by: +1. MERGE similar preferences (e.g., "use bullet points" + "format with bullets" → single preference) +2. REMOVE redundant or contradictory preferences (keep the more specific one) +3. PRESERVE all unique, meaningful preferences +4. Keep the same "When [condition], [action]." format + +Output ONLY the consolidated preferences, one per line, in this exact format: +When [condition], [action]. + +Do not add explanations or commentary. Just output the preference lines.""" + + def consolidate_user_preferences(self, user_id: str) -> int: + """ + Consolidate user preferences at session end using LLM. + + Merges similar preferences, removes redundancy, and creates cleaner + preference descriptions. Only runs if user has enough preferences. + + Args: + user_id: The user whose preferences to consolidate. + + Returns: + Number of preferences after consolidation (0 if skipped). + """ + if not self.enable_preference_consolidation: + return 0 + + # Get user's memory cards + user_cards = [c for c in self._memory_cards if c.user_id == user_id] + + if len(user_cards) < self.consolidation_threshold: + return len(user_cards) + + # Build preference list for prompt + pref_lines = [card.note_text for card in user_cards] + preferences_text = "\n".join(f"- {p}" for p in pref_lines) + + # Call LLM for consolidation + prompt = self._CONSOLIDATION_PROMPT.format(preferences=preferences_text) + messages = [{"role": "user", "content": prompt}] + + try: + result = self._chat_model.answer(messages, max_new_tokens=512) + consolidated_text = result.get("content", "").strip() + + if not consolidated_text: + return len(user_cards) + + # Parse consolidated preferences + new_prefs = [] + for line in consolidated_text.split("\n"): + line = line.strip() + if not line or not line.startswith("When "): + continue + # Parse "When [condition], [action]." + if ", " in line: + parts = line.split(", ", 1) + condition = parts[0].replace("When ", "").strip() + action = parts[1].rstrip(".").strip() + if condition and action: + new_prefs.append({ + "condition": condition, + "action": action, + "is_global": self._classify_preference_scope(condition) if self.enable_global_preferences else False, + }) + + if not new_prefs: + return len(user_cards) + + # Remove old cards for this user + keep_indices = [i for i, c in enumerate(self._memory_cards) if c.user_id != user_id] + self._memory_cards = [self._memory_cards[i] for i in keep_indices] + if len(keep_indices) > 0 and len(self._memory_embeddings) > 0: + self._memory_embeddings = self._memory_embeddings[keep_indices] + self._item_vectors = self._item_vectors[keep_indices] + else: + embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096 + self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32) + self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Add consolidated preferences + for pref in new_prefs: + note_text = f"When {pref['condition']}, {pref['action']}." + + # Compute embedding + e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0] + v_note = self._projection.transform_vector(np.array(e_note)) + + # Create card + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=user_id, + source_session_id=f"consolidated_{user_id}", + source_turn_ids=[], + raw_queries=[], + preference_list=PreferenceList(preferences=[ + Preference(condition=pref["condition"], action=pref["action"], confidence=1.0) + ]), + note_text=note_text, + embedding_e=list(e_note), + kind="pref", + is_global=pref["is_global"], + ) + + self._memory_cards.append(card) + self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])]) + self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])]) + + print(f"[PersonalizedLLM] Consolidated {len(user_cards)} → {len(new_prefs)} preferences for user {user_id}") + return len(new_prefs) + + except Exception as e: + print(f"[PersonalizedLLM] Consolidation failed for user {user_id}: {e}") + return len(user_cards) + def _add_preferences_as_memory( self, prefs: PreferenceList, @@ -628,6 +945,9 @@ class PersonalizedLLM: e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0] v_note = self._projection.transform_vector(np.array(e_note)) + # Classify as global or conditional + is_global = self._classify_preference_scope(pref.condition) if self.enable_global_preferences else False + # Create new memory card card = MemoryCard( card_id=str(uuid.uuid4()), @@ -639,6 +959,7 @@ class PersonalizedLLM: note_text=note_text, embedding_e=list(e_note), kind="pref", + is_global=is_global, ) # Add to memory store @@ -788,35 +1109,61 @@ class PersonalizedLLM: if extracted_prefs: print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}") + # Separate global preferences (bypass retrieval) from conditional ones + global_notes = [] + retrieval_cards = self._memory_cards + retrieval_embeddings = self._memory_embeddings + retrieval_item_vectors = self._item_vectors + if self.enable_global_preferences: + global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id] + global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10 + # Filter out global cards for retrieval + cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global] + if cond_indices: + retrieval_cards = [self._memory_cards[i] for i in cond_indices] + retrieval_embeddings = self._memory_embeddings[cond_indices] + if len(self._item_vectors) > 0: + retrieval_item_vectors = self._item_vectors[cond_indices] + else: + retrieval_cards = [] + retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings + retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Query transformation for better retrieval matching + retrieval_queries = None + if self.enable_query_transform: + retrieval_queries = self._transform_query_for_retrieval(query) + # Retrieve memories - # In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector - # In "full" mode: policy-based retrieval with user vector influence if self.mode == "nopersonal": candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( user_id=user_id, query=query, embed_model=self._embed_model, reranker=self._reranker, - memory_cards=self._memory_cards, - memory_embeddings=self._memory_embeddings, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], only_own_memories=self.only_own_memories, + queries=retrieval_queries, + dynamic_topk=self._rl_cfg["dynamic_topk"], + dynamic_min_k=self._rl_cfg["dynamic_min_k"], + dynamic_max_k=self._rl_cfg["dynamic_max_k"], + dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"], ) else: beta_long = self._rl_cfg["beta_long"] beta_short = self._rl_cfg["beta_short"] - # eval_mode=True -> sample=False (greedy/deterministic) - # eval_mode=False -> sample=True (stochastic/exploration) candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy( user_id=user_id, query=query, embed_model=self._embed_model, reranker=self._reranker, - memory_cards=self._memory_cards, - memory_embeddings=self._memory_embeddings, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, user_store=self._user_store, - item_vectors=self._item_vectors, + item_vectors=retrieval_item_vectors, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], beta_long=beta_long, @@ -824,27 +1171,39 @@ class PersonalizedLLM: tau=self._rl_cfg["tau"], only_own_memories=self.only_own_memories, sample=not self.eval_mode, + queries=retrieval_queries, ) - + # Get selected memories memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] memory_notes = [m.note_text for m in memories_t] + # Apply preference rewrite if enabled + if self.enable_preference_rewrite and memory_notes: + memory_notes = self._rewrite_preferences(memory_notes, query) + # Debug: show retrieval info - if memories_t: + if memories_t or global_notes: print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...") - print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}") + print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}") for i, m in enumerate(memories_t[:3]): # Show top 3 score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...") - + + # Combine all notes for prompt (global + retrieved) + # For chat(), we combine all notes; chat_prepare() handles them separately + if self.mode != "vanilla": + all_memory_notes = (global_notes if global_notes else []) + memory_notes + else: + all_memory_notes = memory_notes + # Build prompt and count tokens prompt_tokens = self._count_tokens(query) for turn in session.history: prompt_tokens += self._count_tokens(turn.text) - for note in memory_notes: + for note in all_memory_notes: prompt_tokens += self._count_tokens(note) - + # Generate answer (with best-of-N if enabled) if self.best_of_n > 1: # Generate N responses and pick the best one @@ -852,7 +1211,7 @@ class PersonalizedLLM: for i in range(self.best_of_n): resp = self._chat_model.answer( history=session.history, - memory_notes=memory_notes, + memory_notes=all_memory_notes, max_new_tokens=self._rl_cfg["max_new_tokens"], temperature=0.8, # Slightly higher temp for diversity ) @@ -869,7 +1228,7 @@ class PersonalizedLLM: else: answer_t = self._chat_model.answer( history=session.history, - memory_notes=memory_notes, + memory_notes=all_memory_notes, max_new_tokens=self._rl_cfg["max_new_tokens"], ) @@ -920,7 +1279,7 @@ class PersonalizedLLM: debug=debug, ) - def chat_prepare(self, user_id: str, query: str) -> dict: + def chat_prepare(self, user_id: str, query: str, skip_extraction: bool = False, skip_auto_reward: bool = False) -> dict: """ Prepare for chat without calling the LLM. @@ -984,7 +1343,8 @@ class PersonalizedLLM: } # Auto-compute reward via LLM judge if enabled - if self._llm_reward_client is not None: + # skip_auto_reward=True when batch framework handles rewards externally + if self._llm_reward_client is not None and not skip_auto_reward: import asyncio try: reward, gating = asyncio.run(eval_step_llm( @@ -1006,7 +1366,7 @@ class PersonalizedLLM: # Extract preferences from conversation extracted_prefs = [] - if self.enable_preference_extraction: + if self.enable_preference_extraction and not skip_extraction: prefs = self._extractor.extract_turn(session.history) if prefs.preferences: print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})") @@ -1016,6 +1376,30 @@ class PersonalizedLLM: if extracted_prefs: print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}") + # Separate global preferences (bypass retrieval) from conditional ones + global_notes = [] + retrieval_cards = self._memory_cards + retrieval_embeddings = self._memory_embeddings + retrieval_item_vectors = self._item_vectors + if self.enable_global_preferences: + global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id] + global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10 + cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global] + if cond_indices: + retrieval_cards = [self._memory_cards[i] for i in cond_indices] + retrieval_embeddings = self._memory_embeddings[cond_indices] + if len(self._item_vectors) > 0: + retrieval_item_vectors = self._item_vectors[cond_indices] + else: + retrieval_cards = [] + retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings + retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32) + + # Query transformation for better retrieval matching + retrieval_queries = None + if self.enable_query_transform: + retrieval_queries = self._transform_query_for_retrieval(query) + # Retrieve memories if self.mode == "nopersonal": candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy( @@ -1023,11 +1407,16 @@ class PersonalizedLLM: query=query, embed_model=self._embed_model, reranker=self._reranker, - memory_cards=self._memory_cards, - memory_embeddings=self._memory_embeddings, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], only_own_memories=self.only_own_memories, + queries=retrieval_queries, + dynamic_topk=self._rl_cfg["dynamic_topk"], + dynamic_min_k=self._rl_cfg["dynamic_min_k"], + dynamic_max_k=self._rl_cfg["dynamic_max_k"], + dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"], ) else: beta_long = self._rl_cfg["beta_long"] @@ -1037,10 +1426,10 @@ class PersonalizedLLM: query=query, embed_model=self._embed_model, reranker=self._reranker, - memory_cards=self._memory_cards, - memory_embeddings=self._memory_embeddings, + memory_cards=retrieval_cards, + memory_embeddings=retrieval_embeddings, user_store=self._user_store, - item_vectors=self._item_vectors, + item_vectors=retrieval_item_vectors, topk_dense=self._rl_cfg["dense_topk"], topk_rerank=self._rl_cfg["rerank_topk"], beta_long=beta_long, @@ -1048,14 +1437,19 @@ class PersonalizedLLM: tau=self._rl_cfg["tau"], only_own_memories=self.only_own_memories, sample=not self.eval_mode, + queries=retrieval_queries, ) memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else [] memory_notes = [m.note_text for m in memories_t] - if memories_t: + # Apply preference rewrite if enabled + if self.enable_preference_rewrite and memory_notes: + memory_notes = self._rewrite_preferences(memory_notes, query) + + if memories_t or global_notes: print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...") - print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}") + print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}") for i, m in enumerate(memories_t[:3]): score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...") @@ -1064,14 +1458,17 @@ class PersonalizedLLM: prompt_tokens = self._count_tokens(query) for turn in session.history: prompt_tokens += self._count_tokens(turn.text) - for note in memory_notes: + all_notes = memory_notes + (global_notes if self.mode != "vanilla" else []) + for note in all_notes: prompt_tokens += self._count_tokens(note) - # Build messages for LLM + # Build messages for LLM (pass global_notes separately for distinct prompt sections) + effective_global = global_notes if (self.enable_global_preferences and self.mode != "vanilla") else None messages = self._chat_model.build_messages( history=session.history, memory_notes=memory_notes, max_new_tokens=self._rl_cfg["max_new_tokens"], + global_notes=effective_global, ) # Return messages and context for chat_complete @@ -1176,22 +1573,47 @@ class PersonalizedLLM: debug=debug, ) + def apply_extracted_preferences(self, user_id: str, pref_dict: dict) -> list: + """Apply pre-computed extraction results (from batch extraction) to memory.""" + prefs = PreferenceList.model_validate(pref_dict) + if not prefs.preferences: + return [] + ctx = self._get_or_create_session(user_id) + query = ctx.session_state.history[-1].text if ctx.session_state.history else "" + extracted = self._add_preferences_as_memory(prefs, query, user_id, ctx.turn_counter) + if extracted: + print(f"[DEBUG] Batch-added {len(extracted)} to memory. Total cards: {len(self._memory_cards)}") + return extracted + + def get_last_user_query(self, user_id: str) -> str: + """Get the last user message text for this user's session.""" + ctx = self._sessions.get(user_id) + if ctx and ctx.session_state.history: + for t in reversed(ctx.session_state.history): + if t.role == "user": + return t.text + return "" + def reset_session(self, user_id: str) -> None: """ Reset session for a user (new chat window). - + This clears: - Session conversation history - Short-term user vector (z_short) - Pending RL update info - + This preserves: - Long-term user vector (z_long) - - User's memory cards - + - User's memory cards (may be consolidated if enabled) + Args: user_id: The user whose session to reset. """ + # Consolidate preferences at session end (before clearing session) + if self.enable_preference_consolidation: + self.consolidate_user_preferences(user_id) + # Clear session context if user_id in self._sessions: del self._sessions[user_id] @@ -1270,14 +1692,14 @@ class PersonalizedLLM: """ if not self.enable_rl_updates: return - + # In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline) if self.mode in ("nopersonal", "vanilla"): return - + user_id = feedback.user_id ctx = self._sessions.get(user_id) - + if ctx is None or ctx.pending_rl_update is None: return @@ -1289,12 +1711,15 @@ class PersonalizedLLM: pending.get("last_policy_probs") is not None and pending.get("last_chosen_indices") is not None and len(pending["last_chosen_indices"]) > 0): - + # Extract chosen vectors chosen_indices = pending["last_chosen_indices"] candidate_vectors = pending["last_candidate_item_vectors"] - + if len(candidate_vectors) > 0: + print(f"[DEBUG-REINFORCE] User={user_id} reward={feedback.reward:.2f} " + f"n_candidates={len(candidate_vectors)} chosen={chosen_indices} " + f"probs_shape={pending['last_policy_probs'].shape if hasattr(pending['last_policy_probs'], 'shape') else 'N/A'}") # REINFORCE expects: # - item_vectors: ALL candidate vectors [K, k] # - chosen_indices: indices into those candidates @@ -1313,6 +1738,7 @@ class PersonalizedLLM: short_decay=self._rl_cfg["short_decay"], ) + print(f"[DEBUG-REINFORCE] updated={updated} z_long_norm={np.linalg.norm(user_state.z_long):.15e}") if updated: self._user_store.save_state(user_state) |
