From b6c3e4e51eeab703b40284459c6e9fff2151216c Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Wed, 18 Mar 2026 18:25:09 -0500 Subject: Initial release: VARS - personalized LLM with RAG and user vector learning --- configs/base.yaml | 0 configs/local_models.yaml | 66 + configs/reranker.yaml | 3 + configs/retrieval.yaml | 5 + configs/user_model.yaml | 14 + pyproject.toml | 38 + requirements.txt | 9 + src/personalization/__init__.py | 0 src/personalization/config/__init__.py | 0 src/personalization/config/registry.py | 146 ++ src/personalization/config/settings.py | 75 + src/personalization/feedback/__init__.py | 0 src/personalization/feedback/gating.py | 72 + src/personalization/feedback/handlers.py | 87 + src/personalization/feedback/llm_reward.py | 253 +++ src/personalization/feedback/local_llm_reward.py | 370 ++++ src/personalization/feedback/online_update.py | 0 src/personalization/feedback/reward_model.py | 64 + src/personalization/feedback/sampler.py | 109 ++ src/personalization/feedback/schemas.py | 23 + src/personalization/models/__init__.py | 0 src/personalization/models/embedding/__init__.py | 11 + src/personalization/models/embedding/base.py | 37 + src/personalization/models/embedding/qwen3_8b.py | 89 + src/personalization/models/llm/__init__.py | 4 + src/personalization/models/llm/base.py | 29 + src/personalization/models/llm/prompt_builder.py | 0 src/personalization/models/llm/vllm_chat.py | 244 +++ .../models/preference_extractor/__init__.py | 5 + .../models/preference_extractor/base.py | 17 + .../models/preference_extractor/gpt4o_extractor.py | 165 ++ .../models/preference_extractor/llm_extractor.py | 153 ++ .../models/preference_extractor/rule_extractor.py | 205 +++ src/personalization/models/reranker/__init__.py | 0 src/personalization/models/reranker/base.py | 16 + .../models/reranker/qwen3_reranker.py | 96 + src/personalization/retrieval/__init__.py | 0 src/personalization/retrieval/chunking/__init__.py | 0 
src/personalization/retrieval/chunking/rules.py | 0 src/personalization/retrieval/pipeline.py | 388 +++++ .../retrieval/preference_store/__init__.py | 0 .../retrieval/preference_store/base.py | 0 .../retrieval/preference_store/schemas.py | 48 + .../retrieval/preference_store/vector_kv.py | 0 src/personalization/retrieval/rerank.py | 0 src/personalization/retrieval/store/__init__.py | 0 src/personalization/retrieval/store/base.py | 0 src/personalization/retrieval/store/faiss_store.py | 0 src/personalization/serving/__init__.py | 22 + src/personalization/serving/personalized_llm.py | 1835 ++++++++++++++++++++ src/personalization/types.py | 4 + src/personalization/user_model/__init__.py | 0 src/personalization/user_model/features.py | 49 + src/personalization/user_model/policy/__init__.py | 0 src/personalization/user_model/policy/optimizer.py | 0 src/personalization/user_model/policy/reinforce.py | 104 ++ src/personalization/user_model/scoring.py | 25 + src/personalization/user_model/session_state.py | 19 + src/personalization/user_model/tensor_store.py | 80 + src/personalization/utils/__init__.py | 0 src/personalization/utils/ids.py | 0 src/personalization/utils/io.py | 0 src/personalization/utils/logging.py | 0 src/personalization/utils/timing.py | 0 64 files changed, 4979 insertions(+) create mode 100644 configs/base.yaml create mode 100644 configs/local_models.yaml create mode 100644 configs/reranker.yaml create mode 100644 configs/retrieval.yaml create mode 100644 configs/user_model.yaml create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 src/personalization/__init__.py create mode 100644 src/personalization/config/__init__.py create mode 100644 src/personalization/config/registry.py create mode 100644 src/personalization/config/settings.py create mode 100644 src/personalization/feedback/__init__.py create mode 100644 src/personalization/feedback/gating.py create mode 100644 src/personalization/feedback/handlers.py create mode 
100644 src/personalization/feedback/llm_reward.py create mode 100644 src/personalization/feedback/local_llm_reward.py create mode 100644 src/personalization/feedback/online_update.py create mode 100644 src/personalization/feedback/reward_model.py create mode 100644 src/personalization/feedback/sampler.py create mode 100644 src/personalization/feedback/schemas.py create mode 100644 src/personalization/models/__init__.py create mode 100644 src/personalization/models/embedding/__init__.py create mode 100644 src/personalization/models/embedding/base.py create mode 100644 src/personalization/models/embedding/qwen3_8b.py create mode 100644 src/personalization/models/llm/__init__.py create mode 100644 src/personalization/models/llm/base.py create mode 100644 src/personalization/models/llm/prompt_builder.py create mode 100644 src/personalization/models/llm/vllm_chat.py create mode 100644 src/personalization/models/preference_extractor/__init__.py create mode 100644 src/personalization/models/preference_extractor/base.py create mode 100644 src/personalization/models/preference_extractor/gpt4o_extractor.py create mode 100644 src/personalization/models/preference_extractor/llm_extractor.py create mode 100644 src/personalization/models/preference_extractor/rule_extractor.py create mode 100644 src/personalization/models/reranker/__init__.py create mode 100644 src/personalization/models/reranker/base.py create mode 100644 src/personalization/models/reranker/qwen3_reranker.py create mode 100644 src/personalization/retrieval/__init__.py create mode 100644 src/personalization/retrieval/chunking/__init__.py create mode 100644 src/personalization/retrieval/chunking/rules.py create mode 100644 src/personalization/retrieval/pipeline.py create mode 100644 src/personalization/retrieval/preference_store/__init__.py create mode 100644 src/personalization/retrieval/preference_store/base.py create mode 100644 src/personalization/retrieval/preference_store/schemas.py create mode 100644 
src/personalization/retrieval/preference_store/vector_kv.py create mode 100644 src/personalization/retrieval/rerank.py create mode 100644 src/personalization/retrieval/store/__init__.py create mode 100644 src/personalization/retrieval/store/base.py create mode 100644 src/personalization/retrieval/store/faiss_store.py create mode 100644 src/personalization/serving/__init__.py create mode 100644 src/personalization/serving/personalized_llm.py create mode 100644 src/personalization/types.py create mode 100644 src/personalization/user_model/__init__.py create mode 100644 src/personalization/user_model/features.py create mode 100644 src/personalization/user_model/policy/__init__.py create mode 100644 src/personalization/user_model/policy/optimizer.py create mode 100644 src/personalization/user_model/policy/reinforce.py create mode 100644 src/personalization/user_model/scoring.py create mode 100644 src/personalization/user_model/session_state.py create mode 100644 src/personalization/user_model/tensor_store.py create mode 100644 src/personalization/utils/__init__.py create mode 100644 src/personalization/utils/ids.py create mode 100644 src/personalization/utils/io.py create mode 100644 src/personalization/utils/logging.py create mode 100644 src/personalization/utils/timing.py diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000..e69de29 diff --git a/configs/local_models.yaml b/configs/local_models.yaml new file mode 100644 index 0000000..ea001ea --- /dev/null +++ b/configs/local_models.yaml @@ -0,0 +1,66 @@ +# Base path for all models +_base_path: &base ./ + +models: + llm: + # New Multi-Backend Config + qwen_1_5b: + backend: qwen + path: .//models/qwen2.5-1.5b-instruct + device: auto + dtype: bfloat16 + max_context_length: 4096 + + llama_8b: + backend: llama + path: .//models/llama-3.1-8b-instruct + device: auto + dtype: bfloat16 + max_context_length: 8192 + + # vLLM backend for high-throughput experiments + llama_8b_vllm: + backend: 
vllm + path: .//models/llama-3.1-8b-instruct + vllm_url: http://localhost:8003/v1 + model_name: meta-llama/Llama-3.1-8B-Instruct + max_context_length: 8192 + + # Legacy fallback (needed if from_config is called directly without name) + hf_id: Qwen/Qwen2.5-1.5B-Instruct + local_path: .//models/qwen2.5-1.5b-instruct + dtype: bfloat16 + device_map: auto + + preference_extractor: + # Default/Legacy + default: + hf_id: Qwen/Qwen2.5-0.5B-Instruct + local_path: .//models/qwen2.5-0.5b-instruct + dtype: bfloat16 + device_map: auto + # New SFT Extractor + qwen3_0_6b_sft: + path: .//models/pref-extractor-qwen3-0.6b-sft + prompt_template_path: fine_tuning_prompt_template.txt + device: auto + dtype: bfloat16 + max_new_tokens: 512 + embedding: + qwen3: + hf_id: Qwen/Qwen3-Embedding-8B + local_path: .//models/qwen3-embedding-8b + nemotron: + hf_id: nvidia/llama-embed-nemotron-8b + local_path: .//models/llama-embed-nemotron-8b + reranker: + qwen3_8b: + hf_id: Qwen/Qwen3-Reranker-8B + local_path: .//models/rerankers/qwen3-reranker-8b + dtype: bfloat16 + device_map: auto + bge_base: + hf_id: BAAI/bge-reranker-base + local_path: .//models/rerankers/bge-reranker-base + dtype: float16 + device_map: auto diff --git a/configs/reranker.yaml b/configs/reranker.yaml new file mode 100644 index 0000000..c376fc7 --- /dev/null +++ b/configs/reranker.yaml @@ -0,0 +1,3 @@ +reranker: + default: qwen3_8b + diff --git a/configs/retrieval.yaml b/configs/retrieval.yaml new file mode 100644 index 0000000..d2e100e --- /dev/null +++ b/configs/retrieval.yaml @@ -0,0 +1,5 @@ +retrieval: + dense_topk: 64 # Initial recall count + rerank_topk: 8 # Count fed to LLM after rerank + pca_dim: 256 + diff --git a/configs/user_model.yaml b/configs/user_model.yaml new file mode 100644 index 0000000..7b8e230 --- /dev/null +++ b/configs/user_model.yaml @@ -0,0 +1,14 @@ +user_model: + item_dim: 256 + user_dim: 256 + beta_long: 0.1 # Enable personalization for Day 4 + beta_short: 0.3 + tau: 1.0 + 
preference_extractor_name: qwen3_0_6b_sft # Switch to new extractor + rl: + eta_long: 1.0e-3 + eta_short: 5.0e-3 + ema_alpha: 0.05 + short_decay: 0.1 + +llm_name: llama_8b # Switch backend to Llama 3.1 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a76f9bb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[ +tool.black +] + +[build-system] +requires = ["hatchling>=1.18"] +build-backend = "hatchling.build" + +[project] +name = "personalization-user-model" +version = "0.1.0" +description = "Personalized memory RAG system with online user modeling" +readme = "README.md" +requires-python = ">=3.10" +license = { text = "Apache-2.0" } +authors = [ + { name = "Anonymous" } +] +dependencies = [ + "torch>=2.3.0", + "transformers>=4.44.0", + "accelerate>=0.33.0", + "huggingface_hub>=0.24.0", + "pydantic>=2.7.0", + "pyyaml>=6.0.0", + "safetensors>=0.4.2" +] + +[project.urls] +homepage = "https://example.com" + +[tool.hatch.build.targets.wheel] +packages = ["src/personalization"] + +[tool.hatch.metadata] +allow-direct-references = true + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1e227de --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +torch>=2.3.0 +transformers>=4.44.0 +accelerate>=0.33.0 +huggingface_hub>=0.24.0 +pydantic>=2.7.0 +PyYAML>=6.0.0 +safetensors>=0.4.2 + + diff --git a/src/personalization/__init__.py b/src/personalization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/personalization/config/__init__.py b/src/personalization/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/personalization/config/registry.py b/src/personalization/config/registry.py new file mode 100644 index 0000000..6badeae --- /dev/null +++ b/src/personalization/config/registry.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, Optional +import torch +import yaml + +from personalization.config 
import settings + +# Project root for resolving config paths +_PROJECT_ROOT = Path(__file__).parent.parent.parent.parent + +# Avoid circular imports by NOT importing extractors here at top level +# from personalization.models.preference_extractor.base import PreferenceExtractorBase +# from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor +# from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor +# from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM + +_DTYPE_MAP: Dict[str, torch.dtype] = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, +} + +def choose_dtype(preferred: Optional[str] = None) -> torch.dtype: + if preferred and preferred.lower() in _DTYPE_MAP: + dt = _DTYPE_MAP[preferred.lower()] + else: + dt = torch.bfloat16 if torch.cuda.is_available() else torch.float32 + if dt is torch.bfloat16 and not torch.cuda.is_available(): + return torch.float32 + return dt + +def choose_device_map(spec: Optional[str] = "auto") -> Any: + return spec or "auto" + +def ensure_local_path(path_str: str) -> str: + path = Path(path_str) + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + return str(path) + +# --- Chat Model Factory --- +def get_chat_model(name: str, device_override: Optional[str] = None): + """ + Get a chat model by name. + + Args: + name: Model name (e.g., "qwen_1_5b", "llama_8b") + device_override: Optional device override (e.g., "cuda:2"). If None, uses config default. 
+ """ + from personalization.models.llm.base import ChatModel + from personalization.models.llm.qwen_instruct import QwenInstruct + from personalization.models.llm.llama_instruct import LlamaChatModel + from personalization.models.llm.vllm_chat import VLLMChatModel + + cfg = settings.load_local_models_config() + + # Try to load raw config to support multi-backend map + with open(_PROJECT_ROOT / "configs/local_models.yaml", "r") as f: + raw_cfg = yaml.safe_load(f) + + models = raw_cfg.get("models", {}).get("llm", {}) + + # If models['llm'] is a dict of configs (new style) + if isinstance(models, dict) and "backend" in models.get(name, {}): + spec = models[name] + backend = spec.get("backend", "qwen") + path = spec["path"] + device = device_override or spec.get("device", "cuda") # Use override if provided + dtype = spec.get("dtype", "bfloat16") + max_len = spec.get("max_context_length", 4096) + + if backend == "qwen": + return QwenInstruct( + model_path=path, + device=device, + dtype=choose_dtype(dtype), # Converts string to torch.dtype + max_context_length=max_len + ) + elif backend == "llama": + return LlamaChatModel( + model_path=path, + device=device, + dtype=choose_dtype(dtype), # Converts string to torch.dtype + max_context_length=max_len + ) + elif backend == "vllm": + # Use vLLM HTTP API for high-throughput inference + vllm_url = spec.get("vllm_url", "http://localhost:8003/v1") + return VLLMChatModel( + vllm_url=vllm_url, + model_name=spec.get("model_name"), + max_context_length=max_len + ) + + # Fallback to legacy single config + return QwenInstruct.from_config(cfg) + +def get_preference_extractor(name: Optional[str] = None): + # Deferred imports to break circular dependency + from personalization.models.preference_extractor.rule_extractor import QwenRuleExtractor + from personalization.models.preference_extractor.gpt4o_extractor import GPT4OExtractor + from personalization.models.preference_extractor.llm_extractor import PreferenceExtractorLLM + + cfg = 
settings.load_local_models_config() + pref_cfg = cfg.preference_extractor + + if name is None: + if isinstance(pref_cfg, dict) and "qwen3_0_6b_sft" in pref_cfg: + name = "qwen3_0_6b_sft" + else: + name = "rule" + + if isinstance(pref_cfg, dict) and name in pref_cfg: + spec = pref_cfg[name] + if name == "qwen3_0_6b_sft": + # Use QwenRuleExtractor which we have updated for SFT End-to-End logic + return QwenRuleExtractor( + model_path=spec["path"], + device_map=spec.get("device", "auto"), + dtype=choose_dtype(spec.get("dtype", "bfloat16")), + ) + # Add 'default' handling if mapped to rule/gpt + if name == "default": + pass + + if name == "gpt4o": + return GPT4OExtractor.from_config(cfg) + elif name == "gpt5_mini": + import os + return GPT4OExtractor(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-5-mini") + elif name == "rule": + if isinstance(pref_cfg, dict): + if "default" in pref_cfg: + # Manually construct to bypass ModelSpec mismatch if needed + spec_dict = pref_cfg["default"] + return QwenRuleExtractor( + model_path=spec_dict["local_path"], + dtype=choose_dtype(spec_dict.get("dtype")), + device_map=choose_device_map(spec_dict.get("device_map")) + ) + else: + return QwenRuleExtractor.from_config(cfg) + + raise ValueError(f"Could not load preference extractor: {name}") diff --git a/src/personalization/config/settings.py b/src/personalization/config/settings.py new file mode 100644 index 0000000..8f0cc8a --- /dev/null +++ b/src/personalization/config/settings.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional, Any, Dict + +import yaml +from pydantic import BaseModel, Field + + +class ModelSpec(BaseModel): + hf_id: str = Field(..., description="Hugging Face repository id") + local_path: str = Field(..., description="Local directory for model weights") + dtype: Optional[str] = Field( + default="bfloat16", description="Preferred torch dtype: bfloat16|float16|float32" + ) + device_map: 
Optional[str] = Field(default="auto", description="Device map policy") + + +class EmbeddingModelsConfig(BaseModel): + qwen3: Optional[ModelSpec] = None + nemotron: Optional[ModelSpec] = None + + +class RerankerModelsConfig(BaseModel): + qwen3_8b: Optional[ModelSpec] = None + + +class LocalModelsConfig(BaseModel): + llm: ModelSpec + preference_extractor: Any # Allow flexible dict or ModelSpec for now to support map + embedding: Optional[EmbeddingModelsConfig] = None + reranker: Optional[RerankerModelsConfig] = None + + +def _resolve_config_path(env_key: str, default_rel: str) -> Path: + value = os.getenv(env_key) + if value: + return Path(value).expanduser().resolve() + # Use project root (parent of src/personalization/config) instead of cwd + project_root = Path(__file__).parent.parent.parent.parent + return (project_root / default_rel).resolve() + + +def load_local_models_config(path: Optional[str] = None) -> LocalModelsConfig: + config_path = Path(path) if path else _resolve_config_path( + "LOCAL_MODELS_CONFIG", "configs/local_models.yaml" + ) + with open(config_path, "r", encoding="utf-8") as f: + raw = yaml.safe_load(f) or {} + models = raw.get("models", {}) + embedding_cfg = None + if "embedding" in models: + emb = models["embedding"] or {} + # dtype/device_map are not necessary for embedders; ModelSpec still accepts them + embedding_cfg = EmbeddingModelsConfig( + qwen3=ModelSpec(**emb["qwen3"]) if "qwen3" in emb else None, + nemotron=ModelSpec(**emb["nemotron"]) if "nemotron" in emb else None, + ) + + reranker_cfg = None + if "reranker" in models: + rer = models["reranker"] or {} + reranker_cfg = RerankerModelsConfig( + qwen3_8b=ModelSpec(**rer["qwen3_8b"]) if "qwen3_8b" in rer else None + ) + + return LocalModelsConfig( + llm=ModelSpec(**models["llm"]), + preference_extractor=models["preference_extractor"], # Pass raw dict/value + embedding=embedding_cfg, + reranker=reranker_cfg, + ) + + diff --git a/src/personalization/feedback/__init__.py 
b/src/personalization/feedback/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/personalization/feedback/gating.py b/src/personalization/feedback/gating.py new file mode 100644 index 0000000..d741874 --- /dev/null +++ b/src/personalization/feedback/gating.py @@ -0,0 +1,72 @@ +import numpy as np +from personalization.feedback.schemas import TurnSample + +def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray: + # matrix: [N, d], vector: [d] + # return: [N] + norm_m = np.linalg.norm(matrix, axis=1) + norm_v = np.linalg.norm(vector) + + # Avoid div by zero + den = norm_m * norm_v + den[den == 0] = 1e-9 + + return np.dot(matrix, vector) / den + +def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float: + """ + Return g_t in [0,1], representing how much the reward is due to retrieval. + """ + e_q = sample.query_embedding_t + e_q1 = sample.query_embedding_t1 + + if e_q is None or e_q1 is None or not sample.memories: + return 0.5 # Neutral + + # We need embeddings of the memories. + # In a real pipeline, we might pass them in sample.memory_embeddings. + # If missing, we can't compute sim. 
+ if sample.memory_embeddings is None: + # Try to use embedding_e from memory cards if available + # But MemoryCard.embedding_e is List[float] + try: + mem_embs = np.array([m.embedding_e for m in sample.memories]) + if mem_embs.shape[1] == 0: # Empty embeddings + return 0.5 + except: + return 0.5 + else: + mem_embs = sample.memory_embeddings + + # Compute similarities + # shape: [K] + sims_q = cosine_sim_batch(mem_embs, e_q) + sims_q1 = cosine_sim_batch(mem_embs, e_q1) + + s_q_max = sims_q.max() if len(sims_q) > 0 else 0 + s_q1_max = sims_q1.max() if len(sims_q1) > 0 else 0 + + g = 0.5 + + # Heuristics + + # Case A: Retrieval clearly irrelevant + bad reward + # q_t / q_{t+1} have low similarity to memories -> likely retrieval failure (or no relevant memories) + if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2: + g = 0.9 # Blame retrieval (for failing to find anything, or nothing exists) + + # Case B: Retrieval looks good but reward is bad + # Memories are relevant to query, but user still unhappy -> LLM didn't use them well? + elif reward_hat < -0.5 and s_q_max > 0.5: + g = 0.2 # Likely LLM fault + + # Case C: Good reward + # If reward is high, we assume both did okay. 
+ elif reward_hat > 0.5: + if s_q_max > 0.4: + g = 0.6 # Retrieval helped + else: + g = 0.3 # LLM handled it without strong retrieval help + + return float(g) + diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py new file mode 100644 index 0000000..f0468b6 --- /dev/null +++ b/src/personalization/feedback/handlers.py @@ -0,0 +1,87 @@ +from typing import Tuple, List, Optional +import numpy as np + +from personalization.retrieval.preference_store.schemas import MemoryCard +from personalization.feedback.schemas import TurnSample +from personalization.feedback.reward_model import estimate_reward +from personalization.feedback.gating import estimate_retrieval_gating +from personalization.feedback.llm_reward import ( + LLMRewardClient, LLMRewardConfig, RewardResult +) + + +def eval_step( + q_t: str, + answer_t: str, + q_t1: str, + memories_t: List[MemoryCard], + query_embedding_t: Optional[np.ndarray] = None, + query_embedding_t1: Optional[np.ndarray] = None, +) -> Tuple[float, float]: + """ + Keyword-based evaluation (legacy). + Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat). + """ + mem_embs = None + if memories_t and memories_t[0].embedding_e: + try: + mem_embs = np.array([m.embedding_e for m in memories_t]) + except: + pass + + sample = TurnSample( + user_id="", + session_id="", + turn_id=0, + query_t=q_t, + answer_t=answer_t, + query_t1=q_t1, + memories=memories_t, + query_embedding_t=query_embedding_t, + query_embedding_t1=query_embedding_t1, + memory_embeddings=mem_embs, + ) + + r_hat = estimate_reward(sample) + g_hat = estimate_retrieval_gating(sample, r_hat) + + return r_hat, g_hat + + +async def eval_step_llm( + q_t: str, + answer_t: str, + q_t1: str, + memories_t: List[MemoryCard], + client: LLMRewardClient, + query_embedding_t: Optional[np.ndarray] = None, + query_embedding_t1: Optional[np.ndarray] = None, +) -> Tuple[float, float]: + """ + LLM-as-judge evaluation (async). 
+ Returns (reward, gating) where gating=0.0 if update should be skipped. + + The gating signal is derived from the judge's confidence and label: + - If confidence < tau_c or label == topic_shift: gating = 0.0 + - Otherwise: gating = confidence (continuous, in [tau_c, 1.0]) + + This replaces the old heuristic gating with the judge's own confidence. + """ + sample = TurnSample( + user_id="", + session_id="", + turn_id=0, + query_t=q_t, + answer_t=answer_t, + query_t1=q_t1, + memories=memories_t, + query_embedding_t=query_embedding_t, + query_embedding_t1=query_embedding_t1, + ) + + result: RewardResult = await client.judge(sample) + + if result.should_update: + return result.reward, result.confidence + else: + return 0.0, 0.0 diff --git a/src/personalization/feedback/llm_reward.py b/src/personalization/feedback/llm_reward.py new file mode 100644 index 0000000..6adcf98 --- /dev/null +++ b/src/personalization/feedback/llm_reward.py @@ -0,0 +1,253 @@ +""" +LLM-as-Judge reward model using OpenAI GPT-5-nano (async for parallelism). + +Replaces keyword-based heuristic reward with structured LLM judgement. +Judge receives only (q_t, a_t, q_{t+1}) — no oracle preference cards, no history. + +Label taxonomy → scalar reward mapping: + neg_constraint_restate → -1.0 + neg_correction → -0.8 + neg_confusion → -0.6 + pos_praise → +0.8 + pos_progress → +0.1 + neutral → 0.0 + topic_shift → 0.0 (update skipped) + +Confidence gating: if confidence < tau_c, reward is set to 0 and update is skipped. 
+""" +from __future__ import annotations + +import asyncio +import hashlib +import json +import os +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +from openai import AsyncOpenAI, RateLimitError, APITimeoutError, APIConnectionError + +from personalization.feedback.schemas import TurnSample + + +# --- Label → Reward Mapping --- + +REWARD_MAP: Dict[str, float] = { + "neg_constraint_restate": -1.0, + "neg_correction": -0.8, + "neg_confusion": -0.6, + "pos_praise": +0.8, + "pos_progress": +0.1, + "neutral": 0.0, + "topic_shift": 0.0, +} + +VALID_LABELS = set(REWARD_MAP.keys()) + + +# --- Configuration --- + +@dataclass +class LLMRewardConfig: + model: str = "gpt-5-nano" + api_key: Optional[str] = None # Falls back to OPENAI_API_KEY env var + base_url: Optional[str] = None # For custom endpoints + max_concurrent: int = 32 # Semaphore limit for parallel requests + max_retries: int = 3 + retry_base_delay: float = 1.0 # Exponential backoff base (seconds) + timeout: float = 60.0 # Per-request timeout (reasoning models are slower) + max_completion_tokens: int = 2048 # Must be high — reasoning models use internal tokens + confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this + enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1}) + + +# --- Prompt --- + +JUDGE_SYSTEM_PROMPT = """\ +You are a feedback classifier. Given a user query (q_t), the assistant's response (a_t), \ +and the user's next message (q_{t+1}), classify the user's follow-up into exactly one label. + +Labels (mutually exclusive): +- neg_constraint_restate: User reasserts constraints/preferences as correction (e.g., "as I said…", "remember…", "按我说的…"). +- neg_correction: User indicates the content is wrong or the assistant failed to answer. +- neg_confusion: User indicates confusion or requests re-explanation. +- pos_praise: Explicit praise or satisfaction with the response. 
+- pos_progress: Constructive continuation (examples, extensions, what-if, next steps) without complaint. +- neutral: Ambiguous or minimal feedback, not clearly positive or negative. +- topic_shift: User switches to a new unrelated task/topic. + +Output a JSON object with fields: label, confidence (0-1), rationale (one short sentence).""" + +JUDGE_USER_TEMPLATE = """\ +q_t: {query_t} + +a_t: {answer_t} + +q_{{t+1}}: {query_t1}""" + + +# --- Result Dataclass --- + +@dataclass +class RewardResult: + label: str + confidence: float + rationale: str + reward: float + should_update: bool # False if gated by confidence or topic_shift + + +# --- Async Client --- + +class LLMRewardClient: + """Async OpenAI client for LLM-as-judge reward estimation.""" + + def __init__(self, config: Optional[LLMRewardConfig] = None): + self.config = config or LLMRewardConfig() + self._client = AsyncOpenAI( + api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"), + base_url=self.config.base_url, + timeout=self.config.timeout, + ) + self._semaphore = asyncio.Semaphore(self.config.max_concurrent) + self._cache: Dict[str, RewardResult] = {} + + def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str: + """Deterministic hash of the judge input triple.""" + content = f"{query_t}\x00{answer_t}\x00{query_t1}" + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + async def _call_with_retry(self, messages: List[dict]) -> str: + """Single LLM call with exponential backoff retry.""" + for attempt in range(self.config.max_retries): + try: + async with self._semaphore: + response = await self._client.chat.completions.create( + model=self.config.model, + messages=messages, + max_completion_tokens=self.config.max_completion_tokens, + response_format={"type": "json_object"}, + ) + content = response.choices[0].message.content + if content: + return content.strip() + # Reasoning model may exhaust tokens on thinking — retry + if response.choices[0].finish_reason == "length": + 
continue + return "" + except (RateLimitError, APITimeoutError, APIConnectionError) as e: + if attempt == self.config.max_retries - 1: + raise + delay = self.config.retry_base_delay * (2 ** attempt) + await asyncio.sleep(delay) + return "" + + def _build_messages(self, sample: TurnSample) -> List[dict]: + """Construct the judge prompt from (q_t, a_t, q_{t+1}) only.""" + user_content = JUDGE_USER_TEMPLATE.format( + query_t=sample.query_t, + answer_t=sample.answer_t, + query_t1=sample.query_t1, + ) + return [ + {"role": "system", "content": JUDGE_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + def _parse_result(self, raw: str) -> RewardResult: + """Parse structured JSON output into RewardResult.""" + try: + parsed = json.loads(raw) + label = parsed["label"] + confidence = float(parsed["confidence"]) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + + # Confidence gating and topic_shift skip + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except (json.JSONDecodeError, KeyError, TypeError, ValueError): + return RewardResult( + label="neutral", + confidence=0.0, + rationale="parse_failure", + reward=0.0, + should_update=False, + ) + + async def judge(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async). 
Returns RewardResult with gating applied.""" + # Cache lookup + if self.config.enable_cache: + key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1) + if key in self._cache: + return self._cache[key] + + messages = self._build_messages(sample) + raw = await self._call_with_retry(messages) + result = self._parse_result(raw) + + # Cache store + if self.config.enable_cache: + self._cache[key] = result + + return result + + async def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]: + """Judge a batch of turns in parallel. Returns list of RewardResult.""" + tasks = [self.judge(s) for s in samples] + return await asyncio.gather(*tasks) + + async def close(self): + """Close the underlying HTTP client.""" + await self._client.close() + + +# --- Synchronous Wrappers --- + +def estimate_reward_llm( + sample: TurnSample, + config: Optional[LLMRewardConfig] = None, +) -> Tuple[float, bool]: + """ + Synchronous single-sample reward estimation. + Returns (reward, should_update). + """ + client = LLMRewardClient(config) + try: + result = asyncio.run(client.judge(sample)) + return result.reward, result.should_update + finally: + asyncio.run(client.close()) + + +def estimate_rewards_batch( + samples: List[TurnSample], + config: Optional[LLMRewardConfig] = None, +) -> List[Tuple[float, bool]]: + """ + Synchronous batch reward estimation (runs async internally). + Returns list of (reward, should_update) tuples. + """ + client = LLMRewardClient(config) + try: + results = asyncio.run(client.judge_batch(samples)) + return [(r.reward, r.should_update) for r in results] + finally: + asyncio.run(client.close()) diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py new file mode 100644 index 0000000..70bbeb8 --- /dev/null +++ b/src/personalization/feedback/local_llm_reward.py @@ -0,0 +1,370 @@ +""" +Local LLM reward model using vLLM server for batch inference. 
+ +Drop-in replacement for LLMRewardClient when you want to use a local model +(e.g., Llama-3.1-8B-Instruct) instead of OpenAI API. + +Uses BatchVLLMClient for efficient concurrent requests - vLLM's continuous +batching will process them together for high throughput. +""" +from __future__ import annotations + +import asyncio +import hashlib +import json +from dataclasses import dataclass +from typing import Dict, List, Optional + +import aiohttp + +from personalization.feedback.schemas import TurnSample +from personalization.feedback.llm_reward import ( + REWARD_MAP, + VALID_LABELS, + JUDGE_SYSTEM_PROMPT, + JUDGE_USER_TEMPLATE, + RewardResult, +) + + +@dataclass +class LocalLLMRewardConfig: + """Configuration for local LLM reward model.""" + vllm_url: str = "http://localhost:8005/v1" # vLLM server URL + model_name: Optional[str] = None # Auto-discovered if None + max_tokens: int = 256 + temperature: float = 0.1 + max_concurrent: int = 100 # High concurrency for vLLM batching + timeout: Optional[int] = 60 # Per-request timeout in seconds + confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this + enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1}) + + +class LocalLLMRewardClient: + """ + Local LLM reward client using vLLM server. + + Designed for batch processing - uses async HTTP requests that vLLM + batches together via continuous batching for high throughput. 
+ """ + + def __init__(self, config: Optional[LocalLLMRewardConfig] = None): + self.config = config or LocalLLMRewardConfig() + self._model_name = self.config.model_name + self._cache: Dict[str, RewardResult] = {} + + # Discover model name if not provided + if self._model_name is None: + self._discover_model_sync() + + def _discover_model_sync(self): + """Synchronously discover model name from vLLM server.""" + import requests + try: + response = requests.get( + f"{self.config.vllm_url}/models", + timeout=10 + ) + response.raise_for_status() + models = response.json() + if models.get("data") and len(models["data"]) > 0: + self._model_name = models["data"][0]["id"] + else: + self._model_name = "default" + except Exception as e: + print(f"[LocalLLMReward] Warning: Could not discover model ({e})") + self._model_name = "default" + + def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str: + """Deterministic hash of the judge input triple.""" + content = f"{query_t}\x00{answer_t}\x00{query_t1}" + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + def _build_messages(self, sample: TurnSample) -> List[dict]: + """Construct the judge prompt from (q_t, a_t, q_{t+1}).""" + user_content = JUDGE_USER_TEMPLATE.format( + query_t=sample.query_t, + answer_t=sample.answer_t, + query_t1=sample.query_t1, + ) + return [ + {"role": "system", "content": JUDGE_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + def _parse_result(self, raw: Optional[str]) -> RewardResult: + """Parse structured JSON output into RewardResult.""" + if raw is None: + return RewardResult( + label="neutral", + confidence=0.0, + rationale="request_failed", + reward=0.0, + should_update=False, + ) + + try: + # Handle markdown code blocks + text = raw.strip() + if text.startswith("```"): + lines = text.split("\n") + text = "\n".join( + lines[1:-1] if lines[-1].strip() == "```" else lines[1:] + ) + + parsed = json.loads(text) + label = parsed.get("label", "neutral") + 
confidence = float(parsed.get("confidence", 0.0)) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + + # Confidence gating and topic_shift skip + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except (json.JSONDecodeError, KeyError, TypeError, ValueError): + # Try to extract JSON from text + import re + match = re.search(r'\{[^}]+\}', raw, re.DOTALL) + if match: + try: + parsed = json.loads(match.group()) + label = parsed.get("label", "neutral") + confidence = float(parsed.get("confidence", 0.0)) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except: + pass + + return RewardResult( + label="neutral", + confidence=0.0, + rationale="parse_failure", + reward=0.0, + should_update=False, + ) + + async def _single_request( + self, + session: aiohttp.ClientSession, + messages: List[dict], + idx: int, + ) -> tuple: + """Make a single async request to vLLM server.""" + payload = { + "model": self._model_name, + "messages": messages, + "max_tokens": self.config.max_tokens, + "temperature": self.config.temperature, + "response_format": {"type": "json_object"}, + } + + for attempt in range(3): + try: + timeout_config = ( + aiohttp.ClientTimeout(total=self.config.timeout) + if self.config.timeout else None + ) + async with session.post( + 
f"{self.config.vllm_url}/chat/completions", + json=payload, + timeout=timeout_config, + ) as response: + if response.status == 200: + result = await response.json() + content = result["choices"][0]["message"]["content"] + return (idx, content, None) + elif response.status == 429: + # Rate limit - wait and retry + await asyncio.sleep(2 ** attempt) + continue + else: + error_text = await response.text() + return (idx, None, f"HTTP {response.status}: {error_text[:200]}") + except asyncio.TimeoutError: + if attempt < 2: + continue + return (idx, None, "Timeout") + except Exception as e: + return (idx, None, str(e)) + + return (idx, None, "Max retries exceeded") + + async def judge_batch_async( + self, + samples: List[TurnSample], + show_progress: bool = False, + ) -> List[RewardResult]: + """ + Judge a batch of turns using concurrent vLLM requests. + + vLLM's continuous batching will process these together for + high throughput. + """ + n_samples = len(samples) + results = [None] * n_samples + + # Check cache and build request list + to_request = [] # (original_idx, messages) + for i, sample in enumerate(samples): + if self.config.enable_cache: + key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1) + if key in self._cache: + results[i] = self._cache[key] + continue + + messages = self._build_messages(sample) + to_request.append((i, messages)) + + if not to_request: + return results + + # Make concurrent requests + semaphore = asyncio.Semaphore(self.config.max_concurrent) + + async def limited_request(session, messages, idx): + async with semaphore: + return await self._single_request(session, messages, idx) + + connector = aiohttp.TCPConnector(limit=self.config.max_concurrent) + headers = {"Content-Type": "application/json"} + + async with aiohttp.ClientSession( + connector=connector, headers=headers + ) as session: + tasks = [ + limited_request(session, messages, orig_idx) + for orig_idx, messages in to_request + ] + + completed = 0 + for coro in 
asyncio.as_completed(tasks): + orig_idx, content, error = await coro + completed += 1 + + if error: + print(f"[LocalLLMReward] Request {orig_idx} failed: {error}") + + result = self._parse_result(content) + results[orig_idx] = result + + # Cache the result + if self.config.enable_cache: + sample = samples[orig_idx] + key = self._cache_key( + sample.query_t, sample.answer_t, sample.query_t1 + ) + self._cache[key] = result + + if show_progress and completed % 10 == 0: + print(f" [LocalLLMReward {completed}/{len(to_request)}] completed") + + return results + + async def judge_async(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async).""" + results = await self.judge_batch_async([sample]) + return results[0] + + def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]: + """ + Judge a batch of turns (sync wrapper). + + This is the main entry point for batch reward estimation. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in an event loop - create a new thread to run the coroutine + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.judge_batch_async(samples)) + return future.result() + else: + return asyncio.run(self.judge_batch_async(samples)) + + async def judge(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async interface for compatibility with LLMRewardClient).""" + return await self.judge_async(sample) + + def judge_sync(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (sync wrapper).""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in an event loop - create a new thread to run the coroutine + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.judge_async(sample)) + return 
future.result() + else: + return asyncio.run(self.judge_async(sample)) + + +# --- Convenience Functions --- + +def estimate_reward_local( + sample: TurnSample, + config: Optional[LocalLLMRewardConfig] = None, +) -> tuple: + """ + Synchronous single-sample reward estimation using local LLM. + Returns (reward, should_update). + """ + client = LocalLLMRewardClient(config) + result = client.judge(sample) + return result.reward, result.should_update + + +def estimate_rewards_batch_local( + samples: List[TurnSample], + config: Optional[LocalLLMRewardConfig] = None, +) -> List[tuple]: + """ + Synchronous batch reward estimation using local LLM. + Returns list of (reward, should_update) tuples. + """ + client = LocalLLMRewardClient(config) + results = client.judge_batch(samples) + return [(r.reward, r.should_update) for r in results] diff --git a/src/personalization/feedback/online_update.py b/src/personalization/feedback/online_update.py new file mode 100644 index 0000000..e69de29 diff --git a/src/personalization/feedback/reward_model.py b/src/personalization/feedback/reward_model.py new file mode 100644 index 0000000..3584b43 --- /dev/null +++ b/src/personalization/feedback/reward_model.py @@ -0,0 +1,64 @@ +import numpy as np +from personalization.feedback.schemas import TurnSample + +def cosine_sim(a: np.ndarray, b: np.ndarray) -> float: + norm_a = np.linalg.norm(a) + norm_b = np.linalg.norm(b) + if norm_a == 0 or norm_b == 0: + return 0.0 + return float(np.dot(a, b) / (norm_a * norm_b)) + +def estimate_reward(sample: TurnSample) -> float: + """ + Return a scalar reward_hat, indicating if the previous answer was helpful. + Range: [-1.0, 1.0] (approx) + """ + + # 1. Language/Topic Coherence + if sample.query_embedding_t is None or sample.query_embedding_t1 is None: + topic_sim = 0.5 + else: + topic_sim = cosine_sim(sample.query_embedding_t, sample.query_embedding_t1) + + # 2. 
Negative Keywords (Complaint/Correction) + negative_keywords = [ + "you didn't", "that's not", "incorrect", "redo", "again", "explain more", + "doesn't help", "wrong", "no", "not what i asked", + "你没", "不是", "这不是", "重来", "重新", "不对", "错了", "没说清楚" + ] + + # 3. Positive Keywords (Follow-up/Elaboration) + positive_keywords = [ + "can you elaborate", "give an example", "continue", "what if", "based on that", + "thanks", "good", "great", "cool", + "能不能详细一点", "举个例子", "再继续", "那如果", "接下来", "在这个基础上", "谢谢", "不错" + ] + + q1_lower = sample.query_t1.lower() + + has_negative = any(kw in q1_lower for kw in negative_keywords) + has_positive = any(kw in q1_lower for kw in positive_keywords) + + reward = 0.0 + + if has_negative: + reward -= 1.0 + + if has_positive: + # Only reward if topic similarity is decent, otherwise might be "thanks, bye" (end of session) + # But "thanks" is good. + reward += 0.5 + if topic_sim > 0.3: + reward += 0.5 + + if topic_sim < 0.2: + # Topic shift -> previous interaction likely finished or failed. + # If no explicit positive/negative, assume neutral/slightly decayed. + # If user changes topic, it often means the previous task is done (neutral/positive) + # OR they gave up (negative). Hard to tell. + # Let's dampen the reward towards 0. 
def build_turn_samples_from_sessions(
    sessions: Iterable[List[ChatTurn]],
    embed_model: EmbeddingModel,
    llm: QwenInstruct,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,
    user_store: UserTensorStore,
    item_vectors: np.ndarray,
    max_samples: Optional[int] = None,
    topk_dense: int = 64,
    topk_rerank: int = 3,
) -> List[TurnSample]:
    """Turn chat sessions into (q_t, a_t, q_{t+1}) samples.

    For each user turn q_t, the first assistant reply after it becomes
    a_t and the next user turn becomes q_{t+1}; memories are retrieved
    for q_t and both queries are embedded. Stops early once
    ``max_samples`` is reached.
    """
    out: List[TurnSample] = []

    def _budget_left() -> bool:
        # max_samples of None/0 means "no cap".
        return not (max_samples and len(out) >= max_samples)

    for session_turns in tqdm(sessions, desc="Building TurnSamples"):
        if not _budget_left():
            break

        # Ensure sorted by turn_id
        ordered = sorted(session_turns, key=lambda t: t.turn_id)

        for pos, turn in enumerate(ordered):
            if not _budget_left():
                break

            if turn.role != "user":
                continue

            # Scan forward: first assistant reply is a_t; the next user
            # turn is q_{t+1} and ends the scan.
            answer_text = ""
            next_user = None
            for later in ordered[pos + 1:]:
                if later.role == "assistant" and not answer_text:
                    answer_text = later.text
                elif later.role == "user":
                    next_user = later
                    break

            if not next_user:
                # End of session or no subsequent user query.
                continue

            # NOTE: answer_text may still be empty when the dataset has
            # no assistant turn between the two user turns; the sample
            # is kept anyway (OASST1-style data normally has replies).

            # Retrieve memories for q_t.
            retrieved = retrieve_with_rerank(
                user_id=turn.user_id,
                query=turn.text,
                embed_model=embed_model,
                reranker=reranker,
                memory_cards=memory_cards,
                memory_embeddings=memory_embeddings,
                user_store=user_store,
                item_vectors=item_vectors,
                topk_dense=topk_dense,
                topk_rerank=topk_rerank,
                beta_long=0.0,
                beta_short=0.0,
                only_own_memories=True  # user-specific memories only
            )

            # Embed both queries (done per sample; could be batched later).
            emb_q = embed_model.encode([turn.text], return_tensor=False)[0]
            emb_q1 = embed_model.encode([next_user.text], return_tensor=False)[0]

            out.append(TurnSample(
                user_id=turn.user_id,
                session_id=turn.session_id,
                turn_id=turn.turn_id,
                query_t=turn.text,
                answer_t=answer_text,
                query_t1=next_user.text,
                memories=retrieved,
                query_embedding_t=np.array(emb_q),
                query_embedding_t1=np.array(emb_q1)
            ))

    return out
Precompute embeddings + # We can do this efficiently later or batch, but here per sample + e_q_t = embed_model.encode([q_t.text], return_tensor=False)[0] + e_q_t1 = embed_model.encode([q_t1.text], return_tensor=False)[0] + + sample = TurnSample( + user_id=q_t.user_id, + session_id=q_t.session_id, + turn_id=q_t.turn_id, + query_t=q_t.text, + answer_t=a_t_text, + query_t1=q_t1.text, + memories=memories_t, + query_embedding_t=np.array(e_q_t), + query_embedding_t1=np.array(e_q_t1) + ) + samples.append(sample) + + return samples + diff --git a/src/personalization/feedback/schemas.py b/src/personalization/feedback/schemas.py new file mode 100644 index 0000000..b15db80 --- /dev/null +++ b/src/personalization/feedback/schemas.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Any +import numpy as np + +from personalization.retrieval.preference_store.schemas import MemoryCard + +@dataclass +class TurnSample: + user_id: str + session_id: str + turn_id: int # index of q_t within the session + query_t: str # q_t + answer_t: str # a_t + query_t1: str # q_{t+1} + memories: List[MemoryCard] # A_t + + # Optional pre-computed vectors and features + query_embedding_t: Optional[np.ndarray] = None + query_embedding_t1: Optional[np.ndarray] = None + memory_embeddings: Optional[np.ndarray] = None # corresponding e_m or v_m for memories + diff --git a/src/personalization/models/__init__.py b/src/personalization/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/personalization/models/embedding/__init__.py b/src/personalization/models/embedding/__init__.py new file mode 100644 index 0000000..05221aa --- /dev/null +++ b/src/personalization/models/embedding/__init__.py @@ -0,0 +1,11 @@ +from .base import EmbeddingModel +from .qwen3_8b import Qwen3Embedding8B +from .nemotron_8b import LlamaEmbedNemotron8B + +__all__ = [ + "EmbeddingModel", + "Qwen3Embedding8B", + "LlamaEmbedNemotron8B", 
+] + + diff --git a/src/personalization/models/embedding/base.py b/src/personalization/models/embedding/base.py new file mode 100644 index 0000000..9f9d4d1 --- /dev/null +++ b/src/personalization/models/embedding/base.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Iterable, List, Sequence + +import torch + + +class EmbeddingModel(ABC): + @abstractmethod + def encode( + self, + texts: Sequence[str], + batch_size: int = 8, + max_length: int = 512, + normalize: bool = True, + return_tensor: bool = False, + ) -> List[List[float]] | torch.Tensor: + """Encode a batch of texts into dense embeddings.""" + raise NotImplementedError + + +def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + # last_hidden_state: [batch, seq_len, hidden] + # attention_mask: [batch, seq_len] + mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [b, s, 1] + summed = (last_hidden_state * mask).sum(dim=1) + counts = mask.sum(dim=1).clamp_min(1e-6) + return summed / counts + + +def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor: + if not normalize: + return x + return torch.nn.functional.normalize(x, p=2, dim=-1) + + diff --git a/src/personalization/models/embedding/qwen3_8b.py b/src/personalization/models/embedding/qwen3_8b.py new file mode 100644 index 0000000..fb02e67 --- /dev/null +++ b/src/personalization/models/embedding/qwen3_8b.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import List, Sequence + +import torch +from transformers import AutoModel, AutoTokenizer + +from personalization.config.registry import choose_dtype, choose_device_map +from personalization.config.settings import LocalModelsConfig +from .base import EmbeddingModel, _mean_pool, _maybe_normalize + + +class Qwen3Embedding8B(EmbeddingModel): + def __init__( + self, + model_path: str, + dtype: torch.dtype, + device_map: str = "auto", + trust_remote_code: bool = 
True, + ) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=True, trust_remote_code=trust_remote_code + ) + + # Handle specific device assignment (e.g., "cuda:0", "cuda:1") + if device_map and device_map.startswith("cuda:"): + # Load to CPU first, then move to specific GPU + self.model = AutoModel.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=None, # Don't use accelerate's device_map + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(device_map) + else: + # Use accelerate's auto device mapping + self.model = AutoModel.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=device_map, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Embedding8B": + if not cfg.embedding or not cfg.embedding.qwen3: + raise ValueError("Embedding config for qwen3 is missing") + spec = cfg.embedding.qwen3 + dtype = choose_dtype(spec.dtype) + device_map = choose_device_map(spec.device_map) + return cls( + spec.local_path, + dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + + @torch.inference_mode() + def encode( + self, + texts: Sequence[str], + batch_size: int = 8, + max_length: int = 512, + normalize: bool = True, + return_tensor: bool = False, + ) -> List[List[float]] | torch.Tensor: + device = next(self.model.parameters()).device + outputs: List[torch.Tensor] = [] + for i in range(0, len(texts), batch_size): + batch = list(texts[i : i + batch_size]) + enc = self.tokenizer( + batch, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", + ).to(device) + model_out = self.model(**enc, output_hidden_states=False, return_dict=True) + pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined] + pooled = _maybe_normalize(pooled, normalize) + outputs.append(pooled) + emb = torch.cat(outputs, dim=0) + if 
return_tensor: + return emb + return emb.cpu().to(torch.float32).tolist() + + diff --git a/src/personalization/models/llm/__init__.py b/src/personalization/models/llm/__init__.py new file mode 100644 index 0000000..3f1af81 --- /dev/null +++ b/src/personalization/models/llm/__init__.py @@ -0,0 +1,4 @@ +from .qwen_instruct import QwenInstruct + +__all__ = ["QwenInstruct"] + diff --git a/src/personalization/models/llm/base.py b/src/personalization/models/llm/base.py new file mode 100644 index 0000000..72b6ca8 --- /dev/null +++ b/src/personalization/models/llm/base.py @@ -0,0 +1,29 @@ +from typing import List, Protocol, Optional +from personalization.types import ChatTurn + +class ChatModel(Protocol): + def answer( + self, + history: List[ChatTurn], + memory_notes: List[str], + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + top_k: Optional[int] = None, + ) -> str: + """ + Generate an assistant response given conversation history and memory notes. + + Args: + history: The conversation history ending with the current user turn. + memory_notes: List of retrieved memory content strings. + max_new_tokens: Max tokens to generate. + temperature: Sampling temperature. + top_p: Top-p sampling. + top_k: Top-k sampling. + + Returns: + The generated assistant response text. + """ + ... + diff --git a/src/personalization/models/llm/prompt_builder.py b/src/personalization/models/llm/prompt_builder.py new file mode 100644 index 0000000..e69de29 diff --git a/src/personalization/models/llm/vllm_chat.py b/src/personalization/models/llm/vllm_chat.py new file mode 100644 index 0000000..d577a30 --- /dev/null +++ b/src/personalization/models/llm/vllm_chat.py @@ -0,0 +1,244 @@ +""" +vLLM-based ChatModel implementation for high-throughput inference. + +This provides the same interface as LlamaChatModel but uses vLLM HTTP API +for much faster inference (3000+ sessions/hr vs 20 sessions/hr). 
+""" + +from typing import List, Optional +import time +import requests + +from personalization.models.llm.base import ChatModel +from personalization.types import ChatTurn + + +class VLLMChatModel(ChatModel): + """ + ChatModel implementation using vLLM HTTP API. + + This is a drop-in replacement for LlamaChatModel that uses vLLM + for much faster inference. + """ + + def __init__( + self, + vllm_url: str = "http://localhost:8003/v1", + model_name: str = None, + max_context_length: int = 8192, + timeout: int = 120, + ): + self.vllm_url = vllm_url.rstrip('/') + self.model_name = model_name + self.max_context_length = max_context_length + self.timeout = timeout + + # Discover model name if not provided + if self.model_name is None: + self._discover_model() + + def _discover_model(self): + """Discover the model name from the vLLM server.""" + max_retries = 30 + for attempt in range(max_retries): + try: + response = requests.get(f"{self.vllm_url}/models", timeout=10) + response.raise_for_status() + models = response.json() + if models.get("data") and len(models["data"]) > 0: + self.model_name = models["data"][0]["id"] + return + except Exception as e: + if attempt < max_retries - 1: + wait_time = min(2 ** attempt * 0.5, 10) + time.sleep(wait_time) + + # Fallback + self.model_name = "default" + print(f"[VLLMChatModel] Warning: Could not discover model, using '{self.model_name}'") + + def health_check(self) -> bool: + """Check if the vLLM server is healthy.""" + try: + response = requests.get(f"{self.vllm_url.replace('/v1', '')}/health", timeout=5) + return response.status_code == 200 + except: + return False + + def _estimate_tokens(self, text: str) -> int: + """Estimate token count using character-based heuristic. + + For Llama models, ~4 characters per token is a reasonable estimate. + We use 3.5 to be conservative (slightly overestimate tokens). 
+ """ + return int(len(text) / 3.5) + + def _build_messages( + self, + history: List[ChatTurn], + memory_notes: List[str], + max_new_tokens: int = 512, + global_notes: List[str] = None, + ) -> List[dict]: + """Build messages list for chat completion API with auto-truncation. + + If the context exceeds max_context_length, older conversation turns + are removed to keep only the most recent context that fits. + + Args: + global_notes: If provided, these are always-applicable preferences + displayed in a separate section from task-specific retrieved notes. + """ + # Use CollaborativeAgents-style system prompt + has_any_notes = memory_notes or global_notes + if has_any_notes: + # Build preference sections + pref_sections = "" + if global_notes: + global_bullet = "\n".join(f"- {n}" for n in global_notes) + pref_sections += f"## General Preferences (always apply)\n{global_bullet}\n\n" + if memory_notes: + task_bullet = "\n".join(f"- {n}" for n in memory_notes) + if global_notes: + pref_sections += f"## Task-Specific Preferences\n{task_bullet}\n" + else: + pref_sections += f"{task_bullet}\n" + + system_content = ( + "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n" + "# User Preferences\n" + "The user has a set of preferences for how you should behave. If you do not follow these preferences, " + "the user will be unable to learn from your response and you will need to adjust your response to adhere " + "to these preferences (so it is best to follow them initially).\n\n" + "**IMPORTANT**: If the user explicitly requests something in THIS conversation (e.g., asks you to change " + "your format, style, or approach), that request takes PRIORITY over the remembered preferences below. 
" + "Always adapt to the user's direct feedback first.\n\n" + "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n" + f"{pref_sections}\n" + "# Before Responding\n" + "Before writing your response, briefly consider:\n" + "1. Which preferences above are relevant to this specific request?\n" + "2. How will you satisfy each relevant preference in your response?\n\n" + "# Conversation Guidelines:\n" + "- If the user asks you to adjust your response (e.g., 'be more concise', 'focus on intuition'), you MUST change your approach accordingly. Do NOT repeat the same response.\n" + "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, " + "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n" + "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n" + "- **Verify**: Before finalizing, check that your response satisfies the relevant preferences listed above.\n" + ) + else: + # Vanilla mode - no preferences + system_content = ( + "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n" + "# Conversation Guidelines:\n" + "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, " + "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n" + "- Your goal is to help the user solve their problem. 
Do your best to help them.\n" + ) + system_message = {"role": "system", "content": system_content} + + # Calculate available tokens for conversation history + # Reserve space for: system prompt + max_new_tokens + safety margin + system_tokens = self._estimate_tokens(system_content) + available_tokens = self.max_context_length - system_tokens - max_new_tokens - 100 # 100 token safety margin + + # Build conversation messages from history + conversation_messages = [] + for turn in history: + conversation_messages.append({"role": turn.role, "content": turn.text}) + + # Check if truncation is needed + total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation_messages) + + if total_conv_tokens > available_tokens: + # Truncate from the beginning (keep recent messages) + truncated_messages = [] + current_tokens = 0 + + # Iterate from most recent to oldest + for msg in reversed(conversation_messages): + msg_tokens = self._estimate_tokens(msg["content"]) + if current_tokens + msg_tokens <= available_tokens: + truncated_messages.insert(0, msg) + current_tokens += msg_tokens + else: + # Stop adding older messages + break + + conversation_messages = truncated_messages + if len(truncated_messages) < len(history): + print(f"[VLLMChatModel] Truncated context: kept {len(truncated_messages)}/{len(history)} turns " + f"({current_tokens}/{total_conv_tokens} estimated tokens)") + + messages = [system_message] + conversation_messages + return messages + + def build_messages( + self, + history: List[ChatTurn], + memory_notes: List[str], + max_new_tokens: int = 512, + global_notes: List[str] = None, + ) -> List[dict]: + """Public method to build messages without calling the API. + + Used for batch processing where messages are collected first, + then sent in batch to vLLM for concurrent processing. 
+ """ + return self._build_messages(history, memory_notes, max_new_tokens, global_notes=global_notes) + + def answer( + self, + history: List[ChatTurn], + memory_notes: List[str], + max_new_tokens: int = 512, + temperature: float = 0.7, + top_p: float = 0.9, + top_k: Optional[int] = None, + ) -> str: + """Generate a response using vLLM HTTP API.""" + messages = self._build_messages(history, memory_notes, max_new_tokens) + + payload = { + "model": self.model_name, + "messages": messages, + "max_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + } + + # Retry with exponential backoff + max_retries = 5 + for attempt in range(max_retries): + try: + response = requests.post( + f"{self.vllm_url}/chat/completions", + json=payload, + timeout=self.timeout + ) + + if response.status_code == 200: + result = response.json() + return result["choices"][0]["message"]["content"] + elif response.status_code == 400: + error_text = response.text + # Handle context length error + if "max_tokens" in error_text and max_new_tokens > 64: + payload["max_tokens"] = max(64, max_new_tokens // 2) + continue + raise RuntimeError(f"vLLM error: {error_text[:200]}") + else: + raise RuntimeError(f"vLLM HTTP {response.status_code}: {response.text[:200]}") + + except requests.exceptions.Timeout: + if attempt < max_retries - 1: + time.sleep(2 ** attempt) + continue + raise RuntimeError("vLLM request timeout") + except requests.exceptions.ConnectionError as e: + if attempt < max_retries - 1: + time.sleep(2 ** attempt) + continue + raise RuntimeError(f"vLLM connection error: {e}") + + raise RuntimeError("Max retries exceeded for vLLM request") diff --git a/src/personalization/models/preference_extractor/__init__.py b/src/personalization/models/preference_extractor/__init__.py new file mode 100644 index 0000000..65e2595 --- /dev/null +++ b/src/personalization/models/preference_extractor/__init__.py @@ -0,0 +1,5 @@ +from .rule_extractor import QwenRuleExtractor +from 
.gpt4o_extractor import GPT4OExtractor +from .base import PreferenceExtractor + +__all__ = ["QwenRuleExtractor", "GPT4OExtractor", "PreferenceExtractor"] diff --git a/src/personalization/models/preference_extractor/base.py b/src/personalization/models/preference_extractor/base.py new file mode 100644 index 0000000..850292f --- /dev/null +++ b/src/personalization/models/preference_extractor/base.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List +from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList + +class PreferenceExtractorBase(ABC): + @abstractmethod + def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList: + """ + Extract preferences from a window of chat turns (history + current query). + """ + raise NotImplementedError + +# Alias for backward compatibility if needed, +# though specific extractors should inherit from PreferenceExtractorBase now. +PreferenceExtractor = PreferenceExtractorBase diff --git a/src/personalization/models/preference_extractor/gpt4o_extractor.py b/src/personalization/models/preference_extractor/gpt4o_extractor.py new file mode 100644 index 0000000..0f70522 --- /dev/null +++ b/src/personalization/models/preference_extractor/gpt4o_extractor.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import json +import os +from typing import Any, Dict, List + +from openai import OpenAI +from personalization.config.settings import LocalModelsConfig +from personalization.models.preference_extractor.base import PreferenceExtractorBase as PreferenceExtractor +from personalization.retrieval.preference_store.schemas import ( + ChatTurn, + PreferenceList, + preference_list_json_schema, +) + + +class GPT4OExtractor(PreferenceExtractor): + def __init__(self, api_key: str, model: str = "gpt-4o") -> None: + self.client = OpenAI(api_key=api_key) + self.model = model + + # Load system prompt template + template_path = 
"fine_tuning_prompt_template.txt" + if os.path.exists(template_path): + with open(template_path, "r", encoding="utf-8") as f: + self.system_prompt = f.read() + else: + # Structured prompt that enforces the PreferenceList schema + self.system_prompt = ( + "You are a preference extraction assistant. " + "Given a user message, extract any user preferences as condition-action rules.\n\n" + "Return a JSON object with exactly this structure:\n" + '{"preferences": [{"condition": "", "action": "", "confidence": <0.0-1.0>}]}\n\n' + "Examples of preferences:\n" + '- {"condition": "general", "action": "respond in Chinese", "confidence": 0.9}\n' + '- {"condition": "when writing code", "action": "use Python with type hints", "confidence": 0.8}\n' + '- {"condition": "when explaining math", "action": "show step-by-step derivation", "confidence": 0.7}\n\n' + "If no preferences are found, return {\"preferences\": []}.\n" + "IMPORTANT: The output MUST be a JSON object with a \"preferences\" key containing a list." 
+ ) + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "GPT4OExtractor": + # We rely on env var for API key, config for other potential settings if needed + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY environment variable not set") + return cls(api_key=api_key) + + def build_preference_prompt(self, query: str) -> str: + # GPT4OExtractor uses the system prompt loaded in __init__ + return self.system_prompt + + def _call_kwargs(self, messages): + """Build kwargs for chat completion, skipping temperature for models that don't support it.""" + kwargs = { + "model": self.model, + "messages": messages, + "response_format": {"type": "json_object"}, + } + # GPT-5 series doesn't support temperature=0 + if not self.model.startswith("gpt-5"): + kwargs["temperature"] = 0.0 + return kwargs + + def extract_preferences(self, query: str) -> Dict[str, Any]: + # Reuse logic but return raw dict + try: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": query}, + ] + response = self.client.chat.completions.create(**self._call_kwargs(messages)) + content = response.choices[0].message.content + if content: + return json.loads(content) + except Exception as e: + print(f"Error calling GPT-4o: {e}") + return {"preferences": []} + + def extract_turn(self, turns) -> PreferenceList: + # Accept both a single ChatTurn and a list of ChatTurns (history) + if isinstance(turns, list): + # Find the last user message in history + last_user_msg = None + for t in reversed(turns): + if hasattr(t, 'role') and t.role == "user": + last_user_msg = t.text + break + if not last_user_msg: + return PreferenceList(preferences=[]) + else: + # Single ChatTurn + if turns.role != "user": + return PreferenceList(preferences=[]) + last_user_msg = turns.text + + try: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": last_user_msg}, + ] + response = 
self.client.chat.completions.create(**self._call_kwargs(messages)) + + content = response.choices[0].message.content + if not content: + return PreferenceList(preferences=[]) + + data = json.loads(content) + return self._parse_to_preference_list(data) + + except Exception as e: + print(f"Error calling GPT-4o: {e}") + return PreferenceList(preferences=[]) + + @staticmethod + def _parse_to_preference_list(data: dict) -> PreferenceList: + """Robustly convert GPT output to PreferenceList, handling non-standard formats.""" + # Best case: already matches schema + if "preferences" in data and isinstance(data["preferences"], list): + prefs = [] + for item in data["preferences"]: + if isinstance(item, dict) and "condition" in item and "action" in item: + prefs.append({ + "condition": str(item["condition"])[:128], + "action": str(item["action"])[:256], + "confidence": float(item.get("confidence", 0.7)), + }) + return PreferenceList.model_validate({"preferences": prefs}) + + # GPT returned a flat dict of preferences - convert to condition/action pairs + prefs = [] + for key, value in data.items(): + if isinstance(value, str) and len(value) > 2: + prefs.append({ + "condition": str(key)[:128] if len(str(key)) > 1 else "general", + "action": str(value)[:256], + "confidence": 0.7, + }) + elif isinstance(value, dict): + # Nested dict: try to extract meaningful pairs + for sub_key, sub_val in value.items(): + if isinstance(sub_val, str) and len(sub_val) > 2: + prefs.append({ + "condition": str(sub_key)[:128], + "action": str(sub_val)[:256], + "confidence": 0.7, + }) + elif isinstance(value, list): + for item in value: + if isinstance(item, str) and len(item) > 2: + prefs.append({ + "condition": str(key)[:128], + "action": str(item)[:256], + "confidence": 0.7, + }) + + return PreferenceList.model_validate({"preferences": prefs[:20]}) + + def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]: + results = [] + for turn in turns: + 
results.append(self.extract_turn(turn)) + return results + diff --git a/src/personalization/models/preference_extractor/llm_extractor.py b/src/personalization/models/preference_extractor/llm_extractor.py new file mode 100644 index 0000000..8f7a6cb --- /dev/null +++ b/src/personalization/models/preference_extractor/llm_extractor.py @@ -0,0 +1,153 @@ +from typing import List, Dict, Any +import torch +import json +import os +from transformers import AutoModelForCausalLM, AutoTokenizer + +from personalization.models.preference_extractor.base import PreferenceExtractorBase +from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList +from personalization.config.settings import LocalModelsConfig +from personalization.config.registry import choose_dtype, choose_device_map + +class PreferenceExtractorLLM(PreferenceExtractorBase): + def __init__( + self, + model_path: str, + prompt_template_path: str = "fine_tuning_prompt_template.txt", + device_map: str = "auto", + dtype: torch.dtype = torch.bfloat16, + max_new_tokens: int = 512, + ) -> None: + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + self.max_new_tokens = max_new_tokens + + if os.path.exists(prompt_template_path): + with open(prompt_template_path, "r", encoding="utf-8") as f: + self.prompt_template = f.read() + else: + print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.") + self.prompt_template = "Extract user preferences from the following conversation." 
+ + @classmethod + def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM": + # We need to access the specific extractor config by name + # Assuming cfg has a way to access extra configs or we update LocalModelsConfig to support multiple extractors + # For now, let's look for it in the 'preference_extractor' dict if it was a Dict, but it is a ModelSpec. + # We need to update LocalModelsConfig to support a dictionary of extractors or a specific one. + # Based on user design doc: + # preference_extractor: + # qwen3_0_6b_sft: ... + + # We might need to manually parse the raw config or update settings.py + # Let's assume settings.py will be updated to hold a map or specific fields. + # For now, if we use the existing ModelSpec for preference_extractor in cfg, we assume it points to this model. + + # BUT the design doc says "preference_extractor" in local_models.yaml will have "qwen3_0_6b_sft" key. + # The current settings.py defines preference_extractor as a single ModelSpec. + # We will need to update settings.py first to support multiple extractors or a dict. + # I will proceed implementing this class assuming arguments are passed, and update settings/registry later. + + # This from_config might change depending on how settings.py is refactored. + # For now I will implement it assuming a direct ModelSpec is passed, or we handle it in registry. + pass + return None + + def _build_prompt(self, turns: List[ChatTurn]) -> str: + # Construct messages list for chat template + messages = [{"role": "system", "content": self.prompt_template}] + + # Window size 6 + window = turns[-6:] + + # Add conversation history + # We need to format the conversation as input context. + # Since the task is to extract preferences from the *whole* context (or latest turn?), + # usually we provide the conversation and ask for extraction. + # But LLaMA-Factory SFT usually expects: + # System: