path: root/src/personalization/feedback/handlers.py
from typing import Tuple, List, Optional
import numpy as np

from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.feedback.schemas import TurnSample
from personalization.feedback.reward_model import estimate_reward
from personalization.feedback.gating import estimate_retrieval_gating
from personalization.feedback.llm_reward import (
    LLMRewardClient, LLMRewardConfig, RewardResult
)


def eval_step(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """
    Keyword-based evaluation (legacy).
    Given (q_t, a_t, q_{t+1}, memories), returns (reward_hat, gating_hat).
    """
    mem_embs = None
    if memories_t and memories_t[0].embedding_e is not None:
        try:
            mem_embs = np.array([m.embedding_e for m in memories_t])
        except (TypeError, ValueError):
            # Embeddings missing or ragged across memories; fall back to None.
            mem_embs = None

    # user_id/session_id/turn_id are placeholders: this legacy path scores a
    # single (q_t, a_t, q_{t+1}) triple without session bookkeeping.
    sample = TurnSample(
        user_id="",
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
        memory_embeddings=mem_embs,
    )

    r_hat = estimate_reward(sample)
    g_hat = estimate_retrieval_gating(sample, r_hat)

    return r_hat, g_hat
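
# Hedged usage sketch (not from the repo's docs): scoring one logged turn with
# the legacy keyword path. `retrieved_cards` is a hypothetical list of
# MemoryCard objects produced by the retrieval layer.
#
#     r_hat, g_hat = eval_step(
#         q_t="What grind size works for espresso?",
#         answer_t="Start fine and adjust by taste.",
#         q_t1="And for pour-over?",
#         memories_t=retrieved_cards,
#     )
#     # g_hat is the retrieval-gating estimate paired with the reward estimate.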


async def eval_step_llm(
    q_t: str,
    answer_t: str,
    q_t1: str,
    memories_t: List[MemoryCard],
    client: LLMRewardClient,
    query_embedding_t: Optional[np.ndarray] = None,
    query_embedding_t1: Optional[np.ndarray] = None,
) -> Tuple[float, float]:
    """
    LLM-as-judge evaluation (async).
    Returns (reward, gating); gating is 0.0 when the update should be skipped.

    The gating signal is derived from the judge's confidence and label:
      - If confidence < tau_c or label == topic_shift: gating = 0.0
      - Otherwise: gating = confidence (continuous, in [tau_c, 1.0])

    This replaces the old heuristic gating with the judge's own confidence.
    """
    sample = TurnSample(
        user_id="",
        session_id="",
        turn_id=0,
        query_t=q_t,
        answer_t=answer_t,
        query_t1=q_t1,
        memories=memories_t,
        query_embedding_t=query_embedding_t,
        query_embedding_t1=query_embedding_t1,
    )

    result: RewardResult = await client.judge(sample)

    # should_update is expected to encode the rule above: False when the
    # judge's confidence is below tau_c or it labels a topic shift.
    if result.should_update:
        return result.reward, result.confidence
    else:
        return 0.0, 0.0
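
# Hedged usage sketch (not part of the original module): calling the async
# judge path from an async caller. Constructing LLMRewardConfig with defaults
# is an assumption this file does not confirm.
#
#     client = LLMRewardClient(LLMRewardConfig())
#     reward, gating = await eval_step_llm(
#         q_t, answer_t, q_t1, memories_t, client,
#     )
#     if gating > 0.0:
#         ...  # apply the confidence-weighted memory update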