Diffstat (limited to 'src/personalization/serving')
 src/personalization/serving/personalized_llm.py | 721
 1 file changed, 637 insertions(+), 84 deletions(-)
diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py
index 2c4d5a8..733ff87 100644
--- a/src/personalization/serving/personalized_llm.py
+++ b/src/personalization/serving/personalized_llm.py
@@ -33,6 +33,7 @@ from personalization.config.settings import load_local_models_config
from personalization.config.registry import get_preference_extractor, get_chat_model
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.models.reranker.bge_reranker import BGEReranker
from personalization.user_model.tensor_store import UserTensorStore, UserState
from personalization.user_model.session_state import OnlineSessionState
from personalization.user_model.features import ItemProjection
@@ -40,7 +41,8 @@ from personalization.retrieval.preference_store.schemas import (
MemoryCard, ChatTurn, PreferenceList, Preference
)
from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy
-from personalization.feedback.handlers import eval_step
+from personalization.feedback.handlers import eval_step, eval_step_llm
+from personalization.feedback.llm_reward import LLMRewardClient, LLMRewardConfig
from personalization.user_model.policy.reinforce import reinforce_update_user_state
@@ -113,6 +115,119 @@ class _SessionContext:
# =============================================================================
+# Shared Model Singletons for Multi-threaded Efficiency
+# =============================================================================
+
+import threading
+
+_shared_embed_model = None
+_shared_reranker = None
+_shared_extractor = None
+# Create the lock eagerly at import time; creating it lazily inside the getter
+# would itself be a race when two threads call it concurrently for the first time.
+_shared_models_lock = threading.Lock()
+
+
+def _get_shared_models_lock():
+    """Return the module-level lock guarding the shared model singletons."""
+    return _shared_models_lock
+
+
+def get_shared_embedding_model(model_path: str, device_map: str = "auto"):
+ """Get or create shared embedding model (thread-safe singleton)."""
+ global _shared_embed_model
+ import torch
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_embed_model is None:
+ print(f"[SharedModels] Loading shared embedding model on {device_map}...")
+ _shared_embed_model = Qwen3Embedding8B(
+ model_path=model_path,
+ dtype=torch.bfloat16,
+ device_map=device_map,
+ )
+ print("[SharedModels] Shared embedding model loaded.")
+ return _shared_embed_model
+
+
+def get_shared_reranker(model_path: str, device_map: str = "auto", reranker_type: str = "qwen3"):
+ """Get or create shared reranker model (thread-safe singleton)."""
+ global _shared_reranker
+ import torch
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_reranker is None:
+ print(f"[SharedModels] Loading shared reranker ({reranker_type}) on {device_map}...")
+ if reranker_type == "bge":
+ _shared_reranker = BGEReranker(
+ model_path=model_path,
+ device_map=device_map,
+ dtype=torch.float16,
+ )
+ else:
+ _shared_reranker = Qwen3Reranker(
+ model_path=model_path,
+ device_map=device_map,
+ dtype=torch.bfloat16,
+ )
+ print("[SharedModels] Shared reranker model loaded.")
+ return _shared_reranker
+
+
+def get_shared_extractor(model_path: str, device_map: str = "auto"):
+ """Get or create shared preference extractor model (thread-safe singleton)."""
+ global _shared_extractor
+ import torch
+ from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_extractor is None:
+ print(f"[SharedModels] Loading shared preference extractor on {device_map}...")
+ _shared_extractor = QwenRuleExtractor(
+ model_path=model_path,
+ dtype=torch.bfloat16,
+ device_map=device_map,
+ )
+ print("[SharedModels] Shared preference extractor loaded.")
+ return _shared_extractor
+
+
+def clear_shared_models():
+ """Free all shared singleton models to reclaim GPU memory between methods."""
+ global _shared_embed_model, _shared_reranker, _shared_extractor
+ import gc
+
+ lock = _get_shared_models_lock()
+ with lock:
+ freed = []
+ if _shared_embed_model is not None:
+ freed.append("embedding")
+ del _shared_embed_model
+ _shared_embed_model = None
+ if _shared_reranker is not None:
+ freed.append("reranker")
+ del _shared_reranker
+ _shared_reranker = None
+ if _shared_extractor is not None:
+ freed.append("extractor")
+ del _shared_extractor
+ _shared_extractor = None
+
+ if freed:
+ gc.collect()
+ try:
+ import torch
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ except ImportError:
+ pass
+ print(f"[SharedModels] Cleared: {', '.join(freed)}")
+
+
+# =============================================================================
# PersonalizedLLM Class
# =============================================================================
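
A minimal sketch (not part of the commit) of the intended multi-threaded usage: several `PersonalizedLLM` instances share one embedding/reranker/extractor via the singletons above. The thread count and queries are illustrative only.

```python
import threading

from personalization.serving.personalized_llm import PersonalizedLLM, clear_shared_models

def worker(profile_id: str) -> None:
    # With use_shared_models=True the first thread loads the heavy models;
    # the others block on the lock and then reuse the same instances.
    llm = PersonalizedLLM(use_shared_models=True, mode="full")
    llm.chat(user_id=profile_id, query="I prefer short, step-by-step answers.")

threads = [threading.Thread(target=worker, args=(f"user_{i}",)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Reclaim GPU memory held by the shared singletons between benchmark methods.
clear_shared_models()
```
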
@@ -163,6 +278,12 @@ class PersonalizedLLM:
mode: str = "full", # "full", "nopersonal", or "vanilla"
eval_mode: bool = True, # True = greedy selection, False = stochastic sampling
device_assignment: Optional[Dict[str, str]] = None, # Multi-GPU support
+ llm_name: Optional[str] = None, # Override LLM name (e.g., "llama_8b_vllm" for vLLM)
+ use_shared_models: bool = False, # Use shared singleton models for multi-threaded efficiency
+ reranker_type: str = "qwen3", # "qwen3" (8B) or "bge" (278M)
+ best_of_n: int = 1, # Generate N responses and pick best (for RAG methods)
+ reward_mode: str = "keyword", # "keyword" (legacy heuristic) or "llm" (GPT-5-nano judge)
+ llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge
):
"""
Initialize the PersonalizedLLM.
@@ -183,12 +304,25 @@ class PersonalizedLLM:
device_assignment: Optional dict to assign models to specific GPUs.
Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"}
If None, uses "auto" for all models.
+            llm_name: Optional override for the chat model name (e.g., "llama_8b_vllm").
+                Takes precedence over the llm_name entry in the YAML config.
+            use_shared_models: If True, use shared singleton models for the embedding
+                model, reranker, and preference extractor. This is essential for
+                multi-threaded/parallel profile processing to avoid loading duplicate
+                models: the first thread loads each model, and subsequent threads
+                reuse the shared instances.
+            reranker_type: Reranker backend, "qwen3" (8B) or "bge" (278M).
+            best_of_n: If greater than 1, generate N responses and keep the best-scored one.
+            reward_mode: "keyword" for the legacy heuristic reward, or "llm" for the LLM judge.
+            llm_reward_config: Optional LLMRewardConfig for the LLM judge.
"""
self.only_own_memories = only_own_memories
+ self.use_shared_models = use_shared_models
self.enable_preference_extraction = enable_preference_extraction
self.enable_rl_updates = enable_rl_updates
self.mode = mode # "full" or "nopersonal"
self.eval_mode = eval_mode # True = greedy, False = sample
+ self.reranker_type = reranker_type # "qwen3" or "bge"
+ self.best_of_n = best_of_n # Generate N responses and pick best
+ self.reward_mode = reward_mode # "keyword" or "llm"
+
+ # Initialize LLM reward client if using LLM judge
+ self._llm_reward_client: Optional[LLMRewardClient] = None
+ if reward_mode == "llm":
+ self._llm_reward_client = LLMRewardClient(llm_reward_config or LLMRewardConfig())
# Multi-GPU device assignment
self._device_assignment = device_assignment or {
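
For reference, a hedged sketch of how the new constructor options fit together; the values below are illustrative, and only the parameter names come from the signature added in this commit.

```python
from personalization.feedback.llm_reward import LLMRewardConfig
from personalization.serving.personalized_llm import PersonalizedLLM

llm = PersonalizedLLM(
    mode="full",
    eval_mode=True,
    llm_name="llama_8b_vllm",        # overrides llm_name from the YAML config
    use_shared_models=True,          # share embed/reranker/extractor singletons
    reranker_type="bge",             # 278M BGE reranker instead of the 8B Qwen3 one
    best_of_n=3,                     # sample 3 answers, keep the best-scored one
    reward_mode="llm",               # LLM judge instead of the keyword heuristic
    llm_reward_config=LLMRewardConfig(),
    device_assignment={"embed": "cuda:0", "reranker": "cuda:1",
                       "chat": "cuda:2", "extractor": "cuda:3"},
)
```
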
@@ -219,6 +353,9 @@ class PersonalizedLLM:
"max_new_tokens": 512,
}
+ # Store llm_name before loading config (needed in _load_config)
+ self._llm_name_override = llm_name
+
# Load config and override RL params if available
self._load_config(config_path)
@@ -249,8 +386,8 @@ class PersonalizedLLM:
if config_path is None:
config_path = "configs/user_model.yaml"
- self._llm_name = "qwen_1_5b" # Default
-
+ self._llm_name = self._llm_name_override or "qwen_1_5b" # Default, can be overridden
+
try:
if os.path.exists(config_path):
with open(config_path, "r") as f:
@@ -260,8 +397,8 @@ class PersonalizedLLM:
for key in self._rl_cfg:
if key in user_cfg:
self._rl_cfg[key] = user_cfg[key]
- # LLM name
- if "llm_name" in user_cfg:
+ # LLM name (only from config if not already set via parameter)
+ if self._llm_name_override is None and "llm_name" in user_cfg:
self._llm_name = user_cfg["llm_name"]
except Exception as e:
print(f"[PersonalizedLLM] Warning: Failed to load config: {e}")
@@ -269,53 +406,110 @@ class PersonalizedLLM:
def _load_models(self):
"""Load all ML models with optional multi-GPU assignment."""
import torch
-
- # Report GPU availability
- num_gpus = torch.cuda.device_count()
- print(f"[PersonalizedLLM] Available GPUs: {num_gpus}")
- for i in range(num_gpus):
- mem = torch.cuda.get_device_properties(i).total_memory / 1e9
- print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")
-
+
+ # Report GPU availability (only once, not for shared model instances)
+ if not self.use_shared_models:
+ num_gpus = torch.cuda.device_count()
+ print(f"[PersonalizedLLM] Available GPUs: {num_gpus}")
+ for i in range(num_gpus):
+ mem = torch.cuda.get_device_properties(i).total_memory / 1e9
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")
+
embed_device = self._device_assignment.get("embed", "auto")
reranker_device = self._device_assignment.get("reranker", "auto")
chat_device = self._device_assignment.get("chat", "auto")
extractor_device = self._device_assignment.get("extractor", "auto")
-
- # Embedding model
- print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...")
- self._embed_model = Qwen3Embedding8B(
- model_path=self._cfg.embedding.qwen3.local_path,
- dtype=torch.bfloat16,
- device_map=embed_device,
- )
-
- # Reranker
- print(f"[PersonalizedLLM] Loading Reranker on {reranker_device}...")
- self._reranker = Qwen3Reranker(
- model_path=self._cfg.reranker.qwen3_8b.local_path,
- device_map=reranker_device,
- dtype=torch.bfloat16,
- )
-
+
+ # Embedding model - only load for modes that use RAG retrieval
+ # Vanilla and contextual modes don't need embedding/reranker
+ needs_retrieval = self.mode not in ("vanilla", "contextual")
+
+ if needs_retrieval:
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared embedding model...")
+ self._embed_model = get_shared_embedding_model(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ device_map=embed_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...")
+ self._embed_model = Qwen3Embedding8B(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ dtype=torch.bfloat16,
+ device_map=embed_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Skipping embedding model (not needed for {self.mode} mode)")
+ self._embed_model = None
+
+ # Reranker - only load for modes that use RAG retrieval
+ # Support both qwen3 (8B) and bge (278M) rerankers
+ if needs_retrieval:
+ if self.reranker_type == "bge":
+ reranker_path = getattr(self._cfg.reranker, "bge_base", None)
+ reranker_path = reranker_path.local_path if reranker_path else "BAAI/bge-reranker-base"
+ else:
+ reranker_path = self._cfg.reranker.qwen3_8b.local_path
+
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared reranker model ({self.reranker_type})...")
+ self._reranker = get_shared_reranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ reranker_type=self.reranker_type,
+ )
+ else:
+ print(f"[PersonalizedLLM] Loading Reranker ({self.reranker_type}) on {reranker_device}...")
+ if self.reranker_type == "bge":
+ self._reranker = BGEReranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ dtype=torch.float16,
+ )
+ else:
+ self._reranker = Qwen3Reranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ dtype=torch.bfloat16,
+ )
+ else:
+ print(f"[PersonalizedLLM] Skipping reranker (not needed for {self.mode} mode)")
+ self._reranker = None
+
# Chat model (via registry for backend switching)
print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...")
# Pass device override if specified (not "auto")
device_for_chat = chat_device if chat_device != "auto" else None
self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat)
-
- # Preference extractor
+
+ # Preference extractor - use shared singleton if enabled
if self.enable_preference_extraction:
extractor_name = "qwen3_0_6b_sft"
- print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...")
- try:
- self._extractor = get_preference_extractor(extractor_name)
- except Exception as e:
- print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Using rule-based.")
- self._extractor = get_preference_extractor("rule")
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared preference extractor...")
+ try:
+ extractor_path = self._cfg.preference_extractor.get("qwen3_0_6b_sft", {}).get("path", None)
+ if extractor_path:
+ self._extractor = get_shared_extractor(
+ model_path=extractor_path,
+ device_map=extractor_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Extractor path not found, using rule-based.")
+ self._extractor = get_preference_extractor("rule")
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load shared extractor: {e}. Using rule-based.")
+ self._extractor = get_preference_extractor("rule")
+ else:
+ print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...")
+ try:
+ self._extractor = get_preference_extractor(extractor_name)
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Using rule-based.")
+ self._extractor = get_preference_extractor("rule")
else:
- print("[PersonalizedLLM] Preference extraction disabled, using rule-based extractor.")
- self._extractor = get_preference_extractor("rule")
+ print("[PersonalizedLLM] Preference extraction disabled, skipping extractor.")
+ self._extractor = None
def _load_memory_store(self):
"""Load memory cards and embeddings."""
@@ -396,33 +590,34 @@ class PersonalizedLLM:
Returns list of preference dicts for debug info.
"""
extracted = []
-
+
if not prefs.preferences or self._projection is None:
return extracted
-
- # Compute embedding for the query
- e_q = self._embed_model.encode([query], return_tensor=False)[0]
- v_q = self._projection.transform_vector(np.array(e_q))
-
+
for pref in prefs.preferences:
note_text = f"When {pref.condition}, {pref.action}."
-
+
# Record for debug
extracted.append({
"condition": pref.condition,
"action": pref.action,
"confidence": pref.confidence,
})
-
+
# Deduplication check
is_duplicate = any(
card.user_id == user_id and card.note_text == note_text
for card in self._memory_cards
)
-
+
if is_duplicate:
continue
-
+
+            # Compute the embedding from note_text (NOT the query) for proper semantic retrieval.
+            # This way a retrieval query like "solve math problem" matches the stored
+            # "When math problems..." note rather than the wording of the original query.
+ e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
+ v_note = self._projection.transform_vector(np.array(e_note))
+
# Create new memory card
card = MemoryCard(
card_id=str(uuid.uuid4()),
@@ -432,21 +627,61 @@ class PersonalizedLLM:
raw_queries=[query],
preference_list=PreferenceList(preferences=[pref]),
note_text=note_text,
- embedding_e=list(e_q),
+ embedding_e=list(e_note),
kind="pref",
)
-
+
# Add to memory store
self._memory_cards.append(card)
- self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_q])])
- self._item_vectors = np.vstack([self._item_vectors, np.array([v_q])])
+ self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
+ self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])
return extracted
-
+
+ def _score_response(self, response: str) -> float:
+ """
+ Score a response for best-of-N selection.
+
+ Higher score = better response. Scoring heuristics:
+ 1. Length: Longer responses typically have more substance
+ 2. Solution indicators: Contains formulas, steps, answers
+ 3. Proactivity: Doesn't end with just a question
+
+ Returns:
+ Float score (higher is better)
+ """
+ score = 0.0
+ response_lower = response.lower()
+
+ # Length score (normalized, cap at 1000 chars)
+ score += min(len(response), 1000) / 1000 * 3.0
+
+ # Solution indicators (+1 each, max 5)
+ solution_indicators = ['=', 'step', 'answer', 'formula', 'result', 'therefore', 'solution']
+ indicator_count = sum(1 for ind in solution_indicators if ind in response_lower)
+ score += min(indicator_count, 5) * 0.5
+
+ # Structured content (+1 for numbered/bulleted lists)
+ if any(marker in response for marker in ['1.', '2.', '- ', '* ', '##']):
+ score += 1.0
+
+ # Penalty for ending with question (passive behavior)
+ # Check last 100 chars for question marks
+ if '?' in response[-100:]:
+ score -= 1.5
+
+ # Bonus for providing concrete values/numbers
+ import re
+ numbers = re.findall(r'\d+\.?\d*', response)
+ if len(numbers) >= 3:
+ score += 1.0
+
+ return score
+
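
A quick sanity check of the best-of-N scoring heuristic (a sketch; `pllm` stands for any constructed `PersonalizedLLM`, and the private method is called directly only for illustration): a concrete worked solution should outscore a short clarifying question.

```python
solution = ("1. Apply the formula A = P(1 + r)**n.\n"
            "2. With P = 1000, r = 0.05 and n = 3 the result is A = 1157.63.\n"
            "Therefore the answer is about 1157.63.")
question = "Could you tell me the interest rate first?"

# Length, the '1.'/'2.' structure, '=', 'formula', 'result', 'therefore',
# 'answer' and several numbers all add to the first score; the second earns
# almost nothing and loses 1.5 for the trailing question mark.
assert pllm._score_response(solution) > pllm._score_response(question)
```
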
# =========================================================================
# Public Interface
# =========================================================================
-
+
def chat(self, user_id: str, query: str) -> AssistantResponse:
"""
Main online chat interface.
@@ -465,34 +700,19 @@ class PersonalizedLLM:
ctx = self._get_or_create_session(user_id)
session = ctx.session_state
user_state = self._user_store.get_state(user_id)
-
+
# Record user vector before for debug
z_long_before = user_state.z_long.copy().tolist()
z_short_before = user_state.z_short.copy().tolist()
-
- # Compute query embedding
- e_q_t = np.array(self._embed_model.encode([query], return_tensor=False)[0])
-
- # Store pending RL update info from last turn (for apply_feedback)
- if session.last_query is not None and self.enable_rl_updates:
- ctx.pending_rl_update = {
- "last_query": session.last_query,
- "last_answer": session.last_answer,
- "last_memories": session.last_memories,
- "last_query_embedding": session.last_query_embedding,
- "current_query_embedding": e_q_t,
- "last_candidate_item_vectors": session.last_candidate_item_vectors,
- "last_policy_probs": session.last_policy_probs,
- "last_chosen_indices": session.last_chosen_indices,
- }
-
+
# Add user turn to history
user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
session.history.append(user_turn)
-
+
# Vanilla mode: pure LLM without any memory or preference extraction
if self.mode == "vanilla":
- # Skip preference extraction and memory retrieval entirely
+ # Skip embedding, preference extraction, and memory retrieval entirely
+ e_q_t = np.zeros(4096, dtype=np.float32) # Placeholder for vanilla mode
extracted_prefs = []
candidates = []
cand_item_vecs = np.array([])
@@ -502,13 +722,61 @@ class PersonalizedLLM:
memories_t = []
memory_notes = []
else:
+ # Compute query embedding (only needed for non-vanilla modes)
+ # Explicitly normalize for consistent cosine similarity with stored embeddings
+ embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
+ if embed_result is None or len(embed_result) == 0:
+ raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
+ e_q_t = np.array(embed_result[0])
+
+ # Store pending RL update info from last turn (for apply_feedback)
+ if session.last_query is not None and self.enable_rl_updates:
+ ctx.pending_rl_update = {
+ "last_query": session.last_query,
+ "last_answer": session.last_answer,
+ "last_memories": session.last_memories,
+ "last_query_embedding": session.last_query_embedding,
+ "current_query_embedding": e_q_t,
+ "last_candidate_item_vectors": session.last_candidate_item_vectors,
+ "last_policy_probs": session.last_policy_probs,
+ "last_chosen_indices": session.last_chosen_indices,
+ }
+
+ # Auto-compute reward via LLM judge if enabled
+ if self.reward_mode == "llm" and self._llm_reward_client is not None:
+ import asyncio
+ try:
+ reward, gating = asyncio.run(eval_step_llm(
+ q_t=session.last_query,
+ answer_t=session.last_answer,
+ q_t1=query,
+ memories_t=session.last_memories or [],
+ client=self._llm_reward_client,
+ ))
+ if gating > 0.0:
+ self.apply_feedback(Feedback(
+ user_id=user_id,
+ turn_id=ctx.turn_counter - 1,
+ reward=reward,
+ gating=gating,
+ ))
+ except Exception as e:
+ # Graceful fallback: skip RL update if judge fails
+ print(f"[LLM-Reward] Judge call failed, skipping update: {e}")
+
# Extract preferences from conversation (if enabled)
+ # extract_turn processes only the last user turn - efficient since called each turn
+ # Preferences accumulate in _memory_cards across turns (dedup prevents duplicates)
extracted_prefs = []
if self.enable_preference_extraction:
prefs = self._extractor.extract_turn(session.history)
+ if prefs.preferences:
+ print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
extracted_prefs = self._add_preferences_as_memory(
prefs, query, user_id, ctx.turn_counter
)
+ if extracted_prefs:
+ print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
# Retrieve memories
# In "nopersonal" mode: deterministic retrieval (dense + rerank + topk), no policy/user vector
@@ -551,6 +819,14 @@ class PersonalizedLLM:
# Get selected memories
memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else []
memory_notes = [m.note_text for m in memories_t]
+
+ # Debug: show retrieval info
+ if memories_t:
+ print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
+ print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}")
+ for i, m in enumerate(memories_t[:3]): # Show top 3
+ score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
+ print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
# Build prompt and count tokens
prompt_tokens = self._count_tokens(query)
@@ -559,13 +835,34 @@ class PersonalizedLLM:
for note in memory_notes:
prompt_tokens += self._count_tokens(note)
- # Generate answer
- answer_t = self._chat_model.answer(
- history=session.history,
- memory_notes=memory_notes,
- max_new_tokens=self._rl_cfg["max_new_tokens"],
- )
-
+ # Generate answer (with best-of-N if enabled)
+ if self.best_of_n > 1:
+ # Generate N responses and pick the best one
+ candidates_responses = []
+ for i in range(self.best_of_n):
+ resp = self._chat_model.answer(
+ history=session.history,
+ memory_notes=memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ temperature=0.8, # Slightly higher temp for diversity
+ )
+ score = self._score_response(resp)
+ candidates_responses.append((resp, score))
+
+ # Sort by score (descending) and pick best
+ candidates_responses.sort(key=lambda x: x[1], reverse=True)
+ answer_t = candidates_responses[0][0]
+ best_score = candidates_responses[0][1]
+
+ if len(candidates_responses) > 1:
+ print(f"[BEST-OF-{self.best_of_n}] Scores: {[f'{s:.2f}' for _, s in candidates_responses]}, picked score={best_score:.2f}")
+ else:
+ answer_t = self._chat_model.answer(
+ history=session.history,
+ memory_notes=memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ )
+
completion_tokens = self._count_tokens(answer_t)
# Add assistant turn to history
@@ -612,7 +909,263 @@ class PersonalizedLLM:
usage=usage,
debug=debug,
)
-
+
+ def chat_prepare(self, user_id: str, query: str) -> dict:
+ """
+ Prepare for chat without calling the LLM.
+
+ This does all the preparation work (embedding, memory retrieval, etc.)
+ and returns the messages to send to the LLM along with context needed
+ for post-processing.
+
+ Used for batch processing where messages are collected first, then
+ sent in batch to vLLM for concurrent processing.
+
+ Args:
+ user_id: Unique identifier for the user.
+ query: Current user query/message.
+
+ Returns:
+ Dict containing:
+ - messages: List of messages to send to LLM
+ - context: Dict with all state needed for chat_complete()
+ """
+ ctx = self._get_or_create_session(user_id)
+ session = ctx.session_state
+ user_state = self._user_store.get_state(user_id)
+
+ # Record user vector before for debug
+ z_long_before = user_state.z_long.copy().tolist()
+ z_short_before = user_state.z_short.copy().tolist()
+
+ # Add user turn to history
+ user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
+ session.history.append(user_turn)
+
+ # Vanilla mode: pure LLM without any memory or preference extraction
+ if self.mode == "vanilla":
+ e_q_t = np.zeros(4096, dtype=np.float32)
+ extracted_prefs = []
+ candidates = []
+ cand_item_vecs = np.array([])
+ base_scores = np.array([])
+ chosen_indices = []
+ probs = np.array([])
+ memories_t = []
+ memory_notes = []
+ else:
+ # Compute query embedding
+ embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
+ if embed_result is None or len(embed_result) == 0:
+ raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
+ e_q_t = np.array(embed_result[0])
+
+ # Store pending RL update info from last turn
+ if session.last_query is not None and self.enable_rl_updates:
+ ctx.pending_rl_update = {
+ "last_query": session.last_query,
+ "last_answer": session.last_answer,
+ "last_memories": session.last_memories,
+ "last_query_embedding": session.last_query_embedding,
+ "current_query_embedding": e_q_t,
+ "last_candidate_item_vectors": session.last_candidate_item_vectors,
+ "last_policy_probs": session.last_policy_probs,
+ "last_chosen_indices": session.last_chosen_indices,
+ }
+
+ # Auto-compute reward via LLM judge if enabled
+ if self.reward_mode == "llm" and self._llm_reward_client is not None:
+ import asyncio
+ try:
+ reward, gating = asyncio.run(eval_step_llm(
+ q_t=session.last_query,
+ answer_t=session.last_answer,
+ q_t1=query,
+ memories_t=session.last_memories or [],
+ client=self._llm_reward_client,
+ ))
+ if gating > 0.0:
+ self.apply_feedback(Feedback(
+ user_id=user_id,
+ turn_id=ctx.turn_counter - 1,
+ reward=reward,
+ gating=gating,
+ ))
+ except Exception as e:
+ print(f"[LLM-Reward] Judge call failed, skipping update: {e}")
+
+ # Extract preferences from conversation
+ extracted_prefs = []
+ if self.enable_preference_extraction:
+ prefs = self._extractor.extract_turn(session.history)
+ if prefs.preferences:
+ print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
+ extracted_prefs = self._add_preferences_as_memory(
+ prefs, query, user_id, ctx.turn_counter
+ )
+ if extracted_prefs:
+ print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
+
+ # Retrieve memories
+ if self.mode == "nopersonal":
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=self._memory_cards,
+ memory_embeddings=self._memory_embeddings,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ only_own_memories=self.only_own_memories,
+ )
+ else:
+ beta_long = self._rl_cfg["beta_long"]
+ beta_short = self._rl_cfg["beta_short"]
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=self._memory_cards,
+ memory_embeddings=self._memory_embeddings,
+ user_store=self._user_store,
+ item_vectors=self._item_vectors,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=self._rl_cfg["tau"],
+ only_own_memories=self.only_own_memories,
+ sample=not self.eval_mode,
+ )
+
+ memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else []
+ memory_notes = [m.note_text for m in memories_t]
+
+ if memories_t:
+ print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
+ print(f"[DEBUG-RETRIEVAL] Candidates={len(candidates)}, Selected={len(memories_t)}")
+ for i, m in enumerate(memories_t[:3]):
+ score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
+ print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
+
+ # Build prompt token count
+ prompt_tokens = self._count_tokens(query)
+ for turn in session.history:
+ prompt_tokens += self._count_tokens(turn.text)
+ for note in memory_notes:
+ prompt_tokens += self._count_tokens(note)
+
+ # Build messages for LLM
+ messages = self._chat_model.build_messages(
+ history=session.history,
+ memory_notes=memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ )
+
+ # Return messages and context for chat_complete
+ return {
+ "messages": messages,
+ "context": {
+ "user_id": user_id,
+ "query": query,
+ "ctx": ctx,
+ "session": session,
+ "user_state": user_state,
+ "z_long_before": z_long_before,
+ "z_short_before": z_short_before,
+ "e_q_t": e_q_t,
+ "extracted_prefs": extracted_prefs,
+ "candidates": candidates,
+ "cand_item_vecs": cand_item_vecs,
+ "chosen_indices": chosen_indices,
+ "probs": probs,
+ "memories_t": memories_t,
+ "memory_notes": memory_notes,
+ "prompt_tokens": prompt_tokens,
+ }
+ }
+
+ def chat_complete(self, answer_t: str, context: dict) -> AssistantResponse:
+ """
+ Complete chat with LLM response.
+
+ This takes the LLM response and context from chat_prepare(), and
+ does all post-processing (add to history, debug info, etc.).
+
+ Args:
+ answer_t: The LLM response text.
+ context: Context dict from chat_prepare().
+
+ Returns:
+ AssistantResponse containing the answer, usage stats, and debug info.
+ """
+ # Unpack context
+ user_id = context["user_id"]
+ query = context["query"]
+ ctx = context["ctx"]
+ session = context["session"]
+ user_state = context["user_state"]
+ z_long_before = context["z_long_before"]
+ z_short_before = context["z_short_before"]
+ e_q_t = context["e_q_t"]
+ extracted_prefs = context["extracted_prefs"]
+ candidates = context["candidates"]
+ cand_item_vecs = context["cand_item_vecs"]
+ chosen_indices = context["chosen_indices"]
+ probs = context["probs"]
+ memories_t = context["memories_t"]
+ memory_notes = context["memory_notes"]
+ prompt_tokens = context["prompt_tokens"]
+
+ completion_tokens = self._count_tokens(answer_t)
+
+ # Add assistant turn to history
+ assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter)
+ session.history.append(assist_turn)
+
+ # Update session state for next turn
+ session.last_query = query
+ session.last_answer = answer_t
+ session.last_memories = memories_t
+ session.last_query_embedding = e_q_t
+ session.last_candidate_item_vectors = cand_item_vecs
+ session.last_policy_probs = probs
+ session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else []
+
+ ctx.turn_counter += 1
+
+ # Build debug info
+ debug = DebugInfo(
+ selected_memory_ids=[m.card_id for m in memories_t],
+ selected_memory_notes=[m.note_text for m in memories_t],
+ selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [],
+ user_vector_before=z_long_before + z_short_before,
+ user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(),
+ extracted_preferences=extracted_prefs,
+ extra={
+ "num_candidates": len(candidates),
+ "num_total_memories": len(self._memory_cards),
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ }
+ )
+
+ # Build usage stats
+ usage = UsageStats(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ model=self._llm_name,
+ )
+
+ return AssistantResponse(
+ answer=answer_t,
+ usage=usage,
+ debug=debug,
+ )
+
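
Finally, a sketch of the batched serving loop that chat_prepare()/chat_complete() are designed for; `generate_batch` and `pending_turns` are placeholders for the caller's batched vLLM call and its queue of (user_id, query) pairs, and only prepare/complete come from this diff.

```python
# 1) Do retrieval and prompt building for every pending turn (no LLM calls yet).
prepared = [llm.chat_prepare(user_id=uid, query=q) for uid, q in pending_turns]

# 2) Send all message lists to the LLM backend at once (placeholder call).
answers = generate_batch([p["messages"] for p in prepared])

# 3) Post-process each answer: history update, debug info, usage stats.
responses = [llm.chat_complete(answer_t=a, context=p["context"])
             for a, p in zip(answers, prepared)]
```
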
def reset_session(self, user_id: str) -> None:
"""
Reset session for a user (new chat window).