summaryrefslogtreecommitdiff
path: root/src/personalization
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization')
-rw-r--r--src/personalization/config/registry.py5
-rw-r--r--src/personalization/config/settings.py4
-rw-r--r--src/personalization/feedback/local_llm_reward.py36
-rw-r--r--src/personalization/models/llm/vllm_chat.py37
-rw-r--r--src/personalization/models/preference_extractor/rule_extractor.py53
-rw-r--r--src/personalization/retrieval/pipeline.py184
-rw-r--r--src/personalization/retrieval/preference_store/schemas.py1
-rw-r--r--src/personalization/serving/personalized_llm.py512
8 files changed, 750 insertions, 82 deletions
diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py
index 6048044..c7a6a09 100644
--- a/src/personalization/config/registry.py
+++ b/src/personalization/config/registry.py
@@ -7,6 +7,9 @@ import yaml
from personalization.config import settings
+# Project root for resolving config paths
+_PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
+
# Avoid circular imports by NOT importing extractors here at top level
# from personalization.models.preference_extractor.base import PreferenceExtractorBase
# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
@@ -54,7 +57,7 @@ def get_chat_model(name: str, device_override: Optional[str] = None):
cfg = settings.load_local_models_config()
# Try to load raw config to support multi-backend map
- with open("configs/local_models.yaml", "r") as f:
+ with open(_PROJECT_ROOT / "configs/local_models.yaml", "r") as f:
raw_cfg = yaml.safe_load(f)
models = raw_cfg.get("models", {}).get("llm", {})
diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py
index 1bb1bbe..8f0cc8a 100644
--- a/src/personalization/config/settings.py
+++ b/src/personalization/config/settings.py
@@ -37,7 +37,9 @@ def _resolve_config_path(env_key: str, default_rel: str) -> Path:
value = os.getenv(env_key)
if value:
return Path(value).expanduser().resolve()
- return (Path.cwd() / default_rel).resolve()
+ # Use project root (parent of src/personalization/config) instead of cwd
+ project_root = Path(__file__).parent.parent.parent.parent
+ return (project_root / default_rel).resolve()
def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig:
diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py
index 9837ff0..70bbeb8 100644
--- a/src/personalization/feedback/local_llm_reward.py
+++ b/src/personalization/feedback/local_llm_reward.py
@@ -307,11 +307,39 @@ class LocalLLMRewardClient:
This is the main entry point for batch reward estimation.
"""
- return asyncio.run(self.judge_batch_async(samples))
-
- def judge(self, sample: TurnSample) -> RewardResult:
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+
+ if loop is not None:
+ # Already in an event loop - create a new thread to run the coroutine
+ import concurrent.futures
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future = executor.submit(asyncio.run, self.judge_batch_async(samples))
+ return future.result()
+ else:
+ return asyncio.run(self.judge_batch_async(samples))
+
+ async def judge(self, sample: TurnSample) -> RewardResult:
+ """Judge a single turn (async interface for compatibility with LLMRewardClient)."""
+ return await self.judge_async(sample)
+
+ def judge_sync(self, sample: TurnSample) -> RewardResult:
"""Judge a single turn (sync wrapper)."""
- return asyncio.run(self.judge_async(sample))
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+
+ if loop is not None:
+ # Already in an event loop - create a new thread to run the coroutine
+ import concurrent.futures
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future = executor.submit(asyncio.run, self.judge_async(sample))
+ return future.result()
+ else:
+ return asyncio.run(self.judge_async(sample))
# --- Convenience Functions ---
diff --git a/src/personalization/models/llm/vllm_chat.py b/src/personalization/models/llm/vllm_chat.py
index b5c3a05..d577a30 100644
--- a/src/personalization/models/llm/vllm_chat.py
+++ b/src/personalization/models/llm/vllm_chat.py
@@ -78,27 +78,53 @@ class VLLMChatModel(ChatModel):
history: List[ChatTurn],
memory_notes: List[str],
max_new_tokens: int = 512,
+ global_notes: List[str] = None,
) -> List[dict]:
"""Build messages list for chat completion API with auto-truncation.
If the context exceeds max_context_length, older conversation turns
are removed to keep only the most recent context that fits.
+
+ Args:
+ global_notes: If provided, these are always-applicable preferences
+ displayed in a separate section from task-specific retrieved notes.
"""
# Use CollaborativeAgents-style system prompt
- if memory_notes:
- bullet = "\n".join(f"- {n}" for n in memory_notes)
+ has_any_notes = memory_notes or global_notes
+ if has_any_notes:
+ # Build preference sections
+ pref_sections = ""
+ if global_notes:
+ global_bullet = "\n".join(f"- {n}" for n in global_notes)
+ pref_sections += f"## General Preferences (always apply)\n{global_bullet}\n\n"
+ if memory_notes:
+ task_bullet = "\n".join(f"- {n}" for n in memory_notes)
+ if global_notes:
+ pref_sections += f"## Task-Specific Preferences\n{task_bullet}\n"
+ else:
+ pref_sections += f"{task_bullet}\n"
+
system_content = (
"You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
"# User Preferences\n"
"The user has a set of preferences for how you should behave. If you do not follow these preferences, "
"the user will be unable to learn from your response and you will need to adjust your response to adhere "
- "to these preferences (so it is best to follow them initially).\n"
+ "to these preferences (so it is best to follow them initially).\n\n"
+ "**IMPORTANT**: If the user explicitly requests something in THIS conversation (e.g., asks you to change "
+ "your format, style, or approach), that request takes PRIORITY over the remembered preferences below. "
+ "Always adapt to the user's direct feedback first.\n\n"
"Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n"
- f"{bullet}\n\n"
+ f"{pref_sections}\n"
+ "# Before Responding\n"
+ "Before writing your response, briefly consider:\n"
+ "1. Which preferences above are relevant to this specific request?\n"
+ "2. How will you satisfy each relevant preference in your response?\n\n"
"# Conversation Guidelines:\n"
+ "- If the user asks you to adjust your response (e.g., 'be more concise', 'focus on intuition'), you MUST change your approach accordingly. Do NOT repeat the same response.\n"
"- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
"specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
"- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n"
+ "- **Verify**: Before finalizing, check that your response satisfies the relevant preferences listed above.\n"
)
else:
# Vanilla mode - no preferences
@@ -152,13 +178,14 @@ class VLLMChatModel(ChatModel):
history: List[ChatTurn],
memory_notes: List[str],
max_new_tokens: int = 512,
+ global_notes: List[str] = None,
) -> List[dict]:
"""Public method to build messages without calling the API.
Used for batch processing where messages are collected first,
then sent in batch to vLLM for concurrent processing.
"""
- return self._build_messages(history, memory_notes, max_new_tokens)
+ return self._build_messages(history, memory_notes, max_new_tokens, global_notes=global_notes)
def answer(
self,
diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py
index 0f743d9..42f43ed 100644
--- a/src/personalization/models/preference_extractor/rule_extractor.py
+++ b/src/personalization/models/preference_extractor/rule_extractor.py
@@ -119,6 +119,59 @@ class QwenRuleExtractor(PreferenceExtractor):
return text[start : end + 1]
return None
+ @torch.inference_mode()
+ def batch_extract_preferences(self, queries: List[str], batch_size: int = 64) -> List[Dict[str, Any]]:
+ """
+ Batch extract preferences from multiple queries using left-padded batching.
+ """
+ if not queries:
+ return []
+
+ # Save and set padding side for decoder-only batched generation
+ orig_padding_side = self.tokenizer.padding_side
+ self.tokenizer.padding_side = "left"
+
+ all_results = []
+ prompts = [self.build_preference_prompt(q) for q in queries]
+
+ for start in range(0, len(prompts), batch_size):
+ batch_prompts = prompts[start:start + batch_size]
+ inputs = self.tokenizer(
+ batch_prompts, return_tensors="pt", padding=True, truncation=True
+ ).to(self.model.device)
+
+ outputs = self.model.generate(
+ **inputs,
+ do_sample=False,
+ max_new_tokens=512,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ )
+
+ for i in range(len(batch_prompts)):
+ input_len = (inputs["attention_mask"][i] == 1).sum().item()
+ gen_ids = outputs[i][input_len:]
+ text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+ try:
+ data = json.loads(text)
+ validated = PreferenceList.model_validate(data)
+ all_results.append(validated.model_dump())
+ except Exception:
+ extracted_json = self._extract_json_substring(text)
+ if extracted_json:
+ try:
+ data = json.loads(extracted_json)
+ validated = PreferenceList.model_validate(data)
+ all_results.append(validated.model_dump())
+ continue
+ except Exception:
+ pass
+ all_results.append({"preferences": []})
+
+ self.tokenizer.padding_side = orig_padding_side
+ return all_results
+
def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
"""
Extract preferences from the LAST user turn in the history.
diff --git a/src/personalization/retrieval/pipeline.py b/src/personalization/retrieval/pipeline.py
index e83940d..6cc7f3e 100644
--- a/src/personalization/retrieval/pipeline.py
+++ b/src/personalization/retrieval/pipeline.py
@@ -12,6 +12,51 @@ def cosine_similarity_matrix(E: np.ndarray, e_q: np.ndarray) -> np.ndarray:
# E: [M, d], e_q: [d]
return np.dot(E, e_q)
+
+def dynamic_topk_selection(
+ scores: np.ndarray,
+ min_k: int = 3,
+ max_k: int = 8,
+ score_ratio: float = 0.5,
+) -> List[int]:
+ """
+ Dynamically select top-k indices based on score distribution.
+
+ Strategy:
+ 1. Sort by score descending
+ 2. Compute threshold = top_score * score_ratio
+ 3. Select all indices with score > threshold
+ 4. Clamp to [min_k, max_k] range
+
+ Args:
+ scores: Array of scores (higher = better)
+ min_k: Minimum number of items to select
+ max_k: Maximum number of items to select
+ score_ratio: Threshold ratio relative to top score
+
+ Returns:
+ List of selected indices (in descending score order)
+ """
+ if len(scores) == 0:
+ return []
+
+ # Sort indices by score descending
+ sorted_indices = np.argsort(scores)[::-1]
+ sorted_scores = scores[sorted_indices]
+
+ # Compute threshold
+ top_score = sorted_scores[0]
+ threshold = top_score * score_ratio
+
+ # Find how many pass threshold
+ n_above_threshold = np.sum(sorted_scores > threshold)
+
+ # Clamp to [min_k, max_k]
+ n_select = max(min_k, min(max_k, n_above_threshold))
+ n_select = min(n_select, len(scores)) # Don't exceed available
+
+ return sorted_indices[:n_select].tolist()
+
def dense_topk_indices(
query: str,
embed_model: EmbeddingModel,
@@ -58,6 +103,49 @@ def dense_topk_indices(
idx = np.argsort(sims)[-k:][::-1]
return idx.tolist()
+def dense_topk_indices_multi_query(
+ queries: List[str],
+ embed_model: EmbeddingModel,
+ memory_embeddings: np.ndarray,
+ valid_indices: List[int] = None,
+ topk: int = 64
+) -> List[int]:
+ """
+ Multi-query dense retrieval: embed all queries, take max similarity per memory,
+ return top-k by max similarity (union effect).
+ """
+ if len(memory_embeddings) == 0:
+ return []
+
+ # Embed all queries at once
+ e_qs = embed_model.encode(queries, normalize=True, return_tensor=False)
+ e_qs = np.array(e_qs, dtype=np.float32) # [Q, d]
+
+ if valid_indices is not None:
+ if len(valid_indices) == 0:
+ return []
+ E_sub = memory_embeddings[valid_indices]
+ # sims: [Q, M_sub]
+ sims = np.dot(e_qs, E_sub.T)
+ # max across queries per memory
+ max_sims = sims.max(axis=0) # [M_sub]
+ k = min(topk, len(max_sims))
+ if k == 0:
+ return []
+ idx_sub = np.argsort(max_sims)[-k:][::-1]
+ return [valid_indices[i] for i in idx_sub]
+
+ # Global search
+ # sims: [Q, M]
+ sims = np.dot(e_qs, memory_embeddings.T)
+ max_sims = sims.max(axis=0) # [M]
+ k = min(topk, len(max_sims))
+ if k == 0:
+ return []
+ idx = np.argsort(max_sims)[-k:][::-1]
+ return idx.tolist()
+
+
def retrieve_with_policy(
user_id: str,
query: str,
@@ -74,6 +162,7 @@ def retrieve_with_policy(
tau: float = 1.0,
only_own_memories: bool = False,
sample: bool = False,
+ queries: List[str] = None,
) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
"""
Returns extended info for policy update:
@@ -90,28 +179,37 @@ def retrieve_with_policy(
if not valid_indices:
return [], np.array([]), np.array([]), [], np.array([])
- # 1. Dense retrieval
- dense_idx = dense_topk_indices(
- query,
- embed_model,
- memory_embeddings,
- valid_indices=valid_indices,
- topk=topk_dense
- )
+ # 1. Dense retrieval (multi-query if available)
+ if queries and len(queries) > 1:
+ dense_idx = dense_topk_indices_multi_query(
+ queries,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
+ else:
+ dense_idx = dense_topk_indices(
+ query,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
# DEBUG: Check for duplicates or out of bounds
if len(dense_idx) > 0:
import os
if os.getenv("RETRIEVAL_DEBUG") == "1":
print(f" [Pipeline] Dense Indices (Top {len(dense_idx)}): {dense_idx[:10]}...")
print(f" [Pipeline] Max Index: {max(dense_idx)} | Memory Size: {len(memory_cards)}")
-
+
if not dense_idx:
return [], np.array([]), np.array([]), [], np.array([])
candidates = [memory_cards[i] for i in dense_idx]
candidate_docs = [c.note_text for c in candidates]
- # 2. Rerank base score (P(yes|q,m))
+ # 2. Rerank base score (P(yes|q,m)) - always use original query for reranking
# Skip reranking if we have fewer candidates than topk_rerank (saves GPU memory)
if len(candidates) <= topk_rerank:
base_scores = np.ones(len(candidates)) # Uniform scores
@@ -165,14 +263,25 @@ def retrieve_no_policy(
topk_dense: int = 64,
topk_rerank: int = 8,
only_own_memories: bool = False,
+ queries: List[str] = None,
+ dynamic_topk: bool = False,
+ dynamic_min_k: int = 3,
+ dynamic_max_k: int = 8,
+ dynamic_score_ratio: float = 0.5,
) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
"""
Deterministic retrieval baseline (NoPersonal mode):
- Dense retrieval -> Rerank -> Top-K (no policy sampling, no user vector influence)
-
+
+ Args:
+ dynamic_topk: If True, use dynamic selection based on score distribution
+ dynamic_min_k: Minimum items to select (when dynamic_topk=True)
+ dynamic_max_k: Maximum items to select (when dynamic_topk=True)
+ dynamic_score_ratio: Threshold = top_score * ratio (when dynamic_topk=True)
+
Returns same structure as retrieve_with_policy for compatibility:
(candidates, candidate_item_vectors, base_scores, chosen_indices, rerank_scores_for_chosen)
-
+
Note: candidate_item_vectors is empty array (not used in NoPersonal mode)
The last return value is rerank scores instead of policy probs
"""
@@ -183,14 +292,23 @@ def retrieve_no_policy(
if not valid_indices:
return [], np.array([]), np.array([]), [], np.array([])
- # 1. Dense retrieval
- dense_idx = dense_topk_indices(
- query,
- embed_model,
- memory_embeddings,
- valid_indices=valid_indices,
- topk=topk_dense
- )
+ # 1. Dense retrieval (multi-query if available)
+ if queries and len(queries) > 1:
+ dense_idx = dense_topk_indices_multi_query(
+ queries,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
+ else:
+ dense_idx = dense_topk_indices(
+ query,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
if not dense_idx:
return [], np.array([]), np.array([]), [], np.array([])
@@ -198,23 +316,33 @@ def retrieve_no_policy(
candidates = [memory_cards[i] for i in dense_idx]
candidate_docs = [c.note_text for c in candidates]
- # 2. Rerank base score (P(yes|q,m))
- # Skip reranking if we have fewer candidates than topk_rerank (saves GPU memory)
- if len(candidates) <= topk_rerank:
+ # 2. Rerank base score (P(yes|q,m)) - always use original query for reranking
+ max_k = dynamic_max_k if dynamic_topk else topk_rerank
+
+ # Skip reranking if we have fewer candidates than needed
+ if len(candidates) <= max_k:
# Just return all candidates without reranking
base_scores = np.ones(len(candidates)) # Uniform scores
chosen_indices = list(range(len(candidates)))
else:
base_scores = np.array(reranker.score(query, candidate_docs))
- # 3. Deterministic Top-K selection based on rerank scores ONLY (no policy)
- k = min(topk_rerank, len(base_scores))
- top_indices_local = base_scores.argsort()[-k:][::-1]
- chosen_indices = top_indices_local.tolist()
+ # 3. Selection: dynamic or fixed top-K
+ if dynamic_topk:
+ chosen_indices = dynamic_topk_selection(
+ base_scores,
+ min_k=dynamic_min_k,
+ max_k=dynamic_max_k,
+ score_ratio=dynamic_score_ratio,
+ )
+ else:
+ k = min(topk_rerank, len(base_scores))
+ top_indices_local = base_scores.argsort()[-k:][::-1]
+ chosen_indices = top_indices_local.tolist()
# Get scores for chosen items (for logging compatibility)
chosen_scores = base_scores[chosen_indices]
-
+
# Return empty item vectors (not used in NoPersonal mode)
# Return rerank scores as the "probs" field for logging compatibility
return candidates, np.array([]), base_scores, chosen_indices, chosen_scores
diff --git a/src/personalization/retrieval/preference_store/schemas.py b/src/personalization/retrieval/preference_store/schemas.py
index eb82558..5245025 100644
--- a/src/personalization/retrieval/preference_store/schemas.py
+++ b/src/personalization/retrieval/preference_store/schemas.py
@@ -45,3 +45,4 @@ class MemoryCard(BaseModel):
note_text: str # Summarized "condition: action" text
embedding_e: List[float] # The embedding vector
kind: Literal["pref", "fact"] = "pref"
+ is_global: bool = False # True = always include in prompt, bypass retrieval
diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py
index 45d002b..785995b 100644
--- a/src/personalization/serving/personalized_llm.py
+++ b/src/personalization/serving/personalized_llm.py
@@ -285,6 +285,17 @@ class PersonalizedLLM:
reward_mode: str = "keyword", # "keyword", "llm" (GPT-4o-mini), or "llm_local" (local vLLM)
llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge
reward_vllm_url: Optional[str] = None, # vLLM URL for local reward model (when reward_mode="llm_local")
+ enable_query_transform: bool = False, # Transform queries for better retrieval matching
+ enable_global_preferences: bool = False, # Separate global prefs that bypass retrieval
+ dynamic_topk: bool = False, # Use dynamic topk based on rerank scores
+ dynamic_min_k: int = 3, # Min preferences for dynamic topk
+ dynamic_max_k: int = 8, # Max preferences for dynamic topk
+ dynamic_score_ratio: float = 0.5, # Threshold = top_score * ratio
+ eta_long: float = None, # Override RL learning rate for z_long
+ eta_short: float = None, # Override RL learning rate for z_short
+ enable_preference_consolidation: bool = False, # Consolidate preferences at session end
+ consolidation_threshold: int = 5, # Min preferences before consolidation
+ enable_preference_rewrite: bool = False, # Use LLM to rewrite/merge retrieved preferences
):
"""
Initialize the PersonalizedLLM.
@@ -319,6 +330,11 @@ class PersonalizedLLM:
self.reranker_type = reranker_type # "qwen3" or "bge"
self.best_of_n = best_of_n # Generate N responses and pick best
self.reward_mode = reward_mode # "keyword", "llm", or "llm_local"
+ self.enable_query_transform = enable_query_transform
+ self.enable_global_preferences = enable_global_preferences
+ self.enable_preference_consolidation = enable_preference_consolidation
+ self.consolidation_threshold = consolidation_threshold
+ self.enable_preference_rewrite = enable_preference_rewrite
# Initialize LLM reward client if using LLM judge
self._llm_reward_client = None # Can be LLMRewardClient or LocalLLMRewardClient
@@ -354,13 +370,18 @@ class PersonalizedLLM:
"beta_long": 2.0, # Increased from 0.1 for stronger personalization
"beta_short": 5.0, # Increased from 0.3
"tau": 1.0,
- "eta_long": 0.01, # Increased from 1e-3 for faster learning
- "eta_short": 0.05, # Increased from 5e-3
+ "eta_long": eta_long if eta_long is not None else 0.01,
+ "eta_short": eta_short if eta_short is not None else 0.05,
"ema_alpha": 0.05,
"short_decay": 0.1,
"dense_topk": 64,
- "rerank_topk": 3,
+ "rerank_topk": 5,
"max_new_tokens": 512,
+ # Dynamic topk settings
+ "dynamic_topk": dynamic_topk,
+ "dynamic_min_k": dynamic_min_k,
+ "dynamic_max_k": dynamic_max_k,
+ "dynamic_score_ratio": dynamic_score_ratio,
}
# Store llm_name before loading config (needed in _load_config)
@@ -528,7 +549,13 @@ class PersonalizedLLM:
self._memory_cards: List[MemoryCard] = []
self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32)
self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
- self._projection = None
+ # Create default projection (truncation to first k dims) so preferences can be added
+ k = self._rl_cfg["item_dim"]
+ d = 4096
+ P = np.zeros((k, d), dtype=np.float32)
+ P[:, :k] = np.eye(k, dtype=np.float32)
+ self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32))
+ print(f"[PersonalizedLLM] Created default projection (truncation, k={k})")
return
# Load cards
@@ -551,8 +578,14 @@ class PersonalizedLLM:
self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
self._item_vectors = proj_data["V"]
else:
- self._projection = None
+ # Create default projection so preferences can still be added
+ k = self._rl_cfg["item_dim"]
+ d = 4096
+ P = np.zeros((k, d), dtype=np.float32)
+ P[:, :k] = np.eye(k, dtype=np.float32)
+ self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32))
self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32)
+ print(f"[PersonalizedLLM] Created default projection (truncation, k={k})")
print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.")
@@ -588,6 +621,290 @@ class PersonalizedLLM:
except Exception:
return len(text) // 4
+ # Task type keywords for query transformation
+ _TASK_KEYWORDS = {
+ "math": ["solve", "calculate", "integral", "equation", "proof", "derivative",
+ "math", "algebra", "geometry", "trigonometry", "calculus", "arithmetic",
+ "formula", "compute", "evaluate", "simplify", "factor", "graph"],
+ "coding": ["code", "program", "function", "implement", "debug", "python", "java",
+ "javascript", "algorithm", "class", "method", "bug", "error", "compile",
+ "script", "html", "css", "sql", "api", "library", "framework"],
+ "writing": ["write", "essay", "paragraph", "summarize", "draft", "compose",
+ "article", "story", "letter", "email", "report", "review", "edit",
+ "rewrite", "paraphrase", "outline"],
+ "explanation": ["explain", "what is", "how does", "why", "describe", "define",
+ "meaning", "concept", "difference between", "compare", "contrast"],
+ }
+
+ def _transform_query_for_retrieval(self, query: str) -> List[str]:
+ """
+ Transform raw user query into multiple retrieval queries to bridge
+ the semantic gap between task queries and preference descriptions.
+
+ Returns [original_query, transformed_query] or [original_query] if
+ no task type detected.
+ """
+ import re
+ query_lower = query.lower()
+ detected_types = []
+ for task_type, keywords in self._TASK_KEYWORDS.items():
+ for kw in keywords:
+ # Use word boundary matching to avoid false positives
+ # e.g., "api" should not match "capital"
+ if re.search(r'\b' + re.escape(kw) + r'\b', query_lower):
+ detected_types.append(task_type)
+ break
+
+ if not detected_types:
+ return [query]
+
+ # Use first detected type (most specific match)
+ task_type = detected_types[0]
+ transformed = f"user preferences for {task_type} tasks: {query}"
+ return [query, transformed]
+
+ # Patterns indicating a global/universal preference condition
+ _GLOBAL_PATTERNS = ["general", "any", "always", "all ", "every", "regardless",
+ "any task", "any topic", "any question", "all tasks", "all topics"]
+
+ # Domain-specific terms that indicate a conditional preference
+ _DOMAIN_TERMS = ["math", "code", "coding", "program", "writing", "essay", "science",
+ "history", "language", "physics", "chemistry", "biology", "literature",
+ "creative", "technical", "formal", "informal", "academic", "casual"]
+
+ def _classify_preference_scope(self, condition: str) -> bool:
+ """
+ Classify whether a preference condition is global (always applicable)
+ or conditional (task-specific).
+
+ Returns True if global, False if conditional.
+ """
+ cond_lower = condition.lower().strip()
+
+ # Check for explicit global patterns
+ for pattern in self._GLOBAL_PATTERNS:
+ if pattern in cond_lower:
+ return True
+
+ # Very short/vague conditions with no domain terms are likely global
+ words = cond_lower.split()
+ if len(words) <= 2:
+ has_domain = any(term in cond_lower for term in self._DOMAIN_TERMS)
+ if not has_domain:
+ return True
+
+ return False
+
+ # Rewrite prompt for merging retrieved preferences
+ _REWRITE_PROMPT = """You are helping to prepare user preferences for an AI assistant.
+
+The user is asking: {query}
+
+Retrieved preferences about this user:
+{preferences}
+
+Task: Create a concise preference summary that the assistant MUST follow.
+
+Rules:
+1. PRESERVE all specific formatting requirements exactly (e.g., "type hints", "snake_case", "code fence with language")
+2. PRESERVE all structural requirements (e.g., "numbered steps", "bullet points", "answer first then explanation")
+3. Only MERGE preferences that are truly redundant (saying the same thing differently)
+4. Output as a short bulleted list if there are multiple distinct requirements
+5. Keep each point actionable and specific - NO vague generalizations like "follow best practices"
+
+Example input:
+- Include type hints in Python code
+- Use snake_case for variable names
+- When explaining, use numbered steps
+
+Example output:
+- Include type hints
+- Use snake_case for variables
+- Use numbered steps for explanations
+
+If no preferences are relevant to this query type, output: "No specific preferences apply."
+
+Preference summary:"""
+
+ def _rewrite_preferences(self, memory_notes: List[str], query: str) -> List[str]:
+ """
+ Use LLM to rewrite/merge multiple retrieved preferences into concise instructions.
+
+ This is similar to Reflection's proper_scaffolding but focuses on merging
+ rather than just filtering.
+
+ Args:
+ memory_notes: List of retrieved preference notes
+ query: Current user query
+
+ Returns:
+ List with single rewritten instruction (or original if rewrite fails/disabled)
+ """
+ if not memory_notes or len(memory_notes) <= 1:
+ return memory_notes
+
+ try:
+ import requests
+
+ # Format preferences for prompt
+ prefs_text = "\n".join(f"- {note}" for note in memory_notes)
+ prompt = self._REWRITE_PROMPT.format(query=query[:200], preferences=prefs_text)
+
+ # Direct vLLM API call (simpler than going through chat model)
+ messages = [{"role": "user", "content": prompt}]
+ payload = {
+ "model": self._chat_model.model_name,
+ "messages": messages,
+ "max_tokens": 150,
+ "temperature": 0.3, # Lower temperature for more consistent output
+ }
+
+ response = requests.post(
+ f"{self._chat_model.vllm_url}/chat/completions",
+ json=payload,
+ timeout=30
+ )
+
+ if response.status_code != 200:
+ print(f"[REWRITE] API error {response.status_code}, keeping original notes")
+ return memory_notes
+
+ result = response.json()
+ rewritten = result["choices"][0]["message"]["content"].strip().strip('"')
+
+ # Validate response
+ if rewritten and len(rewritten) > 10 and "No specific preferences" not in rewritten:
+ print(f"[REWRITE] {len(memory_notes)} notes → 1 merged instruction")
+ return [rewritten]
+ else:
+ print(f"[REWRITE] Kept original {len(memory_notes)} notes (no valid merge)")
+ return memory_notes
+
+ except Exception as e:
+ print(f"[REWRITE] Failed: {e}, keeping original notes")
+ return memory_notes
+
+ # Consolidation prompt for session-end preference merging
+ _CONSOLIDATION_PROMPT = """You are analyzing user preferences extracted from conversations.
+
+Current preferences for this user:
+{preferences}
+
+Task: Consolidate these preferences into a cleaner, more organized set by:
+1. MERGE similar preferences (e.g., "use bullet points" + "format with bullets" → single preference)
+2. REMOVE redundant or contradictory preferences (keep the more specific one)
+3. PRESERVE all unique, meaningful preferences
+4. Keep the same "When [condition], [action]." format
+
+Output ONLY the consolidated preferences, one per line, in this exact format:
+When [condition], [action].
+
+Do not add explanations or commentary. Just output the preference lines."""
+
+ def consolidate_user_preferences(self, user_id: str) -> int:
+ """
+ Consolidate user preferences at session end using LLM.
+
+ Merges similar preferences, removes redundancy, and creates cleaner
+ preference descriptions. Only runs if user has enough preferences.
+
+ Args:
+ user_id: The user whose preferences to consolidate.
+
+ Returns:
+ Number of preferences after consolidation (0 if skipped).
+ """
+ if not self.enable_preference_consolidation:
+ return 0
+
+ # Get user's memory cards
+ user_cards = [c for c in self._memory_cards if c.user_id == user_id]
+
+ if len(user_cards) < self.consolidation_threshold:
+ return len(user_cards)
+
+ # Build preference list for prompt
+ pref_lines = [card.note_text for card in user_cards]
+ preferences_text = "\n".join(f"- {p}" for p in pref_lines)
+
+ # Call LLM for consolidation
+ prompt = self._CONSOLIDATION_PROMPT.format(preferences=preferences_text)
+ messages = [{"role": "user", "content": prompt}]
+
+ try:
+ result = self._chat_model.answer(messages, max_new_tokens=512)
+ consolidated_text = result.get("content", "").strip()
+
+ if not consolidated_text:
+ return len(user_cards)
+
+ # Parse consolidated preferences
+ new_prefs = []
+ for line in consolidated_text.split("\n"):
+ line = line.strip()
+ if not line or not line.startswith("When "):
+ continue
+ # Parse "When [condition], [action]."
+ if ", " in line:
+ parts = line.split(", ", 1)
+ condition = parts[0].replace("When ", "").strip()
+ action = parts[1].rstrip(".").strip()
+ if condition and action:
+ new_prefs.append({
+ "condition": condition,
+ "action": action,
+ "is_global": self._classify_preference_scope(condition) if self.enable_global_preferences else False,
+ })
+
+ if not new_prefs:
+ return len(user_cards)
+
+ # Remove old cards for this user
+ keep_indices = [i for i, c in enumerate(self._memory_cards) if c.user_id != user_id]
+ self._memory_cards = [self._memory_cards[i] for i in keep_indices]
+ if len(keep_indices) > 0 and len(self._memory_embeddings) > 0:
+ self._memory_embeddings = self._memory_embeddings[keep_indices]
+ self._item_vectors = self._item_vectors[keep_indices]
+ else:
+ embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096
+ self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Add consolidated preferences
+ for pref in new_prefs:
+ note_text = f"When {pref['condition']}, {pref['action']}."
+
+ # Compute embedding
+ e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
+ v_note = self._projection.transform_vector(np.array(e_note))
+
+ # Create card
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id=f"consolidated_{user_id}",
+ source_turn_ids=[],
+ raw_queries=[],
+ preference_list=PreferenceList(preferences=[
+ Preference(condition=pref["condition"], action=pref["action"], confidence=1.0)
+ ]),
+ note_text=note_text,
+ embedding_e=list(e_note),
+ kind="pref",
+ is_global=pref["is_global"],
+ )
+
+ self._memory_cards.append(card)
+ self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
+ self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])
+
+ print(f"[PersonalizedLLM] Consolidated {len(user_cards)} → {len(new_prefs)} preferences for user {user_id}")
+ return len(new_prefs)
+
+ except Exception as e:
+ print(f"[PersonalizedLLM] Consolidation failed for user {user_id}: {e}")
+ return len(user_cards)
+
def _add_preferences_as_memory(
self,
prefs: PreferenceList,
@@ -628,6 +945,9 @@ class PersonalizedLLM:
e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
v_note = self._projection.transform_vector(np.array(e_note))
+ # Classify as global or conditional
+ is_global = self._classify_preference_scope(pref.condition) if self.enable_global_preferences else False
+
# Create new memory card
card = MemoryCard(
card_id=str(uuid.uuid4()),
@@ -639,6 +959,7 @@ class PersonalizedLLM:
note_text=note_text,
embedding_e=list(e_note),
kind="pref",
+ is_global=is_global,
)
# Add to memory store
@@ -788,35 +1109,61 @@ class PersonalizedLLM:
if extracted_prefs:
print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
+ # Separate global preferences (bypass retrieval) from conditional ones
+ global_notes = []
+ retrieval_cards = self._memory_cards
+ retrieval_embeddings = self._memory_embeddings
+ retrieval_item_vectors = self._item_vectors
+ if self.enable_global_preferences:
+ global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id]
+ global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10
+ # Filter out global cards for retrieval
+ cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global]
+ if cond_indices:
+ retrieval_cards = [self._memory_cards[i] for i in cond_indices]
+ retrieval_embeddings = self._memory_embeddings[cond_indices]
+ if len(self._item_vectors) > 0:
+ retrieval_item_vectors = self._item_vectors[cond_indices]
+ else:
+ retrieval_cards = []
+ retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings
+ retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Query transformation for better retrieval matching
+ retrieval_queries = None
+ if self.enable_query_transform:
+ retrieval_queries = self._transform_query_for_retrieval(query)
+
# Retrieve memories
- # In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector
- # In "full" mode: policy-based retrieval with user vector influence
if self.mode == "nopersonal":
candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
user_id=user_id,
query=query,
embed_model=self._embed_model,
reranker=self._reranker,
- memory_cards=self._memory_cards,
- memory_embeddings=self._memory_embeddings,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
topk_dense=self._rl_cfg["dense_topk"],
topk_rerank=self._rl_cfg["rerank_topk"],
only_own_memories=self.only_own_memories,
+ queries=retrieval_queries,
+ dynamic_topk=self._rl_cfg["dynamic_topk"],
+ dynamic_min_k=self._rl_cfg["dynamic_min_k"],
+ dynamic_max_k=self._rl_cfg["dynamic_max_k"],
+ dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"],
)
else:
beta_long = self._rl_cfg["beta_long"]
beta_short = self._rl_cfg["beta_short"]
- # eval_mode=True -> sample=False (greedy/deterministic)
- # eval_mode=False -> sample=True (stochastic/exploration)
candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
user_id=user_id,
query=query,
embed_model=self._embed_model,
reranker=self._reranker,
- memory_cards=self._memory_cards,
- memory_embeddings=self._memory_embeddings,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
user_store=self._user_store,
- item_vectors=self._item_vectors,
+ item_vectors=retrieval_item_vectors,
topk_dense=self._rl_cfg["dense_topk"],
topk_rerank=self._rl_cfg["rerank_topk"],
beta_long=beta_long,
@@ -824,27 +1171,39 @@ class PersonalizedLLM:
tau=self._rl_cfg["tau"],
only_own_memories=self.only_own_memories,
sample=not self.eval_mode,
+ queries=retrieval_queries,
)
-
+
# Get selected memories
memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else []
memory_notes = [m.note_text for m in memories_t]
+ # Apply preference rewrite if enabled
+ if self.enable_preference_rewrite and memory_notes:
+ memory_notes = self._rewrite_preferences(memory_notes, query)
+
# Debug: show retrieval info
- if memories_t:
+ if memories_t or global_notes:
print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
- print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}")
+ print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}")
for i, m in enumerate(memories_t[:3]): # Show top 3
score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
-
+
+ # Combine all notes for prompt (global + retrieved)
+ # For chat(), we combine all notes; chat_prepare() handles them separately
+ if self.mode != "vanilla":
+ all_memory_notes = (global_notes if global_notes else []) + memory_notes
+ else:
+ all_memory_notes = memory_notes
+
# Build prompt and count tokens
prompt_tokens = self._count_tokens(query)
for turn in session.history:
prompt_tokens += self._count_tokens(turn.text)
- for note in memory_notes:
+ for note in all_memory_notes:
prompt_tokens += self._count_tokens(note)
-
+
# Generate answer (with best-of-N if enabled)
if self.best_of_n > 1:
# Generate N responses and pick the best one
@@ -852,7 +1211,7 @@ class PersonalizedLLM:
for i in range(self.best_of_n):
resp = self._chat_model.answer(
history=session.history,
- memory_notes=memory_notes,
+ memory_notes=all_memory_notes,
max_new_tokens=self._rl_cfg["max_new_tokens"],
temperature=0.8, # Slightly higher temp for diversity
)
@@ -869,7 +1228,7 @@ class PersonalizedLLM:
else:
answer_t = self._chat_model.answer(
history=session.history,
- memory_notes=memory_notes,
+ memory_notes=all_memory_notes,
max_new_tokens=self._rl_cfg["max_new_tokens"],
)
@@ -920,7 +1279,7 @@ class PersonalizedLLM:
debug=debug,
)
- def chat_prepare(self, user_id: str, query: str) -> dict:
+ def chat_prepare(self, user_id: str, query: str, skip_extraction: bool = False, skip_auto_reward: bool = False) -> dict:
"""
Prepare for chat without calling the LLM.
@@ -984,7 +1343,8 @@ class PersonalizedLLM:
}
# Auto-compute reward via LLM judge if enabled
- if self._llm_reward_client is not None:
+ # skip_auto_reward=True when batch framework handles rewards externally
+ if self._llm_reward_client is not None and not skip_auto_reward:
import asyncio
try:
reward, gating = asyncio.run(eval_step_llm(
@@ -1006,7 +1366,7 @@ class PersonalizedLLM:
# Extract preferences from conversation
extracted_prefs = []
- if self.enable_preference_extraction:
+ if self.enable_preference_extraction and not skip_extraction:
prefs = self._extractor.extract_turn(session.history)
if prefs.preferences:
print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
@@ -1016,6 +1376,30 @@ class PersonalizedLLM:
if extracted_prefs:
print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
+ # Separate global preferences (bypass retrieval) from conditional ones
+ global_notes = []
+ retrieval_cards = self._memory_cards
+ retrieval_embeddings = self._memory_embeddings
+ retrieval_item_vectors = self._item_vectors
+ if self.enable_global_preferences:
+ global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id]
+ global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10
+ cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global]
+ if cond_indices:
+ retrieval_cards = [self._memory_cards[i] for i in cond_indices]
+ retrieval_embeddings = self._memory_embeddings[cond_indices]
+ if len(self._item_vectors) > 0:
+ retrieval_item_vectors = self._item_vectors[cond_indices]
+ else:
+ retrieval_cards = []
+ retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings
+ retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Query transformation for better retrieval matching
+ retrieval_queries = None
+ if self.enable_query_transform:
+ retrieval_queries = self._transform_query_for_retrieval(query)
+
# Retrieve memories
if self.mode == "nopersonal":
candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
@@ -1023,11 +1407,16 @@ class PersonalizedLLM:
query=query,
embed_model=self._embed_model,
reranker=self._reranker,
- memory_cards=self._memory_cards,
- memory_embeddings=self._memory_embeddings,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
topk_dense=self._rl_cfg["dense_topk"],
topk_rerank=self._rl_cfg["rerank_topk"],
only_own_memories=self.only_own_memories,
+ queries=retrieval_queries,
+ dynamic_topk=self._rl_cfg["dynamic_topk"],
+ dynamic_min_k=self._rl_cfg["dynamic_min_k"],
+ dynamic_max_k=self._rl_cfg["dynamic_max_k"],
+ dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"],
)
else:
beta_long = self._rl_cfg["beta_long"]
@@ -1037,10 +1426,10 @@ class PersonalizedLLM:
query=query,
embed_model=self._embed_model,
reranker=self._reranker,
- memory_cards=self._memory_cards,
- memory_embeddings=self._memory_embeddings,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
user_store=self._user_store,
- item_vectors=self._item_vectors,
+ item_vectors=retrieval_item_vectors,
topk_dense=self._rl_cfg["dense_topk"],
topk_rerank=self._rl_cfg["rerank_topk"],
beta_long=beta_long,
@@ -1048,14 +1437,19 @@ class PersonalizedLLM:
tau=self._rl_cfg["tau"],
only_own_memories=self.only_own_memories,
sample=not self.eval_mode,
+ queries=retrieval_queries,
)
memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else []
memory_notes = [m.note_text for m in memories_t]
- if memories_t:
+ # Apply preference rewrite if enabled
+ if self.enable_preference_rewrite and memory_notes:
+ memory_notes = self._rewrite_preferences(memory_notes, query)
+
+ if memories_t or global_notes:
print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
- print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}")
+ print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}")
for i, m in enumerate(memories_t[:3]):
score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
@@ -1064,14 +1458,17 @@ class PersonalizedLLM:
prompt_tokens = self._count_tokens(query)
for turn in session.history:
prompt_tokens += self._count_tokens(turn.text)
- for note in memory_notes:
+ all_notes = memory_notes + (global_notes if self.mode != "vanilla" else [])
+ for note in all_notes:
prompt_tokens += self._count_tokens(note)
- # Build messages for LLM
+ # Build messages for LLM (pass global_notes separately for distinct prompt sections)
+ effective_global = global_notes if (self.enable_global_preferences and self.mode != "vanilla") else None
messages = self._chat_model.build_messages(
history=session.history,
memory_notes=memory_notes,
max_new_tokens=self._rl_cfg["max_new_tokens"],
+ global_notes=effective_global,
)
# Return messages and context for chat_complete
@@ -1176,22 +1573,47 @@ class PersonalizedLLM:
debug=debug,
)
+ def apply_extracted_preferences(self, user_id: str, pref_dict: dict) -> list:
+ """Apply pre-computed extraction results (from batch extraction) to memory."""
+ prefs = PreferenceList.model_validate(pref_dict)
+ if not prefs.preferences:
+ return []
+ ctx = self._get_or_create_session(user_id)
+ query = ctx.session_state.history[-1].text if ctx.session_state.history else ""
+ extracted = self._add_preferences_as_memory(prefs, query, user_id, ctx.turn_counter)
+ if extracted:
+ print(f"[DEBUG] Batch-added {len(extracted)} to memory. Total cards: {len(self._memory_cards)}")
+ return extracted
+
+ def get_last_user_query(self, user_id: str) -> str:
+ """Get the last user message text for this user's session."""
+ ctx = self._sessions.get(user_id)
+ if ctx and ctx.session_state.history:
+ for t in reversed(ctx.session_state.history):
+ if t.role == "user":
+ return t.text
+ return ""
+
def reset_session(self, user_id: str) -> None:
"""
Reset session for a user (new chat window).
-
+
This clears:
- Session conversation history
- Short-term user vector (z_short)
- Pending RL update info
-
+
This preserves:
- Long-term user vector (z_long)
- - User's memory cards
-
+ - User's memory cards (may be consolidated if enabled)
+
Args:
user_id: The user whose session to reset.
"""
+ # Consolidate preferences at session end (before clearing session)
+ if self.enable_preference_consolidation:
+ self.consolidate_user_preferences(user_id)
+
# Clear session context
if user_id in self._sessions:
del self._sessions[user_id]
@@ -1270,14 +1692,14 @@ class PersonalizedLLM:
"""
if not self.enable_rl_updates:
return
-
+
# In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline)
if self.mode in ("nopersonal", "vanilla"):
return
-
+
user_id = feedback.user_id
ctx = self._sessions.get(user_id)
-
+
if ctx is None or ctx.pending_rl_update is None:
return
@@ -1289,12 +1711,15 @@ class PersonalizedLLM:
pending.get("last_policy_probs") is not None and
pending.get("last_chosen_indices") is not None and
len(pending["last_chosen_indices"]) > 0):
-
+
# Extract chosen vectors
chosen_indices = pending["last_chosen_indices"]
candidate_vectors = pending["last_candidate_item_vectors"]
-
+
if len(candidate_vectors) > 0:
+ print(f"[DEBUG-REINFORCE] User={user_id} reward={feedback.reward:.2f} "
+ f"n_candidates={len(candidate_vectors)} chosen={chosen_indices} "
+ f"probs_shape={pending['last_policy_probs'].shape if hasattr(pending['last_policy_probs'], 'shape') else 'N/A'}")
# REINFORCE expects:
# - item_vectors: ALL candidate vectors [K, k]
# - chosen_indices: indices into those candidates
@@ -1313,6 +1738,7 @@ class PersonalizedLLM:
short_decay=self._rl_cfg["short_decay"],
)
+ print(f"[DEBUG-REINFORCE] updated={updated} z_long_norm={np.linalg.norm(user_state.z_long):.15e}")
if updated:
self._user_store.save_state(user_state)