Diffstat (limited to 'src/personalization/feedback/handlers.py')
-rw-r--r--  src/personalization/feedback/handlers.py  59
1 file changed, 48 insertions(+), 11 deletions(-)
diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py
index 60a8d17..f0468b6 100644
--- a/src/personalization/feedback/handlers.py
+++ b/src/personalization/feedback/handlers.py
@@ -5,6 +5,10 @@ from personalization.retrieval.preference_store.schemas import MemoryCard
 from personalization.feedback.schemas import TurnSample
 from personalization.feedback.reward_model import estimate_reward
 from personalization.feedback.gating import estimate_retrieval_gating
+from personalization.feedback.llm_reward import (
+    LLMRewardClient, LLMRewardConfig, RewardResult
+)
+
 
 def eval_step(
     q_t: str,
@@ -15,23 +19,18 @@ def eval_step(
     query_embedding_t1: Optional[np.ndarray] = None,
 ) -> Tuple[float, float]:
     """
-    Unified evaluation interface.
+    Keyword-based evaluation (legacy).
     Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
     """
-
-    # Construct a lightweight TurnSample
-    # We might need embeddings for gating. If not provided, gating might return default.
-
-    # Ensure memories have embeddings for gating
     mem_embs = None
     if memories_t and memories_t[0].embedding_e:
         try:
             mem_embs = np.array([m.embedding_e for m in memories_t])
         except Exception:
             pass
-
+
     sample = TurnSample(
-        user_id="",  # Not needed for simple eval
+        user_id="",
         session_id="",
         turn_id=0,
         query_t=q_t,
@@ -40,11 +39,49 @@
         memories=memories_t,
         query_embedding_t=query_embedding_t,
         query_embedding_t1=query_embedding_t1,
-        memory_embeddings=mem_embs
+        memory_embeddings=mem_embs,
     )
-
+
     r_hat = estimate_reward(sample)
     g_hat = estimate_retrieval_gating(sample, r_hat)
-
+
     return r_hat, g_hat
+
+async def eval_step_llm(
+    q_t: str,
+    answer_t: str,
+    q_t1: str,
+    memories_t: List[MemoryCard],
+    client: LLMRewardClient,
+    query_embedding_t: Optional[np.ndarray] = None,
+    query_embedding_t1: Optional[np.ndarray] = None,
+) -> Tuple[float, float]:
+    """
+    LLM-as-judge evaluation (async).
+    Returns (reward, gating), where gating is 0.0 if the update should be skipped.
+
+    The gating signal is derived from the judge's confidence and label:
+    - If confidence < tau_c or label == topic_shift: gating = 0.0
+    - Otherwise: gating = confidence (continuous, in [tau_c, 1.0])
+
+    This replaces the old heuristic gating with the judge's own confidence.
+    """
+    sample = TurnSample(
+        user_id="",
+        session_id="",
+        turn_id=0,
+        query_t=q_t,
+        answer_t=answer_t,
+        query_t1=q_t1,
+        memories=memories_t,
+        query_embedding_t=query_embedding_t,
+        query_embedding_t1=query_embedding_t1,
+    )
+
+    result: RewardResult = await client.judge(sample)
+
+    if result.should_update:
+        return result.reward, result.confidence
+    else:
+        return 0.0, 0.0
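
The judge's decision logic lives in llm_reward.py, which is not part of this diff. A minimal sketch of how RewardResult.should_update might encode the gating rule described in the docstring: the field names reward, confidence, and should_update come from the call sites above, while the label values and the tau_c default below are assumptions.

# Sketch only: llm_reward.py is not shown in this diff. Label values and the
# tau_c default are assumptions; reward, confidence, label, tau_c, and
# should_update are the names referenced by handlers.py and its docstring.
from dataclasses import dataclass

@dataclass
class RewardResult:
    reward: float
    confidence: float
    label: str            # e.g. "helpful" vs. "topic_shift" (assumed label set)
    tau_c: float = 0.7    # assumed confidence threshold

    @property
    def should_update(self) -> bool:
        # Mirrors the docstring's rule: skip low-confidence judgments and
        # topic shifts; otherwise the caller gates by the confidence itself.
        return self.confidence >= self.tau_c and self.label != "topic_shift"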
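A hedged usage sketch of the new entry point. The LLMRewardConfig fields are not shown in the diff, so the zero-argument construction and the example queries below are illustrative only.

import asyncio

from personalization.feedback.handlers import eval_step_llm
from personalization.feedback.llm_reward import LLMRewardClient, LLMRewardConfig

async def main() -> None:
    # Assumed construction; the real LLMRewardConfig fields are not in this diff.
    client = LLMRewardClient(LLMRewardConfig())

    reward, gating = await eval_step_llm(
        q_t="What should I cook tonight?",
        answer_t="Based on your saved preferences, a vegetarian curry would fit.",
        q_t1="Nice, how long will that take?",
        memories_t=[],  # list of MemoryCard; empty keeps the example self-contained
        client=client,
    )

    # gating doubles as an update weight: 0.0 means skip the preference update.
    if gating > 0.0:
        print(f"apply update: reward={reward:.2f}, weight={gating:.2f}")
    else:
        print("update skipped (low confidence or topic shift)")

asyncio.run(main())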