path: root/src/personalization/feedback/handlers.py
from typing import Tuple, List, Optional
import numpy as np

from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.feedback.schemas import TurnSample
from personalization.feedback.reward_model import estimate_reward
from personalization.feedback.gating import estimate_retrieval_gating

def eval_step(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """
    Unified evaluation interface.
    Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
    """
    
    # Build a lightweight TurnSample for the estimators. Gating uses
    # embeddings when available; if none are supplied it may fall back
    # to a default score.

    # Stack memory embeddings for gating, but only if every memory has one.
    # (Compare against None rather than relying on truthiness: an ndarray
    # with more than one element raises ValueError in a boolean context.)
    mem_embs = None
    if memories_t and all(m.embedding_e is not None for m in memories_t):
        try:
            mem_embs = np.array([m.embedding_e for m in memories_t])
        except (ValueError, TypeError):
            # Ragged or non-numeric embeddings: proceed without them.
            mem_embs = None
            
    sample = TurnSample(
        user_id="",  # not needed for standalone evaluation
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
        memory_embeddings=mem_embs,
    )
    
    r_hat = estimate_reward(sample)
    g_hat = estimate_retrieval_gating(sample, r_hat)
    
    return r_hat, g_hat
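

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative addition, not part of the
    # original module). It calls eval_step with placeholder strings, no
    # retrieved memories, and no query embeddings, assuming the estimators
    # tolerate an empty memory list and that gating falls back to its
    # default path in that case.
    r_hat, g_hat = eval_step(
        q_t="Any good hard sci-fi recommendations?",
        answer_t="You might enjoy 'Blindsight' by Peter Watts.",
        q_t1="Thanks! What else has Watts written?",
        memories_t=[],
    )
    print(f"reward_hat={r_hat:.3f}  gating_hat={g_hat:.3f}")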