summaryrefslogtreecommitdiff
path: root/src/personalization/feedback
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/feedback')
-rw-r--r--src/personalization/feedback/__init__.py0
-rw-r--r--src/personalization/feedback/gating.py72
-rw-r--r--src/personalization/feedback/handlers.py87
-rw-r--r--src/personalization/feedback/llm_reward.py253
-rw-r--r--src/personalization/feedback/local_llm_reward.py370
-rw-r--r--src/personalization/feedback/online_update.py0
-rw-r--r--src/personalization/feedback/reward_model.py64
-rw-r--r--src/personalization/feedback/sampler.py109
-rw-r--r--src/personalization/feedback/schemas.py23
9 files changed, 978 insertions, 0 deletions
diff --git a/src/personalization/feedback/__init__.py b/src/personalization/feedback/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/feedback/__init__.py
diff --git a/src/personalization/feedback/gating.py b/src/personalization/feedback/gating.py
new file mode 100644
index 0000000..d741874
--- /dev/null
+++ b/src/personalization/feedback/gating.py
@@ -0,0 +1,72 @@
+import numpy as np
+from personalization.feedback.schemas import TurnSample
+
def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray:
    """Cosine similarity between every row of *matrix* and *vector*.

    Args:
        matrix: [N, d] array of row vectors.
        vector: [d] query vector.

    Returns:
        [N] array of cosine similarities.
    """
    row_norms = np.linalg.norm(matrix, axis=1)
    vec_norm = np.linalg.norm(vector)

    # Guard against zero-norm rows / vector: substitute a tiny epsilon so
    # the division is defined (similarity comes out ~0 for those rows).
    denom = row_norms * vec_norm
    denom = np.where(denom == 0, 1e-9, denom)

    return (matrix @ vector) / denom
+
def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float:
    """
    Return g_t in [0,1], representing how much the reward is due to retrieval.

    Heuristic: compares the retrieved memories' embedding similarity to the
    current query (q_t) and the follow-up query (q_{t+1}) and combines that
    with the sign of the estimated reward. Falls back to a neutral 0.5
    whenever the required embeddings are unavailable or malformed.

    Args:
        sample: Turn sample carrying queries, memories and optional embeddings.
        reward_hat: Scalar reward estimate for this turn (roughly [-1, 1]).

    Returns:
        Gating value g_t in [0, 1]; higher means "retrieval is responsible".
    """
    e_q = sample.query_embedding_t
    e_q1 = sample.query_embedding_t1

    if e_q is None or e_q1 is None or not sample.memories:
        return 0.5  # Neutral: not enough signal to assign blame/credit.

    # Prefer precomputed memory embeddings (sample.memory_embeddings);
    # otherwise fall back to the per-card embedding_e lists.
    if sample.memory_embeddings is None:
        try:
            mem_embs = np.array([m.embedding_e for m in sample.memories])
        except Exception:
            # Ragged / missing embeddings -> cannot compute similarities.
            return 0.5
        # A ragged or empty embedding list yields a non-2D / zero-width
        # array; similarities are undefined in that case.
        if mem_embs.ndim != 2 or mem_embs.shape[1] == 0:
            return 0.5
    else:
        mem_embs = sample.memory_embeddings

    # Similarity of every memory to q_t and q_{t+1}; shape: [K]
    sims_q = cosine_sim_batch(mem_embs, e_q)
    sims_q1 = cosine_sim_batch(mem_embs, e_q1)

    s_q_max = sims_q.max() if len(sims_q) > 0 else 0
    s_q1_max = sims_q1.max() if len(sims_q1) > 0 else 0

    g = 0.5

    # Heuristics

    # Case A: Retrieval clearly irrelevant + bad reward.
    # q_t / q_{t+1} have low similarity to memories -> likely retrieval
    # failure (or no relevant memories exist).
    if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2:
        g = 0.9  # Blame retrieval (for failing to find anything, or nothing exists)

    # Case B: Retrieval looks good but reward is bad.
    # Memories are relevant to query, but user still unhappy -> LLM didn't
    # use them well.
    elif reward_hat < -0.5 and s_q_max > 0.5:
        g = 0.2  # Likely LLM fault

    # Case C: Good reward — assume both did okay, credit retrieval more
    # when the memories were similar to the query.
    elif reward_hat > 0.5:
        if s_q_max > 0.4:
            g = 0.6  # Retrieval helped
        else:
            g = 0.3  # LLM handled it without strong retrieval help

    return float(g)
+
diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py
new file mode 100644
index 0000000..f0468b6
--- /dev/null
+++ b/src/personalization/feedback/handlers.py
@@ -0,0 +1,87 @@
+from typing import Tuple, List, Optional
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.reward_model import estimate_reward
+from personalization.feedback.gating import estimate_retrieval_gating
+from personalization.feedback.llm_reward import (
+ LLMRewardClient, LLMRewardConfig, RewardResult
+)
+
+
def eval_step(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """
    Keyword-based evaluation (legacy).
    Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).

    Args:
        q_t: Current user query.
        answer_t: Assistant answer to q_t.
        q_t1: The user's next query (implicit feedback signal).
        memories_t: Memory cards retrieved for q_t.
        query_embedding_t: Optional embedding of q_t.
        query_embedding_t1: Optional embedding of q_{t+1}.

    Returns:
        (reward_hat, gating_hat) pair of floats.
    """
    # Stack per-card embeddings when present. A ragged list (inconsistent
    # dimensions) leaves mem_embs as None, and the gating estimator then
    # degrades to its embedding-free fallback instead of crashing.
    mem_embs = None
    if memories_t and memories_t[0].embedding_e:
        try:
            mem_embs = np.array([m.embedding_e for m in memories_t])
        except Exception:
            mem_embs = None

    sample = TurnSample(
        user_id="",
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
        memory_embeddings=mem_embs,
    )

    r_hat = estimate_reward(sample)
    g_hat = estimate_retrieval_gating(sample, r_hat)

    return r_hat, g_hat
+
+
async def eval_step_llm(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    client: LLMRewardClient,
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """
    LLM-as-judge evaluation (async).
    Returns (reward, gating) where gating=0.0 if update should be skipped.

    The gating signal comes from the judge's own confidence and label:
    - confidence < tau_c or label == topic_shift  -> gating = 0.0
    - otherwise                                   -> gating = confidence

    This replaces the old heuristic gating with the judge's confidence.
    """
    turn = TurnSample(
        user_id="",
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
    )

    verdict: RewardResult = await client.judge(turn)

    if not verdict.should_update:
        return 0.0, 0.0
    return verdict.reward, verdict.confidence
diff --git a/src/personalization/feedback/llm_reward.py b/src/personalization/feedback/llm_reward.py
new file mode 100644
index 0000000..6adcf98
--- /dev/null
+++ b/src/personalization/feedback/llm_reward.py
@@ -0,0 +1,253 @@
+"""
+LLM-as-Judge reward model using OpenAI GPT-5-nano (async for parallelism).
+
+Replaces keyword-based heuristic reward with structured LLM judgement.
+Judge receives only (q_t, a_t, q_{t+1}) — no oracle preference cards, no history.
+
+Label taxonomy → scalar reward mapping:
+ neg_constraint_restate → -1.0
+ neg_correction → -0.8
+ neg_confusion → -0.6
+ pos_praise → +0.8
+ pos_progress → +0.1
+ neutral → 0.0
+ topic_shift → 0.0 (update skipped)
+
+Confidence gating: if confidence < tau_c, reward is set to 0 and update is skipped.
+"""
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
+
+from openai import AsyncOpenAI, RateLimitError, APITimeoutError, APIConnectionError
+
+from personalization.feedback.schemas import TurnSample
+
+
+# --- Label → Reward Mapping ---
+
# Maps each judge label to a scalar reward in [-1, 1]. Negative labels are
# ordered by severity (restating a constraint is the strongest negative
# signal); topic_shift maps to 0.0 and is additionally skipped by the
# confidence gate during parsing.
REWARD_MAP: Dict[str, float] = {
    "neg_constraint_restate": -1.0,
    "neg_correction": -0.8,
    "neg_confusion": -0.6,
    "pos_praise": +0.8,
    "pos_progress": +0.1,
    "neutral": 0.0,
    "topic_shift": 0.0,
}

# Labels the judge is allowed to emit; anything else is coerced to
# "neutral" with zero confidence when the response is parsed.
VALID_LABELS = set(REWARD_MAP.keys())
+
+
+# --- Configuration ---
+
@dataclass
class LLMRewardConfig:
    """Configuration for the async OpenAI LLM-as-judge reward client."""
    model: str = "gpt-5-nano"
    api_key: Optional[str] = None  # Falls back to OPENAI_API_KEY env var
    base_url: Optional[str] = None  # For custom endpoints
    max_concurrent: int = 32  # Semaphore limit for parallel requests
    max_retries: int = 3  # Attempts per request before giving up / re-raising
    retry_base_delay: float = 1.0  # Exponential backoff base (seconds)
    timeout: float = 60.0  # Per-request timeout (reasoning models are slower)
    max_completion_tokens: int = 2048  # Must be high — reasoning models use internal tokens
    confidence_threshold: float = 0.6  # tau_c: skip update if confidence < this
    enable_cache: bool = True  # Cache by hash of (q_t, a_t, q_{t+1})
+
+
+# --- Prompt ---
+
# System prompt for the judge. Defines the mutually-exclusive label
# taxonomy; the judge sees ONLY the (q_t, a_t, q_{t+1}) triple — no
# preference cards and no conversation history.
JUDGE_SYSTEM_PROMPT = """\
You are a feedback classifier. Given a user query (q_t), the assistant's response (a_t), \
and the user's next message (q_{t+1}), classify the user's follow-up into exactly one label.

Labels (mutually exclusive):
- neg_constraint_restate: User reasserts constraints/preferences as correction (e.g., "as I said…", "remember…", "按我说的…").
- neg_correction: User indicates the content is wrong or the assistant failed to answer.
- neg_confusion: User indicates confusion or requests re-explanation.
- pos_praise: Explicit praise or satisfaction with the response.
- pos_progress: Constructive continuation (examples, extensions, what-if, next steps) without complaint.
- neutral: Ambiguous or minimal feedback, not clearly positive or negative.
- topic_shift: User switches to a new unrelated task/topic.

Output a JSON object with fields: label, confidence (0-1), rationale (one short sentence)."""

# User message template. The doubled braces in {{t+1}} are escaped so that
# str.format leaves the literal text "q_{t+1}" intact.
JUDGE_USER_TEMPLATE = """\
q_t: {query_t}

a_t: {answer_t}

q_{{t+1}}: {query_t1}"""
+
+
+# --- Result Dataclass ---
+
@dataclass
class RewardResult:
    """Structured judge verdict plus the derived scalar reward."""
    label: str  # One of VALID_LABELS (unknown labels are coerced to "neutral")
    confidence: float  # Judge's self-reported confidence, nominally in [0, 1]
    rationale: str  # One-sentence justification from the judge
    reward: float  # Scalar from REWARD_MAP (zeroed when the update is gated)
    should_update: bool  # False if gated by confidence or topic_shift
+
+
+# --- Async Client ---
+
class LLMRewardClient:
    """Async OpenAI client for LLM-as-judge reward estimation.

    Requests run concurrently up to ``config.max_concurrent`` (bounded by a
    semaphore). Verdicts are memoized in-process, keyed by a SHA-256 hash of
    the (q_t, a_t, q_{t+1}) triple; parse failures are NOT cached so that a
    transient bad API response cannot permanently poison a cache entry.
    """

    def __init__(self, config: "Optional[LLMRewardConfig]" = None):
        self.config = config or LLMRewardConfig()
        self._client = AsyncOpenAI(
            api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
            base_url=self.config.base_url,
            timeout=self.config.timeout,
        )
        # Caps the number of in-flight API calls across all tasks.
        self._semaphore = asyncio.Semaphore(self.config.max_concurrent)
        # Per-client verdict cache; lives as long as the client instance.
        self._cache: "Dict[str, RewardResult]" = {}

    def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str:
        """Deterministic hash of the judge input triple."""
        # NUL separators make the concatenation unambiguous.
        content = f"{query_t}\x00{answer_t}\x00{query_t1}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    async def _call_with_retry(self, messages: "List[dict]") -> str:
        """Single LLM call with exponential-backoff retry.

        Retries on rate limits, timeouts and connection errors, and also
        when the model spends its entire token budget on internal reasoning
        and returns an empty message (finish_reason == "length"). Returns
        "" if no attempt yields content; re-raises the last transport error
        when every attempt failed with one.
        """
        for attempt in range(self.config.max_retries):
            try:
                async with self._semaphore:
                    response = await self._client.chat.completions.create(
                        model=self.config.model,
                        messages=messages,
                        max_completion_tokens=self.config.max_completion_tokens,
                        response_format={"type": "json_object"},
                    )
                content = response.choices[0].message.content
                if content:
                    return content.strip()
                if response.choices[0].finish_reason == "length":
                    # Reasoning model exhausted tokens on thinking — retry.
                    continue
                return ""
            except (RateLimitError, APITimeoutError, APIConnectionError):
                if attempt == self.config.max_retries - 1:
                    raise
                delay = self.config.retry_base_delay * (2 ** attempt)
                await asyncio.sleep(delay)
        return ""

    def _build_messages(self, sample: "TurnSample") -> "List[dict]":
        """Construct the judge prompt from (q_t, a_t, q_{t+1}) only."""
        user_content = JUDGE_USER_TEMPLATE.format(
            query_t=sample.query_t,
            answer_t=sample.answer_t,
            query_t1=sample.query_t1,
        )
        return [
            {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

    def _parse_result(self, raw: str) -> "RewardResult":
        """Parse structured JSON output into RewardResult.

        Unknown labels are coerced to neutral/zero-confidence, confidence is
        clamped into [0, 1] (it is later used directly as a gating weight),
        and reward is zeroed whenever the update is gated (low confidence or
        topic_shift). Unparseable output yields a neutral, non-updating
        result with rationale "parse_failure".
        """
        try:
            parsed = json.loads(raw)
            label = parsed["label"]
            # Clamp: a misbehaving judge may report confidence outside [0, 1].
            confidence = min(1.0, max(0.0, float(parsed["confidence"])))
            rationale = parsed.get("rationale", "")

            if label not in VALID_LABELS:
                label = "neutral"
                confidence = 0.0

            reward = REWARD_MAP[label]

            # Confidence gating and topic_shift skip
            should_update = (
                confidence >= self.config.confidence_threshold
                and label != "topic_shift"
            )
            if not should_update:
                reward = 0.0

            return RewardResult(
                label=label,
                confidence=confidence,
                rationale=rationale,
                reward=reward,
                should_update=should_update,
            )
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            return RewardResult(
                label="neutral",
                confidence=0.0,
                rationale="parse_failure",
                reward=0.0,
                should_update=False,
            )

    async def judge(self, sample: "TurnSample") -> "RewardResult":
        """Judge a single turn (async). Returns RewardResult with gating applied.

        Successful verdicts are cached; parse failures are not, so the same
        triple can be retried on a later call.
        """
        key = None
        if self.config.enable_cache:
            key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1)
            if key in self._cache:
                return self._cache[key]

        messages = self._build_messages(sample)
        raw = await self._call_with_retry(messages)
        result = self._parse_result(raw)

        # Only cache genuine verdicts — caching a parse failure would pin a
        # transient API hiccup for the lifetime of the client.
        if key is not None and result.rationale != "parse_failure":
            self._cache[key] = result

        return result

    async def judge_batch(self, samples: "List[TurnSample]") -> "List[RewardResult]":
        """Judge a batch of turns in parallel; output order matches input."""
        tasks = [self.judge(s) for s in samples]
        return await asyncio.gather(*tasks)

    async def close(self):
        """Close the underlying HTTP client."""
        await self._client.close()
+
+
+# --- Synchronous Wrappers ---
+
def estimate_reward_llm(
    sample: TurnSample,
    config: Optional[LLMRewardConfig] = None,
) -> Tuple[float, bool]:
    """
    Synchronous single-sample reward estimation.
    Returns (reward, should_update).

    The judge call and the client shutdown run inside ONE event loop:
    calling asyncio.run() twice (once for judge, once for close) binds the
    async HTTP client to an already-closed loop and can raise
    "Event loop is closed" during cleanup.
    """
    client = LLMRewardClient(config)

    async def _judge_and_close() -> RewardResult:
        try:
            return await client.judge(sample)
        finally:
            await client.close()

    result = asyncio.run(_judge_and_close())
    return result.reward, result.should_update
+
+
def estimate_rewards_batch(
    samples: List[TurnSample],
    config: Optional[LLMRewardConfig] = None,
) -> List[Tuple[float, bool]]:
    """
    Synchronous batch reward estimation (runs async internally).
    Returns list of (reward, should_update) tuples, aligned with `samples`.

    The batch judging and the client shutdown share ONE event loop: the
    previous implementation called asyncio.run() twice, which closed the
    judging loop before close() ran and could raise "Event loop is closed"
    from the HTTP client's cleanup.
    """
    client = LLMRewardClient(config)

    async def _judge_all_and_close() -> List[RewardResult]:
        try:
            return await client.judge_batch(samples)
        finally:
            await client.close()

    results = asyncio.run(_judge_all_and_close())
    return [(r.reward, r.should_update) for r in results]
diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py
new file mode 100644
index 0000000..70bbeb8
--- /dev/null
+++ b/src/personalization/feedback/local_llm_reward.py
@@ -0,0 +1,370 @@
+"""
+Local LLM reward model using vLLM server for batch inference.
+
+Drop-in replacement for LLMRewardClient when you want to use a local model
+(e.g., Llama-3.1-8B-Instruct) instead of OpenAI API.
+
+Uses BatchVLLMClient for efficient concurrent requests - vLLM's continuous
+batching will process them together for high throughput.
+"""
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import json
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import aiohttp
+
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.llm_reward import (
+ REWARD_MAP,
+ VALID_LABELS,
+ JUDGE_SYSTEM_PROMPT,
+ JUDGE_USER_TEMPLATE,
+ RewardResult,
+)
+
+
@dataclass
class LocalLLMRewardConfig:
    """Configuration for local LLM reward model (vLLM OpenAI-compatible server)."""
    vllm_url: str = "http://localhost:8005/v1"  # vLLM server URL (OpenAI-compatible /v1 base)
    model_name: Optional[str] = None  # Auto-discovered from /models if None
    max_tokens: int = 256  # Completion budget; judge output is a short JSON object
    temperature: float = 0.1  # Near-greedy for classification stability
    max_concurrent: int = 100  # High concurrency for vLLM continuous batching
    timeout: Optional[int] = 60  # Per-request timeout in seconds (None = no timeout)
    confidence_threshold: float = 0.6  # tau_c: skip update if confidence < this
    enable_cache: bool = True  # Cache by hash of (q_t, a_t, q_{t+1})
+
+
class LocalLLMRewardClient:
    """
    Local LLM reward client using vLLM server.

    Designed for batch processing — uses async HTTP requests that vLLM
    batches together via continuous batching for high throughput.

    Fixes over the original: the label→reward gating logic that was
    duplicated in both branches of ``_parse_result`` now lives in one
    helper, the bare ``except:`` is narrowed to the parse-specific
    exceptions, the event-loop handling shared by the sync wrappers is
    factored into ``_run_coro_sync``, and parse/request failures are no
    longer cached (a transient server hiccup should be retryable).
    """

    def __init__(self, config: "Optional[LocalLLMRewardConfig]" = None):
        self.config = config or LocalLLMRewardConfig()
        self._model_name = self.config.model_name
        # In-memory verdict cache keyed by sha256 of (q_t, a_t, q_{t+1}).
        self._cache: "Dict[str, RewardResult]" = {}

        # Discover the served model name if not provided explicitly.
        if self._model_name is None:
            self._discover_model_sync()

    def _discover_model_sync(self):
        """Synchronously discover the model name from the vLLM server.

        Falls back to "default" (with a printed warning) if the server is
        unreachable, so construction never raises; a wrong model name then
        surfaces at request time instead.
        """
        import requests
        try:
            response = requests.get(
                f"{self.config.vllm_url}/models",
                timeout=10
            )
            response.raise_for_status()
            models = response.json()
            if models.get("data") and len(models["data"]) > 0:
                self._model_name = models["data"][0]["id"]
            else:
                self._model_name = "default"
        except Exception as e:
            print(f"[LocalLLMReward] Warning: Could not discover model ({e})")
            self._model_name = "default"

    def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str:
        """Deterministic hash of the judge input triple."""
        # NUL separators make the concatenation unambiguous.
        content = f"{query_t}\x00{answer_t}\x00{query_t1}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _build_messages(self, sample: "TurnSample") -> "List[dict]":
        """Construct the judge prompt from (q_t, a_t, q_{t+1})."""
        user_content = JUDGE_USER_TEMPLATE.format(
            query_t=sample.query_t,
            answer_t=sample.answer_t,
            query_t1=sample.query_t1,
        )
        return [
            {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

    def _result_from_parsed(self, parsed: dict) -> "RewardResult":
        """Map a parsed judge JSON object to a gated RewardResult.

        Unknown labels are coerced to neutral/zero-confidence; reward is
        zeroed whenever the update is gated (low confidence or topic_shift).
        """
        label = parsed.get("label", "neutral")
        confidence = float(parsed.get("confidence", 0.0))
        rationale = parsed.get("rationale", "")

        if label not in VALID_LABELS:
            label = "neutral"
            confidence = 0.0

        reward = REWARD_MAP[label]

        # Confidence gating and topic_shift skip
        should_update = (
            confidence >= self.config.confidence_threshold
            and label != "topic_shift"
        )
        if not should_update:
            reward = 0.0

        return RewardResult(
            label=label,
            confidence=confidence,
            rationale=rationale,
            reward=reward,
            should_update=should_update,
        )

    @staticmethod
    def _failure_result(reason: str) -> "RewardResult":
        """Neutral, non-updating result for failed or unparseable calls."""
        return RewardResult(
            label="neutral",
            confidence=0.0,
            rationale=reason,
            reward=0.0,
            should_update=False,
        )

    def _parse_result(self, raw: "Optional[str]") -> "RewardResult":
        """Parse the model's raw text output into a RewardResult.

        Handles markdown code fences and, as a last resort, extracts the
        first {...} span from free-form text (the judge schema is flat, so
        a non-nested match suffices). Any failure yields a neutral,
        non-updating result instead of raising.
        """
        if raw is None:
            return self._failure_result("request_failed")

        # Strip markdown code fences (```json ... ```), which local models
        # often emit despite the JSON response_format hint.
        text = raw.strip()
        if text.startswith("```"):
            lines = text.split("\n")
            text = "\n".join(
                lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
            )

        try:
            return self._result_from_parsed(json.loads(text))
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            pass

        # Fallback: pull the first JSON-object-looking span out of the text.
        import re
        match = re.search(r'\{[^}]+\}', raw, re.DOTALL)
        if match:
            try:
                return self._result_from_parsed(json.loads(match.group()))
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                pass

        return self._failure_result("parse_failure")

    async def _single_request(
        self,
        session: "aiohttp.ClientSession",
        messages: "List[dict]",
        idx: int,
    ) -> tuple:
        """POST one chat-completion request to the vLLM server.

        Returns (idx, content, error): on success content is the raw model
        text and error is None; on failure content is None and error is a
        short description. Retries up to 3 times on HTTP 429 and timeouts.
        """
        payload = {
            "model": self._model_name,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "response_format": {"type": "json_object"},
        }

        for attempt in range(3):
            try:
                timeout_config = (
                    aiohttp.ClientTimeout(total=self.config.timeout)
                    if self.config.timeout else None
                )
                async with session.post(
                    f"{self.config.vllm_url}/chat/completions",
                    json=payload,
                    timeout=timeout_config,
                ) as response:
                    if response.status == 200:
                        result = await response.json()
                        content = result["choices"][0]["message"]["content"]
                        return (idx, content, None)
                    elif response.status == 429:
                        # Rate limit - wait and retry
                        await asyncio.sleep(2 ** attempt)
                        continue
                    else:
                        error_text = await response.text()
                        return (idx, None, f"HTTP {response.status}: {error_text[:200]}")
            except asyncio.TimeoutError:
                if attempt < 2:
                    continue
                return (idx, None, "Timeout")
            except Exception as e:
                return (idx, None, str(e))

        return (idx, None, "Max retries exceeded")

    async def judge_batch_async(
        self,
        samples: "List[TurnSample]",
        show_progress: bool = False,
    ) -> "List[RewardResult]":
        """
        Judge a batch of turns using concurrent vLLM requests.

        Cached triples are answered immediately; the rest are fanned out as
        concurrent HTTP requests that vLLM's continuous batching processes
        together. Result order matches the input order.
        """
        n_samples = len(samples)
        results = [None] * n_samples

        # Check cache and build request list
        to_request = []  # (original_idx, messages)
        for i, sample in enumerate(samples):
            if self.config.enable_cache:
                key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1)
                if key in self._cache:
                    results[i] = self._cache[key]
                    continue

            to_request.append((i, self._build_messages(sample)))

        if not to_request:
            return results

        # Client-side concurrency cap (vLLM batches server-side anyway).
        semaphore = asyncio.Semaphore(self.config.max_concurrent)

        async def limited_request(session, messages, idx):
            async with semaphore:
                return await self._single_request(session, messages, idx)

        connector = aiohttp.TCPConnector(limit=self.config.max_concurrent)
        headers = {"Content-Type": "application/json"}

        async with aiohttp.ClientSession(
            connector=connector, headers=headers
        ) as session:
            tasks = [
                limited_request(session, messages, orig_idx)
                for orig_idx, messages in to_request
            ]

            completed = 0
            for coro in asyncio.as_completed(tasks):
                orig_idx, content, error = await coro
                completed += 1

                if error:
                    print(f"[LocalLLMReward] Request {orig_idx} failed: {error}")

                result = self._parse_result(content)
                results[orig_idx] = result

                # Only cache genuine verdicts — a failed request or
                # unparseable reply should be retryable later.
                if (
                    self.config.enable_cache
                    and result.rationale not in ("parse_failure", "request_failed")
                ):
                    sample = samples[orig_idx]
                    key = self._cache_key(
                        sample.query_t, sample.answer_t, sample.query_t1
                    )
                    self._cache[key] = result

                if show_progress and completed % 10 == 0:
                    print(f" [LocalLLMReward {completed}/{len(to_request)}] completed")

        return results

    async def judge_async(self, sample: "TurnSample") -> "RewardResult":
        """Judge a single turn (async)."""
        results = await self.judge_batch_async([sample])
        return results[0]

    @staticmethod
    def _run_coro_sync(coro):
        """Run *coro* to completion from synchronous code.

        asyncio.run() cannot be nested: if an event loop is already running
        in this thread, the coroutine is executed on a fresh loop inside a
        helper thread; otherwise it is run directly.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return executor.submit(asyncio.run, coro).result()

    def judge_batch(self, samples: "List[TurnSample]") -> "List[RewardResult]":
        """
        Judge a batch of turns (sync wrapper).

        This is the main entry point for batch reward estimation.
        """
        return self._run_coro_sync(self.judge_batch_async(samples))

    async def judge(self, sample: "TurnSample") -> "RewardResult":
        """Judge a single turn (async interface for compatibility with LLMRewardClient)."""
        return await self.judge_async(sample)

    def judge_sync(self, sample: "TurnSample") -> "RewardResult":
        """Judge a single turn (sync wrapper)."""
        return self._run_coro_sync(self.judge_async(sample))
+
+
+# --- Convenience Functions ---
+
def estimate_reward_local(
    sample: TurnSample,
    config: Optional[LocalLLMRewardConfig] = None,
) -> tuple:
    """
    Synchronous single-sample reward estimation using local LLM.
    Returns (reward, should_update).

    Bug fix: LocalLLMRewardClient.judge() is a coroutine — calling it from
    synchronous code returned an un-awaited coroutine object, so the
    attribute accesses below raised AttributeError. The synchronous
    wrapper judge_sync() is the correct entry point here.
    """
    client = LocalLLMRewardClient(config)
    result = client.judge_sync(sample)
    return result.reward, result.should_update
+
+
def estimate_rewards_batch_local(
    samples: List[TurnSample],
    config: Optional[LocalLLMRewardConfig] = None,
) -> List[tuple]:
    """
    Synchronous batch reward estimation using local LLM.
    Returns list of (reward, should_update) tuples.
    """
    judged = LocalLLMRewardClient(config).judge_batch(samples)
    return [(verdict.reward, verdict.should_update) for verdict in judged]
diff --git a/src/personalization/feedback/online_update.py b/src/personalization/feedback/online_update.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/feedback/online_update.py
diff --git a/src/personalization/feedback/reward_model.py b/src/personalization/feedback/reward_model.py
new file mode 100644
index 0000000..3584b43
--- /dev/null
+++ b/src/personalization/feedback/reward_model.py
@@ -0,0 +1,64 @@
+import numpy as np
+from personalization.feedback.schemas import TurnSample
+
def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between vectors *a* and *b*.

    Returns 0.0 if either vector has zero norm (similarity undefined).
    """
    na = np.linalg.norm(a)
    nb = np.linalg.norm(b)
    if na == 0 or nb == 0:
        return 0.0
    return float(np.dot(a, b) / (na * nb))
+
def _contains_keyword(text: str, keyword: str) -> bool:
    """True if *keyword* occurs in *text* as a standalone word/phrase.

    For ASCII keywords we require word boundaries so that short words such
    as "no" or "again" do not fire inside unrelated words ("know", "now",
    "against"). CJK keywords have no word boundaries, so they keep plain
    substring semantics.
    """
    import re
    if keyword.isascii():
        pattern = r"(?<![a-z0-9])" + re.escape(keyword) + r"(?![a-z0-9])"
        return re.search(pattern, text) is not None
    return keyword in text


def estimate_reward(sample: "TurnSample") -> float:
    """
    Return a scalar reward_hat, indicating if the previous answer was helpful.
    Range: [-1.0, 1.0] (explicitly clipped).

    Combines (1) topic coherence between q_t and q_{t+1} embeddings and
    (2) positive/negative feedback keywords found in the follow-up query.

    Fix: keywords are now matched with word boundaries for ASCII entries —
    plain substring matching made "no" fire inside "know"/"now"/"nothing"
    and "again" inside "against", flipping the reward sign spuriously.
    """

    # 1. Language/Topic Coherence (neutral 0.5 when embeddings are missing)
    if sample.query_embedding_t is None or sample.query_embedding_t1 is None:
        topic_sim = 0.5
    else:
        topic_sim = cosine_sim(sample.query_embedding_t, sample.query_embedding_t1)

    # 2. Negative Keywords (Complaint/Correction)
    negative_keywords = [
        "you didn't", "that's not", "incorrect", "redo", "again", "explain more",
        "doesn't help", "wrong", "no", "not what i asked",
        "你没", "不是", "这不是", "重来", "重新", "不对", "错了", "没说清楚"
    ]

    # 3. Positive Keywords (Follow-up/Elaboration)
    positive_keywords = [
        "can you elaborate", "give an example", "continue", "what if", "based on that",
        "thanks", "good", "great", "cool",
        "能不能详细一点", "举个例子", "再继续", "那如果", "接下来", "在这个基础上", "谢谢", "不错"
    ]

    q1_lower = sample.query_t1.lower()

    has_negative = any(_contains_keyword(q1_lower, kw) for kw in negative_keywords)
    has_positive = any(_contains_keyword(q1_lower, kw) for kw in positive_keywords)

    reward = 0.0

    if has_negative:
        reward -= 1.0

    if has_positive:
        # Praise gets base credit unconditionally; an on-topic follow-up
        # earns an extra bonus (an off-topic "thanks, bye" only gets the base).
        reward += 0.5
        if topic_sim > 0.3:
            reward += 0.5

    if topic_sim < 0.2:
        # Topic shift -> previous interaction likely finished or failed.
        # The user either completed the task (neutral/positive) or gave up
        # (negative) — ambiguous, so dampen the reward towards 0.
        reward *= 0.5

    # Clip to [-1, 1]
    return max(-1.0, min(1.0, reward))
+
diff --git a/src/personalization/feedback/sampler.py b/src/personalization/feedback/sampler.py
new file mode 100644
index 0000000..9e26912
--- /dev/null
+++ b/src/personalization/feedback/sampler.py
@@ -0,0 +1,109 @@
+from typing import Iterable, List, Optional
+import numpy as np
+from tqdm import tqdm
+
+from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard
+from personalization.feedback.schemas import TurnSample
+from personalization.retrieval.pipeline import retrieve_with_rerank
+from personalization.models.llm.qwen_instruct import QwenInstruct
+from personalization.models.embedding.base import EmbeddingModel
+from personalization.models.reranker.base import Reranker
+from personalization.user_model.tensor_store import UserTensorStore
+
def build_turn_samples_from_sessions(
    sessions: Iterable[List[ChatTurn]],
    embed_model: EmbeddingModel,
    llm: QwenInstruct,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,
    user_store: UserTensorStore,
    item_vectors: np.ndarray,
    max_samples: Optional[int] = None,
    topk_dense: int = 64,
    topk_rerank: int = 3,
) -> List[TurnSample]:
    """Build (q_t, a_t, q_{t+1}) TurnSamples from chat sessions.

    For every user turn q_t, finds the first assistant reply after it (a_t,
    may be empty) and the next user turn (q_{t+1}), retrieves memories for
    q_t, embeds both queries, and emits one TurnSample. User turns with no
    subsequent user turn contribute nothing.

    Args:
        sessions: Iterable of sessions, each a list of ChatTurns.
        embed_model: Embedding model for q_t / q_{t+1}.
        llm: LLM handle — not referenced in this function body; presumably
            reserved for generating a_t when it is missing (TODO confirm).
        reranker: Reranker used by the retrieval pipeline.
        memory_cards: Global list of memory cards to retrieve from.
        memory_embeddings: Dense embeddings aligned with memory_cards.
        user_store: Per-user tensor store for personalized retrieval.
        item_vectors: Item vectors for personalized scoring.
        max_samples: Optional cap on the number of samples produced.
        topk_dense: Dense-retrieval candidate count.
        topk_rerank: Final top-k after reranking.

    Returns:
        List of TurnSamples with retrieved memories and query embeddings
        (memory_embeddings on the samples is left unset here).
    """
    samples = []

    for turns in tqdm(sessions, desc="Building TurnSamples"):
        if max_samples and len(samples) >= max_samples:
            break

        # Ensure sorted by turn_id
        sorted_turns = sorted(turns, key=lambda x: x.turn_id)

        # Iterate to find (q_t, a_t, q_{t+1})
        for i in range(len(sorted_turns)):
            if max_samples and len(samples) >= max_samples:
                break

            q_t = sorted_turns[i]
            if q_t.role != "user":
                continue

            # Find next user turn
            # Also try to find assistant response in between
            a_t_text = ""
            q_t1 = None

            # Look ahead: first assistant turn becomes a_t, first user turn
            # becomes q_{t+1} and terminates the scan.
            for j in range(i + 1, len(sorted_turns)):
                next_turn = sorted_turns[j]
                if next_turn.role == "assistant" and not a_t_text:
                    a_t_text = next_turn.text
                elif next_turn.role == "user":
                    q_t1 = next_turn
                    break

            if not q_t1:
                # End of session or no subsequent user query
                continue

            # We have q_t, a_t (optional but preferred), q_t1
            # If a_t is missing, we might skip or use empty string.
            # For RL, we usually need the answer to evaluate quality.
            # If dataset doesn't have assistant turns, we might need to generate one?
            # For now, let's proceed even if a_t is empty, or maybe require it.
            if not a_t_text:
                # Try to use LLM to generate if needed, but for offline sampling
                # from existing chats, we prefer existing answers.
                # If using OASST1, it should have assistant turns.
                pass

            # 3. Retrieve memories for q_t
            memories_t = retrieve_with_rerank(
                user_id=q_t.user_id,
                query=q_t.text,
                embed_model=embed_model,
                reranker=reranker,
                memory_cards=memory_cards,
                memory_embeddings=memory_embeddings,
                user_store=user_store,
                item_vectors=item_vectors,
                topk_dense=topk_dense,
                topk_rerank=topk_rerank,
                beta_long=0.0,
                beta_short=0.0,
                only_own_memories=True # Assume we want user specific memories
            )

            # 4. Precompute embeddings
            # We can do this efficiently later or batch, but here per sample
            e_q_t = embed_model.encode([q_t.text], return_tensor=False)[0]
            e_q_t1 = embed_model.encode([q_t1.text], return_tensor=False)[0]

            sample = TurnSample(
                user_id=q_t.user_id,
                session_id=q_t.session_id,
                turn_id=q_t.turn_id,
                query_t=q_t.text,
                answer_t=a_t_text,
                query_t1=q_t1.text,
                memories=memories_t,
                query_embedding_t=np.array(e_q_t),
                query_embedding_t1=np.array(e_q_t1)
            )
            samples.append(sample)

    return samples
+
diff --git a/src/personalization/feedback/schemas.py b/src/personalization/feedback/schemas.py
new file mode 100644
index 0000000..b15db80
--- /dev/null
+++ b/src/personalization/feedback/schemas.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Optional, Any
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+
@dataclass
class TurnSample:
    """One (q_t, a_t, q_{t+1}) feedback triple plus retrieval context.

    Used by the reward and gating estimators: the follow-up query q_{t+1}
    is treated as implicit user feedback on answer a_t, given the memories
    A_t retrieved for q_t.
    """
    user_id: str
    session_id: str
    turn_id: int  # index of q_t within the session
    query_t: str  # q_t
    answer_t: str  # a_t (may be empty if no assistant turn was found)
    query_t1: str  # q_{t+1}
    memories: List[MemoryCard]  # A_t: memory cards retrieved for q_t

    # Optional pre-computed vectors and features
    query_embedding_t: Optional[np.ndarray] = None  # embedding of q_t
    query_embedding_t1: Optional[np.ndarray] = None  # embedding of q_{t+1}
    memory_embeddings: Optional[np.ndarray] = None  # corresponding e_m or v_m for memories
+