author     YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
commit     e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree       6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5  /src/personalization/feedback/gating.py
Initial commit (clean history)  (HEAD, main)
Diffstat (limited to 'src/personalization/feedback/gating.py')
-rw-r--r--  src/personalization/feedback/gating.py  84
1 file changed, 84 insertions, 0 deletions
diff --git a/src/personalization/feedback/gating.py b/src/personalization/feedback/gating.py
new file mode 100644
index 0000000..d741874
--- /dev/null
+++ b/src/personalization/feedback/gating.py
@@ -0,0 +1,84 @@
+import numpy as np
+
+from personalization.feedback.schemas import TurnSample
+
+
+def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray:
+ # matrix: [N, d], vector: [d]
+ # return: [N]
+ norm_m = np.linalg.norm(matrix, axis=1)
+ norm_v = np.linalg.norm(vector)
+
+ # Avoid div by zero
+ den = norm_m * norm_v
+ den[den == 0] = 1e-9
+
+ return np.dot(matrix, vector) / den
+
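+# Hypothetical usage: cosine_sim_batch(np.eye(3), np.array([1.0, 0.0, 0.0]))
+# returns array([1., 0., 0.]): one cosine similarity per memory row.
+
+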
+def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float:
+    """
+    Return g_t in [0, 1], the share of the reward attributed to retrieval:
+    values near 1 blame/credit retrieval, values near 0 blame/credit the LLM.
+    """
+ e_q = sample.query_embedding_t
+ e_q1 = sample.query_embedding_t1
+
+ if e_q is None or e_q1 is None or not sample.memories:
+        return 0.5  # Neutral: not enough signal to assign credit
+
+    # We need embeddings of the memories. In a real pipeline they may be
+    # passed in sample.memory_embeddings; if missing, fall back to the
+    # per-card MemoryCard.embedding_e (a List[float]).
+    if sample.memory_embeddings is None:
+        try:
+            mem_embs = np.array([m.embedding_e for m in sample.memories], dtype=float)
+        except (TypeError, ValueError):
+            # Missing or ragged embeddings: cosine similarity is undefined.
+            return 0.5
+        if mem_embs.ndim != 2 or mem_embs.shape[1] == 0:  # Empty embeddings
+            return 0.5
+ else:
+        mem_embs = np.asarray(sample.memory_embeddings, dtype=float)
+
+    # Similarity of each memory to the current and next query, shape: [K]
+    sims_q = cosine_sim_batch(mem_embs, np.asarray(e_q, dtype=float))
+    sims_q1 = cosine_sim_batch(mem_embs, np.asarray(e_q1, dtype=float))
+
+    s_q_max = float(sims_q.max()) if sims_q.size > 0 else 0.0
+    s_q1_max = float(sims_q1.max()) if sims_q1.size > 0 else 0.0
+
+    g = 0.5  # Default: credit is split evenly unless a heuristic below fires
+
+ # Heuristics
+
+ # Case A: Retrieval clearly irrelevant + bad reward
+ # q_t / q_{t+1} have low similarity to memories -> likely retrieval failure (or no relevant memories)
+ if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2:
+        g = 0.9  # Blame retrieval (for failing to find anything, or nothing exists)
+
+    # Case B: Retrieval looks good but reward is bad
+    # Memories are relevant to the query, but the user is still unhappy
+    # -> the LLM likely didn't use them well.
+    elif reward_hat < -0.5 and s_q_max > 0.5:
+        g = 0.2  # Likely the LLM's fault
+
+ # Case C: Good reward
+ # If reward is high, we assume both did okay.
+ elif reward_hat > 0.5:
+ if s_q_max > 0.4:
+            g = 0.6  # Retrieval helped
+        else:
+            g = 0.3  # LLM handled it without strong retrieval help
+
+ return float(g)
+
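+# Worked example (hypothetical values): with reward_hat = -0.8,
+# s_q_max = 0.1 and s_q1_max = 0.15 fall under Case A, giving g = 0.9
+# (blame retrieval); the same reward with s_q_max = 0.7 triggers Case B,
+# giving g = 0.2 (blame the LLM). A typical call, assuming TurnSample
+# carries the fields used above:
+#     g = estimate_retrieval_gating(sample, reward_hat=-0.8)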