path: root/src/personalization/feedback/gating.py
blob: d741874b76cad1cced3281db702c9a18332deb6b
import numpy as np
from personalization.feedback.schemas import TurnSample

def cosine_sim_batch(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray:
    """Cosine similarity between each row of `matrix` [N, d] and `vector` [d].

    Returns an array of shape [N].
    """
    norm_m = np.linalg.norm(matrix, axis=1)
    norm_v = np.linalg.norm(vector)

    # Guard against division by zero (all-zero rows or an all-zero vector).
    den = np.maximum(norm_m * norm_v, 1e-9)

    return matrix @ vector / den
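
# A minimal sanity check (illustrative, not part of the pipeline): identical
# directions score ~1.0, orthogonal directions ~0.0, and an all-zero row is
# mapped to 0.0 by the denominator guard above.
#
#     >>> m = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
#     >>> cosine_sim_batch(m, np.array([1.0, 0.0]))
#     array([1., 0., 0.])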

def estimate_retrieval_gating(sample: TurnSample, reward_hat: float) -> float:
    """
    Return a gate g_t in [0, 1] estimating how much of the reward is
    attributable to retrieval (g_t -> 1) rather than generation (g_t -> 0).
    """
    e_q = sample.query_embedding_t
    e_q1 = sample.query_embedding_t1

    if e_q is None or e_q1 is None or not sample.memories:
        return 0.5  # Neutral: not enough signal to assign credit either way.
        
    # We need embeddings of the retrieved memories. In a full pipeline they
    # are passed in sample.memory_embeddings; otherwise fall back to the
    # per-card embedding_e (a List[float] on each MemoryCard).
    if sample.memory_embeddings is None:
        try:
            mem_embs = np.array([m.embedding_e for m in sample.memories])
        except (AttributeError, TypeError, ValueError):
            return 0.5
        if mem_embs.ndim != 2 or mem_embs.shape[1] == 0:
            # Missing, ragged, or empty embeddings: cannot compute similarity.
            return 0.5
    else:
        mem_embs = np.asarray(sample.memory_embeddings)

    # Cosine similarity of each memory to the current query (q_t) and the
    # follow-up query (q_{t+1}); both have shape [K].
    sims_q = cosine_sim_batch(mem_embs, np.asarray(e_q))
    sims_q1 = cosine_sim_batch(mem_embs, np.asarray(e_q1))

    s_q_max = float(sims_q.max()) if sims_q.size else 0.0
    s_q1_max = float(sims_q1.max()) if sims_q1.size else 0.0
    
    g = 0.5  # Default: split credit evenly.

    # Heuristic credit assignment based on reward and similarity:

    # Case A: bad reward and the memories look irrelevant to both queries.
    # Likely a retrieval failure (nothing relevant was found, or nothing
    # relevant exists), so blame retrieval.
    if reward_hat < -0.5 and s_q_max < 0.2 and s_q1_max < 0.2:
        g = 0.9

    # Case B: bad reward even though at least one memory is highly relevant
    # to the query. Retrieval did its job, so the generator likely failed
    # to use the memories well.
    elif reward_hat < -0.5 and s_q_max > 0.5:
        g = 0.2

    # Case C: good reward. Assume both components did their part, and tilt
    # the credit toward retrieval only if it supplied relevant memories.
    elif reward_hat > 0.5:
        if s_q_max > 0.4:
            g = 0.6  # Retrieval helped.
        else:
            g = 0.3  # The generator handled it without strong retrieval help.

    return float(g)
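
# Illustrative usage sketch. This assumes TurnSample can be constructed with
# the fields read above as keyword arguments; check
# personalization.feedback.schemas for the actual constructor and any
# required metadata.
#
#     sample = TurnSample(
#         query_embedding_t=q_t_emb,          # np.ndarray, shape [d]
#         query_embedding_t1=q_t1_emb,        # np.ndarray, shape [d]
#         memories=retrieved_cards,           # list of MemoryCard
#         memory_embeddings=card_emb_matrix,  # np.ndarray, shape [K, d]
#     )
#     g = estimate_retrieval_gating(sample, reward_hat=-0.8)
#     # g near 1.0 -> route the learning signal to the retriever;
#     # g near 0.0 -> route it to the generator.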