author YurenHao0426 <Blackhao0426@gmail.com> 2026-03-18 18:25:09 -0500
committer YurenHao0426 <Blackhao0426@gmail.com> 2026-03-18 18:25:09 -0500
commit b6c3e4e51eeab703b40284459c6e9fff2151216c (patch)
tree 221410886f23214575f93b9ef44fa8431c9a6dfc
Initial release: VARS - personalized LLM with RAG and user vector learning
-rw-r--r--  configs/base.yaml  0
-rw-r--r--  configs/local_models.yaml  66
-rw-r--r--  configs/reranker.yaml  3
-rw-r--r--  configs/retrieval.yaml  5
-rw-r--r--  configs/user_model.yaml  14
-rw-r--r--  pyproject.toml  36
-rw-r--r--  requirements.txt  9
-rw-r--r--  src/personalization/__init__.py  0
-rw-r--r--  src/personalization/config/__init__.py  0
-rw-r--r--  src/personalization/config/registry.py  146
-rw-r--r--  src/personalization/config/settings.py  75
-rw-r--r--  src/personalization/feedback/__init__.py  0
-rw-r--r--  src/personalization/feedback/gating.py  72
-rw-r--r--  src/personalization/feedback/handlers.py  87
-rw-r--r--  src/personalization/feedback/llm_reward.py  253
-rw-r--r--  src/personalization/feedback/local_llm_reward.py  370
-rw-r--r--  src/personalization/feedback/online_update.py  0
-rw-r--r--  src/personalization/feedback/reward_model.py  64
-rw-r--r--  src/personalization/feedback/sampler.py  109
-rw-r--r--  src/personalization/feedback/schemas.py  23
-rw-r--r--  src/personalization/models/__init__.py  0
-rw-r--r--  src/personalization/models/embedding/__init__.py  11
-rw-r--r--  src/personalization/models/embedding/base.py  37
-rw-r--r--  src/personalization/models/embedding/qwen3_8b.py  89
-rw-r--r--  src/personalization/models/llm/__init__.py  4
-rw-r--r--  src/personalization/models/llm/base.py  29
-rw-r--r--  src/personalization/models/llm/prompt_builder.py  0
-rw-r--r--  src/personalization/models/llm/vllm_chat.py  244
-rw-r--r--  src/personalization/models/preference_extractor/__init__.py  5
-rw-r--r--  src/personalization/models/preference_extractor/base.py  17
-rw-r--r--  src/personalization/models/preference_extractor/gpt4o_extractor.py  165
-rw-r--r--  src/personalization/models/preference_extractor/llm_extractor.py  153
-rw-r--r--  src/personalization/models/preference_extractor/rule_extractor.py  205
-rw-r--r--  src/personalization/models/reranker/__init__.py  0
-rw-r--r--  src/personalization/models/reranker/base.py  16
-rw-r--r--  src/personalization/models/reranker/qwen3_reranker.py  96
-rw-r--r--  src/personalization/retrieval/__init__.py  0
-rw-r--r--  src/personalization/retrieval/chunking/__init__.py  0
-rw-r--r--  src/personalization/retrieval/chunking/rules.py  0
-rw-r--r--  src/personalization/retrieval/pipeline.py  388
-rw-r--r--  src/personalization/retrieval/preference_store/__init__.py  0
-rw-r--r--  src/personalization/retrieval/preference_store/base.py  0
-rw-r--r--  src/personalization/retrieval/preference_store/schemas.py  48
-rw-r--r--  src/personalization/retrieval/preference_store/vector_kv.py  0
-rw-r--r--  src/personalization/retrieval/rerank.py  0
-rw-r--r--  src/personalization/retrieval/store/__init__.py  0
-rw-r--r--  src/personalization/retrieval/store/base.py  0
-rw-r--r--  src/personalization/retrieval/store/faiss_store.py  0
-rw-r--r--  src/personalization/serving/__init__.py  22
-rw-r--r--  src/personalization/serving/personalized_llm.py  1835
-rw-r--r--  src/personalization/types.py  4
-rw-r--r--  src/personalization/user_model/__init__.py  0
-rw-r--r--  src/personalization/user_model/features.py  49
-rw-r--r--  src/personalization/user_model/policy/__init__.py  0
-rw-r--r--  src/personalization/user_model/policy/optimizer.py  0
-rw-r--r--  src/personalization/user_model/policy/reinforce.py  104
-rw-r--r--  src/personalization/user_model/scoring.py  25
-rw-r--r--  src/personalization/user_model/session_state.py  19
-rw-r--r--  src/personalization/user_model/tensor_store.py  80
-rw-r--r--  src/personalization/utils/__init__.py  0
-rw-r--r--  src/personalization/utils/ids.py  0
-rw-r--r--  src/personalization/utils/io.py  0
-rw-r--r--  src/personalization/utils/logging.py  0
-rw-r--r--  src/personalization/utils/timing.py  0
64 files changed, 4977 insertions, 0 deletions
diff --git a/configs/base.yaml b/configs/base.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/configs/base.yaml
diff --git a/configs/local_models.yaml b/configs/local_models.yaml
new file mode 100644
index 0000000..ea001ea
--- /dev/null
+++ b/configs/local_models.yaml
@@ -0,0 +1,66 @@
+# Base path for all models
+_base_path: &base ./
+
+models:
+ llm:
+ # New Multi-Backend Config
+ qwen_1_5b:
+ backend: qwen
+ path: ./models/qwen2.5-1.5b-instruct
+ device: auto
+ dtype: bfloat16
+ max_context_length: 4096
+
+ llama_8b:
+ backend: llama
+ path: ./models/llama-3.1-8b-instruct
+ device: auto
+ dtype: bfloat16
+ max_context_length: 8192
+
+ # vLLM backend for high-throughput experiments
+ llama_8b_vllm:
+ backend: vllm
+ path: ./models/llama-3.1-8b-instruct
+ vllm_url: http://localhost:8003/v1
+ model_name: meta-llama/Llama-3.1-8B-Instruct
+ max_context_length: 8192
+
+ # Legacy fallback (needed if from_config is called directly without name)
+ hf_id: Qwen/Qwen2.5-1.5B-Instruct
+ local_path: ./models/qwen2.5-1.5b-instruct
+ dtype: bfloat16
+ device_map: auto
+
+ preference_extractor:
+ # Default/Legacy
+ default:
+ hf_id: Qwen/Qwen2.5-0.5B-Instruct
+ local_path: ./models/qwen2.5-0.5b-instruct
+ dtype: bfloat16
+ device_map: auto
+ # New SFT Extractor
+ qwen3_0_6b_sft:
+ path: ./models/pref-extractor-qwen3-0.6b-sft
+ prompt_template_path: fine_tuning_prompt_template.txt
+ device: auto
+ dtype: bfloat16
+ max_new_tokens: 512
+ embedding:
+ qwen3:
+ hf_id: Qwen/Qwen3-Embedding-8B
+ local_path: ./models/qwen3-embedding-8b
+ nemotron:
+ hf_id: nvidia/llama-embed-nemotron-8b
+ local_path: ./models/llama-embed-nemotron-8b
+ reranker:
+ qwen3_8b:
+ hf_id: Qwen/Qwen3-Reranker-8B
+ local_path: ./models/rerankers/qwen3-reranker-8b
+ dtype: bfloat16
+ device_map: auto
+ bge_base:
+ hf_id: BAAI/bge-reranker-base
+ local_path: ./models/rerankers/bge-reranker-base
+ dtype: float16
+ device_map: auto
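The nested backend map above can be read directly with PyYAML; a minimal sketch (keys taken from the config shown here, PyYAML assumed installed):

```python
# Sketch only: load configs/local_models.yaml and pick one backend spec.
import yaml

with open("configs/local_models.yaml") as f:
    cfg = yaml.safe_load(f)

spec = cfg["models"]["llm"]["llama_8b_vllm"]
print(spec["backend"], spec["vllm_url"])  # vllm http://localhost:8003/v1
```

In practice this resolution is done by `get_chat_model` in `src/personalization/config/registry.py` below.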
diff --git a/configs/reranker.yaml b/configs/reranker.yaml
new file mode 100644
index 0000000..c376fc7
--- /dev/null
+++ b/configs/reranker.yaml
@@ -0,0 +1,3 @@
+reranker:
+ default: qwen3_8b
+
diff --git a/configs/retrieval.yaml b/configs/retrieval.yaml
new file mode 100644
index 0000000..d2e100e
--- /dev/null
+++ b/configs/retrieval.yaml
@@ -0,0 +1,5 @@
+retrieval:
+ dense_topk: 64 # Initial recall count
+ rerank_topk: 8 # Count fed to LLM after rerank
+ pca_dim: 256
+
diff --git a/configs/user_model.yaml b/configs/user_model.yaml
new file mode 100644
index 0000000..7b8e230
--- /dev/null
+++ b/configs/user_model.yaml
@@ -0,0 +1,14 @@
+user_model:
+ item_dim: 256
+ user_dim: 256
+ beta_long: 0.1 # Enable personalization for Day 4
+ beta_short: 0.3
+ tau: 1.0
+ preference_extractor_name: qwen3_0_6b_sft # Switch to new extractor
+ rl:
+ eta_long: 1.0e-3
+ eta_short: 5.0e-3
+ ema_alpha: 0.05
+ short_decay: 0.1
+
+llm_name: llama_8b # Switch backend to Llama 3.1
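`beta_long`/`beta_short` are forwarded into retrieval (see the `retrieve_with_rerank` call in `sampler.py` below). `pipeline.py` itself is not shown in this diff, so the exact formula is an assumption; a plausible reading of these knobs is a linear user-vector bonus on the dense score:

```python
# HYPOTHETICAL sketch: pipeline.py is not shown in this excerpt, so this
# formula is an assumption inferred from the config keys above.
import numpy as np

def personalized_score(dense_sim: float, v_item: np.ndarray,
                       u_long: np.ndarray, u_short: np.ndarray,
                       beta_long: float = 0.1, beta_short: float = 0.3) -> float:
    # dense similarity plus long- and short-term user-vector bonuses
    return (dense_sim
            + beta_long * float(np.dot(u_long, v_item))
            + beta_short * float(np.dot(u_short, v_item)))
```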
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..a76f9bb
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+[tool.black]
+
+[build-system]
+requires = ["hatchling>=1.18"]
+build-backend = "hatchling.build"
+
+[project]
+name = "personalization-user-model"
+version = "0.1.0"
+description = "Personalized memory RAG system with online user modeling"
+readme = "README.md"
+requires-python = ">=3.10"
+license = { text = "Apache-2.0" }
+authors = [
+ { name = "Anonymous" }
+]
+dependencies = [
+ "torch>=2.3.0",
+ "transformers>=4.44.0",
+ "accelerate>=0.33.0",
+ "huggingface_hub>=0.24.0",
+ "pydantic>=2.7.0",
+ "pyyaml>=6.0.0",
+ "safetensors>=0.4.2"
+]
+
+[project.urls]
+homepage = "https://example.com"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/personalization"]
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..1e227de
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+torch>=2.3.0
+transformers>=4.44.0
+accelerate>=0.33.0
+huggingface_hub>=0.24.0
+pydantic>=2.7.0
+PyYAML>=6.0.0
+safetensors>=0.4.2
+
+
diff --git a/src/personalization/__init__.py b/src/personalization/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/__init__.py
diff --git a/src/personalization/config/__init__.py b/src/personalization/config/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/config/__init__.py
diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py
new file mode 100644
index 0000000..6badeae
--- /dev/null
+++ b/src/personalization/config/registry.py
@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+import torch
+import yaml
+
+from personalization.config import settings
+
+# Project root for resolving config paths
+_PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
+
+# Avoid circular imports by NOT importing extractors here at top level
+# from personalization.models.preference_extractor.base import PreferenceExtractorBase
+# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+# from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
+# from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM
+
+_DTYPE_MAP: Dict[str, torch.dtype] = {
+ "bfloat16": torch.bfloat16,
+ "float16": torch.float16,
+ "float32": torch.float32,
+}
+
+def choose_dtype(preferred: Optional[str] = None) -> torch.dtype:
+ if preferred and preferred.lower() in _DTYPE_MAP:
+ dt = _DTYPE_MAP[preferred.lower()]
+ else:
+ dt = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ if dt is torch.bfloat16 and not torch.cuda.is_available():
+ return torch.float32
+ return dt
+
+def choose_device_map(spec: Optional[str] = "auto") -> Any:
+ return spec or "auto"
+
+def ensure_local_path(path_str: str) -> str:
+ path = Path(path_str)
+ if not path.exists():
+ path.mkdir(parents=True, exist_ok=True)
+ return str(path)
+
+# --- Chat Model Factory ---
+def get_chat_model(name: str, device_override: Optional[str] = None):
+ """
+ Get a chat model by name.
+
+ Args:
+ name: Model name (e.g., "qwen_1_5b", "llama_8b")
+ device_override: Optional device override (e.g., "cuda:2"). If None, uses config default.
+ """
+ from personalization.models.llm.base import ChatModel
+ from personalization.models.llm.qwen_instruct import QwenInstruct
+ from personalization.models.llm.llama_instruct import LlamaChatModel
+ from personalization.models.llm.vllm_chat import VLLMChatModel
+
+ cfg = settings.load_local_models_config()
+
+ # Try to load raw config to support multi-backend map
+ with open(_PROJECT_ROOT / "configs/local_models.yaml", "r") as f:
+ raw_cfg = yaml.safe_load(f)
+
+ models = raw_cfg.get("models", {}).get("llm", {})
+
+ # If models['llm'] is a dict of configs (new style)
+ if isinstance(models, dict) and "backend" in models.get(name, {}):
+ spec = models[name]
+ backend = spec.get("backend", "qwen")
+ path = spec["path"]
+ device = device_override or spec.get("device", "cuda") # Use override if provided
+ dtype = spec.get("dtype", "bfloat16")
+ max_len = spec.get("max_context_length", 4096)
+
+ if backend == "qwen":
+ return QwenInstruct(
+ model_path=path,
+ device=device,
+ dtype=choose_dtype(dtype), # Converts string to torch.dtype
+ max_context_length=max_len
+ )
+ elif backend == "llama":
+ return LlamaChatModel(
+ model_path=path,
+ device=device,
+ dtype=choose_dtype(dtype), # Converts string to torch.dtype
+ max_context_length=max_len
+ )
+ elif backend == "vllm":
+ # Use vLLM HTTP API for high-throughput inference
+ vllm_url = spec.get("vllm_url", "http://localhost:8003/v1")
+ return VLLMChatModel(
+ vllm_url=vllm_url,
+ model_name=spec.get("model_name"),
+ max_context_length=max_len
+ )
+
+ # Fallback to legacy single config
+ return QwenInstruct.from_config(cfg)
+
+def get_preference_extractor(name: Optional[str] = None):
+ # Deferred imports to break circular dependency
+ from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+ from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor
+ from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM
+
+ cfg = settings.load_local_models_config()
+ pref_cfg = cfg.preference_extractor
+
+ if name is None:
+ if isinstance(pref_cfg, dict) and "qwen3_0_6b_sft" in pref_cfg:
+ name = "qwen3_0_6b_sft"
+ else:
+ name = "rule"
+
+ if isinstance(pref_cfg, dict) and name in pref_cfg:
+ spec = pref_cfg[name]
+ if name == "qwen3_0_6b_sft":
+ # QwenRuleExtractor has been updated to handle the SFT end-to-end extractor
+ return QwenRuleExtractor(
+ model_path=spec["path"],
+ device_map=spec.get("device", "auto"),
+ dtype=choose_dtype(spec.get("dtype", "bfloat16")),
+ )
+ # Map the legacy 'default' entry onto the rule extractor branch below
+ if name == "default":
+ name = "rule"
+
+ if name == "gpt4o":
+ return GPT4OExtractor.from_config(cfg)
+ elif name == "gpt5_mini":
+ import os
+ return GPT4OExtractor(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-5-mini")
+ elif name == "rule":
+ if isinstance(pref_cfg, dict):
+ if "default" in pref_cfg:
+ # Manually construct to bypass ModelSpec mismatch if needed
+ spec_dict = pref_cfg["default"]
+ return QwenRuleExtractor(
+ model_path=spec_dict["local_path"],
+ dtype=choose_dtype(spec_dict.get("dtype")),
+ device_map=choose_device_map(spec_dict.get("device_map"))
+ )
+ else:
+ return QwenRuleExtractor.from_config(cfg)
+
+ raise ValueError(f"Could not load preference extractor: {name}")
diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py
new file mode 100644
index 0000000..8f0cc8a
--- /dev/null
+++ b/src/personalization/config/settings.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Optional, Any, Dict
+
+import yaml
+from pydantic import BaseModel, Field
+
+
+class ModelSpec(BaseModel):
+ hf_id: str = Field(..., description="Hugging Face repository id")
+ local_path: str = Field(..., description="Local directory for model weights")
+ dtype: Optional[str] = Field(
+ default="bfloat16", description="Preferred torch dtype: bfloat16|float16|float32"
+ )
+ device_map: Optional[str] = Field(default="auto", description="Device map policy")
+
+
+class EmbeddingModelsConfig(BaseModel):
+ qwen3: Optional[ModelSpec] = None
+ nemotron: Optional[ModelSpec] = None
+
+
+class RerankerModelsConfig(BaseModel):
+ qwen3_8b: Optional[ModelSpec] = None
+
+
+class LocalModelsConfig(BaseModel):
+ llm: ModelSpec
+ preference_extractor: Any # Allow flexible dict or ModelSpec for now to support map
+ embedding: Optional[EmbeddingModelsConfig] = None
+ reranker: Optional[RerankerModelsConfig] = None
+
+
+def _resolve_config_path(env_key: str, default_rel: str) -> Path:
+ value = os.getenv(env_key)
+ if value:
+ return Path(value).expanduser().resolve()
+ # Use project root (parent of src/personalization/config) instead of cwd
+ project_root = Path(__file__).parent.parent.parent.parent
+ return (project_root / default_rel).resolve()
+
+
+def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig:
+ config_path = Path(path) if path else _resolve_config_path(
+ "LOCAL_MODELS_CONFIG", "configs/local_models.yaml"
+ )
+ with open(config_path, "r", encoding="utf-8") as f:
+ raw = yaml.safe_load(f) or {}
+ models = raw.get("models", {})
+ embedding_cfg = None
+ if "embedding" in models:
+ emb = models["embedding"] or {}
+ # dtype/device_map are not necessary for embedders; ModelSpec still accepts them
+ embedding_cfg = EmbeddingModelsConfig(
+ qwen3=ModelSpec(**emb["qwen3"]) if "qwen3" in emb else None,
+ nemotron=ModelSpec(**emb["nemotron"]) if "nemotron" in emb else None,
+ )
+
+ reranker_cfg = None
+ if "reranker" in models:
+ rer = models["reranker"] or {}
+ reranker_cfg = RerankerModelsConfig(
+ qwen3_8b=ModelSpec(**rer["qwen3_8b"]) if "qwen3_8b" in rer else None
+ )
+
+ return LocalModelsConfig(
+ llm=ModelSpec(**models["llm"]),
+ preference_extractor=models["preference_extractor"], # Pass raw dict/value
+ embedding=embedding_cfg,
+ reranker=reranker_cfg,
+ )
+
+
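Usage sketch for the loader above (the `LOCAL_MODELS_CONFIG` env var overrides the default `configs/local_models.yaml` path):

```python
# Sketch: load the typed config and inspect a couple of fields.
from personalization.config.settings import load_local_models_config

cfg = load_local_models_config()
print(cfg.llm.hf_id)                     # "Qwen/Qwen2.5-1.5B-Instruct" (legacy block)
print(cfg.reranker.qwen3_8b.local_path)  # "./models/rerankers/qwen3-reranker-8b"
```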
diff --git a/src/personalization/feedback/__init__.py b/src/personalization/feedback/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/feedback/__init__.py
diff --git a/src/personalization/feedback/gating.py b/src/personalization/feedback/gating.py
new file mode 100644
index 0000000..d741874
--- /dev/null
+++ b/src/personalization/feedback/gating.py
@@ -0,0 +1,72 @@
+import numpy as np
+from personalization.feedback.schemas import TurnSample
+
+def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray:
+ # matrix: [N, d], vector: [d]
+ # return: [N]
+ norm_m = np.linalg.norm(matrix, axis=1)
+ norm_v = np.linalg.norm(vector)
+
+ # Avoid div by zero
+ den = norm_m * norm_v
+ den[den == 0] = 1e-9
+
+ return np.dot(matrix, vector) / den
+
+def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float:
+ """
+ Return g_t in [0,1], representing how much the reward is due to retrieval.
+ """
+ e_q = sample.query_embedding_t
+ e_q1 = sample.query_embedding_t1
+
+ if e_q is None or e_q1 is None or not sample.memories:
+ return 0.5 # Neutral
+
+ # We need embeddings of the memories.
+ # In a real pipeline, we might pass them in sample.memory_embeddings.
+ # If missing, we can't compute sim.
+ if sample.memory_embeddings is None:
+ # Try to use embedding_e from memory cards if available
+ # But MemoryCard.embedding_e is List[float]
+ try:
+ mem_embs = np.array([m.embedding_e for m in sample.memories])
+ if mem_embs.shape[1] == 0: # Empty embeddings
+ return 0.5
+ except Exception:
+ return 0.5
+ else:
+ mem_embs = sample.memory_embeddings
+
+ # Compute similarities
+ # shape: [K]
+ sims_q = cosine_sim_batch(mem_embs, e_q)
+ sims_q1 = cosine_sim_batch(mem_embs, e_q1)
+
+ s_q_max = sims_q.max() if len(sims_q) > 0 else 0
+ s_q1_max = sims_q1.max() if len(sims_q1) > 0 else 0
+
+ g = 0.5
+
+ # Heuristics
+
+ # Case A: Retrieval clearly irrelevant + bad reward
+ # q_t / q_{t+1} have low similarity to memories -> likely retrieval failure (or no relevant memories)
+ if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2:
+ g = 0.9 # Blame retrieval (for failing to find anything, or nothing exists)
+
+ # Case B: Retrieval looks good but reward is bad
+ # Memories are relevant to query, but user still unhappy -> LLM didn't use them well?
+ elif reward_hat < -0.5 and s_q_max > 0.5:
+ g = 0.2 # Likely LLM fault
+
+ # Case C: Good reward
+ # If reward is high, we assume both did okay.
+ elif reward_hat > 0.5:
+ if s_q_max > 0.4:
+ g = 0.6 # Retrieval helped
+ else:
+ g = 0.3 # LLM handled it without strong retrieval help
+
+ return float(g)
+
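A toy check of the similarity helper above; with memories orthogonal to the query the max similarity is low, so under a bad reward Case A would fire and blame retrieval:

```python
# Sketch: cosine_sim_batch on toy unit vectors.
import numpy as np
from personalization.feedback.gating import cosine_sim_batch

mems = np.array([[1.0, 0.0], [0.0, 1.0]])  # two memory embeddings
q = np.array([1.0, 0.0])                   # query embedding
print(cosine_sim_batch(mems, q))           # [1. 0.] -> s_q_max = 1.0
```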
diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py
new file mode 100644
index 0000000..f0468b6
--- /dev/null
+++ b/src/personalization/feedback/handlers.py
@@ -0,0 +1,87 @@
+from typing import Tuple, List, Optional
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.reward_model import estimate_reward
+from personalization.feedback.gating import estimate_retrieval_gating
+from personalization.feedback.llm_reward import (
+ LLMRewardClient, LLMRewardConfig, RewardResult
+)
+
+
+def eval_step(
+ q_t: str,
+ answer_t: str,
+ q_t1: str,
+ memories_t: List[MemoryCard],
+ query_embedding_t: Optional[np.ndarray] = None,
+ query_embedding_t1: Optional[np.ndarray] = None,
+) -> Tuple[float, float]:
+ """
+ Keyword-based evaluation (legacy).
+ Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
+ """
+ mem_embs = None
+ if memories_t and memories_t[0].embedding_e:
+ try:
+ mem_embs = np.array([m.embedding_e for m in memories_t])
+ except Exception:
+ pass
+
+ sample = TurnSample(
+ user_id="",
+ session_id="",
+ turn_id=0,
+ query_t=q_t,
+ answer_t=answer_t,
+ query_t1=q_t1,
+ memories=memories_t,
+ query_embedding_t=query_embedding_t,
+ query_embedding_t1=query_embedding_t1,
+ memory_embeddings=mem_embs,
+ )
+
+ r_hat = estimate_reward(sample)
+ g_hat = estimate_retrieval_gating(sample, r_hat)
+
+ return r_hat, g_hat
+
+
+async def eval_step_llm(
+ q_t: str,
+ answer_t: str,
+ q_t1: str,
+ memories_t: List[MemoryCard],
+ client: LLMRewardClient,
+ query_embedding_t: Optional[np.ndarray] = None,
+ query_embedding_t1: Optional[np.ndarray] = None,
+) -> Tuple[float, float]:
+ """
+ LLM-as-judge evaluation (async).
+ Returns (reward, gating) where gating=0.0 if update should be skipped.
+
+ The gating signal is derived from the judge's confidence and label:
+ - If confidence < tau_c or label == topic_shift: gating = 0.0
+ - Otherwise: gating = confidence (continuous, in [tau_c, 1.0])
+
+ This replaces the old heuristic gating with the judge's own confidence.
+ """
+ sample = TurnSample(
+ user_id="",
+ session_id="",
+ turn_id=0,
+ query_t=q_t,
+ answer_t=answer_t,
+ query_t1=q_t1,
+ memories=memories_t,
+ query_embedding_t=query_embedding_t,
+ query_embedding_t1=query_embedding_t1,
+ )
+
+ result: RewardResult = await client.judge(sample)
+
+ if result.should_update:
+ return result.reward, result.confidence
+ else:
+ return 0.0, 0.0
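A minimal run of the keyword-based path above (no memories and no embeddings, so gating falls back to its neutral 0.5):

```python
# Sketch: eval_step on a single (q_t, a_t, q_{t+1}) transition.
from personalization.feedback.handlers import eval_step

r, g = eval_step(
    q_t="Explain gradient descent.",
    answer_t="Gradient descent iteratively follows the negative gradient...",
    q_t1="that's not what I asked, redo it",
    memories_t=[],
)
print(r, g)  # -1.0 0.5 (negative keywords hit; neutral gating without embeddings)
```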
diff --git a/src/personalization/feedback/llm_reward.py b/src/personalization/feedback/llm_reward.py
new file mode 100644
index 0000000..6adcf98
--- /dev/null
+++ b/src/personalization/feedback/llm_reward.py
@@ -0,0 +1,253 @@
+"""
+LLM-as-Judge reward model using OpenAI GPT-5-nano (async for parallelism).
+
+Replaces keyword-based heuristic reward with structured LLM judgement.
+Judge receives only (q_t, a_t, q_{t+1}) — no oracle preference cards, no history.
+
+Label taxonomy → scalar reward mapping:
+ neg_constraint_restate → -1.0
+ neg_correction → -0.8
+ neg_confusion → -0.6
+ pos_praise → +0.8
+ pos_progress → +0.1
+ neutral → 0.0
+ topic_shift → 0.0 (update skipped)
+
+Confidence gating: if confidence < tau_c, reward is set to 0 and update is skipped.
+"""
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
+
+from openai import AsyncOpenAI, RateLimitError, APITimeoutError, APIConnectionError
+
+from personalization.feedback.schemas import TurnSample
+
+
+# --- Label → Reward Mapping ---
+
+REWARD_MAP: Dict[str, float] = {
+ "neg_constraint_restate": -1.0,
+ "neg_correction": -0.8,
+ "neg_confusion": -0.6,
+ "pos_praise": +0.8,
+ "pos_progress": +0.1,
+ "neutral": 0.0,
+ "topic_shift": 0.0,
+}
+
+VALID_LABELS = set(REWARD_MAP.keys())
+
+
+# --- Configuration ---
+
+@dataclass
+class LLMRewardConfig:
+ model: str = "gpt-5-nano"
+ api_key: Optional[str] = None # Falls back to OPENAI_API_KEY env var
+ base_url: Optional[str] = None # For custom endpoints
+ max_concurrent: int = 32 # Semaphore limit for parallel requests
+ max_retries: int = 3
+ retry_base_delay: float = 1.0 # Exponential backoff base (seconds)
+ timeout: float = 60.0 # Per-request timeout (reasoning models are slower)
+ max_completion_tokens: int = 2048 # Must be high — reasoning models use internal tokens
+ confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this
+ enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1})
+
+
+# --- Prompt ---
+
+JUDGE_SYSTEM_PROMPT = """\
+You are a feedback classifier. Given a user query (q_t), the assistant's response (a_t), \
+and the user's next message (q_{t+1}), classify the user's follow-up into exactly one label.
+
+Labels (mutually exclusive):
+- neg_constraint_restate: User reasserts constraints/preferences as correction (e.g., "as I said…", "remember…", "按我说的…").
+- neg_correction: User indicates the content is wrong or the assistant failed to answer.
+- neg_confusion: User indicates confusion or requests re-explanation.
+- pos_praise: Explicit praise or satisfaction with the response.
+- pos_progress: Constructive continuation (examples, extensions, what-if, next steps) without complaint.
+- neutral: Ambiguous or minimal feedback, not clearly positive or negative.
+- topic_shift: User switches to a new unrelated task/topic.
+
+Output a JSON object with fields: label, confidence (0-1), rationale (one short sentence)."""
+
+JUDGE_USER_TEMPLATE = """\
+q_t: {query_t}
+
+a_t: {answer_t}
+
+q_{{t+1}}: {query_t1}"""
+
+
+# --- Result Dataclass ---
+
+@dataclass
+class RewardResult:
+ label: str
+ confidence: float
+ rationale: str
+ reward: float
+ should_update: bool # False if gated by confidence or topic_shift
+
+
+# --- Async Client ---
+
+class LLMRewardClient:
+ """Async OpenAI client for LLM-as-judge reward estimation."""
+
+ def __init__(self, config: Optional[LLMRewardConfig] = None):
+ self.config = config or LLMRewardConfig()
+ self._client = AsyncOpenAI(
+ api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
+ base_url=self.config.base_url,
+ timeout=self.config.timeout,
+ )
+ self._semaphore = asyncio.Semaphore(self.config.max_concurrent)
+ self._cache: Dict[str, RewardResult] = {}
+
+ def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str:
+ """Deterministic hash of the judge input triple."""
+ content = f"{query_t}\x00{answer_t}\x00{query_t1}"
+ return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+ async def _call_with_retry(self, messages: List[dict]) -> str:
+ """Single LLM call with exponential backoff retry."""
+ for attempt in range(self.config.max_retries):
+ try:
+ async with self._semaphore:
+ response = await self._client.chat.completions.create(
+ model=self.config.model,
+ messages=messages,
+ max_completion_tokens=self.config.max_completion_tokens,
+ response_format={"type": "json_object"},
+ )
+ content = response.choices[0].message.content
+ if content:
+ return content.strip()
+ # Reasoning model may exhaust tokens on thinking — retry
+ if response.choices[0].finish_reason == "length":
+ continue
+ return ""
+ except (RateLimitError, APITimeoutError, APIConnectionError) as e:
+ if attempt == self.config.max_retries - 1:
+ raise
+ delay = self.config.retry_base_delay * (2 ** attempt)
+ await asyncio.sleep(delay)
+ return ""
+
+ def _build_messages(self, sample: TurnSample) -> List[dict]:
+ """Construct the judge prompt from (q_t, a_t, q_{t+1}) only."""
+ user_content = JUDGE_USER_TEMPLATE.format(
+ query_t=sample.query_t,
+ answer_t=sample.answer_t,
+ query_t1=sample.query_t1,
+ )
+ return [
+ {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
+ {"role": "user", "content": user_content},
+ ]
+
+ def _parse_result(self, raw: str) -> RewardResult:
+ """Parse structured JSON output into RewardResult."""
+ try:
+ parsed = json.loads(raw)
+ label = parsed["label"]
+ confidence = float(parsed["confidence"])
+ rationale = parsed.get("rationale", "")
+
+ if label not in VALID_LABELS:
+ label = "neutral"
+ confidence = 0.0
+
+ reward = REWARD_MAP[label]
+
+ # Confidence gating and topic_shift skip
+ should_update = (
+ confidence >= self.config.confidence_threshold
+ and label != "topic_shift"
+ )
+ if not should_update:
+ reward = 0.0
+
+ return RewardResult(
+ label=label,
+ confidence=confidence,
+ rationale=rationale,
+ reward=reward,
+ should_update=should_update,
+ )
+ except (json.JSONDecodeError, KeyError, TypeError, ValueError):
+ return RewardResult(
+ label="neutral",
+ confidence=0.0,
+ rationale="parse_failure",
+ reward=0.0,
+ should_update=False,
+ )
+
+ async def judge(self, sample: TurnSample) -> RewardResult:
+ """Judge a single turn (async). Returns RewardResult with gating applied."""
+ # Cache lookup
+ if self.config.enable_cache:
+ key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1)
+ if key in self._cache:
+ return self._cache[key]
+
+ messages = self._build_messages(sample)
+ raw = await self._call_with_retry(messages)
+ result = self._parse_result(raw)
+
+ # Cache store
+ if self.config.enable_cache:
+ self._cache[key] = result
+
+ return result
+
+ async def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]:
+ """Judge a batch of turns in parallel. Returns list of RewardResult."""
+ tasks = [self.judge(s) for s in samples]
+ return await asyncio.gather(*tasks)
+
+ async def close(self):
+ """Close the underlying HTTP client."""
+ await self._client.close()
+
+
+# --- Synchronous Wrappers ---
+
+def estimate_reward_llm(
+ sample: TurnSample,
+ config: Optional[LLMRewardConfig] = None,
+) -> Tuple[float, bool]:
+ """
+ Synchronous single-sample reward estimation.
+ Returns (reward, should_update).
+ """
+ client = LLMRewardClient(config)
+ try:
+ result = asyncio.run(client.judge(sample))
+ return result.reward, result.should_update
+ finally:
+ asyncio.run(client.close())
+
+
+def estimate_rewards_batch(
+ samples: List[TurnSample],
+ config: Optional[LLMRewardConfig] = None,
+) -> List[Tuple[float, bool]]:
+ """
+ Synchronous batch reward estimation (runs async internally).
+ Returns list of (reward, should_update) tuples.
+ """
+ client = LLMRewardClient(config)
+ try:
+ results = asyncio.run(client.judge_batch(samples))
+ return [(r.reward, r.should_update) for r in results]
+ finally:
+ asyncio.run(client.close())
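Usage sketch for the async judge above (`OPENAI_API_KEY` in the environment is assumed; the printed label is model-dependent):

```python
# Sketch: judge one transition with the async client.
import asyncio
from personalization.feedback.llm_reward import LLMRewardClient, LLMRewardConfig
from personalization.feedback.schemas import TurnSample

async def main():
    client = LLMRewardClient(LLMRewardConfig(model="gpt-5-nano"))
    sample = TurnSample(
        user_id="u1", session_id="s1", turn_id=0,
        query_t="Summarize this paper.",
        answer_t="Here is a detailed five-paragraph summary...",
        query_t1="As I said, keep it under 100 words.",
        memories=[],
    )
    result = await client.judge(sample)
    # Plausibly neg_constraint_restate -> reward -1.0 (not guaranteed)
    print(result.label, result.reward, result.should_update)
    await client.close()

asyncio.run(main())
```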
diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py
new file mode 100644
index 0000000..70bbeb8
--- /dev/null
+++ b/src/personalization/feedback/local_llm_reward.py
@@ -0,0 +1,370 @@
+"""
+Local LLM reward model using vLLM server for batch inference.
+
+Drop-in replacement for LLMRewardClient when you want to use a local model
+(e.g., Llama-3.1-8B-Instruct) instead of OpenAI API.
+
+Uses BatchVLLMClient for efficient concurrent requests - vLLM's continuous
+batching will process them together for high throughput.
+"""
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import json
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import aiohttp
+
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.llm_reward import (
+ REWARD_MAP,
+ VALID_LABELS,
+ JUDGE_SYSTEM_PROMPT,
+ JUDGE_USER_TEMPLATE,
+ RewardResult,
+)
+
+
+@dataclass
+class LocalLLMRewardConfig:
+ """Configuration for local LLM reward model."""
+ vllm_url: str = "http://localhost:8005/v1" # vLLM server URL
+ model_name: Optional[str] = None # Auto-discovered if None
+ max_tokens: int = 256
+ temperature: float = 0.1
+ max_concurrent: int = 100 # High concurrency for vLLM batching
+ timeout: Optional[int] = 60 # Per-request timeout in seconds
+ confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this
+ enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1})
+
+
+class LocalLLMRewardClient:
+ """
+ Local LLM reward client using vLLM server.
+
+ Designed for batch processing - uses async HTTP requests that vLLM
+ batches together via continuous batching for high throughput.
+ """
+
+ def __init__(self, config: Optional[LocalLLMRewardConfig] = None):
+ self.config = config or LocalLLMRewardConfig()
+ self._model_name = self.config.model_name
+ self._cache: Dict[str, RewardResult] = {}
+
+ # Discover model name if not provided
+ if self._model_name is None:
+ self._discover_model_sync()
+
+ def _discover_model_sync(self):
+ """Synchronously discover model name from vLLM server."""
+ import requests
+ try:
+ response = requests.get(
+ f"{self.config.vllm_url}/models",
+ timeout=10
+ )
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self._model_name = models["data"][0]["id"]
+ else:
+ self._model_name = "default"
+ except Exception as e:
+ print(f"[LocalLLMReward] Warning: Could not discover model ({e})")
+ self._model_name = "default"
+
+ def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str:
+ """Deterministic hash of the judge input triple."""
+ content = f"{query_t}\x00{answer_t}\x00{query_t1}"
+ return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+ def _build_messages(self, sample: TurnSample) -> List[dict]:
+ """Construct the judge prompt from (q_t, a_t, q_{t+1})."""
+ user_content = JUDGE_USER_TEMPLATE.format(
+ query_t=sample.query_t,
+ answer_t=sample.answer_t,
+ query_t1=sample.query_t1,
+ )
+ return [
+ {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
+ {"role": "user", "content": user_content},
+ ]
+
+ def _parse_result(self, raw: Optional[str]) -> RewardResult:
+ """Parse structured JSON output into RewardResult."""
+ if raw is None:
+ return RewardResult(
+ label="neutral",
+ confidence=0.0,
+ rationale="request_failed",
+ reward=0.0,
+ should_update=False,
+ )
+
+ try:
+ # Handle markdown code blocks
+ text = raw.strip()
+ if text.startswith("```"):
+ lines = text.split("\n")
+ text = "\n".join(
+ lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
+ )
+
+ parsed = json.loads(text)
+ label = parsed.get("label", "neutral")
+ confidence = float(parsed.get("confidence", 0.0))
+ rationale = parsed.get("rationale", "")
+
+ if label not in VALID_LABELS:
+ label = "neutral"
+ confidence = 0.0
+
+ reward = REWARD_MAP[label]
+
+ # Confidence gating and topic_shift skip
+ should_update = (
+ confidence >= self.config.confidence_threshold
+ and label != "topic_shift"
+ )
+ if not should_update:
+ reward = 0.0
+
+ return RewardResult(
+ label=label,
+ confidence=confidence,
+ rationale=rationale,
+ reward=reward,
+ should_update=should_update,
+ )
+ except (json.JSONDecodeError, KeyError, TypeError, ValueError):
+ # Try to extract JSON from text
+ import re
+ match = re.search(r'\{[^}]+\}', raw, re.DOTALL)
+ if match:
+ try:
+ parsed = json.loads(match.group())
+ label = parsed.get("label", "neutral")
+ confidence = float(parsed.get("confidence", 0.0))
+ rationale = parsed.get("rationale", "")
+
+ if label not in VALID_LABELS:
+ label = "neutral"
+ confidence = 0.0
+
+ reward = REWARD_MAP[label]
+ should_update = (
+ confidence >= self.config.confidence_threshold
+ and label != "topic_shift"
+ )
+ if not should_update:
+ reward = 0.0
+
+ return RewardResult(
+ label=label,
+ confidence=confidence,
+ rationale=rationale,
+ reward=reward,
+ should_update=should_update,
+ )
+ except Exception:
+ pass
+
+ return RewardResult(
+ label="neutral",
+ confidence=0.0,
+ rationale="parse_failure",
+ reward=0.0,
+ should_update=False,
+ )
+
+ async def _single_request(
+ self,
+ session: aiohttp.ClientSession,
+ messages: List[dict],
+ idx: int,
+ ) -> tuple:
+ """Make a single async request to vLLM server."""
+ payload = {
+ "model": self._model_name,
+ "messages": messages,
+ "max_tokens": self.config.max_tokens,
+ "temperature": self.config.temperature,
+ "response_format": {"type": "json_object"},
+ }
+
+ for attempt in range(3):
+ try:
+ timeout_config = (
+ aiohttp.ClientTimeout(total=self.config.timeout)
+ if self.config.timeout else None
+ )
+ async with session.post(
+ f"{self.config.vllm_url}/chat/completions",
+ json=payload,
+ timeout=timeout_config,
+ ) as response:
+ if response.status == 200:
+ result = await response.json()
+ content = result["choices"][0]["message"]["content"]
+ return (idx, content, None)
+ elif response.status == 429:
+ # Rate limit - wait and retry
+ await asyncio.sleep(2 ** attempt)
+ continue
+ else:
+ error_text = await response.text()
+ return (idx, None, f"HTTP {response.status}: {error_text[:200]}")
+ except asyncio.TimeoutError:
+ if attempt < 2:
+ continue
+ return (idx, None, "Timeout")
+ except Exception as e:
+ return (idx, None, str(e))
+
+ return (idx, None, "Max retries exceeded")
+
+ async def judge_batch_async(
+ self,
+ samples: List[TurnSample],
+ show_progress: bool = False,
+ ) -> List[RewardResult]:
+ """
+ Judge a batch of turns using concurrent vLLM requests.
+
+ vLLM's continuous batching will process these together for
+ high throughput.
+ """
+ n_samples = len(samples)
+ results = [None] * n_samples
+
+ # Check cache and build request list
+ to_request = [] # (original_idx, messages)
+ for i, sample in enumerate(samples):
+ if self.config.enable_cache:
+ key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1)
+ if key in self._cache:
+ results[i] = self._cache[key]
+ continue
+
+ messages = self._build_messages(sample)
+ to_request.append((i, messages))
+
+ if not to_request:
+ return results
+
+ # Make concurrent requests
+ semaphore = asyncio.Semaphore(self.config.max_concurrent)
+
+ async def limited_request(session, messages, idx):
+ async with semaphore:
+ return await self._single_request(session, messages, idx)
+
+ connector = aiohttp.TCPConnector(limit=self.config.max_concurrent)
+ headers = {"Content-Type": "application/json"}
+
+ async with aiohttp.ClientSession(
+ connector=connector, headers=headers
+ ) as session:
+ tasks = [
+ limited_request(session, messages, orig_idx)
+ for orig_idx, messages in to_request
+ ]
+
+ completed = 0
+ for coro in asyncio.as_completed(tasks):
+ orig_idx, content, error = await coro
+ completed += 1
+
+ if error:
+ print(f"[LocalLLMReward] Request {orig_idx} failed: {error}")
+
+ result = self._parse_result(content)
+ results[orig_idx] = result
+
+ # Cache the result
+ if self.config.enable_cache:
+ sample = samples[orig_idx]
+ key = self._cache_key(
+ sample.query_t, sample.answer_t, sample.query_t1
+ )
+ self._cache[key] = result
+
+ if show_progress and completed % 10 == 0:
+ print(f" [LocalLLMReward {completed}/{len(to_request)}] completed")
+
+ return results
+
+ async def judge_async(self, sample: TurnSample) -> RewardResult:
+ """Judge a single turn (async)."""
+ results = await self.judge_batch_async([sample])
+ return results[0]
+
+ def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]:
+ """
+ Judge a batch of turns (sync wrapper).
+
+ This is the main entry point for batch reward estimation.
+ """
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+
+ if loop is not None:
+ # Already in an event loop - create a new thread to run the coroutine
+ import concurrent.futures
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future = executor.submit(asyncio.run, self.judge_batch_async(samples))
+ return future.result()
+ else:
+ return asyncio.run(self.judge_batch_async(samples))
+
+ async def judge(self, sample: TurnSample) -> RewardResult:
+ """Judge a single turn (async interface for compatibility with LLMRewardClient)."""
+ return await self.judge_async(sample)
+
+ def judge_sync(self, sample: TurnSample) -> RewardResult:
+ """Judge a single turn (sync wrapper)."""
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+
+ if loop is not None:
+ # Already in an event loop - create a new thread to run the coroutine
+ import concurrent.futures
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future = executor.submit(asyncio.run, self.judge_async(sample))
+ return future.result()
+ else:
+ return asyncio.run(self.judge_async(sample))
+
+
+# --- Convenience Functions ---
+
+def estimate_reward_local(
+ sample: TurnSample,
+ config: Optional[LocalLLMRewardConfig] = None,
+) -> tuple:
+ """
+ Synchronous single-sample reward estimation using local LLM.
+ Returns (reward, should_update).
+ """
+ client = LocalLLMRewardClient(config)
+ result = client.judge(sample)
+ return result.reward, result.should_update
+
+
+def estimate_rewards_batch_local(
+ samples: List[TurnSample],
+ config: Optional[LocalLLMRewardConfig] = None,
+) -> List[tuple]:
+ """
+ Synchronous batch reward estimation using local LLM.
+ Returns list of (reward, should_update) tuples.
+ """
+ client = LocalLLMRewardClient(config)
+ results = client.judge_batch(samples)
+ return [(r.reward, r.should_update) for r in results]
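Usage sketch for the local judge (a vLLM server running at the configured URL is assumed):

```python
# Sketch: batch-judge transitions against a local vLLM server.
from personalization.feedback.local_llm_reward import (
    LocalLLMRewardClient, LocalLLMRewardConfig,
)
from personalization.feedback.schemas import TurnSample

sample = TurnSample(
    user_id="u1", session_id="s1", turn_id=0,
    query_t="Fix this function.",
    answer_t="Try replacing the loop with a comprehension...",
    query_t1="That's not what I asked.",
    memories=[],
)
client = LocalLLMRewardClient(LocalLLMRewardConfig(vllm_url="http://localhost:8005/v1"))
[result] = client.judge_batch([sample])
print(result.label, result.reward, result.should_update)
```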
diff --git a/src/personalization/feedback/online_update.py b/src/personalization/feedback/online_update.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/feedback/online_update.py
diff --git a/src/personalization/feedback/reward_model.py b/src/personalization/feedback/reward_model.py
new file mode 100644
index 0000000..3584b43
--- /dev/null
+++ b/src/personalization/feedback/reward_model.py
@@ -0,0 +1,64 @@
+import numpy as np
+from personalization.feedback.schemas import TurnSample
+
+def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
+ norm_a = np.linalg.norm(a)
+ norm_b = np.linalg.norm(b)
+ if norm_a == 0 or norm_b == 0:
+ return 0.0
+ return float(np.dot(a, b) / (norm_a * norm_b))
+
+def estimate_reward(sample: TurnSample) -> float:
+ """
+ Return a scalar reward_hat, indicating if the previous answer was helpful.
+ Range: [-1.0, 1.0] (approx)
+ """
+
+ # 1. Language/Topic Coherence
+ if sample.query_embedding_t is None or sample.query_embedding_t1 is None:
+ topic_sim = 0.5
+ else:
+ topic_sim = cosine_sim(sample.query_embedding_t, sample.query_embedding_t1)
+
+ # 2. Negative Keywords (Complaint/Correction)
+ negative_keywords = [
+ "you didn't", "that's not", "incorrect", "redo", "again", "explain more",
+ "doesn't help", "wrong", "no", "not what i asked",
+ "你没", "不是", "这不是", "重来", "重新", "不对", "错了", "没说清楚"
+ ]
+
+ # 3. Positive Keywords (Follow-up/Elaboration)
+ positive_keywords = [
+ "can you elaborate", "give an example", "continue", "what if", "based on that",
+ "thanks", "good", "great", "cool",
+ "能不能详细一点", "举个例子", "再继续", "那如果", "接下来", "在这个基础上", "谢谢", "不错"
+ ]
+
+ q1_lower = sample.query_t1.lower()
+
+ has_negative = any(kw in q1_lower for kw in negative_keywords)
+ has_positive = any(kw in q1_lower for kw in positive_keywords)
+
+ reward = 0.0
+
+ if has_negative:
+ reward -= 1.0
+
+ if has_positive:
+ # A positive cue alone earns a base bonus (it may just be "thanks, bye"
+ # at the end of a session); topic continuity below adds the other half.
+ reward += 0.5
+ if topic_sim > 0.3:
+ reward += 0.5
+
+ if topic_sim < 0.2:
+ # Topic shift -> previous interaction likely finished or failed.
+ # If no explicit positive/negative, assume neutral/slightly decayed.
+ # If user changes topic, it often means the previous task is done (neutral/positive)
+ # OR they gave up (negative). Hard to tell.
+ # Let's dampen the reward towards 0.
+ reward *= 0.5
+
+ # Clip
+ return max(-1.0, min(1.0, reward))
+
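A worked example of the heuristic above: with no embeddings `topic_sim` defaults to 0.5, positive keywords add 0.5, and `topic_sim > 0.3` adds the other 0.5:

```python
# Sketch: estimate_reward on a clearly positive follow-up.
from personalization.feedback.reward_model import estimate_reward
from personalization.feedback.schemas import TurnSample

s = TurnSample(
    user_id="u", session_id="s", turn_id=0,
    query_t="What is PCA?",
    answer_t="PCA projects data onto the top eigenvectors of the covariance...",
    query_t1="great, can you elaborate?",
    memories=[],
)
print(estimate_reward(s))  # 1.0
```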
diff --git a/src/personalization/feedback/sampler.py b/src/personalization/feedback/sampler.py
new file mode 100644
index 0000000..9e26912
--- /dev/null
+++ b/src/personalization/feedback/sampler.py
@@ -0,0 +1,109 @@
+from typing import Iterable, List, Optional
+import numpy as np
+from tqdm import tqdm
+
+from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard
+from personalization.feedback.schemas import TurnSample
+from personalization.retrieval.pipeline import retrieve_with_rerank
+from personalization.models.llm.qwen_instruct import QwenInstruct
+from personalization.models.embedding.base import EmbeddingModel
+from personalization.models.reranker.base import Reranker
+from personalization.user_model.tensor_store import UserTensorStore
+
+def build_turn_samples_from_sessions(
+ sessions: Iterable[List[ChatTurn]],
+ embed_model: EmbeddingModel,
+ llm: QwenInstruct,
+ reranker: Reranker,
+ memory_cards: List[MemoryCard],
+ memory_embeddings: np.ndarray,
+ user_store: UserTensorStore,
+ item_vectors: np.ndarray,
+ max_samples: Optional[int] = None,
+ topk_dense: int = 64,
+ topk_rerank: int = 3,
+) -> List[TurnSample]:
+ samples = []
+
+ for turns in tqdm(sessions, desc="Building TurnSamples"):
+ if max_samples and len(samples) >= max_samples:
+ break
+
+ # Ensure sorted by turn_id
+ sorted_turns = sorted(turns, key=lambda x: x.turn_id)
+
+ # Iterate to find (q_t, a_t, q_{t+1})
+ for i in range(len(sorted_turns)):
+ if max_samples and len(samples) >= max_samples:
+ break
+
+ q_t = sorted_turns[i]
+ if q_t.role != "user":
+ continue
+
+ # Find next user turn
+ # Also try to find assistant response in between
+ a_t_text = ""
+ q_t1 = None
+
+ # Look ahead
+ for j in range(i + 1, len(sorted_turns)):
+ next_turn = sorted_turns[j]
+ if next_turn.role == "assistant" and not a_t_text:
+ a_t_text = next_turn.text
+ elif next_turn.role == "user":
+ q_t1 = next_turn
+ break
+
+ if not q_t1:
+ # End of session or no subsequent user query
+ continue
+
+ # We now have q_t, q_t1, and (optionally) a_t.
+ # Reward estimation normally needs the assistant answer to judge
+ # quality, but when sampling offline from existing chats the
+ # assistant turn should already be present (OASST1 sessions
+ # include assistant replies).
+ if not a_t_text:
+ # Proceed with an empty answer rather than dropping the sample;
+ # generating one with the LLM here would be possible but is
+ # deliberately not done for offline sampling.
+ pass
+
+ # 3. Retrieve memories for q_t
+ memories_t = retrieve_with_rerank(
+ user_id=q_t.user_id,
+ query=q_t.text,
+ embed_model=embed_model,
+ reranker=reranker,
+ memory_cards=memory_cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=topk_dense,
+ topk_rerank=topk_rerank,
+ beta_long=0.0,
+ beta_short=0.0,
+ only_own_memories=True # Assume we want user specific memories
+ )
+
+ # 4. Precompute embeddings
+ # We can do this efficiently later or batch, but here per sample
+ e_q_t = embed_model.encode([q_t.text], return_tensor=False)[0]
+ e_q_t1 = embed_model.encode([q_t1.text], return_tensor=False)[0]
+
+ sample = TurnSample(
+ user_id=q_t.user_id,
+ session_id=q_t.session_id,
+ turn_id=q_t.turn_id,
+ query_t=q_t.text,
+ answer_t=a_t_text,
+ query_t1=q_t1.text,
+ memories=memories_t,
+ query_embedding_t=np.array(e_q_t),
+ query_embedding_t1=np.array(e_q_t1)
+ )
+ samples.append(sample)
+
+ return samples
+
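For illustration (plain Python, no models needed), the look-ahead pairing the loop above performs on a toy session:

```python
# Toy demo of the (q_t, a_t, q_{t+1}) extraction logic in the sampler.
session = [("user", "q1"), ("assistant", "a1"), ("user", "q2"),
           ("assistant", "a2"), ("user", "q3")]
triples = []
for i, (role, text) in enumerate(session):
    if role != "user":
        continue
    a_t, q_t1 = "", None
    for r, t in session[i + 1:]:
        if r == "assistant" and not a_t:
            a_t = t
        elif r == "user":
            q_t1 = t
            break
    if q_t1:
        triples.append((text, a_t, q_t1))
print(triples)  # [('q1', 'a1', 'q2'), ('q2', 'a2', 'q3')]
```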
diff --git a/src/personalization/feedback/schemas.py b/src/personalization/feedback/schemas.py
new file mode 100644
index 0000000..b15db80
--- /dev/null
+++ b/src/personalization/feedback/schemas.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Optional, Any
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+
+@dataclass
+class TurnSample:
+ user_id: str
+ session_id: str
+ turn_id: int # index of q_t within the session
+ query_t: str # q_t
+ answer_t: str # a_t
+ query_t1: str # q_{t+1}
+ memories: List[MemoryCard] # A_t
+
+ # Optional pre-computed vectors and features
+ query_embedding_t: Optional[np.ndarray] = None
+ query_embedding_t1: Optional[np.ndarray] = None
+ memory_embeddings: Optional[np.ndarray] = None # corresponding e_m or v_m for memories
+
diff --git a/src/personalization/models/__init__.py b/src/personalization/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/__init__.py
diff --git a/src/personalization/models/embedding/__init__.py b/src/personalization/models/embedding/__init__.py
new file mode 100644
index 0000000..05221aa
--- /dev/null
+++ b/src/personalization/models/embedding/__init__.py
@@ -0,0 +1,11 @@
+from .base import EmbeddingModel
+from .qwen3_8b import Qwen3Embedding8B
+from .nemotron_8b import LlamaEmbedNemotron8B
+
+__all__ = [
+ "EmbeddingModel",
+ "Qwen3Embedding8B",
+ "LlamaEmbedNemotron8B",
+]
+
+
diff --git a/src/personalization/models/embedding/base.py b/src/personalization/models/embedding/base.py
new file mode 100644
index 0000000..9f9d4d1
--- /dev/null
+++ b/src/personalization/models/embedding/base.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Sequence
+
+import torch
+
+
+class EmbeddingModel(ABC):
+ @abstractmethod
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ """Encode a batch of texts into dense embeddings."""
+ raise NotImplementedError
+
+
+def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+ # last_hidden_state: [batch, seq_len, hidden]
+ # attention_mask: [batch, seq_len]
+ mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [b, s, 1]
+ summed = (last_hidden_state * mask).sum(dim=1)
+ counts = mask.sum(dim=1).clamp_min(1e-6)
+ return summed / counts
+
+
+def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor:
+ if not normalize:
+ return x
+ return torch.nn.functional.normalize(x, p=2, dim=-1)
+
+
diff --git a/src/personalization/models/embedding/qwen3_8b.py b/src/personalization/models/embedding/qwen3_8b.py
new file mode 100644
index 0000000..fb02e67
--- /dev/null
+++ b/src/personalization/models/embedding/qwen3_8b.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from typing import List, Sequence
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import EmbeddingModel, _mean_pool, _maybe_normalize
+
+
+class Qwen3Embedding8B(EmbeddingModel):
+ def __init__(
+ self,
+ model_path: str,
+ dtype: torch.dtype,
+ device_map: str = "auto",
+ trust_remote_code: bool = True,
+ ) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ # Handle specific device assignment (e.g., "cuda:0", "cuda:1")
+ if device_map and device_map.startswith("cuda:"):
+ # Load to CPU first, then move to specific GPU
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=None, # Don't use accelerate's device_map
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ self.model = self.model.to(device_map)
+ else:
+ # Use accelerate's auto device mapping
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Embedding8B":
+ if not cfg.embedding or not cfg.embedding.qwen3:
+ raise ValueError("Embedding config for qwen3 is missing")
+ spec = cfg.embedding.qwen3
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(
+ spec.local_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+
+ @torch.inference_mode()
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ device = next(self.model.parameters()).device
+ outputs: List[torch.Tensor] = []
+ for i in range(0, len(texts), batch_size):
+ batch = list(texts[i : i + batch_size])
+ enc = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ ).to(device)
+ model_out = self.model(**enc, output_hidden_states=False, return_dict=True)
+ pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined]
+ pooled = _maybe_normalize(pooled, normalize)
+ outputs.append(pooled)
+ emb = torch.cat(outputs, dim=0)
+ if return_tensor:
+ return emb
+ return emb.cpu().to(torch.float32).tolist()
+
+
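Usage sketch for the embedder above (local weights at the configured path are assumed):

```python
# Sketch: encode a few strings with the Qwen3 embedder.
from personalization.config.settings import load_local_models_config
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B

model = Qwen3Embedding8B.from_config(load_local_models_config())
vecs = model.encode(["hello world", "user prefers concise answers"], batch_size=2)
print(len(vecs), len(vecs[0]))  # 2 embeddings, each of the model's hidden size
```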
diff --git a/src/personalization/models/llm/__init__.py b/src/personalization/models/llm/__init__.py
new file mode 100644
index 0000000..3f1af81
--- /dev/null
+++ b/src/personalization/models/llm/__init__.py
@@ -0,0 +1,4 @@
+from .qwen_instruct import QwenInstruct
+
+__all__ = ["QwenInstruct"]
+
diff --git a/src/personalization/models/llm/base.py b/src/personalization/models/llm/base.py
new file mode 100644
index 0000000..72b6ca8
--- /dev/null
+++ b/src/personalization/models/llm/base.py
@@ -0,0 +1,29 @@
+from typing import List, Protocol, Optional
+from personalization.types import ChatTurn
+
+class ChatModel(Protocol):
+ def answer(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ top_k: Optional[int] = None,
+ ) -> str:
+ """
+ Generate an assistant response given conversation history and memory notes.
+
+ Args:
+ history: The conversation history ending with the current user turn.
+ memory_notes: List of retrieved memory content strings.
+ max_new_tokens: Max tokens to generate.
+ temperature: Sampling temperature.
+ top_p: Top-p sampling.
+ top_k: Top-k sampling.
+
+ Returns:
+ The generated assistant response text.
+ """
+ ...
+
diff --git a/src/personalization/models/llm/prompt_builder.py b/src/personalization/models/llm/prompt_builder.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/llm/prompt_builder.py
diff --git a/src/personalization/models/llm/vllm_chat.py b/src/personalization/models/llm/vllm_chat.py
new file mode 100644
index 0000000..d577a30
--- /dev/null
+++ b/src/personalization/models/llm/vllm_chat.py
@@ -0,0 +1,244 @@
+"""
+vLLM-based ChatModel implementation for high-throughput inference.
+
+This provides the same interface as LlamaChatModel but uses vLLM HTTP API
+for much faster inference (3000+ sessions/hr vs 20 sessions/hr).
+"""
+
+from typing import List, Optional
+import time
+import requests
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+
+
+class VLLMChatModel(ChatModel):
+ """
+ ChatModel implementation using vLLM HTTP API.
+
+ This is a drop-in replacement for LlamaChatModel that uses vLLM
+ for much faster inference.
+ """
+
+ def __init__(
+ self,
+ vllm_url: str = "http://localhost:8003/v1",
+ model_name: Optional[str] = None,
+ max_context_length: int = 8192,
+ timeout: int = 120,
+ ):
+ self.vllm_url = vllm_url.rstrip('/')
+ self.model_name = model_name
+ self.max_context_length = max_context_length
+ self.timeout = timeout
+
+ # Discover model name if not provided
+ if self.model_name is None:
+ self._discover_model()
+
+ def _discover_model(self):
+ """Discover the model name from the vLLM server."""
+ max_retries = 30
+ for attempt in range(max_retries):
+ try:
+ response = requests.get(f"{self.vllm_url}/models", timeout=10)
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self.model_name = models["data"][0]["id"]
+ return
+ except Exception as e:
+ if attempt < max_retries - 1:
+ wait_time = min(2 ** attempt * 0.5, 10)
+ time.sleep(wait_time)
+
+ # Fallback
+ self.model_name = "default"
+ print(f"[VLLMChatModel] Warning: Could not discover model, using '{self.model_name}'")
+
+ def health_check(self) -> bool:
+ """Check if the vLLM server is healthy."""
+ try:
+ response = requests.get(f"{self.vllm_url.replace('/v1', '')}/health", timeout=5)
+ return response.status_code == 200
+ except Exception:
+ return False
+
+ def _estimate_tokens(self, text: str) -> int:
+ """Estimate token count using character-based heuristic.
+
+ For Llama models, ~4 characters per token is a reasonable estimate.
+ We use 3.5 to be conservative (slightly overestimate tokens).
+ """
+ return int(len(text) / 3.5)
+
+ def _build_messages(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ global_notes: Optional[List[str]] = None,
+ ) -> List[dict]:
+ """Build messages list for chat completion API with auto-truncation.
+
+ If the context exceeds max_context_length, older conversation turns
+ are removed to keep only the most recent context that fits.
+
+ Args:
+ global_notes: If provided, these are always-applicable preferences
+ displayed in a separate section from task-specific retrieved notes.
+ """
+ # Use CollaborativeAgents-style system prompt
+ has_any_notes = memory_notes or global_notes
+ if has_any_notes:
+ # Build preference sections
+ pref_sections = ""
+ if global_notes:
+ global_bullet = "\n".join(f"- {n}" for n in global_notes)
+ pref_sections += f"## General Preferences (always apply)\n{global_bullet}\n\n"
+ if memory_notes:
+ task_bullet = "\n".join(f"- {n}" for n in memory_notes)
+ if global_notes:
+ pref_sections += f"## Task-Specific Preferences\n{task_bullet}\n"
+ else:
+ pref_sections += f"{task_bullet}\n"
+
+ system_content = (
+ "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
+ "# User Preferences\n"
+ "The user has a set of preferences for how you should behave. If you do not follow these preferences, "
+ "the user will be unable to learn from your response and you will need to adjust your response to adhere "
+ "to these preferences (so it is best to follow them initially).\n\n"
+ "**IMPORTANT**: If the user explicitly requests something in THIS conversation (e.g., asks you to change "
+ "your format, style, or approach), that request takes PRIORITY over the remembered preferences below. "
+ "Always adapt to the user's direct feedback first.\n\n"
+ "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n"
+ f"{pref_sections}\n"
+ "# Before Responding\n"
+ "Before writing your response, briefly consider:\n"
+ "1. Which preferences above are relevant to this specific request?\n"
+ "2. How will you satisfy each relevant preference in your response?\n\n"
+ "# Conversation Guidelines:\n"
+ "- If the user asks you to adjust your response (e.g., 'be more concise', 'focus on intuition'), you MUST change your approach accordingly. Do NOT repeat the same response.\n"
+ "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
+ "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
+ "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n"
+ "- **Verify**: Before finalizing, check that your response satisfies the relevant preferences listed above.\n"
+ )
+ else:
+ # Vanilla mode - no preferences
+ system_content = (
+ "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
+ "# Conversation Guidelines:\n"
+ "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
+ "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
+ "- Your goal is to help the user solve their problem. Do your best to help them.\n"
+ )
+ system_message = {"role": "system", "content": system_content}
+
+ # Calculate available tokens for conversation history
+ # Reserve space for: system prompt + max_new_tokens + safety margin
+ system_tokens = self._estimate_tokens(system_content)
+ available_tokens = self.max_context_length - system_tokens - max_new_tokens - 100 # 100 token safety margin
+
+ # Build conversation messages from history
+ conversation_messages = []
+ for turn in history:
+ conversation_messages.append({"role": turn.role, "content": turn.text})
+
+ # Check if truncation is needed
+ total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation_messages)
+
+ if total_conv_tokens > available_tokens:
+ # Truncate from the beginning (keep recent messages)
+ truncated_messages = []
+ current_tokens = 0
+
+ # Iterate from most recent to oldest
+ for msg in reversed(conversation_messages):
+ msg_tokens = self._estimate_tokens(msg["content"])
+ if current_tokens + msg_tokens <= available_tokens:
+ truncated_messages.insert(0, msg)
+ current_tokens += msg_tokens
+ else:
+ # Stop adding older messages
+ break
+
+ conversation_messages = truncated_messages
+ if len(truncated_messages) < len(history):
+ print(f"[VLLMChatModel] Truncated context: kept {len(truncated_messages)}/{len(history)} turns "
+ f"({current_tokens}/{total_conv_tokens} estimated tokens)")
+
+ messages = [system_message] + conversation_messages
+ return messages
+
+ def build_messages(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+        global_notes: Optional[List[str]] = None,
+ ) -> List[dict]:
+ """Public method to build messages without calling the API.
+
+ Used for batch processing where messages are collected first,
+ then sent in batch to vLLM for concurrent processing.
+ """
+ return self._build_messages(history, memory_notes, max_new_tokens, global_notes=global_notes)
+
+ def answer(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ top_k: Optional[int] = None,
+ ) -> str:
+ """Generate a response using vLLM HTTP API."""
+ messages = self._build_messages(history, memory_notes, max_new_tokens)
+
+        payload = {
+            "model": self.model_name,
+            "messages": messages,
+            "max_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+        }
+        if top_k is not None:
+            # top_k is a vLLM sampling extension to the OpenAI schema
+            payload["top_k"] = top_k
+
+ # Retry with exponential backoff
+ max_retries = 5
+ for attempt in range(max_retries):
+ try:
+ response = requests.post(
+ f"{self.vllm_url}/chat/completions",
+ json=payload,
+ timeout=self.timeout
+ )
+
+ if response.status_code == 200:
+ result = response.json()
+ return result["choices"][0]["message"]["content"]
+                elif response.status_code == 400:
+                    error_text = response.text
+                    # Handle context-length errors by halving the token budget;
+                    # read the current value from the payload so repeated
+                    # retries keep shrinking it instead of resetting it.
+                    if "max_tokens" in error_text and payload["max_tokens"] > 64:
+                        payload["max_tokens"] = max(64, payload["max_tokens"] // 2)
+                        continue
+ raise RuntimeError(f"vLLM error: {error_text[:200]}")
+ else:
+ raise RuntimeError(f"vLLM HTTP {response.status_code}: {response.text[:200]}")
+
+ except requests.exceptions.Timeout:
+ if attempt < max_retries - 1:
+ time.sleep(2 ** attempt)
+ continue
+ raise RuntimeError("vLLM request timeout")
+ except requests.exceptions.ConnectionError as e:
+ if attempt < max_retries - 1:
+ time.sleep(2 ** attempt)
+ continue
+ raise RuntimeError(f"vLLM connection error: {e}")
+
+ raise RuntimeError("Max retries exceeded for vLLM request")
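+
+
+# Example usage (a minimal sketch, assuming a vLLM OpenAI-compatible server is
+# already running at the default URL; ChatTurn is shown with only the fields
+# this module reads, role and text):
+#
+#     model = VLLMChatModel(vllm_url="http://localhost:8003/v1")
+#     if model.health_check():
+#         reply = model.answer(
+#             history=[ChatTurn(role="user", text="Summarize RAG in one line.")],
+#             memory_notes=["prefers concise, bullet-point answers"],
+#             max_new_tokens=128,
+#         )
+#         print(reply)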
diff --git a/src/personalization/models/preference_extractor/__init__.py b/src/personalization/models/preference_extractor/__init__.py
new file mode 100644
index 0000000..65e2595
--- /dev/null
+++ b/src/personalization/models/preference_extractor/__init__.py
@@ -0,0 +1,5 @@
+from .rule_extractor import QwenRuleExtractor
+from .gpt4o_extractor import GPT4OExtractor
+from .base import PreferenceExtractor
+
+__all__ = ["QwenRuleExtractor", "GPT4OExtractor", "PreferenceExtractor"]
diff --git a/src/personalization/models/preference_extractor/base.py b/src/personalization/models/preference_extractor/base.py
new file mode 100644
index 0000000..850292f
--- /dev/null
+++ b/src/personalization/models/preference_extractor/base.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import List
+from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
+
+class PreferenceExtractorBase(ABC):
+ @abstractmethod
+ def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+ """
+ Extract preferences from a window of chat turns (history + current query).
+ """
+ raise NotImplementedError
+
+# Alias kept for backward compatibility; new extractors should inherit from
+# PreferenceExtractorBase directly.
+PreferenceExtractor = PreferenceExtractorBase
diff --git a/src/personalization/models/preference_extractor/gpt4o_extractor.py b/src/personalization/models/preference_extractor/gpt4o_extractor.py
new file mode 100644
index 0000000..0f70522
--- /dev/null
+++ b/src/personalization/models/preference_extractor/gpt4o_extractor.py
@@ -0,0 +1,165 @@
+from __future__ import annotations
+
+import json
+import os
+from typing import Any, Dict, List
+
+from openai import OpenAI
+from personalization.config.settings import LocalModelsConfig
+from personalization.models.preference_extractor.base import PreferenceExtractorBase as PreferenceExtractor
+from personalization.retrieval.preference_store.schemas import (
+ ChatTurn,
+ PreferenceList,
+ preference_list_json_schema,
+)
+
+
+class GPT4OExtractor(PreferenceExtractor):
+ def __init__(self, api_key: str, model: str = "gpt-4o") -> None:
+ self.client = OpenAI(api_key=api_key)
+ self.model = model
+
+ # Load system prompt template
+ template_path = "fine_tuning_prompt_template.txt"
+ if os.path.exists(template_path):
+ with open(template_path, "r", encoding="utf-8") as f:
+ self.system_prompt = f.read()
+ else:
+ # Structured prompt that enforces the PreferenceList schema
+ self.system_prompt = (
+ "You are a preference extraction assistant. "
+ "Given a user message, extract any user preferences as condition-action rules.\n\n"
+ "Return a JSON object with exactly this structure:\n"
+ '{"preferences": [{"condition": "<when this applies>", "action": "<what to do>", "confidence": <0.0-1.0>}]}\n\n'
+ "Examples of preferences:\n"
+ '- {"condition": "general", "action": "respond in Chinese", "confidence": 0.9}\n'
+ '- {"condition": "when writing code", "action": "use Python with type hints", "confidence": 0.8}\n'
+ '- {"condition": "when explaining math", "action": "show step-by-step derivation", "confidence": 0.7}\n\n'
+ "If no preferences are found, return {\"preferences\": []}.\n"
+ "IMPORTANT: The output MUST be a JSON object with a \"preferences\" key containing a list."
+ )
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "GPT4OExtractor":
+ # We rely on env var for API key, config for other potential settings if needed
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ raise ValueError("OPENAI_API_KEY environment variable not set")
+ return cls(api_key=api_key)
+
+ def build_preference_prompt(self, query: str) -> str:
+ # GPT4OExtractor uses the system prompt loaded in __init__
+ return self.system_prompt
+
+ def _call_kwargs(self, messages):
+ """Build kwargs for chat completion, skipping temperature for models that don't support it."""
+ kwargs = {
+ "model": self.model,
+ "messages": messages,
+ "response_format": {"type": "json_object"},
+ }
+ # GPT-5 series doesn't support temperature=0
+ if not self.model.startswith("gpt-5"):
+ kwargs["temperature"] = 0.0
+ return kwargs
+
+ def extract_preferences(self, query: str) -> Dict[str, Any]:
+ # Reuse logic but return raw dict
+ try:
+ messages = [
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": query},
+ ]
+ response = self.client.chat.completions.create(**self._call_kwargs(messages))
+ content = response.choices[0].message.content
+ if content:
+ return json.loads(content)
+ except Exception as e:
+ print(f"Error calling GPT-4o: {e}")
+ return {"preferences": []}
+
+ def extract_turn(self, turns) -> PreferenceList:
+ # Accept both a single ChatTurn and a list of ChatTurns (history)
+ if isinstance(turns, list):
+ # Find the last user message in history
+ last_user_msg = None
+ for t in reversed(turns):
+ if hasattr(t, 'role') and t.role == "user":
+ last_user_msg = t.text
+ break
+ if not last_user_msg:
+ return PreferenceList(preferences=[])
+ else:
+ # Single ChatTurn
+ if turns.role != "user":
+ return PreferenceList(preferences=[])
+ last_user_msg = turns.text
+
+ try:
+ messages = [
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": last_user_msg},
+ ]
+ response = self.client.chat.completions.create(**self._call_kwargs(messages))
+
+ content = response.choices[0].message.content
+ if not content:
+ return PreferenceList(preferences=[])
+
+ data = json.loads(content)
+ return self._parse_to_preference_list(data)
+
+ except Exception as e:
+ print(f"Error calling GPT-4o: {e}")
+ return PreferenceList(preferences=[])
+
+ @staticmethod
+ def _parse_to_preference_list(data: dict) -> PreferenceList:
+ """Robustly convert GPT output to PreferenceList, handling non-standard formats."""
+ # Best case: already matches schema
+ if "preferences" in data and isinstance(data["preferences"], list):
+ prefs = []
+ for item in data["preferences"]:
+ if isinstance(item, dict) and "condition" in item and "action" in item:
+ prefs.append({
+ "condition": str(item["condition"])[:128],
+ "action": str(item["action"])[:256],
+ "confidence": float(item.get("confidence", 0.7)),
+ })
+ return PreferenceList.model_validate({"preferences": prefs})
+
+ # GPT returned a flat dict of preferences - convert to condition/action pairs
+ prefs = []
+ for key, value in data.items():
+ if isinstance(value, str) and len(value) > 2:
+ prefs.append({
+ "condition": str(key)[:128] if len(str(key)) > 1 else "general",
+ "action": str(value)[:256],
+ "confidence": 0.7,
+ })
+ elif isinstance(value, dict):
+ # Nested dict: try to extract meaningful pairs
+ for sub_key, sub_val in value.items():
+ if isinstance(sub_val, str) and len(sub_val) > 2:
+ prefs.append({
+ "condition": str(sub_key)[:128],
+ "action": str(sub_val)[:256],
+ "confidence": 0.7,
+ })
+ elif isinstance(value, list):
+ for item in value:
+ if isinstance(item, str) and len(item) > 2:
+ prefs.append({
+ "condition": str(key)[:128],
+ "action": str(item)[:256],
+ "confidence": 0.7,
+ })
+
+ return PreferenceList.model_validate({"preferences": prefs[:20]})
+
+ def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
+ results = []
+ for turn in turns:
+ results.append(self.extract_turn(turn))
+ return results
+
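+# Example usage (a sketch; requires OPENAI_API_KEY in the environment, and the
+# shown output is illustrative, not an actual API response):
+#
+#     extractor = GPT4OExtractor.from_config(cfg)
+#     result = extractor.extract_preferences("Please always answer in Chinese.")
+#     # -> {"preferences": [{"condition": "general",
+#     #                      "action": "respond in Chinese",
+#     #                      "confidence": 0.9}]}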
diff --git a/src/personalization/models/preference_extractor/llm_extractor.py b/src/personalization/models/preference_extractor/llm_extractor.py
new file mode 100644
index 0000000..8f7a6cb
--- /dev/null
+++ b/src/personalization/models/preference_extractor/llm_extractor.py
@@ -0,0 +1,153 @@
+from typing import List, Dict, Any
+import torch
+import json
+import os
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.models.preference_extractor.base import PreferenceExtractorBase
+from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
+class PreferenceExtractorLLM(PreferenceExtractorBase):
+ def __init__(
+ self,
+ model_path: str,
+ prompt_template_path: str = "fine_tuning_prompt_template.txt",
+ device_map: str = "auto",
+ dtype: torch.dtype = torch.bfloat16,
+ max_new_tokens: int = 512,
+ ) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+ self.max_new_tokens = max_new_tokens
+
+ if os.path.exists(prompt_template_path):
+ with open(prompt_template_path, "r", encoding="utf-8") as f:
+ self.prompt_template = f.read()
+ else:
+ print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.")
+ self.prompt_template = "Extract user preferences from the following conversation."
+
+    @classmethod
+    def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM":
+        # LocalModelsConfig currently exposes `preference_extractor` as a single
+        # ModelSpec, while the design doc calls for a named map of extractors
+        # (e.g. preference_extractor.qwen3_0_6b_sft). Until settings.py and the
+        # registry are refactored to support that, this factory is intentionally
+        # unimplemented; construct PreferenceExtractorLLM directly instead.
+        raise NotImplementedError(
+            "PreferenceExtractorLLM.from_config requires named extractor specs "
+            "in LocalModelsConfig; instantiate the class directly with "
+            "model_path for now."
+        )
+
+    def _build_prompt(self, turns: List[ChatTurn]) -> str:
+        # Construct messages list for the chat template
+        messages = [{"role": "system", "content": self.prompt_template}]
+
+        # Keep a sliding window of the most recent 6 turns
+        window = turns[-6:]
+
+        # The SFT training data is assumed to pack the conversation into a
+        # single user message, so we flatten the windowed history into one
+        # "Role: text" block rather than emitting true multi-turn chat messages.
+        history_texts = []
+        for t in window:
+            role = "User" if t.role == "user" else "Assistant"
+            history_texts.append(f"{role}: {t.text}")
+
+        conversation_text = "\n".join(history_texts)
+        messages.append({"role": "user", "content": conversation_text})
+
+        # Apply chat template
+        prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+        return prompt
+
+    def _generate(self, prompt: str) -> str:
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=False,  # greedy decoding (temperature is ignored when sampling is off)
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
+        # Slice generated token ids instead of decoded text: decoding the full
+        # sequence and cutting at len(prompt) is fragile because detokenization
+        # does not always round-trip the prompt exactly.
+        input_len = inputs["input_ids"].shape[1]
+        gen_ids = outputs[0][input_len:]
+        return self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+ def _parse_preferences(self, raw_output: str) -> PreferenceList:
+ start = raw_output.find("{")
+ end = raw_output.rfind("}")
+
+ if start == -1 or end == -1 or end <= start:
+ return PreferenceList(preferences=[])
+
+ json_str = raw_output[start:end+1]
+ try:
+ data = json.loads(json_str)
+ return PreferenceList.model_validate(data)
+ except Exception:
+ return PreferenceList(preferences=[])
+
+ def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+ prompt = self._build_prompt(turns)
+ raw_output = self._generate(prompt)
+ return self._parse_preferences(raw_output)
+
+ # Legacy support
+ def build_preference_prompt(self, query: str) -> str:
+ # Wrap query in a dummy turn
+ turn = ChatTurn(
+ user_id="dummy", session_id="dummy", turn_id=0,
+ role="user", text=query
+ )
+ return self._build_prompt([turn])
+
+ def extract_preferences(self, query: str) -> Dict[str, Any]:
+ turn = ChatTurn(
+ user_id="dummy", session_id="dummy", turn_id=0,
+ role="user", text=query
+ )
+ prefs = self.extract_turn([turn])
+ return prefs.model_dump()
+
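+# Example usage (a sketch; the checkpoint path is a placeholder, and
+# from_config is not implemented yet, so the class is constructed directly):
+#
+#     extractor = PreferenceExtractorLLM(model_path="/path/to/qwen3-0.6b-sft")
+#     prefs = extractor.extract_preferences("I prefer step-by-step math answers.")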
diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py
new file mode 100644
index 0000000..42f43ed
--- /dev/null
+++ b/src/personalization/models/preference_extractor/rule_extractor.py
@@ -0,0 +1,205 @@
+from __future__ import annotations
+
+import json
+import os
+from typing import Any, Dict, List
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import PreferenceExtractor
+from personalization.retrieval.preference_store.schemas import (
+ PreferenceList,
+ preference_list_json_schema,
+ ChatTurn,
+)
+
+# Hardcoded System Prompt to match SFT training
+# This MUST match what was used in training (scripts/split_train_test.py)
+SFT_SYSTEM_PROMPT = (
+ "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
+ "If no preferences are found, return {\"preferences\": []}."
+)
+
+class QwenRuleExtractor(PreferenceExtractor):
+ """
+ Extractor using a Fine-Tuned (SFT) Qwen model.
+ Despite the name 'RuleExtractor' (legacy), this now performs direct End-to-End extraction.
+ """
+ def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=True
+ )
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+ if self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "QwenRuleExtractor":
+ spec = cfg.preference_extractor
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(spec.local_path, dtype=dtype, device_map=device_map)
+
+ def build_preference_prompt(self, query: str) -> str:
+ """
+ Construct the prompt string using the tokenizer's chat template.
+ Matches the format seen during SFT training.
+ """
+ messages = [
+ {"role": "system", "content": SFT_SYSTEM_PROMPT},
+ {"role": "user", "content": query}
+ ]
+ prompt = self.tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ return prompt
+
+ @torch.inference_mode()
+ def extract_preferences(self, query: str) -> Dict[str, Any]:
+ """
+ Directly extract preferences from query using the SFT model.
+ Returns a dict compatible with PreferenceList model (key: 'preferences').
+ """
+ prompt = self.build_preference_prompt(query)
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+ outputs = self.model.generate(
+ **inputs,
+ do_sample=False, # Deterministic greedy decoding
+ max_new_tokens=512, # Allow enough space for JSON
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ )
+
+ input_len = inputs["input_ids"].shape[1]
+ gen_ids = outputs[0][input_len:]
+ text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+ if os.getenv("PREF_DEBUG") == "1":
+ print(f"[debug][extractor] Raw output: {text}")
+
+ # Try parsing JSON
+ try:
+ # 1. Direct parse
+ data = json.loads(text)
+
+ # 2. Validate against schema structure
+ validated = PreferenceList.model_validate(data)
+ return validated.model_dump()
+
+        except Exception:
+            # Fallback: try to locate a JSON blob if the model emitted extra
+            # text around it (rare for SFT models but possible)
+            extracted_json = self._extract_json_substring(text)
+            if extracted_json:
+                try:
+                    data = json.loads(extracted_json)
+                    validated = PreferenceList.model_validate(data)
+                    return validated.model_dump()
+                except Exception:
+                    pass
+
+        # If everything fails, return an empty result
+        return {"preferences": []}
+
+ def _extract_json_substring(self, text: str) -> str | None:
+ """Helper to find { ... } block in text."""
+ # Find first '{' and last '}'
+ start = text.find('{')
+ end = text.rfind('}')
+ if start != -1 and end != -1 and end > start:
+ return text[start : end + 1]
+ return None
+
+ @torch.inference_mode()
+ def batch_extract_preferences(self, queries: List[str], batch_size: int = 64) -> List[Dict[str, Any]]:
+ """
+ Batch extract preferences from multiple queries using left-padded batching.
+ """
+ if not queries:
+ return []
+
+ # Save and set padding side for decoder-only batched generation
+ orig_padding_side = self.tokenizer.padding_side
+ self.tokenizer.padding_side = "left"
+
+ all_results = []
+ prompts = [self.build_preference_prompt(q) for q in queries]
+
+ for start in range(0, len(prompts), batch_size):
+ batch_prompts = prompts[start:start + batch_size]
+ inputs = self.tokenizer(
+ batch_prompts, return_tensors="pt", padding=True, truncation=True
+ ).to(self.model.device)
+
+ outputs = self.model.generate(
+ **inputs,
+ do_sample=False,
+ max_new_tokens=512,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ )
+
+            for i in range(len(batch_prompts)):
+                # With left padding every row shares the same width, so the
+                # generated tokens start at the padded input length, not at the
+                # per-row count of non-pad tokens.
+                input_len = inputs["input_ids"].shape[1]
+                gen_ids = outputs[i][input_len:]
+                text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+ try:
+ data = json.loads(text)
+ validated = PreferenceList.model_validate(data)
+ all_results.append(validated.model_dump())
+ except Exception:
+ extracted_json = self._extract_json_substring(text)
+ if extracted_json:
+ try:
+ data = json.loads(extracted_json)
+ validated = PreferenceList.model_validate(data)
+ all_results.append(validated.model_dump())
+ continue
+ except Exception:
+ pass
+ all_results.append({"preferences": []})
+
+ self.tokenizer.padding_side = orig_padding_side
+ return all_results
+
+ def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+ """
+ Extract preferences from the LAST user turn in the history.
+ We don't concat history because our SFT model was trained on single-turn extraction.
+ Using context might confuse it unless we trained it that way.
+ """
+ # Find the last user message
+ last_user_msg = None
+ for t in reversed(turns):
+ if t.role == "user":
+ last_user_msg = t.text
+ break
+
+ if not last_user_msg:
+ return PreferenceList(preferences=[])
+
+ result_dict = self.extract_preferences(last_user_msg)
+ return PreferenceList.model_validate(result_dict)
+
+ def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
+ """
+ Extract preferences from ALL user turns individually.
+ """
+ results = []
+ for turn in turns:
+ if turn.role == "user":
+ res = self.extract_preferences(turn.text)
+ results.append(PreferenceList.model_validate(res))
+ else:
+ results.append(PreferenceList(preferences=[]))
+ return results
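+
+
+# Example usage (a sketch; `cfg` is a loaded LocalModelsConfig, and the output
+# shape follows the PreferenceList schema):
+#
+#     extractor = QwenRuleExtractor.from_config(cfg)
+#     results = extractor.batch_extract_preferences(
+#         ["Answer in French.", "Use NumPy for array math."],
+#         batch_size=2,
+#     )
+#     # results[0] -> {"preferences": [{"condition": ..., "action": ..., "confidence": ...}]}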
diff --git a/src/personalization/models/reranker/__init__.py b/src/personalization/models/reranker/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/reranker/__init__.py
diff --git a/src/personalization/models/reranker/base.py b/src/personalization/models/reranker/base.py
new file mode 100644
index 0000000..34cf6ce
--- /dev/null
+++ b/src/personalization/models/reranker/base.py
@@ -0,0 +1,16 @@
+from typing import List, Protocol
+
+class Reranker(Protocol):
+ def score(
+ self,
+ query: str,
+ docs: List[str],
+ **kwargs,
+ ) -> List[float]:
+ """
+ Score multiple candidate documents for the same query.
+ Higher score indicates higher relevance.
+ Returns a list of floats with length equal to len(docs).
+ """
+ ...
+
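+# Because Reranker is a typing.Protocol, any class with a matching `score`
+# method conforms structurally; no inheritance is required. A toy conforming
+# implementation (illustrative only, not used anywhere in this repo):
+#
+#     class KeywordReranker:
+#         def score(self, query: str, docs: List[str], **kwargs) -> List[float]:
+#             terms = set(query.lower().split())
+#             return [float(len(terms & set(d.lower().split()))) for d in docs]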
diff --git a/src/personalization/models/reranker/qwen3_reranker.py b/src/personalization/models/reranker/qwen3_reranker.py
new file mode 100644
index 0000000..b648421
--- /dev/null
+++ b/src/personalization/models/reranker/qwen3_reranker.py
@@ -0,0 +1,96 @@
+from typing import List
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from .base import Reranker
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
+class Qwen3Reranker(Reranker):
+ def __init__(self, model_path: str, device_map: str = "auto", dtype: torch.dtype = torch.bfloat16):
+ # Ensure we pass trust_remote_code=True for Qwen models
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Handle specific device assignment (e.g., "cuda:0", "cuda:1")
+ if device_map and device_map.startswith("cuda:"):
+ # Load to CPU first, then move to specific GPU
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=None,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ )
+ self.model = self.model.to(device_map)
+ else:
+ # Use accelerate's auto device mapping
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+
+ self.yes_token_id = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Reranker":
+ if not cfg.reranker or not cfg.reranker.qwen3_8b:
+ raise ValueError("Reranker config for qwen3_8b is missing")
+ spec = cfg.reranker.qwen3_8b
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(spec.local_path, device_map=device_map, dtype=dtype)
+
+ def _build_prompt(self, query: str, doc: str) -> str:
+ return (
+ "You are a reranker. "
+ "Given a user query and a memory note, answer 'yes' if the note is helpful "
+ "for answering the query, otherwise answer 'no'.\n\n"
+ f"Query: {query}\n"
+ f"Note: {doc}\n"
+ "Answer with a single token: yes or no."
+ )
+
+ @torch.inference_mode()
+ def score(self, query: str, docs: List[str], batch_size: int = 8, **kwargs) -> List[float]:
+ scores = []
+ for i in range(0, len(docs), batch_size):
+ batch_docs = docs[i : i + batch_size]
+ prompts = [self._build_prompt(query, d) for d in batch_docs]
+
+ inputs = self.tokenizer(
+ prompts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=512
+ ).to(self.model.device)
+
+            outputs = self.model(**inputs)
+            logits = outputs.logits  # [batch, seq_len, vocab_size]
+
+            # For a causal LM, the next-token prediction for each prompt sits
+            # at its last *non-padded* position. With right padding (the
+            # default here), logits[:, -1, :] would read pad positions for
+            # shorter prompts, so locate the last real token per sequence.
+            # The cumulative attention mask peaks at the last real token, and
+            # argmax returns its first occurrence, so this works for either
+            # padding side.
+            last_idx = inputs["attention_mask"].cumsum(dim=1).argmax(dim=1)  # [batch]
+            batch_idx = torch.arange(logits.size(0), device=logits.device)
+            last_token_logits = logits[batch_idx, last_idx, :]
+
+            # Log-probability of 'yes' is the relevance score
+            log_probs = torch.log_softmax(last_token_logits, dim=-1)
+            yes_log_probs = log_probs[:, self.yes_token_id]
+
+            scores.extend(yes_log_probs.float().cpu().numpy().tolist())
+
+ return scores
+
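+# Example usage (a sketch; assumes the configured checkpoint behaves like the
+# Qwen3 reranker, scoring relevance as log P("yes")):
+#
+#     reranker = Qwen3Reranker.from_config(cfg)
+#     scores = reranker.score(
+#         "best pasta recipe",
+#         ["user likes Italian food", "user writes Rust at work"],
+#     )
+#     # scores[0] should exceed scores[1]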
diff --git a/src/personalization/retrieval/__init__.py b/src/personalization/retrieval/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/__init__.py
diff --git a/src/personalization/retrieval/chunking/__init__.py b/src/personalization/retrieval/chunking/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/chunking/__init__.py
diff --git a/src/personalization/retrieval/chunking/rules.py b/src/personalization/retrieval/chunking/rules.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/chunking/rules.py
diff --git a/src/personalization/retrieval/pipeline.py b/src/personalization/retrieval/pipeline.py
new file mode 100644
index 0000000..6cc7f3e
--- /dev/null
+++ b/src/personalization/retrieval/pipeline.py
@@ -0,0 +1,388 @@
+import os
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+from personalization.models.embedding.base import EmbeddingModel
+from personalization.models.reranker.base import Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserTensorStore, UserState
+from personalization.user_model.scoring import score_with_user
+from personalization.user_model.policy.reinforce import compute_policy_scores
+
+def cosine_similarity_matrix(E: np.ndarray, e_q: np.ndarray) -> np.ndarray:
+    # E: [M, d], e_q: [d]; assumes rows of E and e_q are L2-normalized, so the
+    # dot product equals cosine similarity.
+    return np.dot(E, e_q)
+
+
+def dynamic_topk_selection(
+ scores: np.ndarray,
+ min_k: int = 3,
+ max_k: int = 8,
+ score_ratio: float = 0.5,
+) -> List[int]:
+ """
+ Dynamically select top-k indices based on score distribution.
+
+ Strategy:
+ 1. Sort by score descending
+ 2. Compute threshold = top_score * score_ratio
+ 3. Select all indices with score > threshold
+ 4. Clamp to [min_k, max_k] range
+
+ Args:
+ scores: Array of scores (higher = better)
+ min_k: Minimum number of items to select
+ max_k: Maximum number of items to select
+ score_ratio: Threshold ratio relative to top score
+
+ Returns:
+ List of selected indices (in descending score order)
+ """
+ if len(scores) == 0:
+ return []
+
+ # Sort indices by score descending
+ sorted_indices = np.argsort(scores)[::-1]
+ sorted_scores = scores[sorted_indices]
+
+    # Compute threshold. Note: the ratio test assumes non-negative scores; with
+    # negative scores (e.g. log-probabilities) few items may pass, but the
+    # min_k floor below still guarantees at least min_k selections.
+    top_score = sorted_scores[0]
+    threshold = top_score * score_ratio
+
+ # Find how many pass threshold
+ n_above_threshold = np.sum(sorted_scores > threshold)
+
+ # Clamp to [min_k, max_k]
+ n_select = max(min_k, min(max_k, n_above_threshold))
+ n_select = min(n_select, len(scores)) # Don't exceed available
+
+ return sorted_indices[:n_select].tolist()
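+
+# Worked example (illustrative): with scores [0.9, 0.8, 0.3, 0.2, 0.1] and
+# score_ratio=0.5, the threshold is 0.45 and two scores pass; min_k=3 then
+# lifts the selection to the top three indices, [0, 1, 2].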
+
+def dense_topk_indices(
+ query: str,
+ embed_model: EmbeddingModel,
+ memory_embeddings: np.ndarray,
+    valid_indices: Optional[List[int]] = None,
+ topk: int = 64
+) -> List[int]:
+ """
+ Return indices of topk memories based on dense embedding similarity.
+ If valid_indices is provided, only search within that subset.
+ """
+ if valid_indices is not None and len(valid_indices) == 0:
+ return []
+
+ e_q_list = embed_model.encode([query], normalize=True, return_tensor=False)
+ e_q = np.array(e_q_list[0], dtype=np.float32)
+
+ # Select subset of embeddings if restricted
+ if valid_indices is not None:
+ # subset_embeddings = memory_embeddings[valid_indices]
+ # But valid_indices might be arbitrary.
+ # Efficient way: only dot product with subset
+ # E_sub: [M_sub, d]
+ E_sub = memory_embeddings[valid_indices]
+ sims_sub = np.dot(E_sub, e_q)
+
+ # Topk within subset
+ k = min(topk, len(sims_sub))
+ if k == 0:
+ return []
+
+ # argsort gives indices relative to E_sub (0..M_sub-1)
+ # We need to map back to original indices
+ idx_sub = np.argsort(sims_sub)[-k:][::-1]
+
+ return [valid_indices[i] for i in idx_sub]
+
+ # Global search
+ sims = np.dot(memory_embeddings, e_q)
+ k = min(topk, len(memory_embeddings))
+ if k == 0:
+ return []
+
+ idx = np.argsort(sims)[-k:][::-1]
+ return idx.tolist()
+
+def dense_topk_indices_multi_query(
+ queries: List[str],
+ embed_model: EmbeddingModel,
+ memory_embeddings: np.ndarray,
+    valid_indices: Optional[List[int]] = None,
+ topk: int = 64
+) -> List[int]:
+ """
+ Multi-query dense retrieval: embed all queries, take max similarity per memory,
+ return top-k by max similarity (union effect).
+ """
+ if len(memory_embeddings) == 0:
+ return []
+
+ # Embed all queries at once
+ e_qs = embed_model.encode(queries, normalize=True, return_tensor=False)
+ e_qs = np.array(e_qs, dtype=np.float32) # [Q, d]
+
+ if valid_indices is not None:
+ if len(valid_indices) == 0:
+ return []
+ E_sub = memory_embeddings[valid_indices]
+ # sims: [Q, M_sub]
+ sims = np.dot(e_qs, E_sub.T)
+ # max across queries per memory
+ max_sims = sims.max(axis=0) # [M_sub]
+ k = min(topk, len(max_sims))
+ if k == 0:
+ return []
+ idx_sub = np.argsort(max_sims)[-k:][::-1]
+ return [valid_indices[i] for i in idx_sub]
+
+ # Global search
+ # sims: [Q, M]
+ sims = np.dot(e_qs, memory_embeddings.T)
+ max_sims = sims.max(axis=0) # [M]
+ k = min(topk, len(max_sims))
+ if k == 0:
+ return []
+ idx = np.argsort(max_sims)[-k:][::-1]
+ return idx.tolist()
+
+
+def retrieve_with_policy(
+ user_id: str,
+ query: str,
+ embed_model: EmbeddingModel,
+ reranker: Reranker,
+ memory_cards: List[MemoryCard],
+ memory_embeddings: np.ndarray, # shape: [M, d]
+ user_store: UserTensorStore,
+ item_vectors: np.ndarray, # shape: [M, k], v_m
+ topk_dense: int = 64,
+ topk_rerank: int = 8,
+ beta_long: float = 0.0,
+ beta_short: float = 0.0,
+ tau: float = 1.0,
+ only_own_memories: bool = False,
+ sample: bool = False,
+    queries: Optional[List[str]] = None,
+) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
+ """
+ Returns extended info for policy update:
+ (candidates, candidate_item_vectors, base_scores, chosen_indices, policy_probs)
+
+ Args:
+ sample: If True, use stochastic sampling from policy distribution (for training/exploration).
+ If False, use deterministic top-k by policy scores (for evaluation).
+ """
+ # 0. Filter indices if needed
+ valid_indices = None
+ if only_own_memories:
+ valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
+ if not valid_indices:
+ return [], np.array([]), np.array([]), [], np.array([])
+
+ # 1. Dense retrieval (multi-query if available)
+ if queries and len(queries) > 1:
+ dense_idx = dense_topk_indices_multi_query(
+ queries,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
+ else:
+ dense_idx = dense_topk_indices(
+ query,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
+    # DEBUG: check for duplicate or out-of-bounds indices
+    if dense_idx and os.getenv("RETRIEVAL_DEBUG") == "1":
+        print(f" [Pipeline] Dense Indices (Top {len(dense_idx)}): {dense_idx[:10]}...")
+        print(f" [Pipeline] Max Index: {max(dense_idx)} | Memory Size: {len(memory_cards)}")
+
+ if not dense_idx:
+ return [], np.array([]), np.array([]), [], np.array([])
+
+ candidates = [memory_cards[i] for i in dense_idx]
+ candidate_docs = [c.note_text for c in candidates]
+
+ # 2. Rerank base score (P(yes|q,m)) - always use original query for reranking
+ # Skip reranking if we have fewer candidates than topk_rerank (saves GPU memory)
+ if len(candidates) <= topk_rerank:
+ base_scores = np.ones(len(candidates)) # Uniform scores
+ else:
+ base_scores = np.array(reranker.score(query, candidate_docs))
+
+ # 3. Policy Scoring (Softmax)
+ user_state: UserState = user_store.get_state(user_id)
+ candidate_vectors = item_vectors[dense_idx] # [K, k]
+
+ policy_out = compute_policy_scores(
+ base_scores=base_scores,
+ user_state=user_state,
+ item_vectors=candidate_vectors,
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=tau
+ )
+
+ # 4. Selection: Greedy (eval) or Stochastic (training)
+ k = min(topk_rerank, len(policy_out.scores))
+
+ if sample:
+ # Stochastic sampling from policy distribution (for training/exploration)
+ # Sample k indices without replacement, weighted by policy probs
+ probs = policy_out.probs
+ # Normalize to ensure sum to 1 (handle numerical issues)
+ probs = probs / (probs.sum() + 1e-10)
+ # Sample without replacement
+ chosen_indices = np.random.choice(
+ len(probs), size=k, replace=False, p=probs
+ ).tolist()
+ else:
+ # Deterministic top-k by policy scores (for evaluation)
+ top_indices_local = policy_out.scores.argsort()[-k:][::-1]
+ chosen_indices = top_indices_local.tolist()
+
+    if os.getenv("RETRIEVAL_DEBUG") == "1":
+        print(f" [Pipeline] Candidates: {len(candidates)} | Chosen Indices: {chosen_indices} | Sample: {sample}")
+
+ return candidates, candidate_vectors, base_scores, chosen_indices, policy_out.probs
+
+def retrieve_no_policy(
+ user_id: str,
+ query: str,
+ embed_model: EmbeddingModel,
+ reranker: Reranker,
+ memory_cards: List[MemoryCard],
+ memory_embeddings: np.ndarray, # shape: [M, d]
+ topk_dense: int = 64,
+ topk_rerank: int = 8,
+ only_own_memories: bool = False,
+    queries: Optional[List[str]] = None,
+ dynamic_topk: bool = False,
+ dynamic_min_k: int = 3,
+ dynamic_max_k: int = 8,
+ dynamic_score_ratio: float = 0.5,
+) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
+ """
+ Deterministic retrieval baseline (NoPersonal mode):
+ - Dense retrieval -> Rerank -> Top-K (no policy sampling, no user vector influence)
+
+ Args:
+ dynamic_topk: If True, use dynamic selection based on score distribution
+ dynamic_min_k: Minimum items to select (when dynamic_topk=True)
+ dynamic_max_k: Maximum items to select (when dynamic_topk=True)
+ dynamic_score_ratio: Threshold = top_score * ratio (when dynamic_topk=True)
+
+ Returns same structure as retrieve_with_policy for compatibility:
+ (candidates, candidate_item_vectors, base_scores, chosen_indices, rerank_scores_for_chosen)
+
+ Note: candidate_item_vectors is empty array (not used in NoPersonal mode)
+ The last return value is rerank scores instead of policy probs
+ """
+ # 0. Filter indices if needed
+ valid_indices = None
+ if only_own_memories:
+ valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
+ if not valid_indices:
+ return [], np.array([]), np.array([]), [], np.array([])
+
+ # 1. Dense retrieval (multi-query if available)
+ if queries and len(queries) > 1:
+ dense_idx = dense_topk_indices_multi_query(
+ queries,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
+ else:
+ dense_idx = dense_topk_indices(
+ query,
+ embed_model,
+ memory_embeddings,
+ valid_indices=valid_indices,
+ topk=topk_dense
+ )
+
+ if not dense_idx:
+ return [], np.array([]), np.array([]), [], np.array([])
+
+ candidates = [memory_cards[i] for i in dense_idx]
+ candidate_docs = [c.note_text for c in candidates]
+
+ # 2. Rerank base score (P(yes|q,m)) - always use original query for reranking
+ max_k = dynamic_max_k if dynamic_topk else topk_rerank
+
+ # Skip reranking if we have fewer candidates than needed
+ if len(candidates) <= max_k:
+ # Just return all candidates without reranking
+ base_scores = np.ones(len(candidates)) # Uniform scores
+ chosen_indices = list(range(len(candidates)))
+ else:
+ base_scores = np.array(reranker.score(query, candidate_docs))
+
+ # 3. Selection: dynamic or fixed top-K
+ if dynamic_topk:
+ chosen_indices = dynamic_topk_selection(
+ base_scores,
+ min_k=dynamic_min_k,
+ max_k=dynamic_max_k,
+ score_ratio=dynamic_score_ratio,
+ )
+ else:
+ k = min(topk_rerank, len(base_scores))
+ top_indices_local = base_scores.argsort()[-k:][::-1]
+ chosen_indices = top_indices_local.tolist()
+
+ # Get scores for chosen items (for logging compatibility)
+ chosen_scores = base_scores[chosen_indices]
+
+ # Return empty item vectors (not used in NoPersonal mode)
+ # Return rerank scores as the "probs" field for logging compatibility
+ return candidates, np.array([]), base_scores, chosen_indices, chosen_scores
+
+
+def retrieve_with_rerank(
+ user_id: str,
+ query: str,
+ embed_model: EmbeddingModel,
+ reranker: Reranker,
+ memory_cards: List[MemoryCard],
+ memory_embeddings: np.ndarray, # shape: [M, d]
+ user_store: UserTensorStore,
+ item_vectors: np.ndarray, # shape: [M, k], v_m
+ topk_dense: int = 64,
+ topk_rerank: int = 8,
+ beta_long: float = 0.0,
+ beta_short: float = 0.0,
+ only_own_memories: bool = False,
+) -> List[MemoryCard]:
+ """
+ Wrapper around retrieve_with_policy for standard inference.
+ """
+ candidates, _, _, chosen_indices, _ = retrieve_with_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=embed_model,
+ reranker=reranker,
+ memory_cards=memory_cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=topk_dense,
+ topk_rerank=topk_rerank,
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=1.0, # Default tau
+ only_own_memories=only_own_memories
+ )
+
+ return [candidates[i] for i in chosen_indices]
+
+
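+# Example call (a sketch; the models, memory arrays, and user store are
+# assumed to be loaded elsewhere, as in serving/personalized_llm.py):
+#
+#     cards = retrieve_with_rerank(
+#         user_id="user_123",
+#         query="plan a vegetarian dinner",
+#         embed_model=embed_model,
+#         reranker=reranker,
+#         memory_cards=all_cards,
+#         memory_embeddings=E,   # [M, d], L2-normalized
+#         user_store=user_store,
+#         item_vectors=V,        # [M, k]
+#     )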
diff --git a/src/personalization/retrieval/preference_store/__init__.py b/src/personalization/retrieval/preference_store/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/preference_store/__init__.py
diff --git a/src/personalization/retrieval/preference_store/base.py b/src/personalization/retrieval/preference_store/base.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/preference_store/base.py
diff --git a/src/personalization/retrieval/preference_store/schemas.py b/src/personalization/retrieval/preference_store/schemas.py
new file mode 100644
index 0000000..5245025
--- /dev/null
+++ b/src/personalization/retrieval/preference_store/schemas.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from typing import List, Literal, Optional, Dict, Any
+
+from pydantic import BaseModel, Field, confloat
+
+
+class Preference(BaseModel):
+ condition: str = Field(
+ ..., min_length=1, max_length=128, description="When the rule applies"
+ )
+ action: str = Field(
+ ..., min_length=1, max_length=256, description="What to do in that case"
+ )
+ confidence: confloat(ge=0.0, le=1.0) = Field(
+ ..., description="Confidence the rule is correct"
+ )
+
+
+class PreferenceList(BaseModel):
+ preferences: List[Preference] = Field(default_factory=list)
+
+
+def preference_list_json_schema() -> dict:
+ return PreferenceList.model_json_schema()
+
+
+class ChatTurn(BaseModel):
+ user_id: str
+ session_id: str
+ turn_id: int
+ role: Literal["user", "assistant"]
+ text: str
+ timestamp: Optional[float] = None
+ meta: Dict[str, Any] = Field(default_factory=dict)
+
+
+class MemoryCard(BaseModel):
+ card_id: str
+ user_id: str
+ source_session_id: str
+ source_turn_ids: List[int]
+ raw_queries: List[str] # The original user utterances
+ preference_list: PreferenceList
+ note_text: str # Summarized "condition: action" text
+ embedding_e: List[float] # The embedding vector
+ kind: Literal["pref", "fact"] = "pref"
+ is_global: bool = False # True = always include in prompt, bypass retrieval
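+
+
+# Example (illustrative): a payload that validates against PreferenceList:
+#
+#     PreferenceList.model_validate({
+#         "preferences": [
+#             {"condition": "when writing code",
+#              "action": "use Python with type hints",
+#              "confidence": 0.8},
+#         ]
+#     })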
diff --git a/src/personalization/retrieval/preference_store/vector_kv.py b/src/personalization/retrieval/preference_store/vector_kv.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/preference_store/vector_kv.py
diff --git a/src/personalization/retrieval/rerank.py b/src/personalization/retrieval/rerank.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/rerank.py
diff --git a/src/personalization/retrieval/store/__init__.py b/src/personalization/retrieval/store/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/store/__init__.py
diff --git a/src/personalization/retrieval/store/base.py b/src/personalization/retrieval/store/base.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/store/base.py
diff --git a/src/personalization/retrieval/store/faiss_store.py b/src/personalization/retrieval/store/faiss_store.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/retrieval/store/faiss_store.py
diff --git a/src/personalization/serving/__init__.py b/src/personalization/serving/__init__.py
new file mode 100644
index 0000000..11adcf8
--- /dev/null
+++ b/src/personalization/serving/__init__.py
@@ -0,0 +1,22 @@
+# Personalization Serving Module
+#
+# This module provides the interface layer for the personalization system.
+
+from personalization.serving.personalized_llm import (
+ PersonalizedLLM,
+ AssistantResponse,
+ UsageStats,
+ DebugInfo,
+ Feedback,
+ create_personalized_llm,
+)
+
+__all__ = [
+ "PersonalizedLLM",
+ "AssistantResponse",
+ "UsageStats",
+ "DebugInfo",
+ "Feedback",
+ "create_personalized_llm",
+]
+
diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py
new file mode 100644
index 0000000..8032e6b
--- /dev/null
+++ b/src/personalization/serving/personalized_llm.py
@@ -0,0 +1,1835 @@
+#!/usr/bin/env python3
+"""
+Personalized LLM Interface for Evaluation.
+
+This module provides the `PersonalizedLLM` class that wraps the entire
+personalization system into a clean interface for evaluation frameworks
+and user simulators.
+
+Interface contract:
+- chat(user_id, query) -> AssistantResponse: Main online interface
+- reset_session(user_id): Clear session history and short-term state
+- reset_user(user_id): Completely reset user (long-term, short-term, memories)
+- apply_feedback(feedback): Apply external feedback for RL updates
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import yaml
+
+# Ensure src/ is on sys.path for standalone usage (this file lives at
+# src/personalization/serving/, so the package root is two levels up)
+_src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+if _src_path not in sys.path:
+    sys.path.insert(0, _src_path)
+
+from personalization.config.settings import load_local_models_config
+from personalization.config.registry import get_preference_extractor, get_chat_model
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+try:
+    from personalization.models.reranker.bge_reranker import BGEReranker
+except ImportError:
+    # bge_reranker.py is not part of this initial release; the BGE option in
+    # get_shared_reranker stays unavailable until it lands.
+    BGEReranker = None
+from personalization.user_model.tensor_store import UserTensorStore, UserState
+from personalization.user_model.session_state import OnlineSessionState
+from personalization.user_model.features import ItemProjection
+from personalization.retrieval.preference_store.schemas import (
+ MemoryCard, ChatTurn, PreferenceList, Preference
+)
+from personalization.retrieval.pipeline import retrieve_with_policy, retrieve_no_policy
+from personalization.feedback.handlers import eval_step, eval_step_llm
+from personalization.feedback.llm_reward import LLMRewardClient, LLMRewardConfig
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+
+
+# =============================================================================
+# Data Classes for Interface
+# =============================================================================
+
+@dataclass
+class UsageStats:
+ """Token usage statistics from a chat completion."""
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+ model: str
+
+
+@dataclass
+class DebugInfo:
+ """
+ Debug information for analysis and ablation studies.
+ All fields are optional - fill what you have, leave empty what you don't.
+ """
+ selected_memory_ids: List[str] = field(default_factory=list)
+ selected_memory_notes: List[str] = field(default_factory=list)
+ selected_memory_scores: List[float] = field(default_factory=list)
+ user_vector_before: Optional[List[float]] = None
+ user_vector_after: Optional[List[float]] = None
+ extracted_preferences: List[Dict[str, Any]] = field(default_factory=list)
+ extra: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class AssistantResponse:
+ """Response from the personalized LLM chat interface."""
+ answer: str
+ usage: UsageStats
+ debug: Optional[DebugInfo] = None
+
+
+@dataclass
+class Feedback:
+ """
+ Feedback data structure for RL updates from user simulator or judge.
+
+ Attributes:
+ user_id: The user this feedback is for.
+ turn_id: The turn this feedback refers to (from the previous turn).
+ reward: Reward scalar computed by user simulator / judge.
+ gating: Gating flag (1=valid learning signal, 0=skip update).
+ meta: Additional metadata for training/analysis.
+ """
+ user_id: str
+ turn_id: int
+ reward: float
+ gating: float # Can be 0.0 or 1.0, or continuous
+ meta: Dict[str, Any] = field(default_factory=dict)
+
+
+# =============================================================================
+# Internal Session State Extended
+# =============================================================================
+
+@dataclass
+class _SessionContext:
+ """Extended session context for evaluation tracking."""
+ session_state: OnlineSessionState
+ turn_counter: int = 0
+ # Store info needed for apply_feedback
+ pending_rl_update: Optional[Dict[str, Any]] = None
+
+
+# =============================================================================
+# Shared Model Singletons for Multi-threaded Efficiency
+# =============================================================================
+
+import threading
+
+_shared_embed_model = None
+_shared_reranker = None
+_shared_extractor = None
+# Created eagerly at import time: creating the lock lazily inside the getter
+# would itself be a race between threads.
+_shared_models_lock = threading.Lock()
+
+
+def _get_shared_models_lock():
+    """Return the threading lock guarding the shared model singletons."""
+    return _shared_models_lock
+
+
+def get_shared_embedding_model(model_path: str, device_map: str = "auto"):
+ """Get or create shared embedding model (thread-safe singleton)."""
+ global _shared_embed_model
+ import torch
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_embed_model is None:
+ print(f"[SharedModels] Loading shared embedding model on {device_map}...")
+ _shared_embed_model = Qwen3Embedding8B(
+ model_path=model_path,
+ dtype=torch.bfloat16,
+ device_map=device_map,
+ )
+ print("[SharedModels] Shared embedding model loaded.")
+ return _shared_embed_model
+
+
+def get_shared_reranker(model_path: str, device_map: str = "auto", reranker_type: str = "qwen3"):
+ """Get or create shared reranker model (thread-safe singleton)."""
+ global _shared_reranker
+ import torch
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_reranker is None:
+ print(f"[SharedModels] Loading shared reranker ({reranker_type}) on {device_map}...")
+ if reranker_type == "bge":
+ _shared_reranker = BGEReranker(
+ model_path=model_path,
+ device_map=device_map,
+ dtype=torch.float16,
+ )
+ else:
+ _shared_reranker = Qwen3Reranker(
+ model_path=model_path,
+ device_map=device_map,
+ dtype=torch.bfloat16,
+ )
+ print("[SharedModels] Shared reranker model loaded.")
+ return _shared_reranker
+
+
+def get_shared_extractor(model_path: str, device_map: str = "auto"):
+ """Get or create shared preference extractor model (thread-safe singleton)."""
+ global _shared_extractor
+ import torch
+ from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor
+
+ lock = _get_shared_models_lock()
+ with lock:
+ if _shared_extractor is None:
+ print(f"[SharedModels] Loading shared preference extractor on {device_map}...")
+ _shared_extractor = QwenRuleExtractor(
+ model_path=model_path,
+ dtype=torch.bfloat16,
+ device_map=device_map,
+ )
+ print("[SharedModels] Shared preference extractor loaded.")
+ return _shared_extractor
+
+
+def clear_shared_models():
+ """Free all shared singleton models to reclaim GPU memory between methods."""
+ global _shared_embed_model, _shared_reranker, _shared_extractor
+ import gc
+
+ lock = _get_shared_models_lock()
+ with lock:
+ freed = []
+ if _shared_embed_model is not None:
+ freed.append("embedding")
+ del _shared_embed_model
+ _shared_embed_model = None
+ if _shared_reranker is not None:
+ freed.append("reranker")
+ del _shared_reranker
+ _shared_reranker = None
+ if _shared_extractor is not None:
+ freed.append("extractor")
+ del _shared_extractor
+ _shared_extractor = None
+
+ if freed:
+ gc.collect()
+ try:
+ import torch
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ except ImportError:
+ pass
+ print(f"[SharedModels] Cleared: {', '.join(freed)}")
+
+
+# =============================================================================
+# PersonalizedLLM Class
+# =============================================================================
+
+class PersonalizedLLM:
+ """
+ Personalized LLM wrapper for evaluation frameworks.
+
+ This class provides a clean interface that accepts only (user_id, query)
+ for the main chat function, while internally managing:
+ - User state vectors (z_long, z_short)
+ - Session history
+ - Memory retrieval and policy
+ - Preference extraction and storage
+ - RL updates
+
+ Example usage:
+ llm = PersonalizedLLM()
+
+ # Reset user for fresh experiment
+ llm.reset_user("user_123")
+
+ # Start a session
+ llm.reset_session("user_123")
+
+ # Chat
+ response = llm.chat("user_123", "What's a good recipe for dinner?")
+ print(response.answer)
+
+ # Apply feedback from previous turn (from turn 2 onwards)
+ llm.apply_feedback(Feedback(
+ user_id="user_123",
+ turn_id=0,
+ reward=0.8,
+ gating=1.0
+ ))
+ """
+
+ def __init__(
+ self,
+ config_path: Optional[str] = None,
+ user_store_path: str = "data/users/user_store_eval.npz",
+ memory_cards_path: str = "data/corpora/memory_cards.jsonl",
+ memory_embeddings_path: str = "data/corpora/memory_embeddings.npy",
+ item_projection_path: str = "data/corpora/item_projection.npz",
+ only_own_memories: bool = True,
+ enable_preference_extraction: bool = True,
+ enable_rl_updates: bool = True,
+ mode: str = "full", # "full", "nopersonal", or "vanilla"
+ eval_mode: bool = True, # True = greedy selection, False = stochastic sampling
+ device_assignment: Optional[Dict[str, str]] = None, # Multi-GPU support
+ llm_name: Optional[str] = None, # Override LLM name (e.g., "llama_8b_vllm" for vLLM)
+ use_shared_models: bool = False, # Use shared singleton models for multi-threaded efficiency
+ reranker_type: str = "qwen3", # "qwen3" (8B) or "bge" (278M)
+ best_of_n: int = 1, # Generate N responses and pick best (for RAG methods)
+ reward_mode: str = "keyword", # "keyword", "llm" (GPT-4o-mini), or "llm_local" (local vLLM)
+ llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge
+ reward_vllm_url: Optional[str] = None, # vLLM URL for local reward model (when reward_mode="llm_local")
+ enable_query_transform: bool = False, # Transform queries for better retrieval matching
+ enable_global_preferences: bool = False, # Separate global prefs that bypass retrieval
+ dynamic_topk: bool = False, # Use dynamic topk based on rerank scores
+ dynamic_min_k: int = 3, # Min preferences for dynamic topk
+ dynamic_max_k: int = 8, # Max preferences for dynamic topk
+ dynamic_score_ratio: float = 0.5, # Threshold = top_score * ratio
+        eta_long: Optional[float] = None,  # Override RL learning rate for z_long
+        eta_short: Optional[float] = None,  # Override RL learning rate for z_short
+ enable_preference_consolidation: bool = False, # Consolidate preferences at session end
+ consolidation_threshold: int = 5, # Min preferences before consolidation
+ enable_preference_rewrite: bool = False, # Use LLM to rewrite/merge retrieved preferences
+ ):
+ """
+ Initialize the PersonalizedLLM.
+
+ Args:
+ config_path: Path to config file. If None, uses default locations.
+ user_store_path: Path to persist user state vectors.
+ memory_cards_path: Path to memory cards JSONL file.
+ memory_embeddings_path: Path to memory embeddings numpy file.
+ item_projection_path: Path to item projection (PCA) file.
+ only_own_memories: If True, only retrieve user's own memories (strict privacy).
+ enable_preference_extraction: If True, extract preferences from user turns.
+ enable_rl_updates: If True, apply RL updates via apply_feedback.
+ mode: "full" for full personalization, "nopersonal" for baseline (no user vector influence),
+ "vanilla" for pure LLM without any memory retrieval or preference extraction.
+ eval_mode: If True, use greedy/deterministic selection (for evaluation).
+ If False, use stochastic sampling (for training/exploration).
+ device_assignment: Optional dict to assign models to specific GPUs.
+ Example: {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2", "extractor": "cuda:3"}
+ If None, uses "auto" for all models.
+ use_shared_models: If True, use shared singleton models for embedding and reranker.
+ This is essential for multi-threaded/parallel profile processing to avoid
+ loading duplicate models. When enabled, the first thread loads the models,
+ and subsequent threads reuse the shared instances.
+ """
+ self.only_own_memories = only_own_memories
+ self.use_shared_models = use_shared_models
+ self.enable_preference_extraction = enable_preference_extraction
+ self.enable_rl_updates = enable_rl_updates
+        self.mode = mode  # "full", "nopersonal", or "vanilla"
+ self.eval_mode = eval_mode # True = greedy, False = sample
+ self.reranker_type = reranker_type # "qwen3" or "bge"
+ self.best_of_n = best_of_n # Generate N responses and pick best
+ self.reward_mode = reward_mode # "keyword", "llm", or "llm_local"
+ self.enable_query_transform = enable_query_transform
+ self.enable_global_preferences = enable_global_preferences
+ self.enable_preference_consolidation = enable_preference_consolidation
+ self.consolidation_threshold = consolidation_threshold
+ self.enable_preference_rewrite = enable_preference_rewrite
+
+ # Initialize LLM reward client if using LLM judge
+ self._llm_reward_client = None # Can be LLMRewardClient or LocalLLMRewardClient
+ if reward_mode == "llm":
+ self._llm_reward_client = LLMRewardClient(llm_reward_config or LLMRewardConfig())
+ elif reward_mode == "llm_local":
+ from personalization.feedback.local_llm_reward import (
+ LocalLLMRewardClient,
+ LocalLLMRewardConfig,
+ )
+ local_config = LocalLLMRewardConfig(
+ vllm_url=reward_vllm_url or "http://localhost:8005/v1",
+ )
+ self._llm_reward_client = LocalLLMRewardClient(local_config)
+
+ # Multi-GPU device assignment
+ self._device_assignment = device_assignment or {
+ "embed": "auto",
+ "reranker": "auto",
+ "chat": "auto",
+ "extractor": "auto",
+ }
+
+ # Paths
+ self._memory_cards_path = memory_cards_path
+ self._memory_embeddings_path = memory_embeddings_path
+ self._item_projection_path = item_projection_path
+
+ # RL Configuration
+ # Note: beta/eta increased for more significant z_u updates
+ self._rl_cfg = {
+ "item_dim": 256,
+ "beta_long": 2.0, # Increased from 0.1 for stronger personalization
+ "beta_short": 5.0, # Increased from 0.3
+ "tau": 1.0,
+ "eta_long": eta_long if eta_long is not None else 0.01,
+ "eta_short": eta_short if eta_short is not None else 0.05,
+ "ema_alpha": 0.05,
+ "short_decay": 0.1,
+ "dense_topk": 64,
+ "rerank_topk": 5,
+ "max_new_tokens": 512,
+ # Dynamic topk settings
+ "dynamic_topk": dynamic_topk,
+ "dynamic_min_k": dynamic_min_k,
+ "dynamic_max_k": dynamic_max_k,
+ "dynamic_score_ratio": dynamic_score_ratio,
+        }
+        # Track explicit constructor overrides so yaml config values cannot
+        # silently clobber them in _load_config
+        self._explicit_rl_overrides = {
+            key for key, val in (("eta_long", eta_long), ("eta_short", eta_short))
+            if val is not None
+        }
+
+ # Store llm_name before loading config (needed in _load_config)
+ self._llm_name_override = llm_name
+
+ # Load config and override RL params if available
+ self._load_config(config_path)
+
+ # Load models
+ print("[PersonalizedLLM] Loading models...")
+ self._load_models()
+
+ # Load memory store
+ print("[PersonalizedLLM] Loading memory store...")
+ self._load_memory_store()
+
+ # Initialize user store
+ self._user_store = UserTensorStore(
+ k=self._rl_cfg["item_dim"],
+ path=user_store_path,
+ )
+
+ # Session contexts per user (in-memory)
+ self._sessions: Dict[str, _SessionContext] = {}
+
+ print("[PersonalizedLLM] Initialization complete.")
+
+ def _load_config(self, config_path: Optional[str]):
+ """Load configuration from yaml files."""
+ self._cfg = load_local_models_config()
+
+ # Try to load user_model.yaml for RL params
+ if config_path is None:
+ config_path = "configs/user_model.yaml"
+
+ self._llm_name = self._llm_name_override or "qwen_1_5b" # Default, can be overridden
+
+ try:
+ if os.path.exists(config_path):
+ with open(config_path, "r") as f:
+ user_cfg = yaml.safe_load(f)
+ if user_cfg:
+                    # Override RL params if present, but never clobber values
+                    # the caller set explicitly via constructor arguments
+                    for key in self._rl_cfg:
+                        if key in user_cfg and key not in self._explicit_rl_overrides:
+                            self._rl_cfg[key] = user_cfg[key]
+ # LLM name (only from config if not already set via parameter)
+ if self._llm_name_override is None and "llm_name" in user_cfg:
+ self._llm_name = user_cfg["llm_name"]
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load config: {e}")
+
+ def _load_models(self):
+ """Load all ML models with optional multi-GPU assignment."""
+ import torch
+
+ # Report GPU availability (only once, not for shared model instances)
+ if not self.use_shared_models:
+ num_gpus = torch.cuda.device_count()
+ print(f"[PersonalizedLLM] Available GPUs: {num_gpus}")
+ for i in range(num_gpus):
+ mem = torch.cuda.get_device_properties(i).total_memory / 1e9
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")
+
+ embed_device = self._device_assignment.get("embed", "auto")
+ reranker_device = self._device_assignment.get("reranker", "auto")
+ chat_device = self._device_assignment.get("chat", "auto")
+ extractor_device = self._device_assignment.get("extractor", "auto")
+
+ # Embedding model - only load for modes that use RAG retrieval
+ # Vanilla and contextual modes don't need embedding/reranker
+ needs_retrieval = self.mode not in ("vanilla", "contextual")
+
+ if needs_retrieval:
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared embedding model...")
+ self._embed_model = get_shared_embedding_model(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ device_map=embed_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...")
+ self._embed_model = Qwen3Embedding8B(
+ model_path=self._cfg.embedding.qwen3.local_path,
+ dtype=torch.bfloat16,
+ device_map=embed_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Skipping embedding model (not needed for {self.mode} mode)")
+ self._embed_model = None
+
+ # Reranker - only load for modes that use RAG retrieval
+ # Support both qwen3 (8B) and bge (278M) rerankers
+ if needs_retrieval:
+ if self.reranker_type == "bge":
+ reranker_path = getattr(self._cfg.reranker, "bge_base", None)
+ reranker_path = reranker_path.local_path if reranker_path else "BAAI/bge-reranker-base"
+ else:
+ reranker_path = self._cfg.reranker.qwen3_8b.local_path
+
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared reranker model ({self.reranker_type})...")
+ self._reranker = get_shared_reranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ reranker_type=self.reranker_type,
+ )
+ else:
+ print(f"[PersonalizedLLM] Loading Reranker ({self.reranker_type}) on {reranker_device}...")
+ if self.reranker_type == "bge":
+ self._reranker = BGEReranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ dtype=torch.float16,
+ )
+ else:
+ self._reranker = Qwen3Reranker(
+ model_path=reranker_path,
+ device_map=reranker_device,
+ dtype=torch.bfloat16,
+ )
+ else:
+ print(f"[PersonalizedLLM] Skipping reranker (not needed for {self.mode} mode)")
+ self._reranker = None
+
+ # Chat model (via registry for backend switching)
+ print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...")
+ # Pass device override if specified (not "auto")
+ device_for_chat = chat_device if chat_device != "auto" else None
+ self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat)
+
+ # Preference extractor - use shared singleton if enabled
+ if self.enable_preference_extraction:
+ extractor_name = "qwen3_0_6b_sft"
+ if self.use_shared_models:
+ print(f"[PersonalizedLLM] Using shared preference extractor...")
+ try:
+ extractor_path = self._cfg.preference_extractor.get("qwen3_0_6b_sft", {}).get("path", None)
+ if extractor_path:
+ self._extractor = get_shared_extractor(
+ model_path=extractor_path,
+ device_map=extractor_device,
+ )
+ else:
+ print(f"[PersonalizedLLM] Extractor path not found, using rule-based.")
+ self._extractor = get_preference_extractor("rule")
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load shared extractor: {e}. Trying fallbacks...")
+ try:
+ self._extractor = get_preference_extractor("rule")
+ except Exception as e2:
+ print(f"[PersonalizedLLM] Rule extractor also failed: {e2}. Using GPT-5-mini extractor.")
+ self._extractor = get_preference_extractor("gpt5_mini")
+ else:
+ print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...")
+ try:
+ self._extractor = get_preference_extractor(extractor_name)
+ except Exception as e:
+ print(f"[PersonalizedLLM] Warning: Failed to load {extractor_name}: {e}. Trying fallbacks...")
+ try:
+ self._extractor = get_preference_extractor("rule")
+ except Exception as e2:
+ print(f"[PersonalizedLLM] Rule extractor also failed: {e2}. Using GPT-5-mini extractor.")
+ self._extractor = get_preference_extractor("gpt5_mini")
+ else:
+ print("[PersonalizedLLM] Preference extraction disabled, skipping extractor.")
+ self._extractor = None
+
+ def _load_memory_store(self):
+ """Load memory cards and embeddings."""
+ if not os.path.exists(self._memory_cards_path):
+ print(f"[PersonalizedLLM] Warning: Memory cards not found at {self._memory_cards_path}")
+ self._memory_cards: List[MemoryCard] = []
+ self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+ # Create default projection (truncation to first k dims) so preferences can be added
+ k = self._rl_cfg["item_dim"]
+ d = 4096
+ P = np.zeros((k, d), dtype=np.float32)
+ P[:, :k] = np.eye(k, dtype=np.float32)
+ self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32))
+ print(f"[PersonalizedLLM] Created default projection (truncation, k={k})")
+ return
+
+ # Load cards
+ self._memory_cards = []
+ with open(self._memory_cards_path, "r") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ self._memory_cards.append(MemoryCard.model_validate_json(line))
+
+ # Load embeddings
+ if os.path.exists(self._memory_embeddings_path):
+ self._memory_embeddings = np.load(self._memory_embeddings_path)
+ else:
+ self._memory_embeddings = np.zeros((len(self._memory_cards), 4096), dtype=np.float32)
+
+ # Load projection
+ if os.path.exists(self._item_projection_path):
+ proj_data = np.load(self._item_projection_path)
+ self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+ self._item_vectors = proj_data["V"]
+ else:
+ # Create default projection so preferences can still be added
+ k = self._rl_cfg["item_dim"]
+ d = 4096
+ P = np.zeros((k, d), dtype=np.float32)
+ P[:, :k] = np.eye(k, dtype=np.float32)
+ self._projection = ItemProjection(P=P, mean=np.zeros(d, dtype=np.float32))
+ self._item_vectors = np.zeros((len(self._memory_cards), self._rl_cfg["item_dim"]), dtype=np.float32)
+ print(f"[PersonalizedLLM] Created default projection (truncation, k={k})")
+
+ print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.")
+
+ def _get_or_create_session(self, user_id: str) -> _SessionContext:
+ """Get or create session context for a user."""
+ if user_id not in self._sessions:
+ self._sessions[user_id] = _SessionContext(
+ session_state=OnlineSessionState(user_id=user_id),
+ turn_counter=0,
+ )
+ return self._sessions[user_id]
+
+ def _build_chat_turn(self, user_id: str, text: str, role: str, turn_id: int) -> ChatTurn:
+ """Build a ChatTurn object."""
+ return ChatTurn(
+ user_id=user_id,
+ session_id=f"eval_session_{user_id}",
+ turn_id=turn_id,
+ role=role,
+ text=text,
+ meta={"source": "eval"}
+ )
+
+ def _count_tokens(self, text: str) -> int:
+ """Estimate token count using the tokenizer."""
+ try:
+ # Use the chat model's tokenizer if available
+ if hasattr(self._chat_model, 'tokenizer'):
+ return len(self._chat_model.tokenizer.encode(text))
+ else:
+ # Rough estimate: ~4 chars per token
+ return len(text) // 4
+ except Exception:
+ return len(text) // 4
+
+ # Task type keywords for query transformation
+ _TASK_KEYWORDS = {
+ "math": ["solve", "calculate", "integral", "equation", "proof", "derivative",
+ "math", "algebra", "geometry", "trigonometry", "calculus", "arithmetic",
+ "formula", "compute", "evaluate", "simplify", "factor", "graph"],
+ "coding": ["code", "program", "function", "implement", "debug", "python", "java",
+ "javascript", "algorithm", "class", "method", "bug", "error", "compile",
+ "script", "html", "css", "sql", "api", "library", "framework"],
+ "writing": ["write", "essay", "paragraph", "summarize", "draft", "compose",
+ "article", "story", "letter", "email", "report", "review", "edit",
+ "rewrite", "paraphrase", "outline"],
+ "explanation": ["explain", "what is", "how does", "why", "describe", "define",
+ "meaning", "concept", "difference between", "compare", "contrast"],
+ }
+
+ def _transform_query_for_retrieval(self, query: str) -> List[str]:
+ """
+ Transform raw user query into multiple retrieval queries to bridge
+ the semantic gap between task queries and preference descriptions.
+
+ Returns [original_query, transformed_query] or [original_query] if
+ no task type detected.
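+
+        Example (illustrative): "debug this python function" matches the
+        coding keywords, so this returns:
+            ["debug this python function",
+             "user preferences for coding tasks: debug this python function"]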
+ """
+ import re
+ query_lower = query.lower()
+ detected_types = []
+ for task_type, keywords in self._TASK_KEYWORDS.items():
+ for kw in keywords:
+ # Use word boundary matching to avoid false positives
+ # e.g., "api" should not match "capital"
+ if re.search(r'\b' + re.escape(kw) + r'\b', query_lower):
+ detected_types.append(task_type)
+ break
+
+ if not detected_types:
+ return [query]
+
+        # Use the first detected type; _TASK_KEYWORDS is ordered from more
+        # specific (math, coding) to more generic (explanation)
+ task_type = detected_types[0]
+ transformed = f"user preferences for {task_type} tasks: {query}"
+ return [query, transformed]
+
+ # Patterns indicating a global/universal preference condition
+    _GLOBAL_PATTERNS = ["general", "any", "always", "all", "every", "regardless",
+                        "any task", "any topic", "any question", "all tasks", "all topics"]
+
+ # Domain-specific terms that indicate a conditional preference
+ _DOMAIN_TERMS = ["math", "code", "coding", "program", "writing", "essay", "science",
+ "history", "language", "physics", "chemistry", "biology", "literature",
+ "creative", "technical", "formal", "informal", "academic", "casual"]
+
+ def _classify_preference_scope(self, condition: str) -> bool:
+ """
+ Classify whether a preference condition is global (always applicable)
+ or conditional (task-specific).
+
+ Returns True if global, False if conditional.
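+
+        Example (illustrative): "asked any question" matches the global
+        pattern "any" and returns True; "solving math problems" matches no
+        global pattern and returns False.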
+ """
+ cond_lower = condition.lower().strip()
+
+        # Check for explicit global patterns (word-boundary match so that,
+        # e.g., "any" does not fire inside "company")
+        import re
+        for pattern in self._GLOBAL_PATTERNS:
+            if re.search(r'\b' + re.escape(pattern) + r'\b', cond_lower):
+                return True
+
+ # Very short/vague conditions with no domain terms are likely global
+ words = cond_lower.split()
+ if len(words) <= 2:
+ has_domain = any(term in cond_lower for term in self._DOMAIN_TERMS)
+ if not has_domain:
+ return True
+
+ return False
+
+ # Rewrite prompt for merging retrieved preferences
+ _REWRITE_PROMPT = """You are helping to prepare user preferences for an AI assistant.
+
+The user is asking: {query}
+
+Retrieved preferences about this user:
+{preferences}
+
+Task: Create a concise preference summary that the assistant MUST follow.
+
+Rules:
+1. PRESERVE all specific formatting requirements exactly (e.g., "type hints", "snake_case", "code fence with language")
+2. PRESERVE all structural requirements (e.g., "numbered steps", "bullet points", "answer first then explanation")
+3. Only MERGE preferences that are truly redundant (saying the same thing differently)
+4. Output as a short bulleted list if there are multiple distinct requirements
+5. Keep each point actionable and specific - NO vague generalizations like "follow best practices"
+
+Example input:
+- Include type hints in Python code
+- Use snake_case for variable names
+- When explaining, use numbered steps
+
+Example output:
+- Include type hints
+- Use snake_case for variables
+- Use numbered steps for explanations
+
+If no preferences are relevant to this query type, output: "No specific preferences apply."
+
+Preference summary:"""
+
+ def _rewrite_preferences(self, memory_notes: List[str], query: str) -> List[str]:
+ """
+ Use LLM to rewrite/merge multiple retrieved preferences into concise instructions.
+
+ This is similar to Reflection's proper_scaffolding but focuses on merging
+ rather than just filtering.
+
+ Args:
+ memory_notes: List of retrieved preference notes
+ query: Current user query
+
+ Returns:
+ List with single rewritten instruction (or original if rewrite fails/disabled)
+ """
+ if not memory_notes or len(memory_notes) <= 1:
+ return memory_notes
+
+ try:
+ import requests
+
+ # Format preferences for prompt
+ prefs_text = "\n".join(f"- {note}" for note in memory_notes)
+ prompt = self._REWRITE_PROMPT.format(query=query[:200], preferences=prefs_text)
+
+ # Direct vLLM API call (simpler than going through chat model)
+ messages = [{"role": "user", "content": prompt}]
+ payload = {
+ "model": self._chat_model.model_name,
+ "messages": messages,
+ "max_tokens": 150,
+ "temperature": 0.3, # Lower temperature for more consistent output
+ }
+
+ response = requests.post(
+ f"{self._chat_model.vllm_url}/chat/completions",
+ json=payload,
+ timeout=30
+ )
+
+ if response.status_code != 200:
+ print(f"[REWRITE] API error {response.status_code}, keeping original notes")
+ return memory_notes
+
+ result = response.json()
+ rewritten = result["choices"][0]["message"]["content"].strip().strip('"')
+
+ # Validate response
+ if rewritten and len(rewritten) > 10 and "No specific preferences" not in rewritten:
+ print(f"[REWRITE] {len(memory_notes)} notes → 1 merged instruction")
+ return [rewritten]
+ else:
+ print(f"[REWRITE] Kept original {len(memory_notes)} notes (no valid merge)")
+ return memory_notes
+
+ except Exception as e:
+ print(f"[REWRITE] Failed: {e}, keeping original notes")
+ return memory_notes
+
+ # Consolidation prompt for session-end preference merging
+ _CONSOLIDATION_PROMPT = """You are analyzing user preferences extracted from conversations.
+
+Current preferences for this user:
+{preferences}
+
+Task: Consolidate these preferences into a cleaner, more organized set by:
+1. MERGE similar preferences (e.g., "use bullet points" + "format with bullets" → single preference)
+2. REMOVE redundant or contradictory preferences (keep the more specific one)
+3. PRESERVE all unique, meaningful preferences
+4. Keep the same "When [condition], [action]." format
+
+Output ONLY the consolidated preferences, one per line, in this exact format:
+When [condition], [action].
+
+Do not add explanations or commentary. Just output the preference lines."""
+
+ def consolidate_user_preferences(self, user_id: str) -> int:
+ """
+ Consolidate user preferences at session end using LLM.
+
+ Merges similar preferences, removes redundancy, and creates cleaner
+ preference descriptions. Only runs if user has enough preferences.
+
+ Args:
+ user_id: The user whose preferences to consolidate.
+
+ Returns:
+ Number of preferences after consolidation (0 if skipped).
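+
+        Example (illustrative): "When asked for code, include type hints." and
+        "When writing code, add type hints." would be merged into a single
+        preference; output lines must keep the "When [condition], [action]."
+        format to be parsed.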
+ """
+ if not self.enable_preference_consolidation:
+ return 0
+
+ # Get user's memory cards
+ user_cards = [c for c in self._memory_cards if c.user_id == user_id]
+
+ if len(user_cards) < self.consolidation_threshold:
+ return len(user_cards)
+
+ # Build preference list for prompt
+ pref_lines = [card.note_text for card in user_cards]
+ preferences_text = "\n".join(f"- {p}" for p in pref_lines)
+
+ # Call LLM for consolidation
+ prompt = self._CONSOLIDATION_PROMPT.format(preferences=preferences_text)
+ messages = [{"role": "user", "content": prompt}]
+
+ try:
+ result = self._chat_model.answer(messages, max_new_tokens=512)
+ consolidated_text = result.get("content", "").strip()
+
+ if not consolidated_text:
+ return len(user_cards)
+
+ # Parse consolidated preferences
+ new_prefs = []
+ for line in consolidated_text.split("\n"):
+ line = line.strip()
+ if not line or not line.startswith("When "):
+ continue
+ # Parse "When [condition], [action]."
+ if ", " in line:
+ parts = line.split(", ", 1)
+ condition = parts[0].replace("When ", "").strip()
+ action = parts[1].rstrip(".").strip()
+ if condition and action:
+ new_prefs.append({
+ "condition": condition,
+ "action": action,
+ "is_global": self._classify_preference_scope(condition) if self.enable_global_preferences else False,
+ })
+
+ if not new_prefs:
+ return len(user_cards)
+
+ # Remove old cards for this user
+ keep_indices = [i for i, c in enumerate(self._memory_cards) if c.user_id != user_id]
+ self._memory_cards = [self._memory_cards[i] for i in keep_indices]
+ if len(keep_indices) > 0 and len(self._memory_embeddings) > 0:
+ self._memory_embeddings = self._memory_embeddings[keep_indices]
+ self._item_vectors = self._item_vectors[keep_indices]
+ else:
+ embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096
+ self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Add consolidated preferences
+ for pref in new_prefs:
+ note_text = f"When {pref['condition']}, {pref['action']}."
+
+ # Compute embedding
+ e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
+ v_note = self._projection.transform_vector(np.array(e_note))
+
+ # Create card
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id=f"consolidated_{user_id}",
+ source_turn_ids=[],
+ raw_queries=[],
+ preference_list=PreferenceList(preferences=[
+ Preference(condition=pref["condition"], action=pref["action"], confidence=1.0)
+ ]),
+ note_text=note_text,
+ embedding_e=list(e_note),
+ kind="pref",
+ is_global=pref["is_global"],
+ )
+
+ self._memory_cards.append(card)
+ self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
+ self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])
+
+ print(f"[PersonalizedLLM] Consolidated {len(user_cards)} → {len(new_prefs)} preferences for user {user_id}")
+ return len(new_prefs)
+
+ except Exception as e:
+ print(f"[PersonalizedLLM] Consolidation failed for user {user_id}: {e}")
+ return len(user_cards)
+
+ def _add_preferences_as_memory(
+ self,
+ prefs: PreferenceList,
+ query: str,
+ user_id: str,
+ turn_id: int,
+ ) -> List[Dict[str, Any]]:
+ """
+ Add extracted preferences as new memory cards.
+ Returns list of preference dicts for debug info.
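+        Each dict has keys "condition", "action", and "confidence".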
+ """
+ extracted = []
+
+ if not prefs.preferences or self._projection is None:
+ return extracted
+
+ for pref in prefs.preferences:
+ note_text = f"When {pref.condition}, {pref.action}."
+
+ # Record for debug
+ extracted.append({
+ "condition": pref.condition,
+ "action": pref.action,
+ "confidence": pref.confidence,
+ })
+
+ # Deduplication check
+ is_duplicate = any(
+ card.user_id == user_id and card.note_text == note_text
+ for card in self._memory_cards
+ )
+
+ if is_duplicate:
+ continue
+
+ # Compute embedding from note_text (NOT query) for proper semantic retrieval
+ # This ensures retrieval query "solve math problem" matches stored "When math problems..."
+ e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
+ v_note = self._projection.transform_vector(np.array(e_note))
+
+ # Classify as global or conditional
+ is_global = self._classify_preference_scope(pref.condition) if self.enable_global_preferences else False
+
+ # Create new memory card
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id=f"eval_session_{user_id}",
+ source_turn_ids=[turn_id],
+ raw_queries=[query],
+ preference_list=PreferenceList(preferences=[pref]),
+ note_text=note_text,
+ embedding_e=list(e_note),
+ kind="pref",
+ is_global=is_global,
+ )
+
+ # Add to memory store
+ self._memory_cards.append(card)
+ self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
+ self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])
+
+ return extracted
+
+ def _score_response(self, response: str) -> float:
+ """
+ Score a response for best-of-N selection.
+
+ Higher score = better response. Scoring heuristics:
+ 1. Length: Longer responses typically have more substance
+ 2. Solution indicators: Contains formulas, steps, answers
+ 3. Proactivity: Doesn't end with just a question
+
+ Returns:
+ Float score (higher is better)
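+
+        Example (illustrative): a 600-char answer containing "step", "=", and
+        "answer", plus a numbered list and 3+ numeric values scores about
+        0.6*3.0 + 3*0.5 + 1.0 + 1.0 = 5.3; ending with "?" would subtract 1.5.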
+ """
+ score = 0.0
+ response_lower = response.lower()
+
+ # Length score (normalized, cap at 1000 chars)
+ score += min(len(response), 1000) / 1000 * 3.0
+
+        # Solution indicators (+0.5 each, counting at most 5 indicators)
+ solution_indicators = ['=', 'step', 'answer', 'formula', 'result', 'therefore', 'solution']
+ indicator_count = sum(1 for ind in solution_indicators if ind in response_lower)
+ score += min(indicator_count, 5) * 0.5
+
+        # Structured content (+1 for numbered/bulleted lists or markdown headings)
+ if any(marker in response for marker in ['1.', '2.', '- ', '* ', '##']):
+ score += 1.0
+
+ # Penalty for ending with question (passive behavior)
+ # Check last 100 chars for question marks
+ if '?' in response[-100:]:
+ score -= 1.5
+
+ # Bonus for providing concrete values/numbers
+ import re
+ numbers = re.findall(r'\d+\.?\d*', response)
+ if len(numbers) >= 3:
+ score += 1.0
+
+ return score
+
+ # =========================================================================
+ # Public Interface
+ # =========================================================================
+
+ def chat(self, user_id: str, query: str) -> AssistantResponse:
+ """
+ Main online chat interface.
+
+ Args:
+ user_id: Unique identifier for the user.
+ query: Current user query/message.
+
+ Returns:
+ AssistantResponse containing the answer, usage stats, and debug info.
+
+ Notes:
+ - Internally manages user state, session history, memory retrieval
+ - After this call, you can call apply_feedback() with the turn's feedback
+ """
+ ctx = self._get_or_create_session(user_id)
+ session = ctx.session_state
+ user_state = self._user_store.get_state(user_id)
+
+ # Record user vector before for debug
+ z_long_before = user_state.z_long.copy().tolist()
+ z_short_before = user_state.z_short.copy().tolist()
+
+ # Add user turn to history
+ user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
+ session.history.append(user_turn)
+
+ # Vanilla mode: pure LLM without any memory or preference extraction
+ if self.mode == "vanilla":
+ # Skip embedding, preference extraction, and memory retrieval entirely
+ e_q_t = np.zeros(4096, dtype=np.float32) # Placeholder for vanilla mode
+ extracted_prefs = []
+ candidates = []
+ cand_item_vecs = np.array([])
+ base_scores = np.array([])
+ chosen_indices = []
+ probs = np.array([])
+ memories_t = []
+            memory_notes = []
+            global_notes = []
+ else:
+ # Compute query embedding (only needed for non-vanilla modes)
+ # Explicitly normalize for consistent cosine similarity with stored embeddings
+ embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
+ if embed_result is None or len(embed_result) == 0:
+ raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
+ e_q_t = np.array(embed_result[0])
+
+ # Store pending RL update info from last turn (for apply_feedback)
+ if session.last_query is not None and self.enable_rl_updates:
+ ctx.pending_rl_update = {
+ "last_query": session.last_query,
+ "last_answer": session.last_answer,
+ "last_memories": session.last_memories,
+ "last_query_embedding": session.last_query_embedding,
+ "current_query_embedding": e_q_t,
+ "last_candidate_item_vectors": session.last_candidate_item_vectors,
+ "last_policy_probs": session.last_policy_probs,
+ "last_chosen_indices": session.last_chosen_indices,
+ }
+
+ # Auto-compute reward via LLM judge if enabled
+ if self._llm_reward_client is not None:
+ import asyncio
+ try:
+ reward, gating = asyncio.run(eval_step_llm(
+ q_t=session.last_query,
+ answer_t=session.last_answer,
+ q_t1=query,
+ memories_t=session.last_memories or [],
+ client=self._llm_reward_client,
+ ))
+ if gating > 0.0:
+ self.apply_feedback(Feedback(
+ user_id=user_id,
+ turn_id=ctx.turn_counter - 1,
+ reward=reward,
+ gating=gating,
+ ))
+ except Exception as e:
+ # Graceful fallback: skip RL update if judge fails
+ print(f"[LLM-Reward] Judge call failed, skipping update: {e}")
+
+ # Extract preferences from conversation (if enabled)
+ # extract_turn processes only the last user turn - efficient since called each turn
+ # Preferences accumulate in _memory_cards across turns (dedup prevents duplicates)
+ extracted_prefs = []
+ if self.enable_preference_extraction:
+ prefs = self._extractor.extract_turn(session.history)
+ if prefs.preferences:
+ print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
+ extracted_prefs = self._add_preferences_as_memory(
+ prefs, query, user_id, ctx.turn_counter
+ )
+ if extracted_prefs:
+ print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
+
+ # Separate global preferences (bypass retrieval) from conditional ones
+ global_notes = []
+ retrieval_cards = self._memory_cards
+ retrieval_embeddings = self._memory_embeddings
+ retrieval_item_vectors = self._item_vectors
+ if self.enable_global_preferences:
+ global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id]
+ global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10
+ # Filter out global cards for retrieval
+ cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global]
+ if cond_indices:
+ retrieval_cards = [self._memory_cards[i] for i in cond_indices]
+ retrieval_embeddings = self._memory_embeddings[cond_indices]
+ if len(self._item_vectors) > 0:
+ retrieval_item_vectors = self._item_vectors[cond_indices]
+ else:
+ retrieval_cards = []
+ retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings
+ retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Query transformation for better retrieval matching
+ retrieval_queries = None
+ if self.enable_query_transform:
+ retrieval_queries = self._transform_query_for_retrieval(query)
+
+ # Retrieve memories
+ if self.mode == "nopersonal":
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ only_own_memories=self.only_own_memories,
+ queries=retrieval_queries,
+ dynamic_topk=self._rl_cfg["dynamic_topk"],
+ dynamic_min_k=self._rl_cfg["dynamic_min_k"],
+ dynamic_max_k=self._rl_cfg["dynamic_max_k"],
+ dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"],
+ )
+ else:
+ beta_long = self._rl_cfg["beta_long"]
+ beta_short = self._rl_cfg["beta_short"]
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
+ user_store=self._user_store,
+ item_vectors=retrieval_item_vectors,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=self._rl_cfg["tau"],
+ only_own_memories=self.only_own_memories,
+ sample=not self.eval_mode,
+ queries=retrieval_queries,
+ )
+
+ # Get selected memories
+        memories_t = [candidates[int(i)] for i in chosen_indices] if len(chosen_indices) > 0 else []
+ memory_notes = [m.note_text for m in memories_t]
+
+ # Apply preference rewrite if enabled
+ if self.enable_preference_rewrite and memory_notes:
+ memory_notes = self._rewrite_preferences(memory_notes, query)
+
+ # Debug: show retrieval info
+ if memories_t or global_notes:
+ print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
+ print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}")
+ for i, m in enumerate(memories_t[:3]): # Show top 3
+ score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
+ print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
+
+ # Combine all notes for prompt (global + retrieved)
+ # For chat(), we combine all notes; chat_prepare() handles them separately
+ if self.mode != "vanilla":
+ all_memory_notes = (global_notes if global_notes else []) + memory_notes
+ else:
+ all_memory_notes = memory_notes
+
+        # Estimate prompt tokens (session.history already includes the current query)
+        prompt_tokens = 0
+ for turn in session.history:
+ prompt_tokens += self._count_tokens(turn.text)
+ for note in all_memory_notes:
+ prompt_tokens += self._count_tokens(note)
+
+ # Generate answer (with best-of-N if enabled)
+ if self.best_of_n > 1:
+ # Generate N responses and pick the best one
+ candidates_responses = []
+ for i in range(self.best_of_n):
+ resp = self._chat_model.answer(
+ history=session.history,
+ memory_notes=all_memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ temperature=0.8, # Slightly higher temp for diversity
+ )
+ score = self._score_response(resp)
+ candidates_responses.append((resp, score))
+
+ # Sort by score (descending) and pick best
+ candidates_responses.sort(key=lambda x: x[1], reverse=True)
+ answer_t = candidates_responses[0][0]
+ best_score = candidates_responses[0][1]
+
+ if len(candidates_responses) > 1:
+ print(f"[BEST-OF-{self.best_of_n}] Scores: {[f'{s:.2f}' for _, s in candidates_responses]}, picked score={best_score:.2f}")
+ else:
+ answer_t = self._chat_model.answer(
+ history=session.history,
+ memory_notes=all_memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ )
+
+ completion_tokens = self._count_tokens(answer_t)
+
+ # Add assistant turn to history
+ assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter)
+ session.history.append(assist_turn)
+
+ # Update session state for next turn
+ session.last_query = query
+ session.last_answer = answer_t
+ session.last_memories = memories_t
+ session.last_query_embedding = e_q_t
+ session.last_candidate_item_vectors = cand_item_vecs
+ session.last_policy_probs = probs
+ session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else []
+
+ ctx.turn_counter += 1
+
+ # Build debug info
+ debug = DebugInfo(
+ selected_memory_ids=[m.card_id for m in memories_t],
+ selected_memory_notes=[m.note_text for m in memories_t],
+ selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [],
+ user_vector_before=z_long_before + z_short_before, # Concatenated for simplicity
+ user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(),
+ extracted_preferences=extracted_prefs,
+ extra={
+ "num_candidates": len(candidates),
+ "num_total_memories": len(self._memory_cards),
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ }
+ )
+
+ # Build usage stats
+ usage = UsageStats(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ model=self._llm_name,
+ )
+
+ return AssistantResponse(
+ answer=answer_t,
+ usage=usage,
+ debug=debug,
+ )
+
+ def chat_prepare(self, user_id: str, query: str, skip_extraction: bool = False, skip_auto_reward: bool = False) -> dict:
+ """
+ Prepare for chat without calling the LLM.
+
+ This does all the preparation work (embedding, memory retrieval, etc.)
+ and returns the messages to send to the LLM along with context needed
+ for post-processing.
+
+ Used for batch processing where messages are collected first, then
+ sent in batch to vLLM for concurrent processing.
+
+        Args:
+            user_id: Unique identifier for the user.
+            query: Current user query/message.
+            skip_extraction: If True, skip preference extraction (e.g. when a
+                batch framework runs extraction separately via
+                apply_extracted_preferences).
+            skip_auto_reward: If True, skip the automatic LLM-judge reward
+                (e.g. when the batch framework computes rewards externally).
+
+ Returns:
+ Dict containing:
+ - messages: List of messages to send to LLM
+ - context: Dict with all state needed for chat_complete()
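+
+        Example (illustrative batch flow; batch_llm_call is a hypothetical
+        helper that sends all message lists to vLLM concurrently):
+            prepped = [llm.chat_prepare(uid, q) for uid, q in batch]
+            answers = batch_llm_call([p["messages"] for p in prepped])
+            results = [llm.chat_complete(a, p["context"])
+                       for a, p in zip(answers, prepped)]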
+ """
+ ctx = self._get_or_create_session(user_id)
+ session = ctx.session_state
+ user_state = self._user_store.get_state(user_id)
+
+ # Record user vector before for debug
+ z_long_before = user_state.z_long.copy().tolist()
+ z_short_before = user_state.z_short.copy().tolist()
+
+ # Add user turn to history
+ user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
+ session.history.append(user_turn)
+
+ # Vanilla mode: pure LLM without any memory or preference extraction
+ if self.mode == "vanilla":
+ e_q_t = np.zeros(4096, dtype=np.float32)
+ extracted_prefs = []
+ candidates = []
+ cand_item_vecs = np.array([])
+ base_scores = np.array([])
+ chosen_indices = []
+ probs = np.array([])
+ memories_t = []
+            memory_notes = []
+            global_notes = []
+ else:
+ # Compute query embedding
+ embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
+ if embed_result is None or len(embed_result) == 0:
+ raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
+ e_q_t = np.array(embed_result[0])
+
+ # Store pending RL update info from last turn
+ if session.last_query is not None and self.enable_rl_updates:
+ ctx.pending_rl_update = {
+ "last_query": session.last_query,
+ "last_answer": session.last_answer,
+ "last_memories": session.last_memories,
+ "last_query_embedding": session.last_query_embedding,
+ "current_query_embedding": e_q_t,
+ "last_candidate_item_vectors": session.last_candidate_item_vectors,
+ "last_policy_probs": session.last_policy_probs,
+ "last_chosen_indices": session.last_chosen_indices,
+ }
+
+ # Auto-compute reward via LLM judge if enabled
+ # skip_auto_reward=True when batch framework handles rewards externally
+ if self._llm_reward_client is not None and not skip_auto_reward:
+ import asyncio
+ try:
+ reward, gating = asyncio.run(eval_step_llm(
+ q_t=session.last_query,
+ answer_t=session.last_answer,
+ q_t1=query,
+ memories_t=session.last_memories or [],
+ client=self._llm_reward_client,
+ ))
+ if gating > 0.0:
+ self.apply_feedback(Feedback(
+ user_id=user_id,
+ turn_id=ctx.turn_counter - 1,
+ reward=reward,
+ gating=gating,
+ ))
+ except Exception as e:
+ print(f"[LLM-Reward] Judge call failed, skipping update: {e}")
+
+ # Extract preferences from conversation
+ extracted_prefs = []
+ if self.enable_preference_extraction and not skip_extraction:
+ prefs = self._extractor.extract_turn(session.history)
+ if prefs.preferences:
+ print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
+ extracted_prefs = self._add_preferences_as_memory(
+ prefs, query, user_id, ctx.turn_counter
+ )
+ if extracted_prefs:
+ print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")
+
+ # Separate global preferences (bypass retrieval) from conditional ones
+ global_notes = []
+ retrieval_cards = self._memory_cards
+ retrieval_embeddings = self._memory_embeddings
+ retrieval_item_vectors = self._item_vectors
+ if self.enable_global_preferences:
+ global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id]
+ global_notes = [c.note_text for c in global_cards[:10]] # Cap at 10
+ cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global]
+ if cond_indices:
+ retrieval_cards = [self._memory_cards[i] for i in cond_indices]
+ retrieval_embeddings = self._memory_embeddings[cond_indices]
+ if len(self._item_vectors) > 0:
+ retrieval_item_vectors = self._item_vectors[cond_indices]
+ else:
+ retrieval_cards = []
+ retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings
+ retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ # Query transformation for better retrieval matching
+ retrieval_queries = None
+ if self.enable_query_transform:
+ retrieval_queries = self._transform_query_for_retrieval(query)
+
+ # Retrieve memories
+ if self.mode == "nopersonal":
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ only_own_memories=self.only_own_memories,
+ queries=retrieval_queries,
+ dynamic_topk=self._rl_cfg["dynamic_topk"],
+ dynamic_min_k=self._rl_cfg["dynamic_min_k"],
+ dynamic_max_k=self._rl_cfg["dynamic_max_k"],
+ dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"],
+ )
+ else:
+ beta_long = self._rl_cfg["beta_long"]
+ beta_short = self._rl_cfg["beta_short"]
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=query,
+ embed_model=self._embed_model,
+ reranker=self._reranker,
+ memory_cards=retrieval_cards,
+ memory_embeddings=retrieval_embeddings,
+ user_store=self._user_store,
+ item_vectors=retrieval_item_vectors,
+ topk_dense=self._rl_cfg["dense_topk"],
+ topk_rerank=self._rl_cfg["rerank_topk"],
+ beta_long=beta_long,
+ beta_short=beta_short,
+ tau=self._rl_cfg["tau"],
+ only_own_memories=self.only_own_memories,
+ sample=not self.eval_mode,
+ queries=retrieval_queries,
+ )
+
+        memories_t = [candidates[int(i)] for i in chosen_indices] if len(chosen_indices) > 0 else []
+ memory_notes = [m.note_text for m in memories_t]
+
+ # Apply preference rewrite if enabled
+ if self.enable_preference_rewrite and memory_notes:
+ memory_notes = self._rewrite_preferences(memory_notes, query)
+
+ if memories_t or global_notes:
+ print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
+ print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}")
+ for i, m in enumerate(memories_t[:3]):
+ score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
+ print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")
+
+        # Estimate prompt tokens (session.history already includes the current query)
+        prompt_tokens = 0
+ for turn in session.history:
+ prompt_tokens += self._count_tokens(turn.text)
+ all_notes = memory_notes + (global_notes if self.mode != "vanilla" else [])
+ for note in all_notes:
+ prompt_tokens += self._count_tokens(note)
+
+ # Build messages for LLM (pass global_notes separately for distinct prompt sections)
+ effective_global = global_notes if (self.enable_global_preferences and self.mode != "vanilla") else None
+ messages = self._chat_model.build_messages(
+ history=session.history,
+ memory_notes=memory_notes,
+ max_new_tokens=self._rl_cfg["max_new_tokens"],
+ global_notes=effective_global,
+ )
+
+ # Return messages and context for chat_complete
+ return {
+ "messages": messages,
+ "context": {
+ "user_id": user_id,
+ "query": query,
+ "ctx": ctx,
+ "session": session,
+ "user_state": user_state,
+ "z_long_before": z_long_before,
+ "z_short_before": z_short_before,
+ "e_q_t": e_q_t,
+ "extracted_prefs": extracted_prefs,
+ "candidates": candidates,
+ "cand_item_vecs": cand_item_vecs,
+ "base_scores": base_scores,
+ "chosen_indices": chosen_indices,
+ "probs": probs,
+ "memories_t": memories_t,
+ "memory_notes": memory_notes,
+ "prompt_tokens": prompt_tokens,
+ }
+ }
+
+ def chat_complete(self, answer_t: str, context: dict) -> AssistantResponse:
+ """
+ Complete chat with LLM response.
+
+ This takes the LLM response and context from chat_prepare(), and
+ does all post-processing (add to history, debug info, etc.).
+
+ Args:
+ answer_t: The LLM response text.
+ context: Context dict from chat_prepare().
+
+ Returns:
+ AssistantResponse containing the answer, usage stats, and debug info.
+ """
+ # Unpack context
+ user_id = context["user_id"]
+ query = context["query"]
+ ctx = context["ctx"]
+ session = context["session"]
+ user_state = context["user_state"]
+ z_long_before = context["z_long_before"]
+ z_short_before = context["z_short_before"]
+ e_q_t = context["e_q_t"]
+ extracted_prefs = context["extracted_prefs"]
+ candidates = context["candidates"]
+ cand_item_vecs = context["cand_item_vecs"]
+ chosen_indices = context["chosen_indices"]
+ probs = context["probs"]
+ memories_t = context["memories_t"]
+ memory_notes = context["memory_notes"]
+ prompt_tokens = context["prompt_tokens"]
+
+ completion_tokens = self._count_tokens(answer_t)
+
+ # Add assistant turn to history
+ assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter)
+ session.history.append(assist_turn)
+
+ # Update session state for next turn
+ session.last_query = query
+ session.last_answer = answer_t
+ session.last_memories = memories_t
+ session.last_query_embedding = e_q_t
+ session.last_candidate_item_vectors = cand_item_vecs
+ session.last_policy_probs = probs
+ session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else []
+
+ ctx.turn_counter += 1
+
+ # Build debug info
+ debug = DebugInfo(
+ selected_memory_ids=[m.card_id for m in memories_t],
+ selected_memory_notes=[m.note_text for m in memories_t],
+ selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [],
+ user_vector_before=z_long_before + z_short_before,
+ user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(),
+ extracted_preferences=extracted_prefs,
+ extra={
+ "num_candidates": len(candidates),
+ "num_total_memories": len(self._memory_cards),
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ }
+ )
+
+ # Build usage stats
+ usage = UsageStats(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ model=self._llm_name,
+ )
+
+ return AssistantResponse(
+ answer=answer_t,
+ usage=usage,
+ debug=debug,
+ )
+
+ def apply_extracted_preferences(self, user_id: str, pref_dict: dict) -> list:
+ """Apply pre-computed extraction results (from batch extraction) to memory."""
+ prefs = PreferenceList.model_validate(pref_dict)
+ if not prefs.preferences:
+ return []
+ ctx = self._get_or_create_session(user_id)
+ query = ctx.session_state.history[-1].text if ctx.session_state.history else ""
+ extracted = self._add_preferences_as_memory(prefs, query, user_id, ctx.turn_counter)
+ if extracted:
+ print(f"[DEBUG] Batch-added {len(extracted)} to memory. Total cards: {len(self._memory_cards)}")
+ return extracted
+
+ def get_last_user_query(self, user_id: str) -> str:
+ """Get the last user message text for this user's session."""
+ ctx = self._sessions.get(user_id)
+ if ctx and ctx.session_state.history:
+ for t in reversed(ctx.session_state.history):
+ if t.role == "user":
+ return t.text
+ return ""
+
+ def reset_session(self, user_id: str) -> None:
+ """
+ Reset session for a user (new chat window).
+
+ This clears:
+ - Session conversation history
+ - Short-term user vector (z_short)
+ - Pending RL update info
+
+ This preserves:
+ - Long-term user vector (z_long)
+ - User's memory cards (may be consolidated if enabled)
+
+ Args:
+ user_id: The user whose session to reset.
+ """
+ # Consolidate preferences at session end (before clearing session)
+ if self.enable_preference_consolidation:
+ self.consolidate_user_preferences(user_id)
+
+ # Clear session context
+ if user_id in self._sessions:
+ del self._sessions[user_id]
+
+ # Create fresh session
+ self._sessions[user_id] = _SessionContext(
+ session_state=OnlineSessionState(user_id=user_id),
+ turn_counter=0,
+ )
+
+ # Reset short-term vector but keep long-term
+ user_state = self._user_store.get_state(user_id)
+ user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32)
+ self._user_store.save_state(user_state)
+
+ def reset_user(self, user_id: str) -> None:
+ """
+ Completely reset a user (new "life").
+
+ This clears:
+ - Long-term user vector (z_long)
+ - Short-term user vector (z_short)
+ - User's memory cards
+ - Session history
+ - All cached state
+
+ Args:
+ user_id: The user to reset.
+ """
+ # Clear session
+ if user_id in self._sessions:
+ del self._sessions[user_id]
+
+ # Reset user state vectors
+ user_state = self._user_store.get_state(user_id)
+ user_state.z_long = self._user_store.global_init_z.copy()
+ user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32)
+ user_state.reward_ma = 0.0
+ self._user_store.save_state(user_state)
+
+ # Find indices to KEEP (cards NOT belonging to this user)
+ # Must do this BEFORE modifying _memory_cards
+ keep_indices = [
+ i for i, card in enumerate(self._memory_cards)
+ if card.user_id != user_id
+ ]
+
+ # Filter memory cards
+ self._memory_cards = [self._memory_cards[i] for i in keep_indices]
+
+ # Filter embeddings and item vectors to match
+ if len(keep_indices) > 0 and len(self._memory_embeddings) > 0:
+ self._memory_embeddings = self._memory_embeddings[keep_indices]
+ self._item_vectors = self._item_vectors[keep_indices]
+ else:
+ # No cards left or no embeddings
+ embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096
+ self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32)
+ self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
+
+ def apply_feedback(self, feedback: Feedback) -> None:
+ """
+ Apply feedback from user simulator or judge.
+
+ This performs the REINFORCE update to user vectors based on
+ the reward signal from the previous turn.
+
+ Args:
+ feedback: Feedback object containing reward, gating, and metadata.
+
+ Notes:
+            - Should be called AFTER chat() but BEFORE the next chat() call
+            - Uses the stored context from the previous turn: the pending
+              update for turn t-1 is created during the chat() call for turn
+              t, so the first effective update happens after the second chat
+            - If enable_rl_updates is False, this is a no-op
+            - If mode is "nopersonal" or "vanilla", this is a no-op (baseline)
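+
+        Example (illustrative):
+            llm.chat("user_123", "first query")   # turn 0
+            llm.chat("user_123", "second query")  # turn 1; stores pending update for turn 0
+            llm.apply_feedback(Feedback(user_id="user_123", turn_id=0,
+                                        reward=0.8, gating=1.0))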
+ """
+ if not self.enable_rl_updates:
+ return
+
+ # In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline)
+ if self.mode in ("nopersonal", "vanilla"):
+ return
+
+ user_id = feedback.user_id
+ ctx = self._sessions.get(user_id)
+
+ if ctx is None or ctx.pending_rl_update is None:
+ return
+
+ pending = ctx.pending_rl_update
+ user_state = self._user_store.get_state(user_id)
+
+ # Check if we have the necessary data for RL update
+ if (pending.get("last_candidate_item_vectors") is not None and
+ pending.get("last_policy_probs") is not None and
+ pending.get("last_chosen_indices") is not None and
+ len(pending["last_chosen_indices"]) > 0):
+
+ # Extract chosen vectors
+ chosen_indices = pending["last_chosen_indices"]
+ candidate_vectors = pending["last_candidate_item_vectors"]
+
+ if len(candidate_vectors) > 0:
+ print(f"[DEBUG-REINFORCE] User={user_id} reward={feedback.reward:.2f} "
+ f"n_candidates={len(candidate_vectors)} chosen={chosen_indices} "
+ f"probs_shape={pending['last_policy_probs'].shape if hasattr(pending['last_policy_probs'], 'shape') else 'N/A'}")
+ # REINFORCE expects:
+ # - item_vectors: ALL candidate vectors [K, k]
+ # - chosen_indices: indices into those candidates
+ # - policy_probs: probabilities over all K candidates [K]
+ updated = reinforce_update_user_state(
+ user_state=user_state,
+ item_vectors=candidate_vectors, # All candidates, not just chosen
+ chosen_indices=chosen_indices, # Original indices into candidates
+ policy_probs=pending["last_policy_probs"],
+ reward_hat=feedback.reward,
+ gating=feedback.gating,
+ tau=self._rl_cfg["tau"],
+ eta_long=self._rl_cfg["eta_long"],
+ eta_short=self._rl_cfg["eta_short"],
+ ema_alpha=self._rl_cfg["ema_alpha"],
+ short_decay=self._rl_cfg["short_decay"],
+ )
+
+ print(f"[DEBUG-REINFORCE] updated={updated} z_long_norm={np.linalg.norm(user_state.z_long):.15e}")
+ if updated:
+ self._user_store.save_state(user_state)
+
+ # Clear pending update
+ ctx.pending_rl_update = None
+
+ def get_user_state_summary(self, user_id: str) -> Dict[str, Any]:
+ """
+ Get a summary of the user's current state (for debugging/analysis).
+
+ Args:
+ user_id: The user to query.
+
+ Returns:
+ Dictionary with user state information.
+ """
+ user_state = self._user_store.get_state(user_id)
+ ctx = self._sessions.get(user_id)
+
+ user_memory_count = sum(
+ 1 for card in self._memory_cards if card.user_id == user_id
+ )
+
+ return {
+ "user_id": user_id,
+ "z_long_norm": float(np.linalg.norm(user_state.z_long)),
+ "z_short_norm": float(np.linalg.norm(user_state.z_short)),
+ "reward_ma": user_state.reward_ma,
+ "session_history_length": len(ctx.session_state.history) if ctx else 0,
+ "turn_counter": ctx.turn_counter if ctx else 0,
+ "user_memory_count": user_memory_count,
+ "total_memory_count": len(self._memory_cards),
+ }
+
+ def persist(self) -> None:
+ """
+ Persist all state to disk.
+
+ Call this at the end of an evaluation run to save:
+ - User state vectors
+ - Memory cards
+ """
+ # Save user store
+ self._user_store.persist()
+
+ # Save memory cards
+ with open(self._memory_cards_path, "w", encoding="utf-8") as f:
+ for card in self._memory_cards:
+ f.write(card.model_dump_json() + "\n")
+
+ # Save embeddings
+ np.save(self._memory_embeddings_path, self._memory_embeddings)
+
+ # Save item projection with updated vectors
+ if self._projection is not None:
+ np.savez(
+ self._item_projection_path,
+ P=self._projection.P,
+ mean=self._projection.mean,
+ V=self._item_vectors,
+ )
+
+ print("[PersonalizedLLM] State persisted to disk.")
+
+
+# =============================================================================
+# Convenience Factory
+# =============================================================================
+
+def create_personalized_llm(
+ config_path: Optional[str] = None,
+ **kwargs
+) -> PersonalizedLLM:
+ """
+ Factory function to create a PersonalizedLLM instance.
+
+ Args:
+ config_path: Optional path to configuration file.
+ **kwargs: Additional arguments passed to PersonalizedLLM constructor.
+
+ Returns:
+ Configured PersonalizedLLM instance.
+ """
+ return PersonalizedLLM(config_path=config_path, **kwargs)
+
diff --git a/src/personalization/types.py b/src/personalization/types.py
new file mode 100644
index 0000000..a25b560
--- /dev/null
+++ b/src/personalization/types.py
@@ -0,0 +1,4 @@
+from personalization.retrieval.preference_store.schemas import ChatTurn
+
+__all__ = ["ChatTurn"]
+
diff --git a/src/personalization/user_model/__init__.py b/src/personalization/user_model/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/__init__.py
diff --git a/src/personalization/user_model/features.py b/src/personalization/user_model/features.py
new file mode 100644
index 0000000..a4508b4
--- /dev/null
+++ b/src/personalization/user_model/features.py
@@ -0,0 +1,49 @@
+import numpy as np
+from dataclasses import dataclass
+from sklearn.decomposition import PCA
+
+@dataclass
+class ItemProjection:
+ P: np.ndarray # [k, d]
+ mean: np.ndarray # [d]
+
+ @classmethod
+ def from_pca(cls, embeddings: np.ndarray, k: int) -> "ItemProjection":
+ """
+ embeddings: [M, d]
+ """
+ mean = embeddings.mean(axis=0)
+ centered = embeddings - mean
+
+ # Ensure k is not larger than min(n_samples, n_features)
+ n_samples, n_features = embeddings.shape
+ actual_k = min(k, n_samples, n_features)
+
+ pca = PCA(n_components=actual_k)
+ pca.fit(centered)
+
+ # pca.components_: [k, d]
+ P = pca.components_ # Each row is a principal component vector
+
+ # The system expects a fixed projection dimension k. If the data could
+ # only support actual_k < k components, pad P with zero rows so the
+ # downstream shape stays [k, d].
+ if actual_k < k:
+ padding = np.zeros((k - actual_k, n_features), dtype=P.dtype)
+ P = np.vstack([P, padding])
+
+ return cls(P=P, mean=mean)
+
+ def transform_embeddings(self, E: np.ndarray) -> np.ndarray:
+ """
+ E: [N, d] -> [N, k]
+ """
+ return (E - self.mean) @ self.P.T
+
+ def transform_vector(self, e: np.ndarray) -> np.ndarray:
+ """
+ e: [d] -> [k]
+ """
+ return self.P @ (e - self.mean)
+
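+# Example (illustrative only, sizes are made up): fit the projection on
+# M=1000 synthetic 768-dim embeddings, then map vectors down to k=32 dims.
+#
+#     rng = np.random.default_rng(0)
+#     E = rng.standard_normal((1000, 768)).astype(np.float32)
+#     proj = ItemProjection.from_pca(E, k=32)
+#     V = proj.transform_embeddings(E)      # [1000, 32]
+#     v = proj.transform_vector(E[0])       # [32]
+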
diff --git a/src/personalization/user_model/policy/__init__.py b/src/personalization/user_model/policy/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/policy/__init__.py
diff --git a/src/personalization/user_model/policy/optimizer.py b/src/personalization/user_model/policy/optimizer.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/user_model/policy/optimizer.py
diff --git a/src/personalization/user_model/policy/reinforce.py b/src/personalization/user_model/policy/reinforce.py
new file mode 100644
index 0000000..adfaef7
--- /dev/null
+++ b/src/personalization/user_model/policy/reinforce.py
@@ -0,0 +1,104 @@
+from typing import Sequence
+from dataclasses import dataclass
+import numpy as np
+
+from personalization.user_model.tensor_store import UserState
+
+@dataclass
+class PolicyScores:
+ scores: np.ndarray # [K] s(q_t, m; u)
+ probs: np.ndarray # [K] π_z(m|q_t)
+
+def compute_policy_scores(
+ base_scores: np.ndarray, # [K], from reranker
+ user_state: UserState,
+ item_vectors: np.ndarray, # [K, k], v_m for the K candidates
+ beta_long: float,
+ beta_short: float,
+ tau: float,
+) -> PolicyScores:
+ """
+ Compute personalized scores and softmax probabilities.
+ s(q_t, m; u) = s_0(q_t,m) + z_t^{(eff)}.T @ v_m
+ z_t^{(eff)} = beta_long * z_long + beta_short * z_short
+ """
+ if len(item_vectors) == 0:
+ return PolicyScores(scores=np.array([]), probs=np.array([]))
+
+ z_eff = beta_long * user_state.z_long + beta_short * user_state.z_short
+
+ # Calculate personalized term
+ # item_vectors: [K, k]
+ # z_eff: [k]
+ # term: [K]
+ personalization_term = np.dot(item_vectors, z_eff)
+
+ # Total scores
+ scores = base_scores + personalization_term
+
+ # Softmax
+ # Use exp(score/tau)
+ # Subtract max for stability
+ scaled_scores = scores / tau
+ exp_scores = np.exp(scaled_scores - np.max(scaled_scores))
+ probs = exp_scores / np.sum(exp_scores)
+
+ return PolicyScores(scores=scores, probs=probs)
+
+def reinforce_update_user_state(
+ user_state: UserState,
+ item_vectors: np.ndarray, # [K, k] for candidates
+ chosen_indices: Sequence[int], # indices of A_t in 0..K-1
+ policy_probs: np.ndarray, # [K] π_z(m|q_t)
+ reward_hat: float, # \hat r_t
+ gating: float, # g_t
+ tau: float,
+ eta_long: float,
+ eta_short: float,
+ ema_alpha: float,
+ short_decay: float,
+) -> bool:
+ """
+ Update user_state.z_long / z_short / reward_ma in place via REINFORCE.
+ Returns True if an update occurred, False otherwise.
+ """
+ if len(chosen_indices) == 0:
+ return False
+
+ # 1. Baseline Advantage
+ advantage = gating * (reward_hat - user_state.reward_ma)
+
+ # Optimization: skip if advantage is negligible
+ if abs(advantage) < 1e-6:
+ return False
+
+ # 2. Chosen Vector Average (v_{chosen,t})
+ chosen_mask = np.zeros(len(item_vectors), dtype=np.float32)
+ for idx in chosen_indices:
+ idx_int = int(idx)
+ if 0 <= idx_int < len(item_vectors):
+ chosen_mask[idx_int] = 1.0
+
+ if chosen_mask.sum() == 0:
+ return False
+
+ chosen_mask /= chosen_mask.sum() # Normalize to average
+ v_chosen = np.dot(chosen_mask, item_vectors) # [k]
+
+ # 3. Expected Vector (\mu_t(z))
+ # policy_probs: [K]
+ # item_vectors: [K, k]
+ v_expect = np.dot(policy_probs, item_vectors) # [k]
+
+ # 4. Gradient Direction
+ grad = (advantage / tau) * (v_chosen - v_expect)
+
+ # 5. Update Vectors
+ user_state.z_long += eta_long * grad
+ user_state.z_short = (1.0 - short_decay) * user_state.z_short + eta_short * grad
+
+ # 6. Update Reward Baseline (EMA)
+ user_state.reward_ma = (1.0 - ema_alpha) * user_state.reward_ma + ema_alpha * reward_hat
+
+ return True
+
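+# Worked example (made-up numbers and hyperparameters): score K=4 candidates,
+# "choose" the first, and apply a single REINFORCE step.
+#
+#     K, k = 4, 32
+#     state = UserState("u1", np.zeros(k, np.float32), np.zeros(k, np.float32), 0.0)
+#     V = np.random.randn(K, k).astype(np.float32)
+#     ps = compute_policy_scores(np.zeros(K, np.float32), state, V,
+#                                beta_long=1.0, beta_short=0.5, tau=1.0)
+#     reinforce_update_user_state(state, V, chosen_indices=[0],
+#                                 policy_probs=ps.probs, reward_hat=1.0,
+#                                 gating=1.0, tau=1.0, eta_long=0.05,
+#                                 eta_short=0.1, ema_alpha=0.1, short_decay=0.05)
+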
diff --git a/src/personalization/user_model/scoring.py b/src/personalization/user_model/scoring.py
new file mode 100644
index 0000000..75ffc84
--- /dev/null
+++ b/src/personalization/user_model/scoring.py
@@ -0,0 +1,25 @@
+import numpy as np
+from .tensor_store import UserState
+
+def score_with_user(
+ base_score: float,
+ user_state: UserState,
+ v_m: np.ndarray, # [k]
+ beta_long: float,
+ beta_short: float,
+) -> float:
+ """
+ Personalized scoring:
+ s = base_score + (beta_long * z_long + beta_short * z_short) . v_m
+ Day2: beta_long = beta_short = 0 -> s == base_score
+ """
+ z_eff = beta_long * user_state.z_long + beta_short * user_state.z_short
+ # Guard against a dimension mismatch between v_m and z_eff: fall back
+ # to the unpersonalized base score instead of raising mid-retrieval.
+ if v_m.shape != z_eff.shape:
+ return float(base_score)
+
+ term = np.dot(z_eff, v_m)
+ return float(base_score + term)
+
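+# Example: with beta_long = beta_short = 0 the personalization term vanishes
+# and the base score passes through unchanged (the "Day2" behavior above).
+#
+#     state = UserState("u1", np.ones(8, np.float32), np.ones(8, np.float32), 0.0)
+#     score_with_user(0.7, state, np.ones(8, np.float32), 0.0, 0.0)  # -> 0.7
+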
diff --git a/src/personalization/user_model/session_state.py b/src/personalization/user_model/session_state.py
new file mode 100644
index 0000000..5cd2243
--- /dev/null
+++ b/src/personalization/user_model/session_state.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard
+
+@dataclass
+class OnlineSessionState:
+ user_id: str
+ history: List[ChatTurn] = field(default_factory=list)
+ last_query: Optional[str] = None
+ last_answer: Optional[str] = None
+ last_memories: List[MemoryCard] = field(default_factory=list)
+ last_query_embedding: Optional[np.ndarray] = None
+ last_candidate_item_vectors: Optional[np.ndarray] = None # [K, k]
+ last_policy_probs: Optional[np.ndarray] = None # [K]
+ last_chosen_indices: List[int] = field(default_factory=list)
+
+
diff --git a/src/personalization/user_model/tensor_store.py b/src/personalization/user_model/tensor_store.py
new file mode 100644
index 0000000..42dbf4e
--- /dev/null
+++ b/src/personalization/user_model/tensor_store.py
@@ -0,0 +1,80 @@
+import numpy as np
+from dataclasses import dataclass
+from typing import Dict
+import os
+
+@dataclass
+class UserState:
+ user_id: str
+ z_long: np.ndarray # [k]
+ z_short: np.ndarray # [k]
+ reward_ma: float # baseline for reward, init 0.0
+
+class UserTensorStore:
+ def __init__(self, k: int, path: str):
+ self.k = k
+ self.path = path
+ self._states: Dict[str, UserState] = {}
+ self._load()
+
+ # Calculate global mean for initialization
+ if self._states:
+ z_all = np.stack([st.z_long for st in self._states.values()])
+ self.global_init_z = np.mean(z_all, axis=0)
+ else:
+ self.global_init_z = np.zeros(self.k, dtype=np.float32)
+
+ def _load(self):
+ if os.path.exists(self.path):
+ try:
+ data = np.load(self.path, allow_pickle=True)
+ # Simple npz schema, three keys per user:
+ # "{uid}_long" -> z_long [k]
+ # "{uid}_short" -> z_short [k]
+ # "{uid}_meta" -> [reward_ma]
+ for key in data.files:
+ if key.endswith("_long"):
+ uid = key[: -len("_long")]
+ z_long = data[key]
+ z_short = data.get(f"{uid}_short", np.zeros(self.k, dtype=np.float32))
+ meta = data.get(f"{uid}_meta", np.array([0.0]))
+ self._states[uid] = UserState(uid, z_long, z_short, float(meta[0]))
+ except Exception as e:
+ print(f"Warning: Failed to load UserStore from {self.path}: {e}")
+
+ def _save(self):
+ # Persist everything as a single .npz file. Note: np.savez appends
+ # ".npz" when the path lacks the suffix, so self.path should include
+ # it, or _load (which checks self.path verbatim) will miss the file.
+ save_dict = {}
+ for uid, state in self._states.items():
+ save_dict[f"{uid}_long"] = state.z_long
+ save_dict[f"{uid}_short"] = state.z_short
+ save_dict[f"{uid}_meta"] = np.array([state.reward_ma])
+ np.savez(self.path, **save_dict)
+
+ def get_state(self, user_id: str) -> UserState:
+ if user_id not in self._states:
+ # Lazy init with global mean for new users
+ state = UserState(
+ user_id=user_id,
+ z_long=self.global_init_z.copy(),
+ z_short=np.zeros(self.k, dtype=np.float32),
+ reward_ma=0.0,
+ )
+ self._states[user_id] = state
+ return self._states[user_id]
+
+ def save_state(self, state: UserState) -> None:
+ self._states[state.user_id] = state
+
+ def persist(self):
+ """Public method to force save to disk."""
+ self._save()
+
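+# Round-trip sketch (path is a placeholder; keep an explicit ".npz" suffix so
+# save and load agree, per the note in _save):
+#
+#     store = UserTensorStore(k=32, path="artifacts/user_states.npz")
+#     st = store.get_state("u1")        # lazily initialized for new users
+#     st.z_long += 0.1                  # mutated in place; store holds the ref
+#     store.persist()                   # writes "{uid}_long/_short/_meta" keys
+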
diff --git a/src/personalization/utils/__init__.py b/src/personalization/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/__init__.py
diff --git a/src/personalization/utils/ids.py b/src/personalization/utils/ids.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/ids.py
diff --git a/src/personalization/utils/io.py b/src/personalization/utils/io.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/io.py
diff --git a/src/personalization/utils/logging.py b/src/personalization/utils/logging.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/logging.py
diff --git a/src/personalization/utils/timing.py b/src/personalization/utils/timing.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/utils/timing.py