diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-18 18:25:09 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-18 18:25:09 -0500 |
| commit | b6c3e4e51eeab703b40284459c6e9fff2151216c (patch) | |
| tree | 221410886f23214575f93b9ef44fa8431c9a6dfc /src/personalization/feedback | |
Initial release: VARS - personalized LLM with RAG and user vector learning
Diffstat (limited to 'src/personalization/feedback')
| -rw-r--r-- | src/personalization/feedback/__init__.py | 0 | ||||
| -rw-r--r-- | src/personalization/feedback/gating.py | 72 | ||||
| -rw-r--r-- | src/personalization/feedback/handlers.py | 87 | ||||
| -rw-r--r-- | src/personalization/feedback/llm_reward.py | 253 | ||||
| -rw-r--r-- | src/personalization/feedback/local_llm_reward.py | 370 | ||||
| -rw-r--r-- | src/personalization/feedback/online_update.py | 0 | ||||
| -rw-r--r-- | src/personalization/feedback/reward_model.py | 64 | ||||
| -rw-r--r-- | src/personalization/feedback/sampler.py | 109 | ||||
| -rw-r--r-- | src/personalization/feedback/schemas.py | 23 |
9 files changed, 978 insertions, 0 deletions
diff --git a/src/personalization/feedback/__init__.py b/src/personalization/feedback/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/feedback/__init__.py diff --git a/src/personalization/feedback/gating.py b/src/personalization/feedback/gating.py new file mode 100644 index 0000000..d741874 --- /dev/null +++ b/src/personalization/feedback/gating.py @@ -0,0 +1,72 @@ +import numpy as np +from personalization.feedback.schemas import TurnSample + +def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray: + # matrix: [N, d], vector: [d] + # return: [N] + norm_m = np.linalg.norm(matrix, axis=1) + norm_v = np.linalg.norm(vector) + + # Avoid div by zero + den = norm_m * norm_v + den[den == 0] = 1e-9 + + return np.dot(matrix, vector) / den + +def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float: + """ + Return g_t in [0,1], representing how much the reward is due to retrieval. + """ + e_q = sample.query_embedding_t + e_q1 = sample.query_embedding_t1 + + if e_q is None or e_q1 is None or not sample.memories: + return 0.5 # Neutral + + # We need embeddings of the memories. + # In a real pipeline, we might pass them in sample.memory_embeddings. + # If missing, we can't compute sim. + if sample.memory_embeddings is None: + # Try to use embedding_e from memory cards if available + # But MemoryCard.embedding_e is List[float] + try: + mem_embs = np.array([m.embedding_e for m in sample.memories]) + if mem_embs.shape[1] == 0: # Empty embeddings + return 0.5 + except: + return 0.5 + else: + mem_embs = sample.memory_embeddings + + # Compute similarities + # shape: [K] + sims_q = cosine_sim_batch(mem_embs, e_q) + sims_q1 = cosine_sim_batch(mem_embs, e_q1) + + s_q_max = sims_q.max() if len(sims_q) > 0 else 0 + s_q1_max = sims_q1.max() if len(sims_q1) > 0 else 0 + + g = 0.5 + + # Heuristics + + # Case A: Retrieval clearly irrelevant + bad reward + # q_t / q_{t+1} have low similarity to memories -> likely retrieval failure (or no relevant memories) + if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2: + g = 0.9 # Blame retrieval (for failing to find anything, or nothing exists) + + # Case B: Retrieval looks good but reward is bad + # Memories are relevant to query, but user still unhappy -> LLM didn't use them well? + elif reward_hat < -0.5 and s_q_max > 0.5: + g = 0.2 # Likely LLM fault + + # Case C: Good reward + # If reward is high, we assume both did okay. + elif reward_hat > 0.5: + if s_q_max > 0.4: + g = 0.6 # Retrieval helped + else: + g = 0.3 # LLM handled it without strong retrieval help + + return float(g) + diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py new file mode 100644 index 0000000..f0468b6 --- /dev/null +++ b/src/personalization/feedback/handlers.py @@ -0,0 +1,87 @@ +from typing import Tuple, List, Optional +import numpy as np + +from personalization.retrieval.preference_store.schemas import MemoryCard +from personalization.feedback.schemas import TurnSample +from personalization.feedback.reward_model import estimate_reward +from personalization.feedback.gating import estimate_retrieval_gating +from personalization.feedback.llm_reward import ( + LLMRewardClient, LLMRewardConfig, RewardResult +) + + +def eval_step( + q_t: str, + answer_t: str, + q_t1: str, + memories_t: List[MemoryCard], + query_embedding_t: Optional[np.ndarray] = None, + query_embedding_t1: Optional[np.ndarray] = None, +) -> Tuple[float, float]: + """ + Keyword-based evaluation (legacy). + Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat). + """ + mem_embs = None + if memories_t and memories_t[0].embedding_e: + try: + mem_embs = np.array([m.embedding_e for m in memories_t]) + except: + pass + + sample = TurnSample( + user_id="", + session_id="", + turn_id=0, + query_t=q_t, + answer_t=answer_t, + query_t1=q_t1, + memories=memories_t, + query_embedding_t=query_embedding_t, + query_embedding_t1=query_embedding_t1, + memory_embeddings=mem_embs, + ) + + r_hat = estimate_reward(sample) + g_hat = estimate_retrieval_gating(sample, r_hat) + + return r_hat, g_hat + + +async def eval_step_llm( + q_t: str, + answer_t: str, + q_t1: str, + memories_t: List[MemoryCard], + client: LLMRewardClient, + query_embedding_t: Optional[np.ndarray] = None, + query_embedding_t1: Optional[np.ndarray] = None, +) -> Tuple[float, float]: + """ + LLM-as-judge evaluation (async). + Returns (reward, gating) where gating=0.0 if update should be skipped. + + The gating signal is derived from the judge's confidence and label: + - If confidence < tau_c or label == topic_shift: gating = 0.0 + - Otherwise: gating = confidence (continuous, in [tau_c, 1.0]) + + This replaces the old heuristic gating with the judge's own confidence. + """ + sample = TurnSample( + user_id="", + session_id="", + turn_id=0, + query_t=q_t, + answer_t=answer_t, + query_t1=q_t1, + memories=memories_t, + query_embedding_t=query_embedding_t, + query_embedding_t1=query_embedding_t1, + ) + + result: RewardResult = await client.judge(sample) + + if result.should_update: + return result.reward, result.confidence + else: + return 0.0, 0.0 diff --git a/src/personalization/feedback/llm_reward.py b/src/personalization/feedback/llm_reward.py new file mode 100644 index 0000000..6adcf98 --- /dev/null +++ b/src/personalization/feedback/llm_reward.py @@ -0,0 +1,253 @@ +""" +LLM-as-Judge reward model using OpenAI GPT-5-nano (async for parallelism). + +Replaces keyword-based heuristic reward with structured LLM judgement. +Judge receives only (q_t, a_t, q_{t+1}) — no oracle preference cards, no history. + +Label taxonomy → scalar reward mapping: + neg_constraint_restate → -1.0 + neg_correction → -0.8 + neg_confusion → -0.6 + pos_praise → +0.8 + pos_progress → +0.1 + neutral → 0.0 + topic_shift → 0.0 (update skipped) + +Confidence gating: if confidence < tau_c, reward is set to 0 and update is skipped. +""" +from __future__ import annotations + +import asyncio +import hashlib +import json +import os +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +from openai import AsyncOpenAI, RateLimitError, APITimeoutError, APIConnectionError + +from personalization.feedback.schemas import TurnSample + + +# --- Label → Reward Mapping --- + +REWARD_MAP: Dict[str, float] = { + "neg_constraint_restate": -1.0, + "neg_correction": -0.8, + "neg_confusion": -0.6, + "pos_praise": +0.8, + "pos_progress": +0.1, + "neutral": 0.0, + "topic_shift": 0.0, +} + +VALID_LABELS = set(REWARD_MAP.keys()) + + +# --- Configuration --- + +@dataclass +class LLMRewardConfig: + model: str = "gpt-5-nano" + api_key: Optional[str] = None # Falls back to OPENAI_API_KEY env var + base_url: Optional[str] = None # For custom endpoints + max_concurrent: int = 32 # Semaphore limit for parallel requests + max_retries: int = 3 + retry_base_delay: float = 1.0 # Exponential backoff base (seconds) + timeout: float = 60.0 # Per-request timeout (reasoning models are slower) + max_completion_tokens: int = 2048 # Must be high — reasoning models use internal tokens + confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this + enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1}) + + +# --- Prompt --- + +JUDGE_SYSTEM_PROMPT = """\ +You are a feedback classifier. Given a user query (q_t), the assistant's response (a_t), \ +and the user's next message (q_{t+1}), classify the user's follow-up into exactly one label. + +Labels (mutually exclusive): +- neg_constraint_restate: User reasserts constraints/preferences as correction (e.g., "as I said…", "remember…", "按我说的…"). +- neg_correction: User indicates the content is wrong or the assistant failed to answer. +- neg_confusion: User indicates confusion or requests re-explanation. +- pos_praise: Explicit praise or satisfaction with the response. +- pos_progress: Constructive continuation (examples, extensions, what-if, next steps) without complaint. +- neutral: Ambiguous or minimal feedback, not clearly positive or negative. +- topic_shift: User switches to a new unrelated task/topic. + +Output a JSON object with fields: label, confidence (0-1), rationale (one short sentence).""" + +JUDGE_USER_TEMPLATE = """\ +q_t: {query_t} + +a_t: {answer_t} + +q_{{t+1}}: {query_t1}""" + + +# --- Result Dataclass --- + +@dataclass +class RewardResult: + label: str + confidence: float + rationale: str + reward: float + should_update: bool # False if gated by confidence or topic_shift + + +# --- Async Client --- + +class LLMRewardClient: + """Async OpenAI client for LLM-as-judge reward estimation.""" + + def __init__(self, config: Optional[LLMRewardConfig] = None): + self.config = config or LLMRewardConfig() + self._client = AsyncOpenAI( + api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"), + base_url=self.config.base_url, + timeout=self.config.timeout, + ) + self._semaphore = asyncio.Semaphore(self.config.max_concurrent) + self._cache: Dict[str, RewardResult] = {} + + def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str: + """Deterministic hash of the judge input triple.""" + content = f"{query_t}\x00{answer_t}\x00{query_t1}" + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + async def _call_with_retry(self, messages: List[dict]) -> str: + """Single LLM call with exponential backoff retry.""" + for attempt in range(self.config.max_retries): + try: + async with self._semaphore: + response = await self._client.chat.completions.create( + model=self.config.model, + messages=messages, + max_completion_tokens=self.config.max_completion_tokens, + response_format={"type": "json_object"}, + ) + content = response.choices[0].message.content + if content: + return content.strip() + # Reasoning model may exhaust tokens on thinking — retry + if response.choices[0].finish_reason == "length": + continue + return "" + except (RateLimitError, APITimeoutError, APIConnectionError) as e: + if attempt == self.config.max_retries - 1: + raise + delay = self.config.retry_base_delay * (2 ** attempt) + await asyncio.sleep(delay) + return "" + + def _build_messages(self, sample: TurnSample) -> List[dict]: + """Construct the judge prompt from (q_t, a_t, q_{t+1}) only.""" + user_content = JUDGE_USER_TEMPLATE.format( + query_t=sample.query_t, + answer_t=sample.answer_t, + query_t1=sample.query_t1, + ) + return [ + {"role": "system", "content": JUDGE_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + def _parse_result(self, raw: str) -> RewardResult: + """Parse structured JSON output into RewardResult.""" + try: + parsed = json.loads(raw) + label = parsed["label"] + confidence = float(parsed["confidence"]) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + + # Confidence gating and topic_shift skip + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except (json.JSONDecodeError, KeyError, TypeError, ValueError): + return RewardResult( + label="neutral", + confidence=0.0, + rationale="parse_failure", + reward=0.0, + should_update=False, + ) + + async def judge(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async). Returns RewardResult with gating applied.""" + # Cache lookup + if self.config.enable_cache: + key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1) + if key in self._cache: + return self._cache[key] + + messages = self._build_messages(sample) + raw = await self._call_with_retry(messages) + result = self._parse_result(raw) + + # Cache store + if self.config.enable_cache: + self._cache[key] = result + + return result + + async def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]: + """Judge a batch of turns in parallel. Returns list of RewardResult.""" + tasks = [self.judge(s) for s in samples] + return await asyncio.gather(*tasks) + + async def close(self): + """Close the underlying HTTP client.""" + await self._client.close() + + +# --- Synchronous Wrappers --- + +def estimate_reward_llm( + sample: TurnSample, + config: Optional[LLMRewardConfig] = None, +) -> Tuple[float, bool]: + """ + Synchronous single-sample reward estimation. + Returns (reward, should_update). + """ + client = LLMRewardClient(config) + try: + result = asyncio.run(client.judge(sample)) + return result.reward, result.should_update + finally: + asyncio.run(client.close()) + + +def estimate_rewards_batch( + samples: List[TurnSample], + config: Optional[LLMRewardConfig] = None, +) -> List[Tuple[float, bool]]: + """ + Synchronous batch reward estimation (runs async internally). + Returns list of (reward, should_update) tuples. + """ + client = LLMRewardClient(config) + try: + results = asyncio.run(client.judge_batch(samples)) + return [(r.reward, r.should_update) for r in results] + finally: + asyncio.run(client.close()) diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py new file mode 100644 index 0000000..70bbeb8 --- /dev/null +++ b/src/personalization/feedback/local_llm_reward.py @@ -0,0 +1,370 @@ +""" +Local LLM reward model using vLLM server for batch inference. + +Drop-in replacement for LLMRewardClient when you want to use a local model +(e.g., Llama-3.1-8B-Instruct) instead of OpenAI API. + +Uses BatchVLLMClient for efficient concurrent requests - vLLM's continuous +batching will process them together for high throughput. +""" +from __future__ import annotations + +import asyncio +import hashlib +import json +from dataclasses import dataclass +from typing import Dict, List, Optional + +import aiohttp + +from personalization.feedback.schemas import TurnSample +from personalization.feedback.llm_reward import ( + REWARD_MAP, + VALID_LABELS, + JUDGE_SYSTEM_PROMPT, + JUDGE_USER_TEMPLATE, + RewardResult, +) + + +@dataclass +class LocalLLMRewardConfig: + """Configuration for local LLM reward model.""" + vllm_url: str = "http://localhost:8005/v1" # vLLM server URL + model_name: Optional[str] = None # Auto-discovered if None + max_tokens: int = 256 + temperature: float = 0.1 + max_concurrent: int = 100 # High concurrency for vLLM batching + timeout: Optional[int] = 60 # Per-request timeout in seconds + confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this + enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1}) + + +class LocalLLMRewardClient: + """ + Local LLM reward client using vLLM server. + + Designed for batch processing - uses async HTTP requests that vLLM + batches together via continuous batching for high throughput. + """ + + def __init__(self, config: Optional[LocalLLMRewardConfig] = None): + self.config = config or LocalLLMRewardConfig() + self._model_name = self.config.model_name + self._cache: Dict[str, RewardResult] = {} + + # Discover model name if not provided + if self._model_name is None: + self._discover_model_sync() + + def _discover_model_sync(self): + """Synchronously discover model name from vLLM server.""" + import requests + try: + response = requests.get( + f"{self.config.vllm_url}/models", + timeout=10 + ) + response.raise_for_status() + models = response.json() + if models.get("data") and len(models["data"]) > 0: + self._model_name = models["data"][0]["id"] + else: + self._model_name = "default" + except Exception as e: + print(f"[LocalLLMReward] Warning: Could not discover model ({e})") + self._model_name = "default" + + def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str: + """Deterministic hash of the judge input triple.""" + content = f"{query_t}\x00{answer_t}\x00{query_t1}" + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + def _build_messages(self, sample: TurnSample) -> List[dict]: + """Construct the judge prompt from (q_t, a_t, q_{t+1}).""" + user_content = JUDGE_USER_TEMPLATE.format( + query_t=sample.query_t, + answer_t=sample.answer_t, + query_t1=sample.query_t1, + ) + return [ + {"role": "system", "content": JUDGE_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + def _parse_result(self, raw: Optional[str]) -> RewardResult: + """Parse structured JSON output into RewardResult.""" + if raw is None: + return RewardResult( + label="neutral", + confidence=0.0, + rationale="request_failed", + reward=0.0, + should_update=False, + ) + + try: + # Handle markdown code blocks + text = raw.strip() + if text.startswith("```"): + lines = text.split("\n") + text = "\n".join( + lines[1:-1] if lines[-1].strip() == "```" else lines[1:] + ) + + parsed = json.loads(text) + label = parsed.get("label", "neutral") + confidence = float(parsed.get("confidence", 0.0)) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + + # Confidence gating and topic_shift skip + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except (json.JSONDecodeError, KeyError, TypeError, ValueError): + # Try to extract JSON from text + import re + match = re.search(r'\{[^}]+\}', raw, re.DOTALL) + if match: + try: + parsed = json.loads(match.group()) + label = parsed.get("label", "neutral") + confidence = float(parsed.get("confidence", 0.0)) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except: + pass + + return RewardResult( + label="neutral", + confidence=0.0, + rationale="parse_failure", + reward=0.0, + should_update=False, + ) + + async def _single_request( + self, + session: aiohttp.ClientSession, + messages: List[dict], + idx: int, + ) -> tuple: + """Make a single async request to vLLM server.""" + payload = { + "model": self._model_name, + "messages": messages, + "max_tokens": self.config.max_tokens, + "temperature": self.config.temperature, + "response_format": {"type": "json_object"}, + } + + for attempt in range(3): + try: + timeout_config = ( + aiohttp.ClientTimeout(total=self.config.timeout) + if self.config.timeout else None + ) + async with session.post( + f"{self.config.vllm_url}/chat/completions", + json=payload, + timeout=timeout_config, + ) as response: + if response.status == 200: + result = await response.json() + content = result["choices"][0]["message"]["content"] + return (idx, content, None) + elif response.status == 429: + # Rate limit - wait and retry + await asyncio.sleep(2 ** attempt) + continue + else: + error_text = await response.text() + return (idx, None, f"HTTP {response.status}: {error_text[:200]}") + except asyncio.TimeoutError: + if attempt < 2: + continue + return (idx, None, "Timeout") + except Exception as e: + return (idx, None, str(e)) + + return (idx, None, "Max retries exceeded") + + async def judge_batch_async( + self, + samples: List[TurnSample], + show_progress: bool = False, + ) -> List[RewardResult]: + """ + Judge a batch of turns using concurrent vLLM requests. + + vLLM's continuous batching will process these together for + high throughput. + """ + n_samples = len(samples) + results = [None] * n_samples + + # Check cache and build request list + to_request = [] # (original_idx, messages) + for i, sample in enumerate(samples): + if self.config.enable_cache: + key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1) + if key in self._cache: + results[i] = self._cache[key] + continue + + messages = self._build_messages(sample) + to_request.append((i, messages)) + + if not to_request: + return results + + # Make concurrent requests + semaphore = asyncio.Semaphore(self.config.max_concurrent) + + async def limited_request(session, messages, idx): + async with semaphore: + return await self._single_request(session, messages, idx) + + connector = aiohttp.TCPConnector(limit=self.config.max_concurrent) + headers = {"Content-Type": "application/json"} + + async with aiohttp.ClientSession( + connector=connector, headers=headers + ) as session: + tasks = [ + limited_request(session, messages, orig_idx) + for orig_idx, messages in to_request + ] + + completed = 0 + for coro in asyncio.as_completed(tasks): + orig_idx, content, error = await coro + completed += 1 + + if error: + print(f"[LocalLLMReward] Request {orig_idx} failed: {error}") + + result = self._parse_result(content) + results[orig_idx] = result + + # Cache the result + if self.config.enable_cache: + sample = samples[orig_idx] + key = self._cache_key( + sample.query_t, sample.answer_t, sample.query_t1 + ) + self._cache[key] = result + + if show_progress and completed % 10 == 0: + print(f" [LocalLLMReward {completed}/{len(to_request)}] completed") + + return results + + async def judge_async(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async).""" + results = await self.judge_batch_async([sample]) + return results[0] + + def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]: + """ + Judge a batch of turns (sync wrapper). + + This is the main entry point for batch reward estimation. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in an event loop - create a new thread to run the coroutine + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.judge_batch_async(samples)) + return future.result() + else: + return asyncio.run(self.judge_batch_async(samples)) + + async def judge(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async interface for compatibility with LLMRewardClient).""" + return await self.judge_async(sample) + + def judge_sync(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (sync wrapper).""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in an event loop - create a new thread to run the coroutine + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.judge_async(sample)) + return future.result() + else: + return asyncio.run(self.judge_async(sample)) + + +# --- Convenience Functions --- + +def estimate_reward_local( + sample: TurnSample, + config: Optional[LocalLLMRewardConfig] = None, +) -> tuple: + """ + Synchronous single-sample reward estimation using local LLM. + Returns (reward, should_update). + """ + client = LocalLLMRewardClient(config) + result = client.judge(sample) + return result.reward, result.should_update + + +def estimate_rewards_batch_local( + samples: List[TurnSample], + config: Optional[LocalLLMRewardConfig] = None, +) -> List[tuple]: + """ + Synchronous batch reward estimation using local LLM. + Returns list of (reward, should_update) tuples. + """ + client = LocalLLMRewardClient(config) + results = client.judge_batch(samples) + return [(r.reward, r.should_update) for r in results] diff --git a/src/personalization/feedback/online_update.py b/src/personalization/feedback/online_update.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/personalization/feedback/online_update.py diff --git a/src/personalization/feedback/reward_model.py b/src/personalization/feedback/reward_model.py new file mode 100644 index 0000000..3584b43 --- /dev/null +++ b/src/personalization/feedback/reward_model.py @@ -0,0 +1,64 @@ +import numpy as np +from personalization.feedback.schemas import TurnSample + +def cosine_sim(a: np.ndarray, b: np.ndarray) -> float: + norm_a = np.linalg.norm(a) + norm_b = np.linalg.norm(b) + if norm_a == 0 or norm_b == 0: + return 0.0 + return float(np.dot(a, b) / (norm_a * norm_b)) + +def estimate_reward(sample: TurnSample) -> float: + """ + Return a scalar reward_hat, indicating if the previous answer was helpful. + Range: [-1.0, 1.0] (approx) + """ + + # 1. Language/Topic Coherence + if sample.query_embedding_t is None or sample.query_embedding_t1 is None: + topic_sim = 0.5 + else: + topic_sim = cosine_sim(sample.query_embedding_t, sample.query_embedding_t1) + + # 2. Negative Keywords (Complaint/Correction) + negative_keywords = [ + "you didn't", "that's not", "incorrect", "redo", "again", "explain more", + "doesn't help", "wrong", "no", "not what i asked", + "你没", "不是", "这不是", "重来", "重新", "不对", "错了", "没说清楚" + ] + + # 3. Positive Keywords (Follow-up/Elaboration) + positive_keywords = [ + "can you elaborate", "give an example", "continue", "what if", "based on that", + "thanks", "good", "great", "cool", + "能不能详细一点", "举个例子", "再继续", "那如果", "接下来", "在这个基础上", "谢谢", "不错" + ] + + q1_lower = sample.query_t1.lower() + + has_negative = any(kw in q1_lower for kw in negative_keywords) + has_positive = any(kw in q1_lower for kw in positive_keywords) + + reward = 0.0 + + if has_negative: + reward -= 1.0 + + if has_positive: + # Only reward if topic similarity is decent, otherwise might be "thanks, bye" (end of session) + # But "thanks" is good. + reward += 0.5 + if topic_sim > 0.3: + reward += 0.5 + + if topic_sim < 0.2: + # Topic shift -> previous interaction likely finished or failed. + # If no explicit positive/negative, assume neutral/slightly decayed. + # If user changes topic, it often means the previous task is done (neutral/positive) + # OR they gave up (negative). Hard to tell. + # Let's dampen the reward towards 0. + reward *= 0.5 + + # Clip + return max(-1.0, min(1.0, reward)) + diff --git a/src/personalization/feedback/sampler.py b/src/personalization/feedback/sampler.py new file mode 100644 index 0000000..9e26912 --- /dev/null +++ b/src/personalization/feedback/sampler.py @@ -0,0 +1,109 @@ +from typing import Iterable, List, Optional +import numpy as np +from tqdm import tqdm + +from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard +from personalization.feedback.schemas import TurnSample +from personalization.retrieval.pipeline import retrieve_with_rerank +from personalization.models.llm.qwen_instruct import QwenInstruct +from personalization.models.embedding.base import EmbeddingModel +from personalization.models.reranker.base import Reranker +from personalization.user_model.tensor_store import UserTensorStore + +def build_turn_samples_from_sessions( + sessions: Iterable[List[ChatTurn]], + embed_model: EmbeddingModel, + llm: QwenInstruct, + reranker: Reranker, + memory_cards: List[MemoryCard], + memory_embeddings: np.ndarray, + user_store: UserTensorStore, + item_vectors: np.ndarray, + max_samples: Optional[int] = None, + topk_dense: int = 64, + topk_rerank: int = 3, +) -> List[TurnSample]: + samples = [] + + for turns in tqdm(sessions, desc="Building TurnSamples"): + if max_samples and len(samples) >= max_samples: + break + + # Ensure sorted by turn_id + sorted_turns = sorted(turns, key=lambda x: x.turn_id) + + # Iterate to find (q_t, a_t, q_{t+1}) + for i in range(len(sorted_turns)): + if max_samples and len(samples) >= max_samples: + break + + q_t = sorted_turns[i] + if q_t.role != "user": + continue + + # Find next user turn + # Also try to find assistant response in between + a_t_text = "" + q_t1 = None + + # Look ahead + for j in range(i + 1, len(sorted_turns)): + next_turn = sorted_turns[j] + if next_turn.role == "assistant" and not a_t_text: + a_t_text = next_turn.text + elif next_turn.role == "user": + q_t1 = next_turn + break + + if not q_t1: + # End of session or no subsequent user query + continue + + # We have q_t, a_t (optional but preferred), q_t1 + # If a_t is missing, we might skip or use empty string. + # For RL, we usually need the answer to evaluate quality. + # If dataset doesn't have assistant turns, we might need to generate one? + # For now, let's proceed even if a_t is empty, or maybe require it. + if not a_t_text: + # Try to use LLM to generate if needed, but for offline sampling + # from existing chats, we prefer existing answers. + # If using OASST1, it should have assistant turns. + pass + + # 3. Retrieve memories for q_t + memories_t = retrieve_with_rerank( + user_id=q_t.user_id, + query=q_t.text, + embed_model=embed_model, + reranker=reranker, + memory_cards=memory_cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=topk_dense, + topk_rerank=topk_rerank, + beta_long=0.0, + beta_short=0.0, + only_own_memories=True # Assume we want user specific memories + ) + + # 4. Precompute embeddings + # We can do this efficiently later or batch, but here per sample + e_q_t = embed_model.encode([q_t.text], return_tensor=False)[0] + e_q_t1 = embed_model.encode([q_t1.text], return_tensor=False)[0] + + sample = TurnSample( + user_id=q_t.user_id, + session_id=q_t.session_id, + turn_id=q_t.turn_id, + query_t=q_t.text, + answer_t=a_t_text, + query_t1=q_t1.text, + memories=memories_t, + query_embedding_t=np.array(e_q_t), + query_embedding_t1=np.array(e_q_t1) + ) + samples.append(sample) + + return samples + diff --git a/src/personalization/feedback/schemas.py b/src/personalization/feedback/schemas.py new file mode 100644 index 0000000..b15db80 --- /dev/null +++ b/src/personalization/feedback/schemas.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Any +import numpy as np + +from personalization.retrieval.preference_store.schemas import MemoryCard + +@dataclass +class TurnSample: + user_id: str + session_id: str + turn_id: int # index of q_t within the session + query_t: str # q_t + answer_t: str # a_t + query_t1: str # q_{t+1} + memories: List[MemoryCard] # A_t + + # Optional pre-computed vectors and features + query_embedding_t: Optional[np.ndarray] = None + query_embedding_t1: Optional[np.ndarray] = None + memory_embeddings: Optional[np.ndarray] = None # corresponding e_m or v_m for memories + |
