import numpy as np

from personalization.feedback.schemas import TurnSample


def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray:
    # matrix: [N, d], vector: [d] -> cosine similarity of each row against vector, shape [N]
    norm_m = np.linalg.norm(matrix, axis=1)
    norm_v = np.linalg.norm(vector)
    # Avoid division by zero when a row or the query vector is all zeros.
    den = norm_m * norm_v
    den[den == 0] = 1e-9
    return np.dot(matrix, vector) / den
def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float:
    """
    Return g_t in [0, 1]: the estimated share of the reward attributable to retrieval
    (high g_t) rather than to how the LLM used the retrieved memories (low g_t).
    """
    e_q = sample.query_embedding_t
    e_q1 = sample.query_embedding_t1
    if e_q is None or e_q1 is None or not sample.memories:
        return 0.5  # Not enough information; stay neutral.
    # We need embeddings for the retrieved memories. In a real pipeline they may be
    # passed in sample.memory_embeddings; otherwise fall back to the embedding stored
    # on each memory card (MemoryCard.embedding_e, a List[float]).
    if sample.memory_embeddings is None:
        try:
            mem_embs = np.array([m.embedding_e for m in sample.memories], dtype=float)
        except (TypeError, ValueError):
            return 0.5  # Missing or malformed card embeddings; cannot score retrieval.
        if mem_embs.ndim != 2 or mem_embs.shape[1] == 0:  # Ragged or empty embeddings
            return 0.5
    else:
        mem_embs = np.asarray(sample.memory_embeddings, dtype=float)
    # Similarity of each retrieved memory to the query at turn t and at turn t+1, shape [K].
    sims_q = cosine_sim_batch(mem_embs, e_q)
    sims_q1 = cosine_sim_batch(mem_embs, e_q1)
    s_q_max = float(sims_q.max()) if sims_q.size else 0.0
    s_q1_max = float(sims_q1.max()) if sims_q1.size else 0.0
    g = 0.5
    # Heuristic attribution:
    # Case A: retrieval clearly irrelevant and the reward is bad.
    # q_t / q_{t+1} have low similarity to every memory -> likely a retrieval failure
    # (or no relevant memories exist at all).
    if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2:
        g = 0.9  # Blame retrieval for failing to find anything (or nothing exists).
    # Case B: retrieval looks good but the reward is bad.
    # Memories are relevant to the query, yet the user is still unhappy -> the LLM
    # probably did not use them well.
    elif reward_hat < -0.5 and s_q_max > 0.5:
        g = 0.2  # Likely the LLM's fault.
    # Case C: good reward -> assume both components did okay.
    elif reward_hat > 0.5:
        if s_q_max > 0.4:
            g = 0.6  # Retrieval helped.
        else:
            g = 0.3  # The LLM handled it without strong retrieval help.
    return float(g)
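
# A minimal usage sketch with hypothetical values. The real input is a TurnSample from
# personalization.feedback.schemas; here a SimpleNamespace stands in for it, carrying
# only the fields estimate_retrieval_gating actually reads. The 4-d vectors and the
# reward below are made up purely for illustration.
if __name__ == "__main__":
    from types import SimpleNamespace

    fake_sample = SimpleNamespace(
        query_embedding_t=np.array([0.1, 0.9, 0.0, 0.2]),
        query_embedding_t1=np.array([0.2, 0.8, 0.1, 0.1]),
        memories=[object()],  # non-empty, so the neutral early return is not taken
        memory_embeddings=np.array([[0.10, 0.85, 0.05, 0.15],
                                    [0.90, 0.00, 0.30, 0.10]]),
    )
    # A highly similar memory plus a strongly negative reward falls into Case B:
    # retrieval found relevant material, so most of the blame goes to the LLM (g = 0.2).
    print(estimate_retrieval_gating(fake_sample, reward_hat=-0.8))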