summaryrefslogtreecommitdiff
path: root/src/personalization/feedback/schemas.py
blob: b15db8023582190241f45bdfc80db384fc3e1886 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from __future__ import annotations

from dataclasses import dataclass, fields
from typing import List, Optional, Any

import numpy as np

from personalization.retrieval.preference_store.schemas import MemoryCard

@dataclass
class TurnSample:
    """One turn of a session (q_t, a_t) together with the next query q_{t+1}.

    Bundles the identifiers of the turn, the raw query/answer texts, the
    memory cards labelled A_t for the turn, and optional pre-computed
    embedding arrays.

    Equality is field-by-field, like the dataclass default, except that
    NumPy array fields are compared with ``np.array_equal``. The generated
    tuple-based ``__eq__`` would raise ``ValueError`` (ambiguous truth value
    of an array) when two distinct multi-element arrays are compared.
    Instances remain unhashable, as with the generated ``__eq__``.
    """

    # --- identifiers ---
    user_id: str
    session_id: str
    turn_id: int        # index of q_t within the session

    # --- raw texts ---
    query_t: str        # q_t
    answer_t: str       # a_t
    query_t1: str       # q_{t+1}

    # --- memory cards for this turn ---
    memories: List[MemoryCard]   # A_t

    # --- optional pre-computed vectors and features (None when absent) ---
    query_embedding_t: Optional[np.ndarray] = None
    query_embedding_t1: Optional[np.ndarray] = None
    # corresponding e_m or v_m for memories
    # NOTE(review): presumably aligned with the order of `memories` — confirm.
    memory_embeddings: Optional[np.ndarray] = None

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is a TurnSample with equal field values.

        Returns ``NotImplemented`` for objects of a different class so that
        Python falls back to its default comparison handling.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        for spec in fields(self):
            if not spec.compare:
                continue
            mine = getattr(self, spec.name)
            theirs = getattr(other, spec.name)
            # Identity short-circuit mirrors tuple comparison semantics and
            # covers the common "both None" case.
            if mine is theirs:
                continue
            if isinstance(mine, np.ndarray) or isinstance(theirs, np.ndarray):
                # Arrays need a reduced comparison: `==` is element-wise and
                # its result cannot be used as a single truth value.
                if mine is None or theirs is None:
                    return False
                if not np.array_equal(mine, theirs):
                    return False
            elif not (mine == theirs):
                return False
        return True