diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-18 18:25:09 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-18 18:25:09 -0500 |
| commit | b6c3e4e51eeab703b40284459c6e9fff2151216c (patch) | |
| tree | 221410886f23214575f93b9ef44fa8431c9a6dfc /src/personalization | |
Initial release: VARS - personalized LLM with RAG and user vector learning
Diffstat (limited to 'src/personalization')
57 files changed, 4844 insertions, 0 deletions
diff --git a/src/personalization/__init__.py b/src/personalization/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/__init__.py diff --git a/src/personalization/config/__init__.py b/src/personalization/config/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/config/__init__.py diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py new file mode 100644 index 0000000..6badeae --- /dev/null +++ b/src/personalization/config/registry.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, Optional +import torch +import yaml + +from personalization.config import settings + +# Project root for resolving config paths +_PROJECT_ROOT = Path(__file__).parent.parent.parent.parent + +# Avoid circular imports by NOT importing extractors here at top level +# from personalization.models.preference_extractor.base import PreferenceExtractorBase +# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor +# from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor +# from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM + +_DTYPE_MAP: Dict[str, torch.dtype] = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, +} + +def choose_dtype(preferred: Optional[str] = None) -> torch.dtype: + if preferred and preferred.lower() in _DTYPE_MAP: + dt = _DTYPE_MAP[preferred.lower()] + else: + dt = torch.bfloat16 if torch.cuda.is_available() else torch.float32 + if dt is torch.bfloat16 and not torch.cuda.is_available(): + return torch.float32 + return dt + +def choose_device_map(spec: Optional[str] = "auto") -> Any: + return spec or "auto" + +def ensure_local_path(path_str: str) -> str: + path = Path(path_str) + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + 
return str(path) + +# --- Chat Model Factory --- +def get_chat_model(name: str, device_override: Optional[str] = None): + """ + Get a chat model by name. + + Args: + name: Model name (e.g., "qwen_1_5b", "llama_8b") + device_override: Optional device override (e.g., "cuda:2"). If None, uses config default. + """ + from personalization.models.llm.base import ChatModel + from personalization.models.llm.qwen_instruct import QwenInstruct + from personalization.models.llm.llama_instruct import LlamaChatModel + from personalization.models.llm.vllm_chat import VLLMChatModel + + cfg = settings.load_local_models_config() + + # Try to load raw config to support multi-backend map + with open(_PROJECT_ROOT / "configs/local_models.yaml", "r") as f: + raw_cfg = yaml.safe_load(f) + + models = raw_cfg.get("models", {}).get("llm", {}) + + # If models['llm'] is a dict of configs (new style) + if isinstance(models, dict) and "backend" in models.get(name, {}): + spec = models[name] + backend = spec.get("backend", "qwen") + path = spec["path"] + device = device_override or spec.get("device", "cuda") # Use override if provided + dtype = spec.get("dtype", "bfloat16") + max_len = spec.get("max_context_length", 4096) + + if backend == "qwen": + return QwenInstruct( + model_path=path, + device=device, + dtype=choose_dtype(dtype), # Converts string to torch.dtype + max_context_length=max_len + ) + elif backend == "llama": + return LlamaChatModel( + model_path=path, + device=device, + dtype=choose_dtype(dtype), # Converts string to torch.dtype + max_context_length=max_len + ) + elif backend == "vllm": + # Use vLLM HTTP API for high-throughput inference + vllm_url = spec.get("vllm_url", "http://localhost:8003/v1") + return VLLMChatModel( + vllm_url=vllm_url, + model_name=spec.get("model_name"), + max_context_length=max_len + ) + + # Fallback to legacy single config + return QwenInstruct.from_config(cfg) + +def get_preference_extractor(name: Optional[str] = None): + # Deferred imports to break 
circular dependency + from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor + from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor + from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM + + cfg = settings.load_local_models_config() + pref_cfg = cfg.preference_extractor + + if name is None: + if isinstance(pref_cfg, dict) and "qwen3_0_6b_sft" in pref_cfg: + name = "qwen3_0_6b_sft" + else: + name = "rule" + + if isinstance(pref_cfg, dict) and name in pref_cfg: + spec = pref_cfg[name] + if name == "qwen3_0_6b_sft": + # Use QwenRuleExtractor which we have updated for SFT End-to-End logic + return QwenRuleExtractor( + model_path=spec["path"], + device_map=spec.get("device", "auto"), + dtype=choose_dtype(spec.get("dtype", "bfloat16")), + ) + # Add 'default' handling if mapped to rule/gpt + if name == "default": + pass + + if name == "gpt4o": + return GPT4OExtractor.from_config(cfg) + elif name == "gpt5_mini": + import os + return GPT4OExtractor(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-5-mini") + elif name == "rule": + if isinstance(pref_cfg, dict): + if "default" in pref_cfg: + # Manually construct to bypass ModelSpec mismatch if needed + spec_dict = pref_cfg["default"] + return QwenRuleExtractor( + model_path=spec_dict["local_path"], + dtype=choose_dtype(spec_dict.get("dtype")), + device_map=choose_device_map(spec_dict.get("device_map")) + ) + else: + return QwenRuleExtractor.from_config(cfg) + + raise ValueError(f"Could not load preference extractor: {name}") diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py new file mode 100644 index 0000000..8f0cc8a --- /dev/null +++ b/src/personalization/config/settings.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional, Any, Dict + +import yaml +from pydantic import BaseModel, Field + + +class 
ModelSpec(BaseModel): + hf_id: str = Field(..., description="Hugging Face repository id") + local_path: str = Field(..., description="Local directory for model weights") + dtype: Optional[str] = Field( + default="bfloat16", description="Preferred torch dtype: bfloat16|float16|float32" + ) + device_map: Optional[str] = Field(default="auto", description="Device map policy") + + +class EmbeddingModelsConfig(BaseModel): + qwen3: Optional[ModelSpec] = None + nemotron: Optional[ModelSpec] = None + + +class RerankerModelsConfig(BaseModel): + qwen3_8b: Optional[ModelSpec] = None + + +class LocalModelsConfig(BaseModel): + llm: ModelSpec + preference_extractor: Any # Allow flexible dict or ModelSpec for now to support map + embedding: Optional[EmbeddingModelsConfig] = None + reranker: Optional[RerankerModelsConfig] = None + + +def _resolve_config_path(env_key: str, default_rel: str) -> Path: + value = os.getenv(env_key) + if value: + return Path(value).expanduser().resolve() + # Use project root (parent of src/personalization/config) instead of cwd + project_root = Path(__file__).parent.parent.parent.parent + return (project_root / default_rel).resolve() + + +def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig: + config_path = Path(path) if path else _resolve_config_path( + "LOCAL_MODELS_CONFIG", "configs/local_models.yaml" + ) + with open(config_path, "r", encoding="utf-8") as f: + raw = yaml.safe_load(f) or {} + models = raw.get("models", {}) + embedding_cfg = None + if "embedding" in models: + emb = models["embedding"] or {} + # dtype/device_map are not necessary for embedders; ModelSpec still accepts them + embedding_cfg = EmbeddingModelsConfig( + qwen3=ModelSpec(**emb["qwen3"]) if "qwen3" in emb else None, + nemotron=ModelSpec(**emb["nemotron"]) if "nemotron" in emb else None, + ) + + reranker_cfg = None + if "reranker" in models: + rer = models["reranker"] or {} + reranker_cfg = RerankerModelsConfig( + qwen3_8b=ModelSpec(**rer["qwen3_8b"]) if 
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Type-only import: avoids a runtime dependency (and potential cycle)
    # on the schemas module.
    from personalization.feedback.schemas import TurnSample


def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray:
    """Cosine similarity between each row of *matrix* [N, d] and *vector* [d].

    Returns:
        Array of shape [N]; rows (or vector) with zero norm yield ~0.
    """
    norm_m = np.linalg.norm(matrix, axis=1)
    norm_v = np.linalg.norm(vector)

    den = norm_m * norm_v
    den[den == 0] = 1e-9  # avoid division by zero for all-zero vectors

    return np.dot(matrix, vector) / den


def estimate_retrieval_gating(sample: "TurnSample", reward_hat: float) -> float:
    """Estimate g_t in [0, 1]: how much of the reward is attributable to retrieval.

    Heuristic: compares the retrieved memories' similarity to the current and
    next queries, then blames retrieval vs. the LLM based on *reward_hat*.
    Returns 0.5 (neutral) whenever the needed embeddings are unavailable.
    """
    e_q = sample.query_embedding_t
    e_q1 = sample.query_embedding_t1

    if e_q is None or e_q1 is None or not sample.memories:
        return 0.5  # neutral: cannot attribute without embeddings/memories

    # Prefer precomputed memory embeddings; otherwise fall back to the
    # per-card embedding_e lists.
    if sample.memory_embeddings is None:
        try:
            mem_embs = np.asarray(
                [m.embedding_e for m in sample.memories], dtype=float
            )
            # Ragged or empty embeddings -> cannot compute similarities.
            if mem_embs.ndim != 2 or mem_embs.shape[1] == 0:
                return 0.5
        except Exception:
            return 0.5
    else:
        mem_embs = sample.memory_embeddings

    # Similarities of each memory to q_t and q_{t+1}; shape [K].
    sims_q = cosine_sim_batch(mem_embs, e_q)
    sims_q1 = cosine_sim_batch(mem_embs, e_q1)

    s_q_max = sims_q.max() if sims_q.size else 0
    s_q1_max = sims_q1.max() if sims_q1.size else 0

    g = 0.5

    # Case A: retrieval clearly irrelevant + bad reward -> blame retrieval
    # (it failed to find anything relevant, or nothing relevant exists).
    if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2:
        g = 0.9

    # Case B: memories look relevant but the user is still unhappy ->
    # likely the LLM failed to use them well.
    elif reward_hat < -0.5 and s_q_max > 0.5:
        g = 0.2

    # Case C: good reward -> split credit by how relevant retrieval was.
    elif reward_hat > 0.5:
        if s_q_max > 0.4:
            g = 0.6  # retrieval helped
        else:
            g = 0.3  # LLM handled it without strong retrieval help

    return float(g)
from typing import Tuple, List, Optional
import numpy as np

from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.feedback.schemas import TurnSample
from personalization.feedback.reward_model import estimate_reward
from personalization.feedback.gating import estimate_retrieval_gating
from personalization.feedback.llm_reward import (
    LLMRewardClient, LLMRewardConfig, RewardResult
)


def eval_step(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """Keyword-based evaluation (legacy).

    Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
    """
    mem_embs = None
    if memories_t and memories_t[0].embedding_e:
        try:
            mem_embs = np.array([m.embedding_e for m in memories_t])
        except Exception:
            # Ragged/invalid per-card embeddings: fall back to None so the
            # gating estimator recomputes (or stays neutral) downstream.
            mem_embs = None

    sample = TurnSample(
        user_id="",
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
        memory_embeddings=mem_embs,
    )

    r_hat = estimate_reward(sample)
    g_hat = estimate_retrieval_gating(sample, r_hat)

    return r_hat, g_hat


async def eval_step_llm(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    client: LLMRewardClient,
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """LLM-as-judge evaluation (async).

    Returns (reward, gating) where gating=0.0 if update should be skipped.

    The gating signal is derived from the judge's confidence and label:
    - If confidence < tau_c or label == topic_shift: gating = 0.0
    - Otherwise: gating = confidence (continuous, in [tau_c, 1.0])

    This replaces the old heuristic gating with the judge's own confidence.
    """
    sample = TurnSample(
        user_id="",
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
    )

    result: RewardResult = await client.judge(sample)

    if result.should_update:
        return result.reward, result.confidence
    else:
        return 0.0, 0.0
+""" +from __future__ import annotations + +import asyncio +import hashlib +import json +import os +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +from openai import AsyncOpenAI, RateLimitError, APITimeoutError, APIConnectionError + +from personalization.feedback.schemas import TurnSample + + +# --- Label → Reward Mapping --- + +REWARD_MAP: Dict[str, float] = { + "neg_constraint_restate": -1.0, + "neg_correction": -0.8, + "neg_confusion": -0.6, + "pos_praise": +0.8, + "pos_progress": +0.1, + "neutral": 0.0, + "topic_shift": 0.0, +} + +VALID_LABELS = set(REWARD_MAP.keys()) + + +# --- Configuration --- + +@dataclass +class LLMRewardConfig: + model: str = "gpt-5-nano" + api_key: Optional[str] = None # Falls back to OPENAI_API_KEY env var + base_url: Optional[str] = None # For custom endpoints + max_concurrent: int = 32 # Semaphore limit for parallel requests + max_retries: int = 3 + retry_base_delay: float = 1.0 # Exponential backoff base (seconds) + timeout: float = 60.0 # Per-request timeout (reasoning models are slower) + max_completion_tokens: int = 2048 # Must be high — reasoning models use internal tokens + confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this + enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1}) + + +# --- Prompt --- + +JUDGE_SYSTEM_PROMPT = """\ +You are a feedback classifier. Given a user query (q_t), the assistant's response (a_t), \ +and the user's next message (q_{t+1}), classify the user's follow-up into exactly one label. + +Labels (mutually exclusive): +- neg_constraint_restate: User reasserts constraints/preferences as correction (e.g., "as I said…", "remember…", "按我说的…"). +- neg_correction: User indicates the content is wrong or the assistant failed to answer. +- neg_confusion: User indicates confusion or requests re-explanation. +- pos_praise: Explicit praise or satisfaction with the response. 
+- pos_progress: Constructive continuation (examples, extensions, what-if, next steps) without complaint. +- neutral: Ambiguous or minimal feedback, not clearly positive or negative. +- topic_shift: User switches to a new unrelated task/topic. + +Output a JSON object with fields: label, confidence (0-1), rationale (one short sentence).""" + +JUDGE_USER_TEMPLATE = """\ +q_t: {query_t} + +a_t: {answer_t} + +q_{{t+1}}: {query_t1}""" + + +# --- Result Dataclass --- + +@dataclass +class RewardResult: + label: str + confidence: float + rationale: str + reward: float + should_update: bool # False if gated by confidence or topic_shift + + +# --- Async Client --- + +class LLMRewardClient: + """Async OpenAI client for LLM-as-judge reward estimation.""" + + def __init__(self, config: Optional[LLMRewardConfig] = None): + self.config = config or LLMRewardConfig() + self._client = AsyncOpenAI( + api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"), + base_url=self.config.base_url, + timeout=self.config.timeout, + ) + self._semaphore = asyncio.Semaphore(self.config.max_concurrent) + self._cache: Dict[str, RewardResult] = {} + + def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str: + """Deterministic hash of the judge input triple.""" + content = f"{query_t}\x00{answer_t}\x00{query_t1}" + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + async def _call_with_retry(self, messages: List[dict]) -> str: + """Single LLM call with exponential backoff retry.""" + for attempt in range(self.config.max_retries): + try: + async with self._semaphore: + response = await self._client.chat.completions.create( + model=self.config.model, + messages=messages, + max_completion_tokens=self.config.max_completion_tokens, + response_format={"type": "json_object"}, + ) + content = response.choices[0].message.content + if content: + return content.strip() + # Reasoning model may exhaust tokens on thinking — retry + if response.choices[0].finish_reason == "length": + 
continue + return "" + except (RateLimitError, APITimeoutError, APIConnectionError) as e: + if attempt == self.config.max_retries - 1: + raise + delay = self.config.retry_base_delay * (2 ** attempt) + await asyncio.sleep(delay) + return "" + + def _build_messages(self, sample: TurnSample) -> List[dict]: + """Construct the judge prompt from (q_t, a_t, q_{t+1}) only.""" + user_content = JUDGE_USER_TEMPLATE.format( + query_t=sample.query_t, + answer_t=sample.answer_t, + query_t1=sample.query_t1, + ) + return [ + {"role": "system", "content": JUDGE_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + def _parse_result(self, raw: str) -> RewardResult: + """Parse structured JSON output into RewardResult.""" + try: + parsed = json.loads(raw) + label = parsed["label"] + confidence = float(parsed["confidence"]) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + + # Confidence gating and topic_shift skip + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except (json.JSONDecodeError, KeyError, TypeError, ValueError): + return RewardResult( + label="neutral", + confidence=0.0, + rationale="parse_failure", + reward=0.0, + should_update=False, + ) + + async def judge(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async). 
Returns RewardResult with gating applied.""" + # Cache lookup + if self.config.enable_cache: + key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1) + if key in self._cache: + return self._cache[key] + + messages = self._build_messages(sample) + raw = await self._call_with_retry(messages) + result = self._parse_result(raw) + + # Cache store + if self.config.enable_cache: + self._cache[key] = result + + return result + + async def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]: + """Judge a batch of turns in parallel. Returns list of RewardResult.""" + tasks = [self.judge(s) for s in samples] + return await asyncio.gather(*tasks) + + async def close(self): + """Close the underlying HTTP client.""" + await self._client.close() + + +# --- Synchronous Wrappers --- + +def estimate_reward_llm( + sample: TurnSample, + config: Optional[LLMRewardConfig] = None, +) -> Tuple[float, bool]: + """ + Synchronous single-sample reward estimation. + Returns (reward, should_update). + """ + client = LLMRewardClient(config) + try: + result = asyncio.run(client.judge(sample)) + return result.reward, result.should_update + finally: + asyncio.run(client.close()) + + +def estimate_rewards_batch( + samples: List[TurnSample], + config: Optional[LLMRewardConfig] = None, +) -> List[Tuple[float, bool]]: + """ + Synchronous batch reward estimation (runs async internally). + Returns list of (reward, should_update) tuples. + """ + client = LLMRewardClient(config) + try: + results = asyncio.run(client.judge_batch(samples)) + return [(r.reward, r.should_update) for r in results] + finally: + asyncio.run(client.close()) diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py new file mode 100644 index 0000000..70bbeb8 --- /dev/null +++ b/src/personalization/feedback/local_llm_reward.py @@ -0,0 +1,370 @@ +""" +Local LLM reward model using vLLM server for batch inference. 
+ +Drop-in replacement for LLMRewardClient when you want to use a local model +(e.g., Llama-3.1-8B-Instruct) instead of OpenAI API. + +Uses BatchVLLMClient for efficient concurrent requests - vLLM's continuous +batching will process them together for high throughput. +""" +from __future__ import annotations + +import asyncio +import hashlib +import json +from dataclasses import dataclass +from typing import Dict, List, Optional + +import aiohttp + +from personalization.feedback.schemas import TurnSample +from personalization.feedback.llm_reward import ( + REWARD_MAP, + VALID_LABELS, + JUDGE_SYSTEM_PROMPT, + JUDGE_USER_TEMPLATE, + RewardResult, +) + + +@dataclass +class LocalLLMRewardConfig: + """Configuration for local LLM reward model.""" + vllm_url: str = "http://localhost:8005/v1" # vLLM server URL + model_name: Optional[str] = None # Auto-discovered if None + max_tokens: int = 256 + temperature: float = 0.1 + max_concurrent: int = 100 # High concurrency for vLLM batching + timeout: Optional[int] = 60 # Per-request timeout in seconds + confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this + enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1}) + + +class LocalLLMRewardClient: + """ + Local LLM reward client using vLLM server. + + Designed for batch processing - uses async HTTP requests that vLLM + batches together via continuous batching for high throughput. 
+ """ + + def __init__(self, config: Optional[LocalLLMRewardConfig] = None): + self.config = config or LocalLLMRewardConfig() + self._model_name = self.config.model_name + self._cache: Dict[str, RewardResult] = {} + + # Discover model name if not provided + if self._model_name is None: + self._discover_model_sync() + + def _discover_model_sync(self): + """Synchronously discover model name from vLLM server.""" + import requests + try: + response = requests.get( + f"{self.config.vllm_url}/models", + timeout=10 + ) + response.raise_for_status() + models = response.json() + if models.get("data") and len(models["data"]) > 0: + self._model_name = models["data"][0]["id"] + else: + self._model_name = "default" + except Exception as e: + print(f"[LocalLLMReward] Warning: Could not discover model ({e})") + self._model_name = "default" + + def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str: + """Deterministic hash of the judge input triple.""" + content = f"{query_t}\x00{answer_t}\x00{query_t1}" + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + def _build_messages(self, sample: TurnSample) -> List[dict]: + """Construct the judge prompt from (q_t, a_t, q_{t+1}).""" + user_content = JUDGE_USER_TEMPLATE.format( + query_t=sample.query_t, + answer_t=sample.answer_t, + query_t1=sample.query_t1, + ) + return [ + {"role": "system", "content": JUDGE_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + def _parse_result(self, raw: Optional[str]) -> RewardResult: + """Parse structured JSON output into RewardResult.""" + if raw is None: + return RewardResult( + label="neutral", + confidence=0.0, + rationale="request_failed", + reward=0.0, + should_update=False, + ) + + try: + # Handle markdown code blocks + text = raw.strip() + if text.startswith("```"): + lines = text.split("\n") + text = "\n".join( + lines[1:-1] if lines[-1].strip() == "```" else lines[1:] + ) + + parsed = json.loads(text) + label = parsed.get("label", "neutral") + 
confidence = float(parsed.get("confidence", 0.0)) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + + # Confidence gating and topic_shift skip + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except (json.JSONDecodeError, KeyError, TypeError, ValueError): + # Try to extract JSON from text + import re + match = re.search(r'\{[^}]+\}', raw, re.DOTALL) + if match: + try: + parsed = json.loads(match.group()) + label = parsed.get("label", "neutral") + confidence = float(parsed.get("confidence", 0.0)) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except: + pass + + return RewardResult( + label="neutral", + confidence=0.0, + rationale="parse_failure", + reward=0.0, + should_update=False, + ) + + async def _single_request( + self, + session: aiohttp.ClientSession, + messages: List[dict], + idx: int, + ) -> tuple: + """Make a single async request to vLLM server.""" + payload = { + "model": self._model_name, + "messages": messages, + "max_tokens": self.config.max_tokens, + "temperature": self.config.temperature, + "response_format": {"type": "json_object"}, + } + + for attempt in range(3): + try: + timeout_config = ( + aiohttp.ClientTimeout(total=self.config.timeout) + if self.config.timeout else None + ) + async with session.post( + 
f"{self.config.vllm_url}/chat/completions", + json=payload, + timeout=timeout_config, + ) as response: + if response.status == 200: + result = await response.json() + content = result["choices"][0]["message"]["content"] + return (idx, content, None) + elif response.status == 429: + # Rate limit - wait and retry + await asyncio.sleep(2 ** attempt) + continue + else: + error_text = await response.text() + return (idx, None, f"HTTP {response.status}: {error_text[:200]}") + except asyncio.TimeoutError: + if attempt < 2: + continue + return (idx, None, "Timeout") + except Exception as e: + return (idx, None, str(e)) + + return (idx, None, "Max retries exceeded") + + async def judge_batch_async( + self, + samples: List[TurnSample], + show_progress: bool = False, + ) -> List[RewardResult]: + """ + Judge a batch of turns using concurrent vLLM requests. + + vLLM's continuous batching will process these together for + high throughput. + """ + n_samples = len(samples) + results = [None] * n_samples + + # Check cache and build request list + to_request = [] # (original_idx, messages) + for i, sample in enumerate(samples): + if self.config.enable_cache: + key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1) + if key in self._cache: + results[i] = self._cache[key] + continue + + messages = self._build_messages(sample) + to_request.append((i, messages)) + + if not to_request: + return results + + # Make concurrent requests + semaphore = asyncio.Semaphore(self.config.max_concurrent) + + async def limited_request(session, messages, idx): + async with semaphore: + return await self._single_request(session, messages, idx) + + connector = aiohttp.TCPConnector(limit=self.config.max_concurrent) + headers = {"Content-Type": "application/json"} + + async with aiohttp.ClientSession( + connector=connector, headers=headers + ) as session: + tasks = [ + limited_request(session, messages, orig_idx) + for orig_idx, messages in to_request + ] + + completed = 0 + for coro in 
asyncio.as_completed(tasks): + orig_idx, content, error = await coro + completed += 1 + + if error: + print(f"[LocalLLMReward] Request {orig_idx} failed: {error}") + + result = self._parse_result(content) + results[orig_idx] = result + + # Cache the result + if self.config.enable_cache: + sample = samples[orig_idx] + key = self._cache_key( + sample.query_t, sample.answer_t, sample.query_t1 + ) + self._cache[key] = result + + if show_progress and completed % 10 == 0: + print(f" [LocalLLMReward {completed}/{len(to_request)}] completed") + + return results + + async def judge_async(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async).""" + results = await self.judge_batch_async([sample]) + return results[0] + + def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]: + """ + Judge a batch of turns (sync wrapper). + + This is the main entry point for batch reward estimation. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in an event loop - create a new thread to run the coroutine + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.judge_batch_async(samples)) + return future.result() + else: + return asyncio.run(self.judge_batch_async(samples)) + + async def judge(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async interface for compatibility with LLMRewardClient).""" + return await self.judge_async(sample) + + def judge_sync(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (sync wrapper).""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in an event loop - create a new thread to run the coroutine + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.judge_async(sample)) + return 
import re
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Type-only import: avoids a runtime dependency (and potential cycle)
    # on the schemas module.
    from personalization.feedback.schemas import TurnSample


def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D vectors; 0.0 if either is zero."""
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(a, b) / (norm_a * norm_b))


def _contains_keyword(text: str, keywords) -> bool:
    """True if any keyword occurs in *text*.

    ASCII keywords are matched on letter boundaries so that short words like
    "no" or "again" do not fire inside "know", "note", or "against" (the old
    plain-substring check produced such false positives). CJK keywords keep
    substring matching, since word boundaries are not meaningful there.
    """
    for kw in keywords:
        if kw.isascii():
            pattern = r"(?<![A-Za-z])" + re.escape(kw) + r"(?![A-Za-z])"
            if re.search(pattern, text):
                return True
        elif kw in text:
            return True
    return False


def estimate_reward(sample: "TurnSample") -> float:
    """
    Return a scalar reward_hat, indicating if the previous answer was helpful.
    Range: [-1.0, 1.0] (approx)
    """

    # 1. Language/Topic Coherence
    if sample.query_embedding_t is None or sample.query_embedding_t1 is None:
        topic_sim = 0.5  # unknown -> assume moderate coherence
    else:
        topic_sim = cosine_sim(sample.query_embedding_t, sample.query_embedding_t1)

    # 2. Negative Keywords (Complaint/Correction)
    negative_keywords = [
        "you didn't", "that's not", "incorrect", "redo", "again", "explain more",
        "doesn't help", "wrong", "no", "not what i asked",
        "你没", "不是", "这不是", "重来", "重新", "不对", "错了", "没说清楚"
    ]

    # 3. Positive Keywords (Follow-up/Elaboration)
    positive_keywords = [
        "can you elaborate", "give an example", "continue", "what if", "based on that",
        "thanks", "good", "great", "cool",
        "能不能详细一点", "举个例子", "再继续", "那如果", "接下来", "在这个基础上", "谢谢", "不错"
    ]

    q1_lower = sample.query_t1.lower()

    has_negative = _contains_keyword(q1_lower, negative_keywords)
    has_positive = _contains_keyword(q1_lower, positive_keywords)

    reward = 0.0

    if has_negative:
        reward -= 1.0

    if has_positive:
        # Only reward fully if topic similarity is decent; otherwise the
        # positive might just be "thanks, bye" (end of session).
        reward += 0.5
        if topic_sim > 0.3:
            reward += 0.5

    if topic_sim < 0.2:
        # Topic shift -> previous interaction likely finished or failed.
        # Done-task vs gave-up is ambiguous, so dampen the reward towards 0.
        reward *= 0.5

    # Clip to [-1, 1].
    return max(-1.0, min(1.0, reward))
not q_t1: + # End of session or no subsequent user query + continue + + # We have q_t, a_t (optional but preferred), q_t1 + # If a_t is missing, we might skip or use empty string. + # For RL, we usually need the answer to evaluate quality. + # If dataset doesn't have assistant turns, we might need to generate one? + # For now, let's proceed even if a_t is empty, or maybe require it. + if not a_t_text: + # Try to use LLM to generate if needed, but for offline sampling + # from existing chats, we prefer existing answers. + # If using OASST1, it should have assistant turns. + pass + + # 3. Retrieve memories for q_t + memories_t = retrieve_with_rerank( + user_id=q_t.user_id, + query=q_t.text, + embed_model=embed_model, + reranker=reranker, + memory_cards=memory_cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=topk_dense, + topk_rerank=topk_rerank, + beta_long=0.0, + beta_short=0.0, + only_own_memories=True # Assume we want user specific memories + ) + + # 4. 
Precompute embeddings + # We can do this efficiently later or batch, but here per sample + e_q_t = embed_model.encode([q_t.text], return_tensor=False)[0] + e_q_t1 = embed_model.encode([q_t1.text], return_tensor=False)[0] + + sample = TurnSample( + user_id=q_t.user_id, + session_id=q_t.session_id, + turn_id=q_t.turn_id, + query_t=q_t.text, + answer_t=a_t_text, + query_t1=q_t1.text, + memories=memories_t, + query_embedding_t=np.array(e_q_t), + query_embedding_t1=np.array(e_q_t1) + ) + samples.append(sample) + + return samples + diff --git a/src/personalization/feedback/schemas.py b/src/personalization/feedback/schemas.py new file mode 100644 index 0000000..b15db80 --- /dev/null +++ b/src/personalization/feedback/schemas.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Any +import numpy as np + +from personalization.retrieval.preference_store.schemas import MemoryCard + +@dataclass +class TurnSample: + user_id: str + session_id: str + turn_id: int # index of q_t within the session + query_t: str # q_t + answer_t: str # a_t + query_t1: str # q_{t+1} + memories: List[MemoryCard] # A_t + + # Optional pre-computed vectors and features + query_embedding_t: Optional[np.ndarray] = None + query_embedding_t1: Optional[np.ndarray] = None + memory_embeddings: Optional[np.ndarray] = None # corresponding e_m or v_m for memories + diff --git a/src/personalization/models/__init__.py b/src/personalization/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/models/__init__.py diff --git a/src/personalization/models/embedding/__init__.py b/src/personalization/models/embedding/__init__.py new file mode 100644 index 0000000..05221aa --- /dev/null +++ b/src/personalization/models/embedding/__init__.py @@ -0,0 +1,11 @@ +from .base import EmbeddingModel +from .qwen3_8b import Qwen3Embedding8B +from .nemotron_8b import LlamaEmbedNemotron8B + +__all__ = [ + 
"EmbeddingModel", + "Qwen3Embedding8B", + "LlamaEmbedNemotron8B", +] + + diff --git a/src/personalization/models/embedding/base.py b/src/personalization/models/embedding/base.py new file mode 100644 index 0000000..9f9d4d1 --- /dev/null +++ b/src/personalization/models/embedding/base.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Iterable, List, Sequence + +import torch + + +class EmbeddingModel(ABC): + @abstractmethod + def encode( + self, + texts: Sequence[str], + batch_size: int = 8, + max_length: int = 512, + normalize: bool = True, + return_tensor: bool = False, + ) -> List[List[float]] | torch.Tensor: + """Encode a batch of texts into dense embeddings.""" + raise NotImplementedError + + +def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + # last_hidden_state: [batch, seq_len, hidden] + # attention_mask: [batch, seq_len] + mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [b, s, 1] + summed = (last_hidden_state * mask).sum(dim=1) + counts = mask.sum(dim=1).clamp_min(1e-6) + return summed / counts + + +def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor: + if not normalize: + return x + return torch.nn.functional.normalize(x, p=2, dim=-1) + + diff --git a/src/personalization/models/embedding/qwen3_8b.py b/src/personalization/models/embedding/qwen3_8b.py new file mode 100644 index 0000000..fb02e67 --- /dev/null +++ b/src/personalization/models/embedding/qwen3_8b.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import List, Sequence + +import torch +from transformers import AutoModel, AutoTokenizer + +from personalization.config.registry import choose_dtype, choose_device_map +from personalization.config.settings import LocalModelsConfig +from .base import EmbeddingModel, _mean_pool, _maybe_normalize + + +class Qwen3Embedding8B(EmbeddingModel): + def __init__( + self, + model_path: str, + dtype: 
torch.dtype, + device_map: str = "auto", + trust_remote_code: bool = True, + ) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=True, trust_remote_code=trust_remote_code + ) + + # Handle specific device assignment (e.g., "cuda:0", "cuda:1") + if device_map and device_map.startswith("cuda:"): + # Load to CPU first, then move to specific GPU + self.model = AutoModel.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=None, # Don't use accelerate's device_map + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(device_map) + else: + # Use accelerate's auto device mapping + self.model = AutoModel.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=device_map, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Embedding8B": + if not cfg.embedding or not cfg.embedding.qwen3: + raise ValueError("Embedding config for qwen3 is missing") + spec = cfg.embedding.qwen3 + dtype = choose_dtype(spec.dtype) + device_map = choose_device_map(spec.device_map) + return cls( + spec.local_path, + dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + + @torch.inference_mode() + def encode( + self, + texts: Sequence[str], + batch_size: int = 8, + max_length: int = 512, + normalize: bool = True, + return_tensor: bool = False, + ) -> List[List[float]] | torch.Tensor: + device = next(self.model.parameters()).device + outputs: List[torch.Tensor] = [] + for i in range(0, len(texts), batch_size): + batch = list(texts[i : i + batch_size]) + enc = self.tokenizer( + batch, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", + ).to(device) + model_out = self.model(**enc, output_hidden_states=False, return_dict=True) + pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined] + pooled = _maybe_normalize(pooled, 
normalize) + outputs.append(pooled) + emb = torch.cat(outputs, dim=0) + if return_tensor: + return emb + return emb.cpu().to(torch.float32).tolist() + + diff --git a/src/personalization/models/llm/__init__.py b/src/personalization/models/llm/__init__.py new file mode 100644 index 0000000..3f1af81 --- /dev/null +++ b/src/personalization/models/llm/__init__.py @@ -0,0 +1,4 @@ +from .qwen_instruct import QwenInstruct + +__all__ = ["QwenInstruct"] + diff --git a/src/personalization/models/llm/base.py b/src/personalization/models/llm/base.py new file mode 100644 index 0000000..72b6ca8 --- /dev/null +++ b/src/personalization/models/llm/base.py @@ -0,0 +1,29 @@ +from typing import List, Protocol, Optional +from personalization.types import ChatTurn + +class ChatModel(Protocol): + def answer( + self, + history: List[ChatTurn], + memory_notes: List[str], + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + top_k: Optional[int] = None, + ) -> str: + """ + Generate an assistant response given conversation history and memory notes. + + Args: + history: The conversation history ending with the current user turn. + memory_notes: List of retrieved memory content strings. + max_new_tokens: Max tokens to generate. + temperature: Sampling temperature. + top_p: Top-p sampling. + top_k: Top-k sampling. + + Returns: + The generated assistant response text. + """ + ... + diff --git a/src/personalization/models/llm/prompt_builder.py b/src/personalization/models/llm/prompt_builder.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/models/llm/prompt_builder.py diff --git a/src/personalization/models/llm/vllm_chat.py b/src/personalization/models/llm/vllm_chat.py new file mode 100644 index 0000000..d577a30 --- /dev/null +++ b/src/personalization/models/llm/vllm_chat.py @@ -0,0 +1,244 @@ +""" +vLLM-based ChatModel implementation for high-throughput inference. 
+ +This provides the same interface as LlamaChatModel but uses vLLM HTTP API +for much faster inference (3000+ sessions/hr vs 20 sessions/hr). +""" + +from typing import List, Optional +import time +import requests + +from personalization.models.llm.base import ChatModel +from personalization.types import ChatTurn + + +class VLLMChatModel(ChatModel): + """ + ChatModel implementation using vLLM HTTP API. + + This is a drop-in replacement for LlamaChatModel that uses vLLM + for much faster inference. + """ + + def __init__( + self, + vllm_url: str = "http://localhost:8003/v1", + model_name: str = None, + max_context_length: int = 8192, + timeout: int = 120, + ): + self.vllm_url = vllm_url.rstrip('/') + self.model_name = model_name + self.max_context_length = max_context_length + self.timeout = timeout + + # Discover model name if not provided + if self.model_name is None: + self._discover_model() + + def _discover_model(self): + """Discover the model name from the vLLM server.""" + max_retries = 30 + for attempt in range(max_retries): + try: + response = requests.get(f"{self.vllm_url}/models", timeout=10) + response.raise_for_status() + models = response.json() + if models.get("data") and len(models["data"]) > 0: + self.model_name = models["data"][0]["id"] + return + except Exception as e: + if attempt < max_retries - 1: + wait_time = min(2 ** attempt * 0.5, 10) + time.sleep(wait_time) + + # Fallback + self.model_name = "default" + print(f"[VLLMChatModel] Warning: Could not discover model, using '{self.model_name}'") + + def health_check(self) -> bool: + """Check if the vLLM server is healthy.""" + try: + response = requests.get(f"{self.vllm_url.replace('/v1', '')}/health", timeout=5) + return response.status_code == 200 + except: + return False + + def _estimate_tokens(self, text: str) -> int: + """Estimate token count using character-based heuristic. + + For Llama models, ~4 characters per token is a reasonable estimate. 
+ We use 3.5 to be conservative (slightly overestimate tokens). + """ + return int(len(text) / 3.5) + + def _build_messages( + self, + history: List[ChatTurn], + memory_notes: List[str], + max_new_tokens: int = 512, + global_notes: List[str] = None, + ) -> List[dict]: + """Build messages list for chat completion API with auto-truncation. + + If the context exceeds max_context_length, older conversation turns + are removed to keep only the most recent context that fits. + + Args: + global_notes: If provided, these are always-applicable preferences + displayed in a separate section from task-specific retrieved notes. + """ + # Use CollaborativeAgents-style system prompt + has_any_notes = memory_notes or global_notes + if has_any_notes: + # Build preference sections + pref_sections = "" + if global_notes: + global_bullet = "\n".join(f"- {n}" for n in global_notes) + pref_sections += f"## General Preferences (always apply)\n{global_bullet}\n\n" + if memory_notes: + task_bullet = "\n".join(f"- {n}" for n in memory_notes) + if global_notes: + pref_sections += f"## Task-Specific Preferences\n{task_bullet}\n" + else: + pref_sections += f"{task_bullet}\n" + + system_content = ( + "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n" + "# User Preferences\n" + "The user has a set of preferences for how you should behave. If you do not follow these preferences, " + "the user will be unable to learn from your response and you will need to adjust your response to adhere " + "to these preferences (so it is best to follow them initially).\n\n" + "**IMPORTANT**: If the user explicitly requests something in THIS conversation (e.g., asks you to change " + "your format, style, or approach), that request takes PRIORITY over the remembered preferences below. 
" + "Always adapt to the user's direct feedback first.\n\n" + "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n" + f"{pref_sections}\n" + "# Before Responding\n" + "Before writing your response, briefly consider:\n" + "1. Which preferences above are relevant to this specific request?\n" + "2. How will you satisfy each relevant preference in your response?\n\n" + "# Conversation Guidelines:\n" + "- If the user asks you to adjust your response (e.g., 'be more concise', 'focus on intuition'), you MUST change your approach accordingly. Do NOT repeat the same response.\n" + "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, " + "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n" + "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n" + "- **Verify**: Before finalizing, check that your response satisfies the relevant preferences listed above.\n" + ) + else: + # Vanilla mode - no preferences + system_content = ( + "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n" + "# Conversation Guidelines:\n" + "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, " + "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n" + "- Your goal is to help the user solve their problem. 
Do your best to help them.\n" + ) + system_message = {"role": "system", "content": system_content} + + # Calculate available tokens for conversation history + # Reserve space for: system prompt + max_new_tokens + safety margin + system_tokens = self._estimate_tokens(system_content) + available_tokens = self.max_context_length - system_tokens - max_new_tokens - 100 # 100 token safety margin + + # Build conversation messages from history + conversation_messages = [] + for turn in history: + conversation_messages.append({"role": turn.role, "content": turn.text}) + + # Check if truncation is needed + total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation_messages) + + if total_conv_tokens > available_tokens: + # Truncate from the beginning (keep recent messages) + truncated_messages = [] + current_tokens = 0 + + # Iterate from most recent to oldest + for msg in reversed(conversation_messages): + msg_tokens = self._estimate_tokens(msg["content"]) + if current_tokens + msg_tokens <= available_tokens: + truncated_messages.insert(0, msg) + current_tokens += msg_tokens + else: + # Stop adding older messages + break + + conversation_messages = truncated_messages + if len(truncated_messages) < len(history): + print(f"[VLLMChatModel] Truncated context: kept {len(truncated_messages)}/{len(history)} turns " + f"({current_tokens}/{total_conv_tokens} estimated tokens)") + + messages = [system_message] + conversation_messages + return messages + + def build_messages( + self, + history: List[ChatTurn], + memory_notes: List[str], + max_new_tokens: int = 512, + global_notes: List[str] = None, + ) -> List[dict]: + """Public method to build messages without calling the API. + + Used for batch processing where messages are collected first, + then sent in batch to vLLM for concurrent processing. 
+ """ + return self._build_messages(history, memory_notes, max_new_tokens, global_notes=global_notes) + + def answer( + self, + history: List[ChatTurn], + memory_notes: List[str], + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + top_k: Optional[int] = None, + ) -> str: + """Generate a response using vLLM HTTP API.""" + messages = self._build_messages(history, memory_notes, max_new_tokens) + + payload = { + "model": self.model_name, + "messages": messages, + "max_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + } + + # Retry with exponential backoff + max_retries = 5 + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.vllm_url}/chat/completions", + json=payload, + timeout=self.timeout + ) + + if response.status_code == 200: + result = response.json() + return result["choices"][0]["message"]["content"] + elif response.status_code == 400: + error_text = response.text + # Handle context length error + if "max_tokens" in error_text and max_new_tokens > 64: + payload["max_tokens"] = max(64, max_new_tokens // 2) + continue + raise RuntimeError(f"vLLM error: {error_text[:200]}") + else: + raise RuntimeError(f"vLLM HTTP {response.status_code}: {response.text[:200]}") + + except requests.exceptions.Timeout: + if attempt < max_retries - 1: + time.sleep(2 ** attempt) + continue + raise RuntimeError("vLLM request timeout") + except requests.exceptions.ConnectionError as e: + if attempt < max_retries - 1: + time.sleep(2 ** attempt) + continue + raise RuntimeError(f"vLLM connection error: {e}") + + raise RuntimeError("Max retries exceeded for vLLM request") diff --git a/src/personalization/models/preference_extractor/__init__.py b/src/personalization/models/preference_extractor/__init__.py new file mode 100644 index 0000000..65e2595 --- /dev/null +++ b/src/personalization/models/preference_extractor/__init__.py @@ -0,0 +1,5 @@ +from .rule_extractor import QwenRuleExtractor +from 
.gpt4o_extractor import GPT4OExtractor +from .base import PreferenceExtractor + +__all__ = ["QwenRuleExtractor", "GPT4OExtractor", "PreferenceExtractor"] diff --git a/src/personalization/models/preference_extractor/base.py b/src/personalization/models/preference_extractor/base.py new file mode 100644 index 0000000..850292f --- /dev/null +++ b/src/personalization/models/preference_extractor/base.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List +from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList + +class PreferenceExtractorBase(ABC): + @abstractmethod + def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList: + """ + Extract preferences from a window of chat turns (history + current query). + """ + raise NotImplementedError + +# Alias for backward compatibility if needed, +# though specific extractors should inherit from PreferenceExtractorBase now. +PreferenceExtractor = PreferenceExtractorBase diff --git a/src/personalization/models/preference_extractor/gpt4o_extractor.py b/src/personalization/models/preference_extractor/gpt4o_extractor.py new file mode 100644 index 0000000..0f70522 --- /dev/null +++ b/src/personalization/models/preference_extractor/gpt4o_extractor.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import json +import os +from typing import Any, Dict, List + +from openai import OpenAI +from personalization.config.settings import LocalModelsConfig +from personalization.models.preference_extractor.base import PreferenceExtractorBase as PreferenceExtractor +from personalization.retrieval.preference_store.schemas import ( + ChatTurn, + PreferenceList, + preference_list_json_schema, +) + + +class GPT4OExtractor(PreferenceExtractor): + def __init__(self, api_key: str, model: str = "gpt-4o") -> None: + self.client = OpenAI(api_key=api_key) + self.model = model + + # Load system prompt template + template_path = 
"fine_tuning_prompt_template.txt" + if os.path.exists(template_path): + with open(template_path, "r", encoding="utf-8") as f: + self.system_prompt = f.read() + else: + # Structured prompt that enforces the PreferenceList schema + self.system_prompt = ( + "You are a preference extraction assistant. " + "Given a user message, extract any user preferences as condition-action rules.\n\n" + "Return a JSON object with exactly this structure:\n" + '{"preferences": [{"condition": "<when this applies>", "action": "<what to do>", "confidence": <0.0-1.0>}]}\n\n' + "Examples of preferences:\n" + '- {"condition": "general", "action": "respond in Chinese", "confidence": 0.9}\n' + '- {"condition": "when writing code", "action": "use Python with type hints", "confidence": 0.8}\n' + '- {"condition": "when explaining math", "action": "show step-by-step derivation", "confidence": 0.7}\n\n' + "If no preferences are found, return {\"preferences\": []}.\n" + "IMPORTANT: The output MUST be a JSON object with a \"preferences\" key containing a list." 
+ ) + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "GPT4OExtractor": + # We rely on env var for API key, config for other potential settings if needed + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY environment variable not set") + return cls(api_key=api_key) + + def build_preference_prompt(self, query: str) -> str: + # GPT4OExtractor uses the system prompt loaded in __init__ + return self.system_prompt + + def _call_kwargs(self, messages): + """Build kwargs for chat completion, skipping temperature for models that don't support it.""" + kwargs = { + "model": self.model, + "messages": messages, + "response_format": {"type": "json_object"}, + } + # GPT-5 series doesn't support temperature=0 + if not self.model.startswith("gpt-5"): + kwargs["temperature"] = 0.0 + return kwargs + + def extract_preferences(self, query: str) -> Dict[str, Any]: + # Reuse logic but return raw dict + try: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": query}, + ] + response = self.client.chat.completions.create(**self._call_kwargs(messages)) + content = response.choices[0].message.content + if content: + return json.loads(content) + except Exception as e: + print(f"Error calling GPT-4o: {e}") + return {"preferences": []} + + def extract_turn(self, turns) -> PreferenceList: + # Accept both a single ChatTurn and a list of ChatTurns (history) + if isinstance(turns, list): + # Find the last user message in history + last_user_msg = None + for t in reversed(turns): + if hasattr(t, 'role') and t.role == "user": + last_user_msg = t.text + break + if not last_user_msg: + return PreferenceList(preferences=[]) + else: + # Single ChatTurn + if turns.role != "user": + return PreferenceList(preferences=[]) + last_user_msg = turns.text + + try: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": last_user_msg}, + ] + response = 
self.client.chat.completions.create(**self._call_kwargs(messages)) + + content = response.choices[0].message.content + if not content: + return PreferenceList(preferences=[]) + + data = json.loads(content) + return self._parse_to_preference_list(data) + + except Exception as e: + print(f"Error calling GPT-4o: {e}") + return PreferenceList(preferences=[]) + + @staticmethod + def _parse_to_preference_list(data: dict) -> PreferenceList: + """Robustly convert GPT output to PreferenceList, handling non-standard formats.""" + # Best case: already matches schema + if "preferences" in data and isinstance(data["preferences"], list): + prefs = [] + for item in data["preferences"]: + if isinstance(item, dict) and "condition" in item and "action" in item: + prefs.append({ + "condition": str(item["condition"])[:128], + "action": str(item["action"])[:256], + "confidence": float(item.get("confidence", 0.7)), + }) + return PreferenceList.model_validate({"preferences": prefs}) + + # GPT returned a flat dict of preferences - convert to condition/action pairs + prefs = [] + for key, value in data.items(): + if isinstance(value, str) and len(value) > 2: + prefs.append({ + "condition": str(key)[:128] if len(str(key)) > 1 else "general", + "action": str(value)[:256], + "confidence": 0.7, + }) + elif isinstance(value, dict): + # Nested dict: try to extract meaningful pairs + for sub_key, sub_val in value.items(): + if isinstance(sub_val, str) and len(sub_val) > 2: + prefs.append({ + "condition": str(sub_key)[:128], + "action": str(sub_val)[:256], + "confidence": 0.7, + }) + elif isinstance(value, list): + for item in value: + if isinstance(item, str) and len(item) > 2: + prefs.append({ + "condition": str(key)[:128], + "action": str(item)[:256], + "confidence": 0.7, + }) + + return PreferenceList.model_validate({"preferences": prefs[:20]}) + + def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]: + results = [] + for turn in turns: + 
from typing import List, Dict, Any
import torch
import json
import os
from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.models.preference_extractor.base import PreferenceExtractorBase
from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
from personalization.config.settings import LocalModelsConfig
from personalization.config.registry import choose_dtype, choose_device_map


class PreferenceExtractorLLM(PreferenceExtractorBase):
    """Preference extractor backed by an SFT'd causal LM.

    Builds a chat prompt (system template + last 6 turns of conversation),
    generates greedily, and parses the first JSON object in the output into
    a ``PreferenceList``. Falls back to an empty list on any parse failure.
    """

    def __init__(
        self,
        model_path: str,
        prompt_template_path: str = "fine_tuning_prompt_template.txt",
        device_map: str = "auto",
        dtype: torch.dtype = torch.bfloat16,
        max_new_tokens: int = 512,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
        self.max_new_tokens = max_new_tokens

        # System prompt comes from the SFT template file; use a generic
        # fallback instruction when it is missing so the extractor still runs.
        if os.path.exists(prompt_template_path):
            with open(prompt_template_path, "r", encoding="utf-8") as f:
                self.prompt_template = f.read()
        else:
            print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.")
            self.prompt_template = "Extract user preferences from the following conversation."

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM":
        """Stub factory — NOT yet functional.

        ``LocalModelsConfig.preference_extractor`` is currently a single
        ModelSpec, while the design doc expects a named map of extractors
        (e.g. ``qwen3_0_6b_sft``). Until settings/registry are refactored,
        this returns None and callers must construct the class directly.
        NOTE(review): returning None violates the annotated return type;
        kept for backward compatibility with existing call sites.
        """
        return None

    def _build_prompt(self, turns: List[ChatTurn]) -> str:
        """Build a chat-template prompt from the last 6 turns.

        The system message is the SFT template; the conversation window is
        flattened to "User: ..." / "Assistant: ..." lines and sent as a
        single user message, matching the SFT training input format.
        """
        messages = [{"role": "system", "content": self.prompt_template}]

        # Context window of the most recent 6 turns.
        window = turns[-6:]

        history_texts = []
        for t in window:
            role = "User" if t.role == "user" else "Assistant"
            history_texts.append(f"{role}: {t.text}")

        conversation_text = "\n".join(history_texts)

        # The SFT dataset's input column was the raw conversation text, so no
        # extra trigger instruction is appended here.
        messages.append({"role": "user", "content": conversation_text})

        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        return prompt

    def _generate(self, prompt: str) -> str:
        """Greedy-decode a completion and return ONLY the newly generated text."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            # NOTE: do_sample=False implies greedy decoding; passing
            # temperature=0.0 alongside it is invalid in transformers and was
            # removed (it triggered warnings / errors depending on version).
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # FIX: slice generated token ids by the input length instead of
        # slicing the decoded string by len(prompt). Decoding with
        # skip_special_tokens=True drops chat-template special tokens, so the
        # decoded text is shorter than the raw prompt string and a character
        # offset slice would truncate output or leak prompt text into it.
        gen_ids = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True)

    def _parse_preferences(self, raw_output: str) -> PreferenceList:
        """Extract the first {...} span from raw output and validate it.

        Returns an empty PreferenceList when no JSON object is found or
        validation fails.
        """
        start = raw_output.find("{")
        end = raw_output.rfind("}")

        if start == -1 or end == -1 or end <= start:
            return PreferenceList(preferences=[])

        json_str = raw_output[start:end + 1]
        try:
            data = json.loads(json_str)
            return PreferenceList.model_validate(data)
        except Exception:
            return PreferenceList(preferences=[])

    def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
        """Extract preferences from a conversation window (last 6 turns used)."""
        prompt = self._build_prompt(turns)
        raw_output = self._generate(prompt)
        return self._parse_preferences(raw_output)

    # Legacy support
    def build_preference_prompt(self, query: str) -> str:
        """Legacy single-query API: wrap the query in a dummy turn."""
        turn = ChatTurn(
            user_id="dummy", session_id="dummy", turn_id=0,
            role="user", text=query
        )
        return self._build_prompt([turn])

    def extract_preferences(self, query: str) -> Dict[str, Any]:
        """Legacy single-query API: returns a PreferenceList-shaped dict."""
        turn = ChatTurn(
            user_id="dummy", session_id="dummy", turn_id=0,
            role="user", text=query
        )
        prefs = self.extract_turn([turn])
        return prefs.model_dump()
+ """ + def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=True, trust_remote_code=True + ) + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "QwenRuleExtractor": + spec = cfg.preference_extractor + dtype = choose_dtype(spec.dtype) + device_map = choose_device_map(spec.device_map) + return cls(spec.local_path, dtype=dtype, device_map=device_map) + + def build_preference_prompt(self, query: str) -> str: + """ + Construct the prompt string using the tokenizer's chat template. + Matches the format seen during SFT training. + """ + messages = [ + {"role": "system", "content": SFT_SYSTEM_PROMPT}, + {"role": "user", "content": query} + ] + prompt = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + return prompt + + @torch.inference_mode() + def extract_preferences(self, query: str) -> Dict[str, Any]: + """ + Directly extract preferences from query using the SFT model. + Returns a dict compatible with PreferenceList model (key: 'preferences'). 
+ """ + prompt = self.build_preference_prompt(query) + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + + outputs = self.model.generate( + **inputs, + do_sample=False, # Deterministic greedy decoding + max_new_tokens=512, # Allow enough space for JSON + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + + input_len = inputs["input_ids"].shape[1] + gen_ids = outputs[0][input_len:] + text = self.tokenizer.decode(gen_ids, skip_special_tokens=True) + + if os.getenv("PREF_DEBUG") == "1": + print(f"[debug][extractor] Raw output: {text}") + + # Try parsing JSON + try: + # 1. Direct parse + data = json.loads(text) + + # 2. Validate against schema structure + validated = PreferenceList.model_validate(data) + return validated.model_dump() + + except Exception: + # Fallback: Try to find JSON blob if model outputted extra text (rare for SFT but possible) + extracted_json = self._extract_json_substring(text) + if extracted_json: + try: + data = json.loads(extracted_json) + validated = PreferenceList.model_validate(data) + return validated.model_dump() + except: + pass + + # If all fails, return empty + return {"preferences": []} + + def _extract_json_substring(self, text: str) -> str | None: + """Helper to find { ... } block in text.""" + # Find first '{' and last '}' + start = text.find('{') + end = text.rfind('}') + if start != -1 and end != -1 and end > start: + return text[start : end + 1] + return None + + @torch.inference_mode() + def batch_extract_preferences(self, queries: List[str], batch_size: int = 64) -> List[Dict[str, Any]]: + """ + Batch extract preferences from multiple queries using left-padded batching. 
+ """ + if not queries: + return [] + + # Save and set padding side for decoder-only batched generation + orig_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "left" + + all_results = [] + prompts = [self.build_preference_prompt(q) for q in queries] + + for start in range(0, len(prompts), batch_size): + batch_prompts = prompts[start:start + batch_size] + inputs = self.tokenizer( + batch_prompts, return_tensors="pt", padding=True, truncation=True + ).to(self.model.device) + + outputs = self.model.generate( + **inputs, + do_sample=False, + max_new_tokens=512, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + + for i in range(len(batch_prompts)): + input_len = (inputs["attention_mask"][i] == 1).sum().item() + gen_ids = outputs[i][input_len:] + text = self.tokenizer.decode(gen_ids, skip_special_tokens=True) + + try: + data = json.loads(text) + validated = PreferenceList.model_validate(data) + all_results.append(validated.model_dump()) + except Exception: + extracted_json = self._extract_json_substring(text) + if extracted_json: + try: + data = json.loads(extracted_json) + validated = PreferenceList.model_validate(data) + all_results.append(validated.model_dump()) + continue + except Exception: + pass + all_results.append({"preferences": []}) + + self.tokenizer.padding_side = orig_padding_side + return all_results + + def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList: + """ + Extract preferences from the LAST user turn in the history. + We don't concat history because our SFT model was trained on single-turn extraction. + Using context might confuse it unless we trained it that way. 
+ """ + # Find the last user message + last_user_msg = None + for t in reversed(turns): + if t.role == "user": + last_user_msg = t.text + break + + if not last_user_msg: + return PreferenceList(preferences=[]) + + result_dict = self.extract_preferences(last_user_msg) + return PreferenceList.model_validate(result_dict) + + def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]: + """ + Extract preferences from ALL user turns individually. + """ + results = [] + for turn in turns: + if turn.role == "user": + res = self.extract_preferences(turn.text) + results.append(PreferenceList.model_validate(res)) + else: + results.append(PreferenceList(preferences=[])) + return results diff --git a/src/personalization/models/reranker/__init__.py b/src/personalization/models/reranker/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/models/reranker/__init__.py diff --git a/src/personalization/models/reranker/base.py b/src/personalization/models/reranker/base.py new file mode 100644 index 0000000..34cf6ce --- /dev/null +++ b/src/personalization/models/reranker/base.py @@ -0,0 +1,16 @@ +from typing import List, Protocol + +class Reranker(Protocol): + def score( + self, + query: str, + docs: List[str], + **kwargs, + ) -> List[float]: + """ + Score multiple candidate documents for the same query. + Higher score indicates higher relevance. + Returns a list of floats with length equal to len(docs). + """ + ... 
from typing import List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import Reranker
from personalization.config.settings import LocalModelsConfig
from personalization.config.registry import choose_dtype, choose_device_map


class Qwen3Reranker(Reranker):
    """Pointwise LLM reranker: scores each (query, note) pair as the
    log-probability that the model's next token is "yes"."""

    def __init__(self, model_path: str, device_map: str = "auto", dtype: torch.dtype = torch.bfloat16):
        # Ensure we pass trust_remote_code=True for Qwen models
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # Handle specific device assignment (e.g., "cuda:0", "cuda:1")
        if device_map and device_map.startswith("cuda:"):
            # Load to CPU first, then move to specific GPU
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            self.model = self.model.to(device_map)
        else:
            # Use accelerate's auto device mapping
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
                trust_remote_code=True,
            )

        # FIX: batched scoring uses padding=True; Qwen tokenizers may ship
        # without a pad token, which would make that call raise.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Token id whose next-token probability we read as the relevance score.
        self.yes_token_id = self.tokenizer("yes", add_special_tokens=False).input_ids[0]

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Reranker":
        """Build the reranker from the project's local-models config."""
        if not cfg.reranker or not cfg.reranker.qwen3_8b:
            raise ValueError("Reranker config for qwen3_8b is missing")
        spec = cfg.reranker.qwen3_8b
        dtype = choose_dtype(spec.dtype)
        device_map = choose_device_map(spec.device_map)
        return cls(spec.local_path, device_map=device_map, dtype=dtype)

    def _build_prompt(self, query: str, doc: str) -> str:
        """Yes/no relevance-judgment prompt for a single (query, note) pair."""
        return (
            "You are a reranker. "
            "Given a user query and a memory note, answer 'yes' if the note is helpful "
            "for answering the query, otherwise answer 'no'.\n\n"
            f"Query: {query}\n"
            f"Note: {doc}\n"
            "Answer with a single token: yes or no."
        )

    @torch.inference_mode()
    def score(self, query: str, docs: List[str], batch_size: int = 8, **kwargs) -> List[float]:
        """Score docs against one query; higher = more relevant.

        Returns one log P(next token == "yes") per doc, in input order.
        """
        scores: List[float] = []

        # FIX: force LEFT padding while scoring. The next-token logits are
        # read at position -1; with right padding (the tokenizer default),
        # that position is a PAD token for every sequence shorter than the
        # batch maximum, so their scores were computed from padding instead
        # of the prompt's last real token. With left padding, position -1 is
        # always the last real token. Restore the original side afterwards.
        orig_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            for i in range(0, len(docs), batch_size):
                batch_docs = docs[i : i + batch_size]
                prompts = [self._build_prompt(query, d) for d in batch_docs]

                inputs = self.tokenizer(
                    prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                ).to(self.model.device)

                outputs = self.model(**inputs)
                # logits shape: [B, L, V]; position -1 predicts the token the
                # model would generate next ("yes" / "no").
                last_token_logits = outputs.logits[:, -1, :]

                # Log-probability of 'yes' under the next-token distribution.
                log_probs = torch.log_softmax(last_token_logits, dim=-1)
                yes_log_probs = log_probs[:, self.yes_token_id]

                scores.extend(yes_log_probs.float().cpu().numpy().tolist())
        finally:
            self.tokenizer.padding_side = orig_padding_side

        return scores
def dynamic_topk_selection(
    scores: np.ndarray,
    min_k: int = 3,
    max_k: int = 8,
    score_ratio: float = 0.5,
) -> List[int]:
    """
    Dynamically select top-k indices based on score distribution.

    Strategy:
    1. Sort by score descending
    2. Compute threshold = top_score * score_ratio
    3. Select all indices with score > threshold
    4. Clamp to [min_k, max_k] range

    Args:
        scores: Array of scores (higher = better)
        min_k: Minimum number of items to select
        max_k: Maximum number of items to select
        score_ratio: Threshold ratio relative to top score

    Returns:
        List of selected indices (in descending score order)
    """
    if len(scores) == 0:
        return []

    # Sort indices by score descending
    sorted_indices = np.argsort(scores)[::-1]
    sorted_scores = scores[sorted_indices]

    # Compute threshold relative to the best score.
    # NOTE(review): if scores can be negative (e.g. reranker log-probs),
    # top_score * score_ratio is GREATER than top_score, nothing passes the
    # threshold, and min_k is always used — confirm callers pass
    # non-negative scores or rescale before calling.
    top_score = sorted_scores[0]
    threshold = top_score * score_ratio

    # Count how many pass the (strict) threshold; cast to a plain int so the
    # clamp below doesn't propagate a numpy scalar.
    n_above_threshold = int(np.sum(sorted_scores > threshold))

    # Clamp to [min_k, max_k], never exceeding the number of candidates.
    n_select = max(min_k, min(max_k, n_above_threshold))
    n_select = min(n_select, len(scores))

    return sorted_indices[:n_select].tolist()


def dense_topk_indices(
    query: str,
    embed_model: "EmbeddingModel",
    memory_embeddings: np.ndarray,
    valid_indices: "Optional[List[int]]" = None,
    topk: int = 64
) -> List[int]:
    """
    Return indices of topk memories based on dense embedding similarity.
    If valid_indices is provided, only search within that subset.

    Embeddings are expected L2-normalized, so the dot product is cosine
    similarity. Returned indices always refer to the FULL memory array.
    """
    if valid_indices is not None and len(valid_indices) == 0:
        return []

    e_q_list = embed_model.encode([query], normalize=True, return_tensor=False)
    e_q = np.array(e_q_list[0], dtype=np.float32)

    # Restricted search: only dot-product against the allowed subset, then
    # map subset-relative argsort positions back to original indices.
    if valid_indices is not None:
        E_sub = memory_embeddings[valid_indices]   # [M_sub, d]
        sims_sub = np.dot(E_sub, e_q)

        k = min(topk, len(sims_sub))
        if k == 0:
            return []

        # argsort ascending -> take last k -> reverse for descending order
        idx_sub = np.argsort(sims_sub)[-k:][::-1]
        return [valid_indices[i] for i in idx_sub]

    # Global search over all memories
    sims = np.dot(memory_embeddings, e_q)
    k = min(topk, len(memory_embeddings))
    if k == 0:
        return []

    idx = np.argsort(sims)[-k:][::-1]
    return idx.tolist()


def dense_topk_indices_multi_query(
    queries: List[str],
    embed_model: "EmbeddingModel",
    memory_embeddings: np.ndarray,
    valid_indices: "Optional[List[int]]" = None,
    topk: int = 64
) -> List[int]:
    """
    Multi-query dense retrieval: embed all queries, take max similarity per memory,
    return top-k by max similarity (union effect).
    """
    if len(memory_embeddings) == 0:
        return []

    # Embed all queries in one batch: [Q, d]
    e_qs = embed_model.encode(queries, normalize=True, return_tensor=False)
    e_qs = np.array(e_qs, dtype=np.float32)

    if valid_indices is not None:
        if len(valid_indices) == 0:
            return []
        E_sub = memory_embeddings[valid_indices]
        sims = np.dot(e_qs, E_sub.T)        # [Q, M_sub]
        max_sims = sims.max(axis=0)         # best query per memory
        k = min(topk, len(max_sims))
        if k == 0:
            return []
        idx_sub = np.argsort(max_sims)[-k:][::-1]
        return [valid_indices[i] for i in idx_sub]

    # Global search
    sims = np.dot(e_qs, memory_embeddings.T)  # [Q, M]
    max_sims = sims.max(axis=0)               # [M]
    k = min(topk, len(max_sims))
    if k == 0:
        return []
    idx = np.argsort(max_sims)[-k:][::-1]
    return idx.tolist()
def retrieve_with_policy(
    user_id: str,
    query: str,
    embed_model: EmbeddingModel,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,  # shape: [M, d]
    user_store: UserTensorStore,
    item_vectors: np.ndarray,  # shape: [M, k], v_m
    topk_dense: int = 64,
    topk_rerank: int = 8,
    beta_long: float = 0.0,
    beta_short: float = 0.0,
    tau: float = 1.0,
    only_own_memories: bool = False,
    sample: bool = False,
    queries: List[str] = None,
) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
    """
    Policy-based retrieval: dense retrieval -> rerank -> user-vector policy scoring -> selection.

    Returns extended info for policy update:
    (candidates, candidate_item_vectors, base_scores, chosen_indices, policy_probs)
    where chosen_indices index into `candidates` (NOT into memory_cards).

    Args:
        sample: If True, use stochastic sampling from policy distribution (for training/exploration).
                If False, use deterministic top-k by policy scores (for evaluation).
        queries: Optional multi-query expansion; when more than one query is given,
                 dense retrieval unions them via max-similarity, but reranking
                 always uses the original `query`.
        beta_long / beta_short / tau: Policy mixing weights and softmax temperature,
                 forwarded to compute_policy_scores.

    Empty result is signalled as ([], np.array([]), np.array([]), [], np.array([])).
    """
    # 0. Filter indices if needed (restrict candidates to this user's own cards)
    valid_indices = None
    if only_own_memories:
        valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
        if not valid_indices:
            return [], np.array([]), np.array([]), [], np.array([])

    # 1. Dense retrieval (multi-query if available)
    if queries and len(queries) > 1:
        dense_idx = dense_topk_indices_multi_query(
            queries,
            embed_model,
            memory_embeddings,
            valid_indices=valid_indices,
            topk=topk_dense
        )
    else:
        dense_idx = dense_topk_indices(
            query,
            embed_model,
            memory_embeddings,
            valid_indices=valid_indices,
            topk=topk_dense
        )
    # DEBUG: Check for duplicates or out of bounds (opt-in via env var)
    if len(dense_idx) > 0:
        import os
        if os.getenv("RETRIEVAL_DEBUG") == "1":
            print(f" [Pipeline] Dense Indices (Top {len(dense_idx)}): {dense_idx[:10]}...")
            print(f" [Pipeline] Max Index: {max(dense_idx)} | Memory Size: {len(memory_cards)}")

    if not dense_idx:
        return [], np.array([]), np.array([]), [], np.array([])

    candidates = [memory_cards[i] for i in dense_idx]
    candidate_docs = [c.note_text for c in candidates]

    # 2. Rerank base score (P(yes|q,m)) - always use original query for reranking
    # Skip reranking if we have fewer candidates than topk_rerank (saves GPU memory);
    # uniform base scores then leave ordering entirely to the policy step below.
    if len(candidates) <= topk_rerank:
        base_scores = np.ones(len(candidates))  # Uniform scores
    else:
        base_scores = np.array(reranker.score(query, candidate_docs))

    # 3. Policy Scoring (Softmax): mix reranker scores with the user's
    # long/short-term vectors against each candidate's item vector.
    user_state: UserState = user_store.get_state(user_id)
    candidate_vectors = item_vectors[dense_idx]  # [K, k]

    policy_out = compute_policy_scores(
        base_scores=base_scores,
        user_state=user_state,
        item_vectors=candidate_vectors,
        beta_long=beta_long,
        beta_short=beta_short,
        tau=tau
    )

    # 4. Selection: Greedy (eval) or Stochastic (training)
    k = min(topk_rerank, len(policy_out.scores))

    if sample:
        # Stochastic sampling from policy distribution (for training/exploration)
        # Sample k indices without replacement, weighted by policy probs
        probs = policy_out.probs
        # Normalize to ensure sum to 1 (handle numerical issues)
        probs = probs / (probs.sum() + 1e-10)
        # Sample without replacement
        chosen_indices = np.random.choice(
            len(probs), size=k, replace=False, p=probs
        ).tolist()
    else:
        # Deterministic top-k by policy scores (for evaluation)
        top_indices_local = policy_out.scores.argsort()[-k:][::-1]
        chosen_indices = top_indices_local.tolist()

    import os
    if os.getenv("RETRIEVAL_DEBUG") == "1":
        print(f" [Pipeline] Candidates: {len(candidates)} | Chosen Indices: {chosen_indices} | Sample: {sample}")

    return candidates, candidate_vectors, base_scores, chosen_indices, policy_out.probs
def retrieve_no_policy(
    user_id: str,
    query: str,
    embed_model: EmbeddingModel,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,  # shape: [M, d]
    topk_dense: int = 64,
    topk_rerank: int = 8,
    only_own_memories: bool = False,
    queries: List[str] = None,
    dynamic_topk: bool = False,
    dynamic_min_k: int = 3,
    dynamic_max_k: int = 8,
    dynamic_score_ratio: float = 0.5,
) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
    """
    Deterministic retrieval baseline (NoPersonal mode):
    - Dense retrieval -> Rerank -> Top-K (no policy sampling, no user vector influence)

    Args:
        dynamic_topk: If True, use dynamic selection based on score distribution
        dynamic_min_k: Minimum items to select (when dynamic_topk=True)
        dynamic_max_k: Maximum items to select (when dynamic_topk=True)
        dynamic_score_ratio: Threshold = top_score * ratio (when dynamic_topk=True)

    Returns same structure as retrieve_with_policy for compatibility:
    (candidates, candidate_item_vectors, base_scores, chosen_indices, rerank_scores_for_chosen)

    Note: candidate_item_vectors is empty array (not used in NoPersonal mode)
    The last return value is rerank scores instead of policy probs
    """
    # 0. Filter indices if needed (restrict to this user's own cards)
    valid_indices = None
    if only_own_memories:
        valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
        if not valid_indices:
            return [], np.array([]), np.array([]), [], np.array([])

    # 1. Dense retrieval (multi-query if available)
    if queries and len(queries) > 1:
        dense_idx = dense_topk_indices_multi_query(
            queries,
            embed_model,
            memory_embeddings,
            valid_indices=valid_indices,
            topk=topk_dense
        )
    else:
        dense_idx = dense_topk_indices(
            query,
            embed_model,
            memory_embeddings,
            valid_indices=valid_indices,
            topk=topk_dense
        )

    if not dense_idx:
        return [], np.array([]), np.array([]), [], np.array([])

    candidates = [memory_cards[i] for i in dense_idx]
    candidate_docs = [c.note_text for c in candidates]

    # 2. Rerank base score (P(yes|q,m)) - always use original query for reranking
    # The effective budget is dynamic_max_k in dynamic mode, topk_rerank otherwise.
    max_k = dynamic_max_k if dynamic_topk else topk_rerank

    # Skip reranking if we have fewer candidates than needed
    if len(candidates) <= max_k:
        # Just return all candidates without reranking (uniform scores; kept
        # in dense-retrieval order — no within-set ranking is applied here)
        base_scores = np.ones(len(candidates))  # Uniform scores
        chosen_indices = list(range(len(candidates)))
    else:
        base_scores = np.array(reranker.score(query, candidate_docs))

        # 3. Selection: dynamic or fixed top-K (only reached when reranking ran)
        if dynamic_topk:
            chosen_indices = dynamic_topk_selection(
                base_scores,
                min_k=dynamic_min_k,
                max_k=dynamic_max_k,
                score_ratio=dynamic_score_ratio,
            )
        else:
            k = min(topk_rerank, len(base_scores))
            top_indices_local = base_scores.argsort()[-k:][::-1]
            chosen_indices = top_indices_local.tolist()

    # Get scores for chosen items (for logging compatibility)
    chosen_scores = base_scores[chosen_indices]

    # Return empty item vectors (not used in NoPersonal mode)
    # Return rerank scores as the "probs" field for logging compatibility
    return candidates, np.array([]), base_scores, chosen_indices, chosen_scores


def retrieve_with_rerank(
    user_id: str,
    query: str,
    embed_model: EmbeddingModel,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,  # shape: [M, d]
    user_store: UserTensorStore,
    item_vectors: np.ndarray,  # shape: [M, k], v_m
    topk_dense: int = 64,
    topk_rerank: int = 8,
    beta_long: float = 0.0,
    beta_short: float = 0.0,
    only_own_memories: bool = False,
) -> List[MemoryCard]:
    """
    Wrapper around retrieve_with_policy for standard inference.

    Uses deterministic selection (sample defaults to False) and tau=1.0,
    and resolves the chosen candidate indices into the actual MemoryCards.
    """
    candidates, _, _, chosen_indices, _ = retrieve_with_policy(
        user_id=user_id,
        query=query,
        embed_model=embed_model,
        reranker=reranker,
        memory_cards=memory_cards,
        memory_embeddings=memory_embeddings,
        user_store=user_store,
        item_vectors=item_vectors,
        topk_dense=topk_dense,
        topk_rerank=topk_rerank,
        beta_long=beta_long,
        beta_short=beta_short,
        tau=1.0,  # Default tau
        only_own_memories=only_own_memories
    )

    return [candidates[i] for i in chosen_indices]
+ """ + candidates, _, _, chosen_indices, _ = retrieve_with_policy( + user_id=user_id, + query=query, + embed_model=embed_model, + reranker=reranker, + memory_cards=memory_cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=topk_dense, + topk_rerank=topk_rerank, + beta_long=beta_long, + beta_short=beta_short, + tau=1.0, # Default tau + only_own_memories=only_own_memories + ) + + return [candidates[i] for i in chosen_indices] + + diff --git a/src/personalization/retrieval/preference_store/__init__.py b/src/personalization/retrieval/preference_store/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/preference_store/__init__.py diff --git a/src/personalization/retrieval/preference_store/base.py b/src/personalization/retrieval/preference_store/base.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/preference_store/base.py diff --git a/src/personalization/retrieval/preference_store/schemas.py b/src/personalization/retrieval/preference_store/schemas.py new file mode 100644 index 0000000..5245025 --- /dev/null +++ b/src/personalization/retrieval/preference_store/schemas.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import List, Literal, Optional, Dict, Any + +from pydantic import BaseModel, Field, confloat + + +class Preference(BaseModel): + condition: str = Field( + ..., min_length=1, max_length=128, description="When the rule applies" + ) + action: str = Field( + ..., min_length=1, max_length=256, description="What to do in that case" + ) + confidence: confloat(ge=0.0, le=1.0) = Field( + ..., description="Confidence the rule is correct" + ) + + +class PreferenceList(BaseModel): + preferences: List[Preference] = Field(default_factory=list) + + +def preference_list_json_schema() -> dict: + return PreferenceList.model_json_schema() + + +class ChatTurn(BaseModel): + user_id: str + 
session_id: str + turn_id: int + role: Literal["user", "assistant"] + text: str + timestamp: Optional[float] = None + meta: Dict[str, Any] = Field(default_factory=dict) + + +class MemoryCard(BaseModel): + card_id: str + user_id: str + source_session_id: str + source_turn_ids: List[int] + raw_queries: List[str] # The original user utterances + preference_list: PreferenceList + note_text: str # Summarized "condition: action" text + embedding_e: List[float] # The embedding vector + kind: Literal["pref", "fact"] = "pref" + is_global: bool = False # True = always include in prompt, bypass retrieval diff --git a/src/personalization/retrieval/preference_store/vector_kv.py b/src/personalization/retrieval/preference_store/vector_kv.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/preference_store/vector_kv.py diff --git a/src/personalization/retrieval/rerank.py b/src/personalization/retrieval/rerank.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/rerank.py diff --git a/src/personalization/retrieval/store/__init__.py b/src/personalization/retrieval/store/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/store/__init__.py diff --git a/src/personalization/retrieval/store/base.py b/src/personalization/retrieval/store/base.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/store/base.py diff --git a/src/personalization/retrieval/store/faiss_store.py b/src/personalization/retrieval/store/faiss_store.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/retrieval/store/faiss_store.py diff --git a/src/personalization/serving/__init__.py b/src/personalization/serving/__init__.py new file mode 100644 index 0000000..11adcf8 --- /dev/null +++ b/src/personalization/serving/__init__.py @@ -0,0 +1,22 @@ +# Personalization Serving Module +# +# This module 
provides the interface layer for the personalization system. + +from personalization.serving.personalized_llm import ( + PersonalizedLLM, + AssistantResponse, + UsageStats, + DebugInfo, + Feedback, + create_personalized_llm, +) + +__all__ = [ + "PersonalizedLLM", + "AssistantResponse", + "UsageStats", + "DebugInfo", + "Feedback", + "create_personalized_llm", +] + diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py new file mode 100644 index 0000000..8032e6b --- /dev/null +++ b/src/personalization/serving/personalized_llm.py @@ -0,0 +1,1835 @@ +#!/usr/bin/env python3 +""" +Personalized LLM Interface for Evaluation. + +This module provides the `PersonalizedLLM` class that wraps the entire +personalization system into a clean interface for evaluation frameworks +and user simulators. + +Interface contract: +- chat(user_id, query) -> AssistantResponse: Main online interface +- reset_session(user_id): Clear session history and short-term state +- reset_user(user_id): Completely reset user (long-term, short-term, memories) +- apply_feedback(feedback): Apply external feedback for RL updates +""" + +from __future__ import annotations + +import os +import sys +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import numpy as np +import yaml + +# Ensure src is in path for standalone usage +_src_path = os.path.join(os.path.dirname(__file__), "../../..") +if _src_path not in sys.path: + sys.path.insert(0, _src_path) + +from personalization.config.settings import load_local_models_config +from personalization.config.registry import get_preference_extractor, get_chat_model +from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B +from personalization.models.reranker.qwen3_reranker import Qwen3Reranker +from personalization.models.reranker.bge_reranker import BGEReranker +from personalization.user_model.tensor_store import UserTensorStore, UserState +from 
# ---------------------------------------------------------------------------
# Interface dataclasses
# ---------------------------------------------------------------------------

@dataclass
class UsageStats:
    """Token accounting for a single chat completion."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    model: str


@dataclass
class DebugInfo:
    """Optional diagnostics attached to a response, for analysis and ablations.

    Every field is optional: populate what is available and leave the rest
    at its default.
    """
    selected_memory_ids: List[str] = field(default_factory=list)
    selected_memory_notes: List[str] = field(default_factory=list)
    selected_memory_scores: List[float] = field(default_factory=list)
    user_vector_before: Optional[List[float]] = None
    user_vector_after: Optional[List[float]] = None
    extracted_preferences: List[Dict[str, Any]] = field(default_factory=list)
    extra: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AssistantResponse:
    """Payload returned by the personalized LLM chat interface."""
    answer: str
    usage: UsageStats
    debug: Optional[DebugInfo] = None


@dataclass
class Feedback:
    """Feedback from a user simulator or judge, consumed by RL updates.

    Attributes:
        user_id: The user this feedback is for.
        turn_id: The turn this feedback refers to (from the previous turn).
        reward: Reward scalar computed by user simulator / judge.
        gating: Gating flag (1=valid learning signal, 0=skip update).
        meta: Additional metadata for training/analysis.
    """
    user_id: str
    turn_id: int
    reward: float
    gating: float  # 0.0 / 1.0, or a continuous value
    meta: Dict[str, Any] = field(default_factory=dict)


# ---------------------------------------------------------------------------
# Internal session state
# ---------------------------------------------------------------------------

@dataclass
class _SessionContext:
    """Per-user session bookkeeping used during evaluation."""
    session_state: OnlineSessionState
    turn_counter: int = 0
    # Carries whatever apply_feedback() needs from the preceding turn.
    pending_rl_update: Optional[Dict[str, Any]] = None


# ---------------------------------------------------------------------------
# Shared model singletons for multi-threaded efficiency
# ---------------------------------------------------------------------------

_shared_embed_model = None
_shared_reranker = None
_shared_extractor = None
_shared_models_lock = None  # Created lazily on first use


def _get_shared_models_lock():
    """Return the process-wide lock guarding the shared-model singletons."""
    global _shared_models_lock
    if _shared_models_lock is None:
        import threading
        _shared_models_lock = threading.Lock()
    return _shared_models_lock


def get_shared_embedding_model(model_path: str, device_map: str = "auto"):
    """Get or create shared embedding model (thread-safe singleton).

    Note: `model_path`/`device_map` only take effect on the first call; later
    callers receive whatever instance was created first.
    """
    global _shared_embed_model
    import torch

    with _get_shared_models_lock():
        if _shared_embed_model is None:
            print(f"[SharedModels] Loading shared embedding model on {device_map}...")
            _shared_embed_model = Qwen3Embedding8B(
                model_path=model_path,
                dtype=torch.bfloat16,
                device_map=device_map,
            )
            print("[SharedModels] Shared embedding model loaded.")
    return _shared_embed_model
def get_shared_reranker(model_path: str, device_map: str = "auto", reranker_type: str = "qwen3"):
    """Get or create shared reranker model (thread-safe singleton).

    `reranker_type` selects BGE ("bge") or Qwen3 (anything else); the choice
    is only honored on the first call that actually creates the instance.
    """
    global _shared_reranker
    import torch

    with _get_shared_models_lock():
        if _shared_reranker is None:
            print(f"[SharedModels] Loading shared reranker ({reranker_type}) on {device_map}...")
            if reranker_type == "bge":
                _shared_reranker = BGEReranker(
                    model_path=model_path,
                    device_map=device_map,
                    dtype=torch.float16,
                )
            else:
                _shared_reranker = Qwen3Reranker(
                    model_path=model_path,
                    device_map=device_map,
                    dtype=torch.bfloat16,
                )
            print("[SharedModels] Shared reranker model loaded.")
    return _shared_reranker


def get_shared_extractor(model_path: str, device_map: str = "auto"):
    """Get or create shared preference extractor model (thread-safe singleton)."""
    global _shared_extractor
    import torch
    from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor

    with _get_shared_models_lock():
        if _shared_extractor is None:
            print(f"[SharedModels] Loading shared preference extractor on {device_map}...")
            _shared_extractor = QwenRuleExtractor(
                model_path=model_path,
                dtype=torch.bfloat16,
                device_map=device_map,
            )
            print("[SharedModels] Shared preference extractor loaded.")
    return _shared_extractor


def clear_shared_models():
    """Free all shared singleton models to reclaim GPU memory between methods."""
    global _shared_embed_model, _shared_reranker, _shared_extractor
    import gc

    with _get_shared_models_lock():
        freed = []
        if _shared_embed_model is not None:
            freed.append("embedding")
            del _shared_embed_model
            _shared_embed_model = None
        if _shared_reranker is not None:
            freed.append("reranker")
            del _shared_reranker
            _shared_reranker = None
        if _shared_extractor is not None:
            freed.append("extractor")
            del _shared_extractor
            _shared_extractor = None

        if freed:
            gc.collect()
            # Release cached GPU blocks as well, when torch is present.
            try:
                import torch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except ImportError:
                pass
            print(f"[SharedModels] Cleared: {', '.join(freed)}")
def __init__(
    self,
    config_path: Optional[str] = None,
    user_store_path: str = "data/users/user_store_eval.npz",
    memory_cards_path: str = "data/corpora/memory_cards.jsonl",
    memory_embeddings_path: str = "data/corpora/memory_embeddings.npy",
    item_projection_path: str = "data/corpora/item_projection.npz",
    only_own_memories: bool = True,
    enable_preference_extraction: bool = True,
    enable_rl_updates: bool = True,
    mode: str = "full",  # "full", "nopersonal", or "vanilla"
    eval_mode: bool = True,  # True = greedy selection, False = stochastic sampling
    device_assignment: Optional[Dict[str, str]] = None,  # Multi-GPU support
    llm_name: Optional[str] = None,  # Override LLM name (e.g., "llama_8b_vllm" for vLLM)
    use_shared_models: bool = False,  # Share singleton models across threads
    reranker_type: str = "qwen3",  # "qwen3" (8B) or "bge" (278M)
    best_of_n: int = 1,  # Generate N responses and pick best (for RAG methods)
    reward_mode: str = "keyword",  # "keyword", "llm" (GPT-4o-mini), or "llm_local" (local vLLM)
    llm_reward_config: Optional["LLMRewardConfig"] = None,  # Config for LLM judge
    reward_vllm_url: Optional[str] = None,  # vLLM URL when reward_mode="llm_local"
    enable_query_transform: bool = False,  # Transform queries for better retrieval matching
    enable_global_preferences: bool = False,  # Separate global prefs that bypass retrieval
    dynamic_topk: bool = False,  # Use dynamic topk based on rerank scores
    dynamic_min_k: int = 3,  # Min preferences for dynamic topk
    dynamic_max_k: int = 8,  # Max preferences for dynamic topk
    dynamic_score_ratio: float = 0.5,  # Threshold = top_score * ratio
    # FIX: these two were annotated `float` while defaulting to None.
    eta_long: Optional[float] = None,  # Override RL learning rate for z_long
    eta_short: Optional[float] = None,  # Override RL learning rate for z_short
    enable_preference_consolidation: bool = False,  # Consolidate prefs at session end
    consolidation_threshold: int = 5,  # Min preferences before consolidation
    enable_preference_rewrite: bool = False,  # LLM-rewrite retrieved preferences
):
    """
    Initialize the PersonalizedLLM.

    Args:
        config_path: Path to config file. If None, uses default locations.
        user_store_path: Path to persist user state vectors.
        memory_cards_path: Path to memory cards JSONL file.
        memory_embeddings_path: Path to memory embeddings numpy file.
        item_projection_path: Path to item projection (PCA) file.
        only_own_memories: If True, only retrieve user's own memories (strict privacy).
        enable_preference_extraction: If True, extract preferences from user turns.
        enable_rl_updates: If True, apply RL updates via apply_feedback.
        mode: "full" for full personalization, "nopersonal" for baseline
            (no user vector influence), "vanilla" for pure LLM without any
            memory retrieval or preference extraction.
        eval_mode: If True, use greedy/deterministic selection (for evaluation).
            If False, use stochastic sampling (for training/exploration).
        device_assignment: Optional dict to assign models to specific GPUs, e.g.
            {"embed": "cuda:0", "reranker": "cuda:1", "chat": "cuda:2",
            "extractor": "cuda:3"}. If None, uses "auto" for all models.
        use_shared_models: If True, reuse shared singleton embedding/reranker/
            extractor models — essential for multi-threaded profile processing
            so each thread does not load its own copies.
        eta_long / eta_short: Optional overrides for the RL learning rates;
            when None the defaults in the RL config are used.
    """
    # --- Feature flags / plain settings ---
    self.only_own_memories = only_own_memories
    self.use_shared_models = use_shared_models
    self.enable_preference_extraction = enable_preference_extraction
    self.enable_rl_updates = enable_rl_updates
    self.mode = mode  # "full", "nopersonal", or "vanilla"
    self.eval_mode = eval_mode  # True = greedy, False = sample
    self.reranker_type = reranker_type  # "qwen3" or "bge"
    self.best_of_n = best_of_n  # Generate N responses and pick best
    self.reward_mode = reward_mode  # "keyword", "llm", or "llm_local"
    self.enable_query_transform = enable_query_transform
    self.enable_global_preferences = enable_global_preferences
    self.enable_preference_consolidation = enable_preference_consolidation
    self.consolidation_threshold = consolidation_threshold
    self.enable_preference_rewrite = enable_preference_rewrite

    # --- LLM judge / reward client (only for LLM-based reward modes) ---
    self._llm_reward_client = None  # LLMRewardClient or LocalLLMRewardClient
    if reward_mode == "llm":
        self._llm_reward_client = LLMRewardClient(llm_reward_config or LLMRewardConfig())
    elif reward_mode == "llm_local":
        # Imported lazily: only needed when a local vLLM judge is requested.
        from personalization.feedback.local_llm_reward import (
            LocalLLMRewardClient,
            LocalLLMRewardConfig,
        )
        local_config = LocalLLMRewardConfig(
            vllm_url=reward_vllm_url or "http://localhost:8005/v1",
        )
        self._llm_reward_client = LocalLLMRewardClient(local_config)

    # --- Multi-GPU device assignment ---
    self._device_assignment = device_assignment or {
        "embed": "auto",
        "reranker": "auto",
        "chat": "auto",
        "extractor": "auto",
    }

    # --- Data paths ---
    self._memory_cards_path = memory_cards_path
    self._memory_embeddings_path = memory_embeddings_path
    self._item_projection_path = item_projection_path

    # --- RL configuration ---
    # Note: beta/eta increased for more significant z_u updates.
    self._rl_cfg = {
        "item_dim": 256,
        "beta_long": 2.0,  # Increased from 0.1 for stronger personalization
        "beta_short": 5.0,  # Increased from 0.3
        "tau": 1.0,
        "eta_long": eta_long if eta_long is not None else 0.01,
        "eta_short": eta_short if eta_short is not None else 0.05,
        "ema_alpha": 0.05,
        "short_decay": 0.1,
        "dense_topk": 64,
        "rerank_topk": 5,
        "max_new_tokens": 512,
        # Dynamic topk settings
        "dynamic_topk": dynamic_topk,
        "dynamic_min_k": dynamic_min_k,
        "dynamic_max_k": dynamic_max_k,
        "dynamic_score_ratio": dynamic_score_ratio,
    }

    # Store llm_name before loading config (consulted in _load_config).
    self._llm_name_override = llm_name

    # Load config (may override RL params and llm_name).
    self._load_config(config_path)

    # Load models
    print("[PersonalizedLLM] Loading models...")
    self._load_models()

    # Load memory store
    print("[PersonalizedLLM] Loading memory store...")
    self._load_memory_store()

    # Initialize user store
    self._user_store = UserTensorStore(
        k=self._rl_cfg["item_dim"],
        path=user_store_path,
    )

    # Session contexts per user (in-memory)
    self._sessions: Dict[str, _SessionContext] = {}

    print("[PersonalizedLLM] Initialization complete.")
def _load_config(self, config_path: Optional[str]):
    """Load configuration from yaml files, overriding RL params and llm_name."""
    self._cfg = load_local_models_config()

    # Try to load user_model.yaml for RL params
    if config_path is None:
        config_path = "configs/user_model.yaml"

    self._llm_name = self._llm_name_override or "qwen_1_5b"  # Default, can be overridden

    try:
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                user_cfg = yaml.safe_load(f)
            if user_cfg:
                # Override RL params if present
                for key in self._rl_cfg:
                    if key in user_cfg:
                        self._rl_cfg[key] = user_cfg[key]
                # LLM name (only from config if not already set via parameter)
                if self._llm_name_override is None and "llm_name" in user_cfg:
                    self._llm_name = user_cfg["llm_name"]
    except Exception as e:
        print(f"[PersonalizedLLM] Warning: Failed to load config: {e}")

def _default_item_projection(self, embed_dim: int = 4096):
    """Build the fallback truncation projection (keeps the first k embedding dims).

    Used whenever no trained PCA projection file is available, so preferences
    can still be added to the store.
    """
    k = self._rl_cfg["item_dim"]
    P = np.zeros((k, embed_dim), dtype=np.float32)
    P[:, :k] = np.eye(k, dtype=np.float32)
    return ItemProjection(P=P, mean=np.zeros(embed_dim, dtype=np.float32))

def _extractor_fallback(self, err: Exception, label: str):
    """Shared fallback chain for the preference extractor: rule-based, then GPT-5-mini."""
    print(f"[PersonalizedLLM] Warning: Failed to load {label}: {err}. Trying fallbacks...")
    try:
        return get_preference_extractor("rule")
    except Exception as e2:
        print(f"[PersonalizedLLM] Rule extractor also failed: {e2}. Using GPT-5-mini extractor.")
        return get_preference_extractor("gpt5_mini")

def _load_models(self):
    """Load all ML models with optional multi-GPU assignment."""
    import torch

    # Report GPU availability (only once, not for shared model instances)
    if not self.use_shared_models:
        num_gpus = torch.cuda.device_count()
        print(f"[PersonalizedLLM] Available GPUs: {num_gpus}")
        for i in range(num_gpus):
            mem = torch.cuda.get_device_properties(i).total_memory / 1e9
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f}GB)")

    embed_device = self._device_assignment.get("embed", "auto")
    reranker_device = self._device_assignment.get("reranker", "auto")
    chat_device = self._device_assignment.get("chat", "auto")
    extractor_device = self._device_assignment.get("extractor", "auto")

    # Embedding model - only load for modes that use RAG retrieval.
    # Vanilla and contextual modes don't need embedding/reranker.
    needs_retrieval = self.mode not in ("vanilla", "contextual")

    if needs_retrieval:
        if self.use_shared_models:
            print(f"[PersonalizedLLM] Using shared embedding model...")
            self._embed_model = get_shared_embedding_model(
                model_path=self._cfg.embedding.qwen3.local_path,
                device_map=embed_device,
            )
        else:
            print(f"[PersonalizedLLM] Loading Embedding model on {embed_device}...")
            self._embed_model = Qwen3Embedding8B(
                model_path=self._cfg.embedding.qwen3.local_path,
                dtype=torch.bfloat16,
                device_map=embed_device,
            )
    else:
        print(f"[PersonalizedLLM] Skipping embedding model (not needed for {self.mode} mode)")
        self._embed_model = None

    # Reranker - only load for modes that use RAG retrieval.
    # Supports both qwen3 (8B) and bge (278M) rerankers.
    if needs_retrieval:
        if self.reranker_type == "bge":
            reranker_path = getattr(self._cfg.reranker, "bge_base", None)
            reranker_path = reranker_path.local_path if reranker_path else "BAAI/bge-reranker-base"
        else:
            reranker_path = self._cfg.reranker.qwen3_8b.local_path

        if self.use_shared_models:
            print(f"[PersonalizedLLM] Using shared reranker model ({self.reranker_type})...")
            self._reranker = get_shared_reranker(
                model_path=reranker_path,
                device_map=reranker_device,
                reranker_type=self.reranker_type,
            )
        else:
            print(f"[PersonalizedLLM] Loading Reranker ({self.reranker_type}) on {reranker_device}...")
            if self.reranker_type == "bge":
                self._reranker = BGEReranker(
                    model_path=reranker_path,
                    device_map=reranker_device,
                    dtype=torch.float16,
                )
            else:
                self._reranker = Qwen3Reranker(
                    model_path=reranker_path,
                    device_map=reranker_device,
                    dtype=torch.bfloat16,
                )
    else:
        print(f"[PersonalizedLLM] Skipping reranker (not needed for {self.mode} mode)")
        self._reranker = None

    # Chat model (via registry for backend switching)
    print(f"[PersonalizedLLM] Loading ChatModel: {self._llm_name} on {chat_device}...")
    # Pass device override only if explicitly specified (not "auto").
    device_for_chat = chat_device if chat_device != "auto" else None
    self._chat_model = get_chat_model(self._llm_name, device_override=device_for_chat)

    # Preference extractor - use shared singleton if enabled
    if self.enable_preference_extraction:
        extractor_name = "qwen3_0_6b_sft"
        if self.use_shared_models:
            print(f"[PersonalizedLLM] Using shared preference extractor...")
            try:
                extractor_path = self._cfg.preference_extractor.get("qwen3_0_6b_sft", {}).get("path", None)
                if extractor_path:
                    self._extractor = get_shared_extractor(
                        model_path=extractor_path,
                        device_map=extractor_device,
                    )
                else:
                    print(f"[PersonalizedLLM] Extractor path not found, using rule-based.")
                    self._extractor = get_preference_extractor("rule")
            except Exception as e:
                self._extractor = self._extractor_fallback(e, "shared extractor")
        else:
            print(f"[PersonalizedLLM] Loading extractor: {extractor_name} on {extractor_device}...")
            try:
                self._extractor = get_preference_extractor(extractor_name)
            except Exception as e:
                self._extractor = self._extractor_fallback(e, extractor_name)
    else:
        print("[PersonalizedLLM] Preference extraction disabled, skipping extractor.")
        self._extractor = None

def _load_memory_store(self):
    """Load memory cards, their embeddings, and the item projection from disk."""
    if not os.path.exists(self._memory_cards_path):
        print(f"[PersonalizedLLM] Warning: Memory cards not found at {self._memory_cards_path}")
        self._memory_cards: List[MemoryCard] = []
        self._memory_embeddings = np.zeros((0, 4096), dtype=np.float32)
        self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)
        # Default projection (truncation to first k dims) so preferences can be added.
        k = self._rl_cfg["item_dim"]
        self._projection = self._default_item_projection()
        print(f"[PersonalizedLLM] Created default projection (truncation, k={k})")
        return

    # Load cards (one JSON document per non-empty line)
    self._memory_cards = []
    with open(self._memory_cards_path, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                self._memory_cards.append(MemoryCard.model_validate_json(line))

    # Load embeddings (zeros fallback keeps shapes consistent with the cards)
    if os.path.exists(self._memory_embeddings_path):
        self._memory_embeddings = np.load(self._memory_embeddings_path)
    else:
        self._memory_embeddings = np.zeros((len(self._memory_cards), 4096), dtype=np.float32)

    # Load projection
    if os.path.exists(self._item_projection_path):
        proj_data = np.load(self._item_projection_path)
        self._projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
        self._item_vectors = proj_data["V"]
    else:
        # Default projection so preferences can still be added
        k = self._rl_cfg["item_dim"]
        self._projection = self._default_item_projection()
        self._item_vectors = np.zeros((len(self._memory_cards), k), dtype=np.float32)
        print(f"[PersonalizedLLM] Created default projection (truncation, k={k})")

    print(f"[PersonalizedLLM] Loaded {len(self._memory_cards)} memory cards.")
def _get_or_create_session(self, user_id: str) -> _SessionContext:
    """Return the cached session context for `user_id`, creating it if absent."""
    ctx = self._sessions.get(user_id)
    if ctx is None:
        ctx = _SessionContext(
            session_state=OnlineSessionState(user_id=user_id),
            turn_counter=0,
        )
        self._sessions[user_id] = ctx
    return ctx

def _build_chat_turn(self, user_id: str, text: str, role: str, turn_id: int) -> ChatTurn:
    """Assemble a ChatTurn for this evaluation session."""
    return ChatTurn(
        user_id=user_id,
        session_id=f"eval_session_{user_id}",
        turn_id=turn_id,
        role=role,
        text=text,
        meta={"source": "eval"},
    )

def _count_tokens(self, text: str) -> int:
    """Estimate token count; falls back to ~4 chars/token without a tokenizer."""
    try:
        tokenizer = getattr(self._chat_model, "tokenizer", None)
        if tokenizer is not None:
            return len(tokenizer.encode(text))
    except Exception:
        pass
    return len(text) // 4

# Task type keywords for query transformation
_TASK_KEYWORDS = {
    "math": ["solve", "calculate", "integral", "equation", "proof", "derivative",
             "math", "algebra", "geometry", "trigonometry", "calculus", "arithmetic",
             "formula", "compute", "evaluate", "simplify", "factor", "graph"],
    "coding": ["code", "program", "function", "implement", "debug", "python", "java",
               "javascript", "algorithm", "class", "method", "bug", "error", "compile",
               "script", "html", "css", "sql", "api", "library", "framework"],
    "writing": ["write", "essay", "paragraph", "summarize", "draft", "compose",
                "article", "story", "letter", "email", "report", "review", "edit",
                "rewrite", "paraphrase", "outline"],
    "explanation": ["explain", "what is", "how does", "why", "describe", "define",
                    "meaning", "concept", "difference between", "compare", "contrast"],
}

def _transform_query_for_retrieval(self, query: str) -> List[str]:
    """
    Transform raw user query into multiple retrieval queries to bridge
    the semantic gap between task queries and preference descriptions.

    Returns [original_query, transformed_query] when a task type is
    detected, otherwise [original_query].
    """
    import re
    lowered = query.lower()
    matched_type = None
    for task_type, keywords in self._TASK_KEYWORDS.items():
        # Word-boundary matching avoids false positives (e.g. "api" inside
        # "capital"); first matching task type wins.
        if any(re.search(r"\b" + re.escape(kw) + r"\b", lowered) for kw in keywords):
            matched_type = task_type
            break

    if matched_type is None:
        return [query]
    return [query, f"user preferences for {matched_type} tasks: {query}"]
+ """ + import re + query_lower = query.lower() + detected_types = [] + for task_type, keywords in self._TASK_KEYWORDS.items(): + for kw in keywords: + # Use word boundary matching to avoid false positives + # e.g., "api" should not match "capital" + if re.search(r'\b' + re.escape(kw) + r'\b', query_lower): + detected_types.append(task_type) + break + + if not detected_types: + return [query] + + # Use first detected type (most specific match) + task_type = detected_types[0] + transformed = f"user preferences for {task_type} tasks: {query}" + return [query, transformed] + + # Patterns indicating a global/universal preference condition + _GLOBAL_PATTERNS = ["general", "any", "always", "all ", "every", "regardless", + "any task", "any topic", "any question", "all tasks", "all topics"] + + # Domain-specific terms that indicate a conditional preference + _DOMAIN_TERMS = ["math", "code", "coding", "program", "writing", "essay", "science", + "history", "language", "physics", "chemistry", "biology", "literature", + "creative", "technical", "formal", "informal", "academic", "casual"] + + def _classify_preference_scope(self, condition: str) -> bool: + """ + Classify whether a preference condition is global (always applicable) + or conditional (task-specific). + + Returns True if global, False if conditional. + """ + cond_lower = condition.lower().strip() + + # Check for explicit global patterns + for pattern in self._GLOBAL_PATTERNS: + if pattern in cond_lower: + return True + + # Very short/vague conditions with no domain terms are likely global + words = cond_lower.split() + if len(words) <= 2: + has_domain = any(term in cond_lower for term in self._DOMAIN_TERMS) + if not has_domain: + return True + + return False + + # Rewrite prompt for merging retrieved preferences + _REWRITE_PROMPT = """You are helping to prepare user preferences for an AI assistant. 
# Rewrite prompt for merging retrieved preferences
_REWRITE_PROMPT = """You are helping to prepare user preferences for an AI assistant.

The user is asking: {query}

Retrieved preferences about this user:
{preferences}

Task: Create a concise preference summary that the assistant MUST follow.

Rules:
1. PRESERVE all specific formatting requirements exactly (e.g., "type hints", "snake_case", "code fence with language")
2. PRESERVE all structural requirements (e.g., "numbered steps", "bullet points", "answer first then explanation")
3. Only MERGE preferences that are truly redundant (saying the same thing differently)
4. Output as a short bulleted list if there are multiple distinct requirements
5. Keep each point actionable and specific - NO vague generalizations like "follow best practices"

Example input:
- Include type hints in Python code
- Use snake_case for variable names
- When explaining, use numbered steps

Example output:
- Include type hints
- Use snake_case for variables
- Use numbered steps for explanations

If no preferences are relevant to this query type, output: "No specific preferences apply."

Preference summary:"""

def _rewrite_preferences(self, memory_notes: List[str], query: str) -> List[str]:
    """
    Use the chat model to rewrite/merge retrieved preferences into a single
    concise instruction (similar in spirit to Reflection's proper_scaffolding,
    but focused on merging rather than filtering).

    Args:
        memory_notes: List of retrieved preference notes
        query: Current user query

    Returns:
        List with a single rewritten instruction, or the original notes when
        there is nothing to merge or the rewrite fails.
    """
    # Zero or one note: nothing to merge.
    if not memory_notes or len(memory_notes) <= 1:
        return memory_notes

    try:
        import requests

        bullet_notes = "\n".join(f"- {note}" for note in memory_notes)
        prompt = self._REWRITE_PROMPT.format(query=query[:200], preferences=bullet_notes)

        # Direct vLLM API call (simpler than going through the chat model).
        payload = {
            "model": self._chat_model.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 150,
            "temperature": 0.3,  # Lower temperature for more consistent output
        }
        response = requests.post(
            f"{self._chat_model.vllm_url}/chat/completions",
            json=payload,
            timeout=30,
        )

        if response.status_code != 200:
            print(f"[REWRITE] API error {response.status_code}, keeping original notes")
            return memory_notes

        result = response.json()
        rewritten = result["choices"][0]["message"]["content"].strip().strip('"')

        # Accept the merge only when it looks substantive.
        if rewritten and len(rewritten) > 10 and "No specific preferences" not in rewritten:
            print(f"[REWRITE] {len(memory_notes)} notes → 1 merged instruction")
            return [rewritten]
        print(f"[REWRITE] Kept original {len(memory_notes)} notes (no valid merge)")
        return memory_notes

    except Exception as e:
        # Best-effort: any failure (missing requests, network, bad payload)
        # falls back to the unmerged notes.
        print(f"[REWRITE] Failed: {e}, keeping original notes")
        return memory_notes
# Consolidation prompt for session-end preference merging
_CONSOLIDATION_PROMPT = """You are analyzing user preferences extracted from conversations.

Current preferences for this user:
{preferences}

Task: Consolidate these preferences into a cleaner, more organized set by:
1. MERGE similar preferences (e.g., "use bullet points" + "format with bullets" → single preference)
2. REMOVE redundant or contradictory preferences (keep the more specific one)
3. PRESERVE all unique, meaningful preferences
4. Keep the same "When [condition], [action]." format

Output ONLY the consolidated preferences, one per line, in this exact format:
When [condition], [action].

Do not add explanations or commentary. Just output the preference lines."""

def consolidate_user_preferences(self, user_id: str) -> int:
    """
    Consolidate user preferences at session end using the chat model.

    Merges similar preferences, removes redundancy, and replaces the user's
    cards with the consolidated set. Only runs if the feature is enabled and
    the user has at least `consolidation_threshold` preferences.

    Args:
        user_id: The user whose preferences to consolidate.

    Returns:
        Number of preferences after consolidation (0 if the feature is off).
    """
    if not self.enable_preference_consolidation:
        return 0

    user_cards = [c for c in self._memory_cards if c.user_id == user_id]
    if len(user_cards) < self.consolidation_threshold:
        return len(user_cards)

    # Ask the LLM for a consolidated preference list.
    preferences_text = "\n".join(f"- {card.note_text}" for card in user_cards)
    prompt = self._CONSOLIDATION_PROMPT.format(preferences=preferences_text)
    messages = [{"role": "user", "content": prompt}]

    try:
        result = self._chat_model.answer(messages, max_new_tokens=512)
        consolidated_text = result.get("content", "").strip()
        if not consolidated_text:
            return len(user_cards)

        # Parse "When [condition], [action]." lines out of the reply.
        new_prefs = []
        for raw_line in consolidated_text.split("\n"):
            raw_line = raw_line.strip()
            if not raw_line.startswith("When "):
                continue
            if ", " not in raw_line:
                continue
            head, tail = raw_line.split(", ", 1)
            condition = head.replace("When ", "").strip()
            action = tail.rstrip(".").strip()
            if condition and action:
                new_prefs.append({
                    "condition": condition,
                    "action": action,
                    "is_global": self._classify_preference_scope(condition) if self.enable_global_preferences else False,
                })

        if not new_prefs:
            return len(user_cards)

        # Drop this user's old cards (and matching embedding rows).
        keep_indices = [i for i, c in enumerate(self._memory_cards) if c.user_id != user_id]
        self._memory_cards = [self._memory_cards[i] for i in keep_indices]
        if len(keep_indices) > 0 and len(self._memory_embeddings) > 0:
            self._memory_embeddings = self._memory_embeddings[keep_indices]
            self._item_vectors = self._item_vectors[keep_indices]
        else:
            embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096
            self._memory_embeddings = np.zeros((0, embed_dim), dtype=np.float32)
            self._item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)

        # Re-add the consolidated preferences as fresh cards.
        for pref in new_prefs:
            note_text = f"When {pref['condition']}, {pref['action']}."

            e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
            v_note = self._projection.transform_vector(np.array(e_note))

            card = MemoryCard(
                card_id=str(uuid.uuid4()),
                user_id=user_id,
                source_session_id=f"consolidated_{user_id}",
                source_turn_ids=[],
                raw_queries=[],
                preference_list=PreferenceList(preferences=[
                    Preference(condition=pref["condition"], action=pref["action"], confidence=1.0)
                ]),
                note_text=note_text,
                embedding_e=list(e_note),
                kind="pref",
                is_global=pref["is_global"],
            )

            self._memory_cards.append(card)
            self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
            self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])

        print(f"[PersonalizedLLM] Consolidated {len(user_cards)} → {len(new_prefs)} preferences for user {user_id}")
        return len(new_prefs)

    except Exception as e:
        print(f"[PersonalizedLLM] Consolidation failed for user {user_id}: {e}")
        return len(user_cards)
def _add_preferences_as_memory(
    self,
    prefs: PreferenceList,
    query: str,
    user_id: str,
    turn_id: int,
) -> List[Dict[str, Any]]:
    """
    Add extracted preferences as new memory cards.

    Returns a list of preference dicts (condition/action/confidence) for
    debug info; duplicates of existing cards are reported but not re-added.
    """
    extracted: List[Dict[str, Any]] = []

    # Nothing to store without preferences or a projection to embed into.
    if not prefs.preferences or self._projection is None:
        return extracted

    for pref in prefs.preferences:
        note_text = f"When {pref.condition}, {pref.action}."

        # Always record for debug, even when deduplicated below.
        extracted.append({
            "condition": pref.condition,
            "action": pref.action,
            "confidence": pref.confidence,
        })

        # Skip exact duplicates already stored for this user.
        if any(c.user_id == user_id and c.note_text == note_text for c in self._memory_cards):
            continue

        # Embed the note text (NOT the query) so retrieval queries like
        # "solve math problem" match stored "When math problems..." notes.
        e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0]
        v_note = self._projection.transform_vector(np.array(e_note))

        # Classify as global or conditional
        is_global = self._classify_preference_scope(pref.condition) if self.enable_global_preferences else False

        card = MemoryCard(
            card_id=str(uuid.uuid4()),
            user_id=user_id,
            source_session_id=f"eval_session_{user_id}",
            source_turn_ids=[turn_id],
            raw_queries=[query],
            preference_list=PreferenceList(preferences=[pref]),
            note_text=note_text,
            embedding_e=list(e_note),
            kind="pref",
            is_global=is_global,
        )

        # Append to the in-memory store (cards + both vector tables).
        self._memory_cards.append(card)
        self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])])
        self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])])

    return extracted
+ e_note = self._embed_model.encode([note_text], normalize=True, return_tensor=False)[0] + v_note = self._projection.transform_vector(np.array(e_note)) + + # Classify as global or conditional + is_global = self._classify_preference_scope(pref.condition) if self.enable_global_preferences else False + + # Create new memory card + card = MemoryCard( + card_id=str(uuid.uuid4()), + user_id=user_id, + source_session_id=f"eval_session_{user_id}", + source_turn_ids=[turn_id], + raw_queries=[query], + preference_list=PreferenceList(preferences=[pref]), + note_text=note_text, + embedding_e=list(e_note), + kind="pref", + is_global=is_global, + ) + + # Add to memory store + self._memory_cards.append(card) + self._memory_embeddings = np.vstack([self._memory_embeddings, np.array([e_note])]) + self._item_vectors = np.vstack([self._item_vectors, np.array([v_note])]) + + return extracted + + def _score_response(self, response: str) -> float: + """ + Score a response for best-of-N selection. + + Higher score = better response. Scoring heuristics: + 1. Length: Longer responses typically have more substance + 2. Solution indicators: Contains formulas, steps, answers + 3. Proactivity: Doesn't end with just a question + + Returns: + Float score (higher is better) + """ + score = 0.0 + response_lower = response.lower() + + # Length score (normalized, cap at 1000 chars) + score += min(len(response), 1000) / 1000 * 3.0 + + # Solution indicators (+1 each, max 5) + solution_indicators = ['=', 'step', 'answer', 'formula', 'result', 'therefore', 'solution'] + indicator_count = sum(1 for ind in solution_indicators if ind in response_lower) + score += min(indicator_count, 5) * 0.5 + + # Structured content (+1 for numbered/bulleted lists) + if any(marker in response for marker in ['1.', '2.', '- ', '* ', '##']): + score += 1.0 + + # Penalty for ending with question (passive behavior) + # Check last 100 chars for question marks + if '?' 
    def chat(self, user_id: str, query: str) -> AssistantResponse:
        """
        Main online chat interface.

        Args:
            user_id: Unique identifier for the user.
            query: Current user query/message.

        Returns:
            AssistantResponse containing the answer, usage stats, and debug info.

        Notes:
            - Internally manages user state, session history, memory retrieval
            - After this call, you can call apply_feedback() with the turn's feedback
        """
        ctx = self._get_or_create_session(user_id)
        session = ctx.session_state
        user_state = self._user_store.get_state(user_id)

        # Record user vector before for debug
        z_long_before = user_state.z_long.copy().tolist()
        z_short_before = user_state.z_short.copy().tolist()

        # Add user turn to history
        user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
        session.history.append(user_turn)

        # Vanilla mode: pure LLM without any memory or preference extraction
        if self.mode == "vanilla":
            # Skip embedding, preference extraction, and memory retrieval entirely.
            # NOTE(review): assumes embedding dim 4096 — matches the zero-card
            # fallback used elsewhere in this file; confirm against the embed model.
            e_q_t = np.zeros(4096, dtype=np.float32)  # Placeholder for vanilla mode
            extracted_prefs = []
            candidates = []
            cand_item_vecs = np.array([])
            base_scores = np.array([])
            chosen_indices = []
            probs = np.array([])
            memories_t = []
            memory_notes = []
        else:
            # Compute query embedding (only needed for non-vanilla modes)
            # Explicitly normalize for consistent cosine similarity with stored embeddings
            embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
            if embed_result is None or len(embed_result) == 0:
                raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
            e_q_t = np.array(embed_result[0])

            # Store pending RL update info from last turn (for apply_feedback)
            if session.last_query is not None and self.enable_rl_updates:
                ctx.pending_rl_update = {
                    "last_query": session.last_query,
                    "last_answer": session.last_answer,
                    "last_memories": session.last_memories,
                    "last_query_embedding": session.last_query_embedding,
                    "current_query_embedding": e_q_t,
                    "last_candidate_item_vectors": session.last_candidate_item_vectors,
                    "last_policy_probs": session.last_policy_probs,
                    "last_chosen_indices": session.last_chosen_indices,
                }

            # Auto-compute reward via LLM judge if enabled
            if self._llm_reward_client is not None:
                import asyncio
                try:
                    reward, gating = asyncio.run(eval_step_llm(
                        q_t=session.last_query,
                        answer_t=session.last_answer,
                        q_t1=query,
                        memories_t=session.last_memories or [],
                        client=self._llm_reward_client,
                    ))
                    if gating > 0.0:
                        self.apply_feedback(Feedback(
                            user_id=user_id,
                            turn_id=ctx.turn_counter - 1,
                            reward=reward,
                            gating=gating,
                        ))
                except Exception as e:
                    # Graceful fallback: skip RL update if judge fails
                    print(f"[LLM-Reward] Judge call failed, skipping update: {e}")

            # Extract preferences from conversation (if enabled)
            # extract_turn processes only the last user turn - efficient since called each turn
            # Preferences accumulate in _memory_cards across turns (dedup prevents duplicates)
            extracted_prefs = []
            if self.enable_preference_extraction:
                prefs = self._extractor.extract_turn(session.history)
                if prefs.preferences:
                    print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
                    extracted_prefs = self._add_preferences_as_memory(
                        prefs, query, user_id, ctx.turn_counter
                    )
                    if extracted_prefs:
                        print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")

            # Separate global preferences (bypass retrieval) from conditional ones
            global_notes = []
            retrieval_cards = self._memory_cards
            retrieval_embeddings = self._memory_embeddings
            retrieval_item_vectors = self._item_vectors
            if self.enable_global_preferences:
                global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id]
                global_notes = [c.note_text for c in global_cards[:10]]  # Cap at 10
                # Filter out global cards for retrieval
                cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global]
                if cond_indices:
                    retrieval_cards = [self._memory_cards[i] for i in cond_indices]
                    retrieval_embeddings = self._memory_embeddings[cond_indices]
                    if len(self._item_vectors) > 0:
                        retrieval_item_vectors = self._item_vectors[cond_indices]
                else:
                    retrieval_cards = []
                    retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings
                    retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)

            # Query transformation for better retrieval matching
            retrieval_queries = None
            if self.enable_query_transform:
                retrieval_queries = self._transform_query_for_retrieval(query)

            # Retrieve memories.  base_scores is unused below in chat();
            # it is kept for signature parity with chat_prepare(), which
            # forwards it in its context dict.
            if self.mode == "nopersonal":
                candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
                    user_id=user_id,
                    query=query,
                    embed_model=self._embed_model,
                    reranker=self._reranker,
                    memory_cards=retrieval_cards,
                    memory_embeddings=retrieval_embeddings,
                    topk_dense=self._rl_cfg["dense_topk"],
                    topk_rerank=self._rl_cfg["rerank_topk"],
                    only_own_memories=self.only_own_memories,
                    queries=retrieval_queries,
                    dynamic_topk=self._rl_cfg["dynamic_topk"],
                    dynamic_min_k=self._rl_cfg["dynamic_min_k"],
                    dynamic_max_k=self._rl_cfg["dynamic_max_k"],
                    dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"],
                )
            else:
                beta_long = self._rl_cfg["beta_long"]
                beta_short = self._rl_cfg["beta_short"]
                candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
                    user_id=user_id,
                    query=query,
                    embed_model=self._embed_model,
                    reranker=self._reranker,
                    memory_cards=retrieval_cards,
                    memory_embeddings=retrieval_embeddings,
                    user_store=self._user_store,
                    item_vectors=retrieval_item_vectors,
                    topk_dense=self._rl_cfg["dense_topk"],
                    topk_rerank=self._rl_cfg["rerank_topk"],
                    beta_long=beta_long,
                    beta_short=beta_short,
                    tau=self._rl_cfg["tau"],
                    only_own_memories=self.only_own_memories,
                    sample=not self.eval_mode,
                    queries=retrieval_queries,
                )

            # Get selected memories.
            # NOTE(review): `if chosen_indices` assumes a plain list here; a
            # multi-element numpy array would make this truthiness test raise —
            # confirm the retrieve_* return type.
            memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else []
            memory_notes = [m.note_text for m in memories_t]

            # Apply preference rewrite if enabled
            if self.enable_preference_rewrite and memory_notes:
                memory_notes = self._rewrite_preferences(memory_notes, query)

            # Debug: show retrieval info
            if memories_t or global_notes:
                print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
                print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}")
                for i, m in enumerate(memories_t[:3]):  # Show top 3
                    score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
                    print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")

        # Combine all notes for prompt (global + retrieved)
        # For chat(), we combine all notes; chat_prepare() handles them separately
        if self.mode != "vanilla":
            all_memory_notes = (global_notes if global_notes else []) + memory_notes
        else:
            all_memory_notes = memory_notes

        # Build prompt and count tokens
        prompt_tokens = self._count_tokens(query)
        for turn in session.history:
            prompt_tokens += self._count_tokens(turn.text)
        for note in all_memory_notes:
            prompt_tokens += self._count_tokens(note)

        # Generate answer (with best-of-N if enabled)
        if self.best_of_n > 1:
            # Generate N responses and pick the best one
            candidates_responses = []
            for i in range(self.best_of_n):
                resp = self._chat_model.answer(
                    history=session.history,
                    memory_notes=all_memory_notes,
                    max_new_tokens=self._rl_cfg["max_new_tokens"],
                    temperature=0.8,  # Slightly higher temp for diversity
                )
                score = self._score_response(resp)
                candidates_responses.append((resp, score))

            # Sort by score (descending) and pick best
            candidates_responses.sort(key=lambda x: x[1], reverse=True)
            answer_t = candidates_responses[0][0]
            best_score = candidates_responses[0][1]

            if len(candidates_responses) > 1:
                print(f"[BEST-OF-{self.best_of_n}] Scores: {[f'{s:.2f}' for _, s in candidates_responses]}, picked score={best_score:.2f}")
        else:
            answer_t = self._chat_model.answer(
                history=session.history,
                memory_notes=all_memory_notes,
                max_new_tokens=self._rl_cfg["max_new_tokens"],
            )

        completion_tokens = self._count_tokens(answer_t)

        # Add assistant turn to history
        assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter)
        session.history.append(assist_turn)

        # Update session state for next turn
        session.last_query = query
        session.last_answer = answer_t
        session.last_memories = memories_t
        session.last_query_embedding = e_q_t
        session.last_candidate_item_vectors = cand_item_vecs
        session.last_policy_probs = probs
        session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else []

        ctx.turn_counter += 1

        # Build debug info
        debug = DebugInfo(
            selected_memory_ids=[m.card_id for m in memories_t],
            selected_memory_notes=[m.note_text for m in memories_t],
            selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [],
            user_vector_before=z_long_before + z_short_before,  # Concatenated for simplicity
            user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(),
            extracted_preferences=extracted_prefs,
            extra={
                "num_candidates": len(candidates),
                "num_total_memories": len(self._memory_cards),
                "z_long_norm": float(np.linalg.norm(user_state.z_long)),
                "z_short_norm": float(np.linalg.norm(user_state.z_short)),
            }
        )

        # Build usage stats
        usage = UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            model=self._llm_name,
        )

        return AssistantResponse(
            answer=answer_t,
            usage=usage,
            debug=debug,
        )
    def chat_prepare(self, user_id: str, query: str, skip_extraction: bool = False, skip_auto_reward: bool = False) -> dict:
        """
        Prepare for chat without calling the LLM.

        This does all the preparation work (embedding, memory retrieval, etc.)
        and returns the messages to send to the LLM along with context needed
        for post-processing.

        Used for batch processing where messages are collected first, then
        sent in batch to vLLM for concurrent processing.

        Args:
            user_id: Unique identifier for the user.
            query: Current user query/message.
            skip_extraction: If True, skip preference extraction (batch
                frameworks may run extraction separately).
            skip_auto_reward: If True, skip the LLM-judge auto reward
                (batch frameworks may compute rewards externally).

        Returns:
            Dict containing:
            - messages: List of messages to send to LLM
            - context: Dict with all state needed for chat_complete()
        """
        ctx = self._get_or_create_session(user_id)
        session = ctx.session_state
        user_state = self._user_store.get_state(user_id)

        # Record user vector before for debug
        z_long_before = user_state.z_long.copy().tolist()
        z_short_before = user_state.z_short.copy().tolist()

        # Add user turn to history
        user_turn = self._build_chat_turn(user_id, query, "user", ctx.turn_counter)
        session.history.append(user_turn)

        # Vanilla mode: pure LLM without any memory or preference extraction
        if self.mode == "vanilla":
            # NOTE(review): assumes embedding dim 4096 — confirm against the
            # embed model.  global_notes is deliberately left undefined here;
            # every later use is guarded by `self.mode != "vanilla"`.
            e_q_t = np.zeros(4096, dtype=np.float32)
            extracted_prefs = []
            candidates = []
            cand_item_vecs = np.array([])
            base_scores = np.array([])
            chosen_indices = []
            probs = np.array([])
            memories_t = []
            memory_notes = []
        else:
            # Compute query embedding
            embed_result = self._embed_model.encode([query], normalize=True, return_tensor=False)
            if embed_result is None or len(embed_result) == 0:
                raise RuntimeError(f"Embedding model returned empty result for query: {query[:100]}")
            e_q_t = np.array(embed_result[0])

            # Store pending RL update info from last turn
            if session.last_query is not None and self.enable_rl_updates:
                ctx.pending_rl_update = {
                    "last_query": session.last_query,
                    "last_answer": session.last_answer,
                    "last_memories": session.last_memories,
                    "last_query_embedding": session.last_query_embedding,
                    "current_query_embedding": e_q_t,
                    "last_candidate_item_vectors": session.last_candidate_item_vectors,
                    "last_policy_probs": session.last_policy_probs,
                    "last_chosen_indices": session.last_chosen_indices,
                }

            # Auto-compute reward via LLM judge if enabled
            # skip_auto_reward=True when batch framework handles rewards externally
            if self._llm_reward_client is not None and not skip_auto_reward:
                import asyncio
                try:
                    reward, gating = asyncio.run(eval_step_llm(
                        q_t=session.last_query,
                        answer_t=session.last_answer,
                        q_t1=query,
                        memories_t=session.last_memories or [],
                        client=self._llm_reward_client,
                    ))
                    if gating > 0.0:
                        self.apply_feedback(Feedback(
                            user_id=user_id,
                            turn_id=ctx.turn_counter - 1,
                            reward=reward,
                            gating=gating,
                        ))
                except Exception as e:
                    # Graceful fallback: skip RL update if judge fails
                    print(f"[LLM-Reward] Judge call failed, skipping update: {e}")

            # Extract preferences from conversation
            extracted_prefs = []
            if self.enable_preference_extraction and not skip_extraction:
                prefs = self._extractor.extract_turn(session.history)
                if prefs.preferences:
                    print(f"[DEBUG] Extracted {len(prefs.preferences)} prefs from history (len={len(session.history)})")
                    extracted_prefs = self._add_preferences_as_memory(
                        prefs, query, user_id, ctx.turn_counter
                    )
                    if extracted_prefs:
                        print(f"[DEBUG] Added {len(extracted_prefs)} to memory. Total cards: {len(self._memory_cards)}")

            # Separate global preferences (bypass retrieval) from conditional ones
            global_notes = []
            retrieval_cards = self._memory_cards
            retrieval_embeddings = self._memory_embeddings
            retrieval_item_vectors = self._item_vectors
            if self.enable_global_preferences:
                global_cards = [c for c in self._memory_cards if c.is_global and c.user_id == user_id]
                global_notes = [c.note_text for c in global_cards[:10]]  # Cap at 10
                cond_indices = [i for i, c in enumerate(self._memory_cards) if not c.is_global]
                if cond_indices:
                    retrieval_cards = [self._memory_cards[i] for i in cond_indices]
                    retrieval_embeddings = self._memory_embeddings[cond_indices]
                    if len(self._item_vectors) > 0:
                        retrieval_item_vectors = self._item_vectors[cond_indices]
                else:
                    retrieval_cards = []
                    retrieval_embeddings = np.zeros((0, self._memory_embeddings.shape[1]), dtype=np.float32) if len(self._memory_embeddings) > 0 else self._memory_embeddings
                    retrieval_item_vectors = np.zeros((0, self._rl_cfg["item_dim"]), dtype=np.float32)

            # Query transformation for better retrieval matching
            retrieval_queries = None
            if self.enable_query_transform:
                retrieval_queries = self._transform_query_for_retrieval(query)

            # Retrieve memories
            if self.mode == "nopersonal":
                candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_no_policy(
                    user_id=user_id,
                    query=query,
                    embed_model=self._embed_model,
                    reranker=self._reranker,
                    memory_cards=retrieval_cards,
                    memory_embeddings=retrieval_embeddings,
                    topk_dense=self._rl_cfg["dense_topk"],
                    topk_rerank=self._rl_cfg["rerank_topk"],
                    only_own_memories=self.only_own_memories,
                    queries=retrieval_queries,
                    dynamic_topk=self._rl_cfg["dynamic_topk"],
                    dynamic_min_k=self._rl_cfg["dynamic_min_k"],
                    dynamic_max_k=self._rl_cfg["dynamic_max_k"],
                    dynamic_score_ratio=self._rl_cfg["dynamic_score_ratio"],
                )
            else:
                beta_long = self._rl_cfg["beta_long"]
                beta_short = self._rl_cfg["beta_short"]
                candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
                    user_id=user_id,
                    query=query,
                    embed_model=self._embed_model,
                    reranker=self._reranker,
                    memory_cards=retrieval_cards,
                    memory_embeddings=retrieval_embeddings,
                    user_store=self._user_store,
                    item_vectors=retrieval_item_vectors,
                    topk_dense=self._rl_cfg["dense_topk"],
                    topk_rerank=self._rl_cfg["rerank_topk"],
                    beta_long=beta_long,
                    beta_short=beta_short,
                    tau=self._rl_cfg["tau"],
                    only_own_memories=self.only_own_memories,
                    sample=not self.eval_mode,
                    queries=retrieval_queries,
                )

            # NOTE(review): `if chosen_indices` assumes a plain list here; a
            # multi-element numpy array would make this truthiness test raise —
            # confirm the retrieve_* return type.
            memories_t = [candidates[int(i)] for i in chosen_indices] if chosen_indices else []
            memory_notes = [m.note_text for m in memories_t]

            # Apply preference rewrite if enabled
            if self.enable_preference_rewrite and memory_notes:
                memory_notes = self._rewrite_preferences(memory_notes, query)

            if memories_t or global_notes:
                print(f"[DEBUG-RETRIEVAL] User={user_id}, Query={query[:50]}...")
                print(f"[DEBUG-RETRIEVAL] Global={len(global_notes)}, Candidates={len(candidates)}, Retrieved={len(memories_t)}")
                for i, m in enumerate(memories_t[:3]):
                    score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0
                    print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...")

        # Build prompt token count
        prompt_tokens = self._count_tokens(query)
        for turn in session.history:
            prompt_tokens += self._count_tokens(turn.text)
        all_notes = memory_notes + (global_notes if self.mode != "vanilla" else [])
        for note in all_notes:
            prompt_tokens += self._count_tokens(note)

        # Build messages for LLM (pass global_notes separately for distinct prompt sections)
        effective_global = global_notes if (self.enable_global_preferences and self.mode != "vanilla") else None
        messages = self._chat_model.build_messages(
            history=session.history,
            memory_notes=memory_notes,
            max_new_tokens=self._rl_cfg["max_new_tokens"],
            global_notes=effective_global,
        )

        # Return messages and context for chat_complete
        return {
            "messages": messages,
            "context": {
                "user_id": user_id,
                "query": query,
                "ctx": ctx,
                "session": session,
                "user_state": user_state,
                "z_long_before": z_long_before,
                "z_short_before": z_short_before,
                "e_q_t": e_q_t,
                "extracted_prefs": extracted_prefs,
                "candidates": candidates,
                "cand_item_vecs": cand_item_vecs,
                "base_scores": base_scores,
                "chosen_indices": chosen_indices,
                "probs": probs,
                "memories_t": memories_t,
                "memory_notes": memory_notes,
                "prompt_tokens": prompt_tokens,
            }
        }
i, m in enumerate(memories_t[:3]): + score = probs[chosen_indices[i]] if i < len(chosen_indices) and chosen_indices[i] < len(probs) else 0 + print(f"[DEBUG-RETRIEVAL] [{i+1}] score={score:.3f}: {m.note_text[:80]}...") + + # Build prompt token count + prompt_tokens = self._count_tokens(query) + for turn in session.history: + prompt_tokens += self._count_tokens(turn.text) + all_notes = memory_notes + (global_notes if self.mode != "vanilla" else []) + for note in all_notes: + prompt_tokens += self._count_tokens(note) + + # Build messages for LLM (pass global_notes separately for distinct prompt sections) + effective_global = global_notes if (self.enable_global_preferences and self.mode != "vanilla") else None + messages = self._chat_model.build_messages( + history=session.history, + memory_notes=memory_notes, + max_new_tokens=self._rl_cfg["max_new_tokens"], + global_notes=effective_global, + ) + + # Return messages and context for chat_complete + return { + "messages": messages, + "context": { + "user_id": user_id, + "query": query, + "ctx": ctx, + "session": session, + "user_state": user_state, + "z_long_before": z_long_before, + "z_short_before": z_short_before, + "e_q_t": e_q_t, + "extracted_prefs": extracted_prefs, + "candidates": candidates, + "cand_item_vecs": cand_item_vecs, + "base_scores": base_scores, + "chosen_indices": chosen_indices, + "probs": probs, + "memories_t": memories_t, + "memory_notes": memory_notes, + "prompt_tokens": prompt_tokens, + } + } + + def chat_complete(self, answer_t: str, context: dict) -> AssistantResponse: + """ + Complete chat with LLM response. + + This takes the LLM response and context from chat_prepare(), and + does all post-processing (add to history, debug info, etc.). + + Args: + answer_t: The LLM response text. + context: Context dict from chat_prepare(). + + Returns: + AssistantResponse containing the answer, usage stats, and debug info. 
+ """ + # Unpack context + user_id = context["user_id"] + query = context["query"] + ctx = context["ctx"] + session = context["session"] + user_state = context["user_state"] + z_long_before = context["z_long_before"] + z_short_before = context["z_short_before"] + e_q_t = context["e_q_t"] + extracted_prefs = context["extracted_prefs"] + candidates = context["candidates"] + cand_item_vecs = context["cand_item_vecs"] + chosen_indices = context["chosen_indices"] + probs = context["probs"] + memories_t = context["memories_t"] + memory_notes = context["memory_notes"] + prompt_tokens = context["prompt_tokens"] + + completion_tokens = self._count_tokens(answer_t) + + # Add assistant turn to history + assist_turn = self._build_chat_turn(user_id, answer_t, "assistant", ctx.turn_counter) + session.history.append(assist_turn) + + # Update session state for next turn + session.last_query = query + session.last_answer = answer_t + session.last_memories = memories_t + session.last_query_embedding = e_q_t + session.last_candidate_item_vectors = cand_item_vecs + session.last_policy_probs = probs + session.last_chosen_indices = list(chosen_indices) if len(chosen_indices) > 0 else [] + + ctx.turn_counter += 1 + + # Build debug info + debug = DebugInfo( + selected_memory_ids=[m.card_id for m in memories_t], + selected_memory_notes=[m.note_text for m in memories_t], + selected_memory_scores=[float(probs[i]) if i < len(probs) else 0.0 for i in chosen_indices] if len(chosen_indices) > 0 else [], + user_vector_before=z_long_before + z_short_before, + user_vector_after=user_state.z_long.tolist() + user_state.z_short.tolist(), + extracted_preferences=extracted_prefs, + extra={ + "num_candidates": len(candidates), + "num_total_memories": len(self._memory_cards), + "z_long_norm": float(np.linalg.norm(user_state.z_long)), + "z_short_norm": float(np.linalg.norm(user_state.z_short)), + } + ) + + # Build usage stats + usage = UsageStats( + prompt_tokens=prompt_tokens, + 
completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + model=self._llm_name, + ) + + return AssistantResponse( + answer=answer_t, + usage=usage, + debug=debug, + ) + + def apply_extracted_preferences(self, user_id: str, pref_dict: dict) -> list: + """Apply pre-computed extraction results (from batch extraction) to memory.""" + prefs = PreferenceList.model_validate(pref_dict) + if not prefs.preferences: + return [] + ctx = self._get_or_create_session(user_id) + query = ctx.session_state.history[-1].text if ctx.session_state.history else "" + extracted = self._add_preferences_as_memory(prefs, query, user_id, ctx.turn_counter) + if extracted: + print(f"[DEBUG] Batch-added {len(extracted)} to memory. Total cards: {len(self._memory_cards)}") + return extracted + + def get_last_user_query(self, user_id: str) -> str: + """Get the last user message text for this user's session.""" + ctx = self._sessions.get(user_id) + if ctx and ctx.session_state.history: + for t in reversed(ctx.session_state.history): + if t.role == "user": + return t.text + return "" + + def reset_session(self, user_id: str) -> None: + """ + Reset session for a user (new chat window). + + This clears: + - Session conversation history + - Short-term user vector (z_short) + - Pending RL update info + + This preserves: + - Long-term user vector (z_long) + - User's memory cards (may be consolidated if enabled) + + Args: + user_id: The user whose session to reset. 
+ """ + # Consolidate preferences at session end (before clearing session) + if self.enable_preference_consolidation: + self.consolidate_user_preferences(user_id) + + # Clear session context + if user_id in self._sessions: + del self._sessions[user_id] + + # Create fresh session + self._sessions[user_id] = _SessionContext( + session_state=OnlineSessionState(user_id=user_id), + turn_counter=0, + ) + + # Reset short-term vector but keep long-term + user_state = self._user_store.get_state(user_id) + user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32) + self._user_store.save_state(user_state) + + def reset_user(self, user_id: str) -> None: + """ + Completely reset a user (new "life"). + + This clears: + - Long-term user vector (z_long) + - Short-term user vector (z_short) + - User's memory cards + - Session history + - All cached state + + Args: + user_id: The user to reset. + """ + # Clear session + if user_id in self._sessions: + del self._sessions[user_id] + + # Reset user state vectors + user_state = self._user_store.get_state(user_id) + user_state.z_long = self._user_store.global_init_z.copy() + user_state.z_short = np.zeros(self._rl_cfg["item_dim"], dtype=np.float32) + user_state.reward_ma = 0.0 + self._user_store.save_state(user_state) + + # Find indices to KEEP (cards NOT belonging to this user) + # Must do this BEFORE modifying _memory_cards + keep_indices = [ + i for i, card in enumerate(self._memory_cards) + if card.user_id != user_id + ] + + # Filter memory cards + self._memory_cards = [self._memory_cards[i] for i in keep_indices] + + # Filter embeddings and item vectors to match + if len(keep_indices) > 0 and len(self._memory_embeddings) > 0: + self._memory_embeddings = self._memory_embeddings[keep_indices] + self._item_vectors = self._item_vectors[keep_indices] + else: + # No cards left or no embeddings + embed_dim = self._memory_embeddings.shape[1] if len(self._memory_embeddings) > 0 else 4096 + self._memory_embeddings = 
    def apply_feedback(self, feedback: Feedback) -> None:
        """
        Apply feedback from user simulator or judge.

        This performs the REINFORCE update to user vectors based on
        the reward signal from the previous turn.

        Args:
            feedback: Feedback object containing reward, gating, and metadata.

        Notes:
            - Should be called AFTER chat() but BEFORE the next chat() call
            - Uses the stored context from the previous turn
            - If enable_rl_updates is False, this is a no-op (logging only)
            - If mode is "nopersonal", this is a no-op (baseline comparison)
        """
        if not self.enable_rl_updates:
            return

        # In "nopersonal" or "vanilla" mode, skip RL updates entirely (baseline)
        if self.mode in ("nopersonal", "vanilla"):
            return

        user_id = feedback.user_id
        ctx = self._sessions.get(user_id)

        # No session or no stored turn context: nothing to update against.
        if ctx is None or ctx.pending_rl_update is None:
            return

        pending = ctx.pending_rl_update
        user_state = self._user_store.get_state(user_id)

        # Check if we have the necessary data for RL update
        if (pending.get("last_candidate_item_vectors") is not None and
            pending.get("last_policy_probs") is not None and
            pending.get("last_chosen_indices") is not None and
            len(pending["last_chosen_indices"]) > 0):

            # Extract chosen vectors
            chosen_indices = pending["last_chosen_indices"]
            candidate_vectors = pending["last_candidate_item_vectors"]

            if len(candidate_vectors) > 0:
                print(f"[DEBUG-REINFORCE] User={user_id} reward={feedback.reward:.2f} "
                      f"n_candidates={len(candidate_vectors)} chosen={chosen_indices} "
                      f"probs_shape={pending['last_policy_probs'].shape if hasattr(pending['last_policy_probs'], 'shape') else 'N/A'}")
                # REINFORCE expects:
                # - item_vectors: ALL candidate vectors [K, k]
                # - chosen_indices: indices into those candidates
                # - policy_probs: probabilities over all K candidates [K]
                updated = reinforce_update_user_state(
                    user_state=user_state,
                    item_vectors=candidate_vectors,  # All candidates, not just chosen
                    chosen_indices=chosen_indices,  # Original indices into candidates
                    policy_probs=pending["last_policy_probs"],
                    reward_hat=feedback.reward,
                    gating=feedback.gating,
                    tau=self._rl_cfg["tau"],
                    eta_long=self._rl_cfg["eta_long"],
                    eta_short=self._rl_cfg["eta_short"],
                    ema_alpha=self._rl_cfg["ema_alpha"],
                    short_decay=self._rl_cfg["short_decay"],
                )

                print(f"[DEBUG-REINFORCE] updated={updated} z_long_norm={np.linalg.norm(user_state.z_long):.15e}")
                # Persist only when the update actually changed the state
                # (the updater mutates user_state in place and reports it).
                if updated:
                    self._user_store.save_state(user_state)

        # Clear pending update unconditionally, so the same turn is never
        # rewarded twice.
        ctx.pending_rl_update = None
reinforce_update_user_state( + user_state=user_state, + item_vectors=candidate_vectors, # All candidates, not just chosen + chosen_indices=chosen_indices, # Original indices into candidates + policy_probs=pending["last_policy_probs"], + reward_hat=feedback.reward, + gating=feedback.gating, + tau=self._rl_cfg["tau"], + eta_long=self._rl_cfg["eta_long"], + eta_short=self._rl_cfg["eta_short"], + ema_alpha=self._rl_cfg["ema_alpha"], + short_decay=self._rl_cfg["short_decay"], + ) + + print(f"[DEBUG-REINFORCE] updated={updated} z_long_norm={np.linalg.norm(user_state.z_long):.15e}") + if updated: + self._user_store.save_state(user_state) + + # Clear pending update + ctx.pending_rl_update = None + + def get_user_state_summary(self, user_id: str) -> Dict[str, Any]: + """ + Get a summary of the user's current state (for debugging/analysis). + + Args: + user_id: The user to query. + + Returns: + Dictionary with user state information. + """ + user_state = self._user_store.get_state(user_id) + ctx = self._sessions.get(user_id) + + user_memory_count = sum( + 1 for card in self._memory_cards if card.user_id == user_id + ) + + return { + "user_id": user_id, + "z_long_norm": float(np.linalg.norm(user_state.z_long)), + "z_short_norm": float(np.linalg.norm(user_state.z_short)), + "reward_ma": user_state.reward_ma, + "session_history_length": len(ctx.session_state.history) if ctx else 0, + "turn_counter": ctx.turn_counter if ctx else 0, + "user_memory_count": user_memory_count, + "total_memory_count": len(self._memory_cards), + } + + def persist(self) -> None: + """ + Persist all state to disk. 
def create_personalized_llm(
    config_path: Optional[str] = None,
    **kwargs
) -> PersonalizedLLM:
    """
    Factory function to create a PersonalizedLLM instance.

    Thin convenience wrapper: all keyword arguments are forwarded
    unchanged to the PersonalizedLLM constructor.

    Args:
        config_path: Optional path to configuration file.
        **kwargs: Additional arguments passed to PersonalizedLLM constructor.

    Returns:
        Configured PersonalizedLLM instance.
    """
    return PersonalizedLLM(config_path=config_path, **kwargs)
+ """ + return PersonalizedLLM(config_path=config_path, **kwargs) + diff --git a/src/personalization/types.py b/src/personalization/types.py new file mode 100644 index 0000000..a25b560 --- /dev/null +++ b/src/personalization/types.py @@ -0,0 +1,4 @@ +from personalization.retrieval.preference_store.schemas import ChatTurn + +__all__ = ["ChatTurn"] + diff --git a/src/personalization/user_model/__init__.py b/src/personalization/user_model/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/user_model/__init__.py diff --git a/src/personalization/user_model/features.py b/src/personalization/user_model/features.py new file mode 100644 index 0000000..a4508b4 --- /dev/null +++ b/src/personalization/user_model/features.py @@ -0,0 +1,49 @@ +import numpy as np +from dataclasses import dataclass +from sklearn.decomposition import PCA + +@dataclass +class ItemProjection: + P: np.ndarray # [k, d] + mean: np.ndarray # [d] + + @classmethod + def from_pca(cls, embeddings: np.ndarray, k: int) -> "ItemProjection": + """ + embeddings: [M, d] + """ + mean = embeddings.mean(axis=0) + centered = embeddings - mean + + # Ensure k is not larger than min(n_samples, n_features) + n_samples, n_features = embeddings.shape + actual_k = min(k, n_samples, n_features) + + pca = PCA(n_components=actual_k) + pca.fit(centered) + + # pca.components_: [k, d] + P = pca.components_ # Each row is a principal component vector + + # If we had to reduce k, we might want to pad P or handle it? + # For now, let's assume we get what we asked for or less if data is small. + # But for the system we want fixed k. + # If actual_k < k, we should pad with zeros to match expected dimension. 
+ if actual_k < k: + padding = np.zeros((k - actual_k, n_features), dtype=P.dtype) + P = np.vstack([P, padding]) + + return cls(P=P, mean=mean) + + def transform_embeddings(self, E: np.ndarray) -> np.ndarray: + """ + E: [N, d] -> [N, k] + """ + return (E - self.mean) @ self.P.T + + def transform_vector(self, e: np.ndarray) -> np.ndarray: + """ + e: [d] -> [k] + """ + return self.P @ (e - self.mean) + diff --git a/src/personalization/user_model/policy/__init__.py b/src/personalization/user_model/policy/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/user_model/policy/__init__.py diff --git a/src/personalization/user_model/policy/optimizer.py b/src/personalization/user_model/policy/optimizer.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/user_model/policy/optimizer.py diff --git a/src/personalization/user_model/policy/reinforce.py b/src/personalization/user_model/policy/reinforce.py new file mode 100644 index 0000000..adfaef7 --- /dev/null +++ b/src/personalization/user_model/policy/reinforce.py @@ -0,0 +1,104 @@ +from typing import Sequence, List +from dataclasses import dataclass +import numpy as np + +from personalization.user_model.tensor_store import UserState + +@dataclass +class PolicyScores: + scores: np.ndarray # [K] s(q_t, m; u) + probs: np.ndarray # [K] π_z(m|q_t) + +def compute_policy_scores( + base_scores: np.ndarray, # [K], from reranker + user_state: UserState, + item_vectors: np.ndarray, # [K, k], v_m for the K candidates + beta_long: float, + beta_short: float, + tau: float, +) -> PolicyScores: + """ + Compute personalized scores and softmax probabilities. 
def reinforce_update_user_state(
    user_state: UserState,
    item_vectors: np.ndarray,  # [K, k] for candidates
    chosen_indices: Sequence[int],  # indices of A_t in 0..K-1
    policy_probs: np.ndarray,  # [K] π_z(m|q_t)
    reward_hat: float,  # \hat r_t
    gating: float,  # g_t
    tau: float,
    eta_long: float,
    eta_short: float,
    ema_alpha: float,
    short_decay: float,
) -> bool:
    """Apply one REINFORCE step to the user's preference vectors, in place.

    Updates ``user_state.z_long`` / ``z_short`` / ``reward_ma`` and returns
    True when an update was applied, False when it was skipped (no chosen
    items, negligible advantage, or all chosen indices out of range).
    """
    if not len(chosen_indices):
        return False

    # Baseline-subtracted, gated advantage.
    advantage = gating * (reward_hat - user_state.reward_ma)
    if abs(advantage) < 1e-6:
        # Negligible learning signal — skip the whole update.
        return False

    # Average vector of the chosen candidates; duplicate indices collapse
    # to a single selection, out-of-range indices are ignored.
    num_candidates = len(item_vectors)
    selection = np.zeros(num_candidates, dtype=np.float32)
    for raw_idx in chosen_indices:
        pos = int(raw_idx)
        if 0 <= pos < num_candidates:
            selection[pos] = 1.0
    total = selection.sum()
    if total == 0:
        return False
    v_chosen = (selection / total) @ item_vectors  # [k]

    # Policy-expected vector \mu_t(z) under π_z.
    v_expect = policy_probs @ item_vectors  # [k]

    # Temperature-scaled REINFORCE gradient direction.
    grad = (advantage / tau) * (v_chosen - v_expect)

    # Long-term vector: plain gradient step (mutated in place).
    user_state.z_long += eta_long * grad
    # Short-term vector: decayed accumulator, rebound to a fresh array.
    user_state.z_short = (1.0 - short_decay) * user_state.z_short + eta_short * grad
    # Reward baseline: exponential moving average toward the observed reward.
    user_state.reward_ma = (1.0 - ema_alpha) * user_state.reward_ma + ema_alpha * reward_hat

    return True
@dataclass
class UserState:
    """Per-user personalization state."""
    user_id: str
    z_long: np.ndarray   # [k] slow-moving long-term preference vector
    z_short: np.ndarray  # [k] fast-decaying short-term preference vector
    reward_ma: float     # EMA baseline of observed rewards, init 0.0


class UserTensorStore:
    """Persistent store of ``UserState`` vectors, one per user.

    On-disk format is a single NumPy .npz archive with three keys per user:
    "{uid}_long", "{uid}_short", and "{uid}_meta" (meta = [reward_ma]).
    """

    def __init__(self, k: int, path: str):
        """
        Args:
            k: Dimensionality of the user vectors.
            path: File path of the persisted archive. Used verbatim — no
                ".npz" suffix is appended on save (see ``_save``).
        """
        self.k = k
        self.path = path
        self._states: Dict[str, UserState] = {}
        self._load()

        # Cold-start prior: new users start at the mean z_long of all known
        # users; zero vector when the store is empty.
        if self._states:
            z_all = np.stack([st.z_long for st in self._states.values()])
            self.global_init_z = np.mean(z_all, axis=0)
        else:
            self.global_init_z = np.zeros(self.k, dtype=np.float32)

    def _load(self) -> None:
        """Populate ``self._states`` from ``self.path`` if the archive exists.

        Best-effort by design: a corrupt or unreadable file is logged and
        leaves the store empty rather than raising.
        """
        if not os.path.exists(self.path):
            return
        try:
            data = np.load(self.path, allow_pickle=True)
            for key in data.files:
                if not key.endswith("_long"):
                    continue
                uid = key[: -len("_long")]
                z_long = data[key]
                # Tolerate partial records: missing pieces fall back to a
                # fresh short-term vector / zero reward baseline.
                if f"{uid}_short" in data.files:
                    z_short = data[f"{uid}_short"]
                else:
                    z_short = np.zeros(self.k, dtype=np.float32)
                if f"{uid}_meta" in data.files:
                    meta = data[f"{uid}_meta"]
                else:
                    meta = np.array([0.0])
                self._states[uid] = UserState(uid, z_long, z_short, float(meta[0]))
        except Exception as e:
            # Deliberate best-effort load: warn and continue empty.
            print(f"Warning: Failed to load UserStore from {self.path}: {e}")

    def _save(self) -> None:
        """Write all states to ``self.path`` as an .npz archive.

        Writes through an open file handle because ``np.savez`` appends a
        ".npz" suffix to plain string paths that lack one — which would save
        to "{path}.npz" while ``_load`` keeps checking ``self.path``, silently
        breaking the persistence roundtrip.
        """
        save_dict = {}
        for uid, state in self._states.items():
            save_dict[f"{uid}_long"] = state.z_long
            save_dict[f"{uid}_short"] = state.z_short
            save_dict[f"{uid}_meta"] = np.array([state.reward_ma])
        parent = os.path.dirname(self.path)
        if parent:
            # Ensure the target directory exists before writing.
            os.makedirs(parent, exist_ok=True)
        with open(self.path, "wb") as f:
            np.savez(f, **save_dict)

    def get_state(self, user_id: str) -> UserState:
        """Return the state for ``user_id``, lazily creating it for new users."""
        if user_id not in self._states:
            # Lazy init: long-term vector seeded from the global mean.
            state = UserState(
                user_id=user_id,
                z_long=self.global_init_z.copy(),
                z_short=np.zeros(self.k, dtype=np.float32),
                reward_ma=0.0,
            )
            self._states[user_id] = state
        return self._states[user_id]

    def save_state(self, state: UserState) -> None:
        """Register (or overwrite) a state in the in-memory map."""
        self._states[state.user_id] = state

    def persist(self) -> None:
        """Public method to force save to disk."""
        self._save()
