summaryrefslogtreecommitdiff
path: root/src/personalization/feedback
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /src/personalization/feedback
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/feedback')
-rw-r--r--src/personalization/feedback/handlers.py59
-rw-r--r--src/personalization/feedback/llm_reward.py253
2 files changed, 301 insertions, 11 deletions
diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py
index 60a8d17..f0468b6 100644
--- a/src/personalization/feedback/handlers.py
+++ b/src/personalization/feedback/handlers.py
@@ -5,6 +5,10 @@ from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.feedback.schemas import TurnSample
from personalization.feedback.reward_model import estimate_reward
from personalization.feedback.gating import estimate_retrieval_gating
+from personalization.feedback.llm_reward import (
+ LLMRewardClient, LLMRewardConfig, RewardResult
+)
+
def eval_step(
q_t: str,
@@ -15,23 +19,18 @@ def eval_step(
query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
"""
- Unified evaluation interface.
+ Keyword-based evaluation (legacy).
Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
"""
-
- # Construct a lightweight TurnSample
- # We might need embeddings for gating. If not provided, gating might return default.
-
- # Ensure memories have embeddings for gating
mem_embs = None
if memories_t and memories_t[0].embedding_e:
try:
mem_embs = np.array([m.embedding_e for m in memories_t])
except:
pass
-
+
sample = TurnSample(
- user_id="", # Not needed for simple eval
+ user_id="",
session_id="",
turn_id=0,
query_t=q_t,
@@ -40,11 +39,49 @@ def eval_step(
memories=memories_t,
query_embedding_t=query_embedding_t,
query_embedding_t1=query_embedding_t1,
- memory_embeddings=mem_embs
+ memory_embeddings=mem_embs,
)
-
+
r_hat = estimate_reward(sample)
g_hat = estimate_retrieval_gating(sample, r_hat)
-
+
return r_hat, g_hat
+
async def eval_step_llm(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    client: LLMRewardClient,
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """
    LLM-as-judge evaluation (async).
    Returns (reward, gating) where gating=0.0 if update should be skipped.

    The gating signal is derived from the judge's confidence and label:
    - If confidence < tau_c or label == topic_shift: gating = 0.0
    - Otherwise: gating = confidence (continuous, in [tau_c, 1.0])

    This replaces the old heuristic gating with the judge's own confidence.
    """
    # Wrap the raw turn in the schema the judge client expects; identity
    # fields are irrelevant for a single stateless judgement.
    turn = TurnSample(
        user_id="",
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
    )

    verdict: RewardResult = await client.judge(turn)

    # Guard clause: a gated-out verdict zeroes both reward and gating.
    if not verdict.should_update:
        return 0.0, 0.0
    return verdict.reward, verdict.confidence
diff --git a/src/personalization/feedback/llm_reward.py b/src/personalization/feedback/llm_reward.py
new file mode 100644
index 0000000..6adcf98
--- /dev/null
+++ b/src/personalization/feedback/llm_reward.py
@@ -0,0 +1,253 @@
+"""
+LLM-as-Judge reward model using OpenAI GPT-5-nano (async for parallelism).
+
+Replaces keyword-based heuristic reward with structured LLM judgement.
+Judge receives only (q_t, a_t, q_{t+1}) — no oracle preference cards, no history.
+
+Label taxonomy → scalar reward mapping:
+ neg_constraint_restate → -1.0
+ neg_correction → -0.8
+ neg_confusion → -0.6
+ pos_praise → +0.8
+ pos_progress → +0.1
+ neutral → 0.0
+ topic_shift → 0.0 (update skipped)
+
+Confidence gating: if confidence < tau_c, reward is set to 0 and update is skipped.
+"""
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import json
+import os
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
+
+from openai import AsyncOpenAI, RateLimitError, APITimeoutError, APIConnectionError
+
+from personalization.feedback.schemas import TurnSample
+
+
# --- Label → Reward Mapping ---

# Scalar reward assigned to each judge label (taxonomy in module docstring).
# Negative labels penalize, positive reinforce; topic_shift maps to 0.0 and
# is additionally skipped entirely at gating time (see _parse_result).
REWARD_MAP: Dict[str, float] = {
    "neg_constraint_restate": -1.0,
    "neg_correction": -0.8,
    "neg_confusion": -0.6,
    "pos_praise": +0.8,
    "pos_progress": +0.1,
    "neutral": 0.0,
    "topic_shift": 0.0,
}

# Accepted judge outputs; any other label degrades to "neutral" at parse time.
VALID_LABELS = set(REWARD_MAP.keys())
+
+
+# --- Configuration ---
+
@dataclass
class LLMRewardConfig:
    """Configuration for the async LLM-as-judge reward client."""
    model: str = "gpt-5-nano"
    api_key: Optional[str] = None  # Falls back to OPENAI_API_KEY env var
    base_url: Optional[str] = None  # For custom / proxy endpoints
    max_concurrent: int = 32  # Semaphore limit for parallel requests
    max_retries: int = 3  # Attempts per call before the error propagates
    retry_base_delay: float = 1.0  # Exponential backoff base (seconds)
    timeout: float = 60.0  # Per-request timeout (reasoning models are slower)
    max_completion_tokens: int = 2048  # Must be high — reasoning models use internal tokens
    confidence_threshold: float = 0.6  # tau_c: skip update if confidence < this
    enable_cache: bool = True  # Cache by hash of (q_t, a_t, q_{t+1})
+
+
+# --- Prompt ---
+
# System prompt for the judge. Mentions "JSON" explicitly, which the
# response_format={"type": "json_object"} API mode requires.
JUDGE_SYSTEM_PROMPT = """\
You are a feedback classifier. Given a user query (q_t), the assistant's response (a_t), \
and the user's next message (q_{t+1}), classify the user's follow-up into exactly one label.

Labels (mutually exclusive):
- neg_constraint_restate: User reasserts constraints/preferences as correction (e.g., "as I said…", "remember…", "按我说的…").
- neg_correction: User indicates the content is wrong or the assistant failed to answer.
- neg_confusion: User indicates confusion or requests re-explanation.
- pos_praise: Explicit praise or satisfaction with the response.
- pos_progress: Constructive continuation (examples, extensions, what-if, next steps) without complaint.
- neutral: Ambiguous or minimal feedback, not clearly positive or negative.
- topic_shift: User switches to a new unrelated task/topic.

Output a JSON object with fields: label, confidence (0-1), rationale (one short sentence)."""

# User-message template; {{t+1}} is an escaped literal brace pair so only the
# three query/answer fields are substituted by str.format.
JUDGE_USER_TEMPLATE = """\
q_t: {query_t}

a_t: {answer_t}

q_{{t+1}}: {query_t1}"""
+
+
+# --- Result Dataclass ---
+
@dataclass
class RewardResult:
    """Judge verdict for one turn, with confidence gating already applied."""
    label: str  # one of VALID_LABELS
    confidence: float  # judge self-reported confidence in [0, 1]
    rationale: str  # one-sentence justification (or "parse_failure")
    reward: float  # REWARD_MAP[label], zeroed when gated out
    should_update: bool  # False if gated by confidence or topic_shift
+
+
+# --- Async Client ---
+
class LLMRewardClient:
    """Async OpenAI client for LLM-as-judge reward estimation.

    Judges a turn triple (q_t, a_t, q_{t+1}) — no preference cards, no
    history — and maps the judge's label to a scalar reward via REWARD_MAP,
    gating out low-confidence and topic_shift verdicts.
    """

    def __init__(self, config: Optional[LLMRewardConfig] = None):
        """Create the async client; config defaults to LLMRewardConfig()."""
        self.config = config or LLMRewardConfig()
        self._client = AsyncOpenAI(
            api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
            base_url=self.config.base_url,
            timeout=self.config.timeout,
        )
        # Bounds the number of in-flight API requests across judge()/judge_batch().
        self._semaphore = asyncio.Semaphore(self.config.max_concurrent)
        self._cache: Dict[str, RewardResult] = {}

    def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str:
        """Deterministic hash of the judge input triple.

        NUL separators keep distinct triples from colliding after
        concatenation.
        """
        content = f"{query_t}\x00{answer_t}\x00{query_t1}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    async def _call_with_retry(self, messages: List[dict]) -> str:
        """Single LLM call with exponential backoff retry.

        Returns the raw message content, or "" when no attempt yields
        content (callers treat "" as a parse failure).
        """
        for attempt in range(self.config.max_retries):
            try:
                async with self._semaphore:
                    response = await self._client.chat.completions.create(
                        model=self.config.model,
                        messages=messages,
                        max_completion_tokens=self.config.max_completion_tokens,
                        response_format={"type": "json_object"},
                    )
                content = response.choices[0].message.content
                if content:
                    return content.strip()
                # Reasoning model may exhaust tokens on thinking — retry.
                if response.choices[0].finish_reason == "length":
                    continue
                return ""
            except (RateLimitError, APITimeoutError, APIConnectionError):
                if attempt == self.config.max_retries - 1:
                    raise
                delay = self.config.retry_base_delay * (2 ** attempt)
                await asyncio.sleep(delay)
        return ""

    def _build_messages(self, sample: TurnSample) -> List[dict]:
        """Construct the judge prompt from (q_t, a_t, q_{t+1}) only."""
        user_content = JUDGE_USER_TEMPLATE.format(
            query_t=sample.query_t,
            answer_t=sample.answer_t,
            query_t1=sample.query_t1,
        )
        return [
            {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

    def _parse_result(self, raw: str) -> RewardResult:
        """Parse structured JSON output into RewardResult.

        Unknown labels and malformed payloads degrade to a gated-out
        neutral result instead of raising.
        """
        try:
            parsed = json.loads(raw)
            label = parsed["label"]
            confidence = float(parsed["confidence"])
            rationale = parsed.get("rationale", "")

            if label not in VALID_LABELS:
                label = "neutral"
                confidence = 0.0

            # Fix: clamp to [0, 1] — the model occasionally emits values
            # outside the requested range, which would distort gating.
            confidence = min(1.0, max(0.0, confidence))

            reward = REWARD_MAP[label]

            # Confidence gating and topic_shift skip
            should_update = (
                confidence >= self.config.confidence_threshold
                and label != "topic_shift"
            )
            if not should_update:
                reward = 0.0

            return RewardResult(
                label=label,
                confidence=confidence,
                rationale=rationale,
                reward=reward,
                should_update=should_update,
            )
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            return RewardResult(
                label="neutral",
                confidence=0.0,
                rationale="parse_failure",
                reward=0.0,
                should_update=False,
            )

    async def judge(self, sample: TurnSample) -> RewardResult:
        """Judge a single turn (async). Returns RewardResult with gating applied."""
        key: Optional[str] = None
        if self.config.enable_cache:
            key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1)
            if key in self._cache:
                return self._cache[key]

        messages = self._build_messages(sample)
        raw = await self._call_with_retry(messages)
        result = self._parse_result(raw)

        # Fix: never cache parse failures — a transient API/format error
        # would otherwise poison the cache for this triple permanently.
        if key is not None and result.rationale != "parse_failure":
            self._cache[key] = result

        return result

    async def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]:
        """Judge a batch of turns in parallel. Returns list of RewardResult."""
        tasks = [self.judge(s) for s in samples]
        return await asyncio.gather(*tasks)

    async def close(self):
        """Close the underlying HTTP client."""
        await self._client.close()
+
+
+# --- Synchronous Wrappers ---
+
def estimate_reward_llm(
    sample: TurnSample,
    config: Optional[LLMRewardConfig] = None,
) -> Tuple[float, bool]:
    """
    Synchronous single-sample reward estimation.
    Returns (reward, should_update).

    Fix: the client is created, used, and closed inside ONE event loop.
    The previous version called asyncio.run() twice, which closed the
    client's HTTP transport on a different loop than it ran on.
    """
    async def _run() -> RewardResult:
        client = LLMRewardClient(config)
        try:
            return await client.judge(sample)
        finally:
            await client.close()

    result = asyncio.run(_run())
    return result.reward, result.should_update
+
+
def estimate_rewards_batch(
    samples: List[TurnSample],
    config: Optional[LLMRewardConfig] = None,
) -> List[Tuple[float, bool]]:
    """
    Synchronous batch reward estimation (runs async internally).
    Returns list of (reward, should_update) tuples.

    Fix: judge_batch() and close() run in the SAME event loop; the
    previous version closed the client via a second asyncio.run(),
    i.e. on a different loop than the one that used it.
    """
    async def _run() -> List[RewardResult]:
        client = LLMRewardClient(config)
        try:
            return await client.judge_batch(samples)
        finally:
            await client.close()

    results = asyncio.run(_run())
    return [(r.reward, r.should_update) for r in results]