path: root/src/personalization/feedback/handlers.py
author    YurenHao0426 <blackhao0426@gmail.com>    2025-12-17 04:29:37 -0600
committer YurenHao0426 <blackhao0426@gmail.com>    2025-12-17 04:29:37 -0600
commit    e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree      6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /src/personalization/feedback/handlers.py
Initial commit (clean history) (HEAD, main)
Diffstat (limited to 'src/personalization/feedback/handlers.py')
-rw-r--r--    src/personalization/feedback/handlers.py    50
1 file changed, 50 insertions(+), 0 deletions(-)
diff --git a/src/personalization/feedback/handlers.py b/src/personalization/feedback/handlers.py
new file mode 100644
index 0000000..60a8d17
--- /dev/null
+++ b/src/personalization/feedback/handlers.py
@@ -0,0 +1,50 @@
+from typing import Tuple, List, Optional
+import numpy as np
+
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.reward_model import estimate_reward
+from personalization.feedback.gating import estimate_retrieval_gating
+
+def eval_step(
+    q_t: str,
+    answer_t: str,
+    q_t1: str,
+    memories_t: List[MemoryCard],
+    query_embedding_t: Optional[np.ndarray] = None,
+    query_embedding_t1: Optional[np.ndarray] = None,
+) -> Tuple[float, float]:
+    """
+    Unified evaluation interface.
+    Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
+    """
+
+    # Build a lightweight TurnSample. Gating relies on embeddings; when none
+    # are provided, it falls back to a default score.
+
+    # Stack memory embeddings for gating, if the retrieved memories carry them.
+    mem_embs = None
+    if memories_t and memories_t[0].embedding_e is not None:
+        try:
+            mem_embs = np.array([m.embedding_e for m in memories_t])
+        except (TypeError, ValueError):
+            mem_embs = None
+
+    sample = TurnSample(
+        user_id="",  # Not needed for simple eval
+        session_id="",
+        turn_id=0,
+        query_t=q_t,
+        answer_t=answer_t,
+        query_t1=q_t1,
+        memories=memories_t,
+        query_embedding_t=query_embedding_t,
+        query_embedding_t1=query_embedding_t1,
+        memory_embeddings=mem_embs,
+    )
+
+    r_hat = estimate_reward(sample)
+    g_hat = estimate_retrieval_gating(sample, r_hat)
+
+    return r_hat, g_hat
+
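For reference, a minimal usage sketch of eval_step (not part of the commit). The MemoryCard constructor arguments below are assumptions about the schema in personalization.retrieval.preference_store.schemas; only embedding_e is actually referenced by handlers.py, and the embedding dimension is illustrative.

import numpy as np

from personalization.feedback.handlers import eval_step
from personalization.retrieval.preference_store.schemas import MemoryCard

# Hypothetical memories retrieved for turn t; field names other than
# embedding_e are assumptions about the MemoryCard schema.
memories = [
    MemoryCard(text="Prefers concise, code-first answers", embedding_e=np.random.rand(384)),
]

r_hat, g_hat = eval_step(
    q_t="How do I profile a slow pandas groupby?",
    answer_t="Time each candidate rewrite separately and compare before optimizing.",
    q_t1="Thanks, can you show the faster version?",
    memories_t=memories,
    query_embedding_t=np.random.rand(384),   # embeddings are optional; gating
    query_embedding_t1=np.random.rand(384),  # falls back to a default without them
)
print(f"reward_hat={r_hat:.3f}, gating_hat={g_hat:.3f}")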