summaryrefslogtreecommitdiff
path: root/src/personalization/retrieval/pipeline.py
blob: 3d3eeb73ba38fdf520c80c032319d747df9e7edf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
from typing import List, Tuple
import numpy as np

from personalization.models.embedding.base import EmbeddingModel
from personalization.models.reranker.base import Reranker
from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.user_model.tensor_store import UserTensorStore, UserState
from personalization.user_model.scoring import score_with_user
from personalization.user_model.policy.reinforce import compute_policy_scores

def cosine_similarity_matrix(E: np.ndarray, e_q: np.ndarray) -> np.ndarray:
    """
    Score every row of ``E`` against the query vector ``e_q``.

    The computation is a plain dot product; no normalization is applied here,
    so the result is a true cosine similarity only when both ``E``'s rows and
    ``e_q`` are already unit-norm.

    Args:
        E: Matrix of embeddings, documented by the original author as [M, d].
        e_q: Query embedding, documented by the original author as [d].

    Returns:
        The result of ``np.dot(E, e_q)`` (one score per row for [M, d] x [d]).
    """
    similarities = np.dot(E, e_q)
    return similarities

def dense_topk_indices(
    query: str,
    embed_model: EmbeddingModel,
    memory_embeddings: np.ndarray,
    valid_indices: List[int] = None,
    topk: int = 64
) -> List[int]:
    """
    Return indices of the topk memories ranked by dense embedding similarity.

    The query is encoded with ``embed_model.encode(..., normalize=True)`` and
    scored against the memory embeddings with a dot product. Results are
    ordered from most to least similar.

    Args:
        query: Text to embed and search with.
        embed_model: Object exposing ``encode(list_of_str, normalize=...,
            return_tensor=...)``; the first returned item is used as the
            query vector.
        memory_embeddings: Memory embedding matrix (one row per memory).
        valid_indices: Optional subset of row indices to search. ``None``
            means search all rows; an empty list yields an empty result.
        topk: Maximum number of indices to return. Non-positive values yield
            an empty result.

    Returns:
        Indices into ``memory_embeddings`` (original, un-subsetted positions),
        at most ``topk`` of them. Empty when there is nothing to search.
    """
    # Cheap exits first, so the (potentially expensive) encoder is never
    # called when no result is possible. This also covers an empty store
    # represented as a 1-D ``np.array([])``, which would otherwise make
    # ``np.dot`` fail on mismatched shapes.
    if topk <= 0 or len(memory_embeddings) == 0:
        return []
    if valid_indices is not None and len(valid_indices) == 0:
        return []

    e_q_list = embed_model.encode([query], normalize=True, return_tensor=False)
    e_q = np.array(e_q_list[0], dtype=np.float32)

    # Restrict the search space when a subset is requested; otherwise score
    # every row.
    restricted = valid_indices is not None
    search_space = memory_embeddings[valid_indices] if restricted else memory_embeddings

    sims = np.dot(search_space, e_q)
    k = min(topk, len(sims))
    if k == 0:
        return []

    # argsort is ascending: take the last k entries and flip to descending.
    ranked = np.argsort(sims)[-k:][::-1]

    if restricted:
        # Positions are relative to the subset; map back to original indices.
        return [valid_indices[i] for i in ranked]
    return ranked.tolist()

def retrieve_with_policy(
    user_id: str,
    query: str,
    embed_model: EmbeddingModel,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,  # shape: [M, d]
    user_store: UserTensorStore,
    item_vectors: np.ndarray,       # shape: [M, k], v_m
    topk_dense: int = 64,
    topk_rerank: int = 8,
    beta_long: float = 0.0,
    beta_short: float = 0.0,
    tau: float = 1.0,
    only_own_memories: bool = False,
    sample: bool = False,
) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
    """
    Retrieve candidate memories and select a subset using the user policy.

    Steps: optional per-user filtering -> dense top-k retrieval -> reranker
    scoring -> ``compute_policy_scores`` -> greedy or sampled selection.

    Returns extended info for policy update:
    (candidates, candidate_item_vectors, base_scores, chosen_indices, policy_probs)
    where ``chosen_indices`` are positions within ``candidates``. When there
    is nothing to retrieve, every element of the tuple is empty.

    Args:
        user_id: User whose state is fetched from ``user_store`` and, when
            ``only_own_memories`` is set, whose cards are searched.
        query: Query text for dense retrieval and reranking.
        embed_model: Encoder handed to ``dense_topk_indices``.
        reranker: Object whose ``score(query, docs)`` yields one base score
            per candidate document.
        memory_cards: All memory cards; ``user_id`` and ``note_text`` are read.
        memory_embeddings: Dense embeddings aligned with ``memory_cards``.
        user_store: Store providing ``get_state(user_id)``.
        item_vectors: Per-memory item vectors aligned with ``memory_cards``.
        topk_dense: Number of candidates kept after dense retrieval.
        topk_rerank: Number of candidates to select. Values <= 0 select none.
        beta_long, beta_short, tau: Forwarded unchanged to
            ``compute_policy_scores`` (presumably long/short-term weights and
            a softmax temperature -- confirm against that function).
        only_own_memories: If True, restrict the search to cards whose
            ``user_id`` matches.
        sample: If True, use stochastic sampling from policy distribution (for training/exploration).
                If False, use deterministic top-k by policy scores (for evaluation).
    """
    import os  # function-scope import, as in the original; only needed for the debug flag

    debug = os.getenv("RETRIEVAL_DEBUG") == "1"

    # 0. Filter indices if needed
    valid_indices = None
    if only_own_memories:
        valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
        if not valid_indices:
            return [], np.array([]), np.array([]), [], np.array([])

    # 1. Dense retrieval
    dense_idx = dense_topk_indices(
        query,
        embed_model,
        memory_embeddings,
        valid_indices=valid_indices,
        topk=topk_dense
    )
    if not dense_idx:
        return [], np.array([]), np.array([]), [], np.array([])

    # DEBUG: Check for duplicates or out of bounds
    if debug:
        print(f"  [Pipeline] Dense Indices (Top {len(dense_idx)}): {dense_idx[:10]}...")
        print(f"  [Pipeline] Max Index: {max(dense_idx)} | Memory Size: {len(memory_cards)}")

    candidates = [memory_cards[i] for i in dense_idx]
    candidate_docs = [c.note_text for c in candidates]

    # 2. Rerank base score (P(yes|q,m))
    base_scores = np.array(reranker.score(query, candidate_docs))

    # 3. Policy Scoring (Softmax)
    user_state: UserState = user_store.get_state(user_id)
    candidate_vectors = item_vectors[dense_idx]  # [K, k]

    policy_out = compute_policy_scores(
        base_scores=base_scores,
        user_state=user_state,
        item_vectors=candidate_vectors,
        beta_long=beta_long,
        beta_short=beta_short,
        tau=tau
    )

    # 4. Selection: Greedy (eval) or Stochastic (training)
    # Clamp k to be non-negative: with k == 0 the slice ``[-k:]`` is ``[0:]``
    # and would select *every* candidate, and a negative k would mis-slice.
    k = max(0, min(topk_rerank, len(policy_out.scores)))

    if k == 0:
        chosen_indices = []
    elif sample:
        # Stochastic sampling from policy distribution (for training/exploration)
        probs = policy_out.probs
        # Normalize to ensure sum to 1 (handle numerical issues)
        probs = probs / (probs.sum() + 1e-10)
        # Sample k indices without replacement, weighted by policy probs
        chosen_indices = np.random.choice(
            len(probs), size=k, replace=False, p=probs
        ).tolist()
    else:
        # Deterministic top-k by policy scores (for evaluation)
        top_indices_local = policy_out.scores.argsort()[-k:][::-1]
        chosen_indices = top_indices_local.tolist()

    if debug:
        print(f"  [Pipeline] Candidates: {len(candidates)} | Chosen Indices: {chosen_indices} | Sample: {sample}")

    return candidates, candidate_vectors, base_scores, chosen_indices, policy_out.probs

def retrieve_no_policy(
    user_id: str,
    query: str,
    embed_model: EmbeddingModel,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,  # shape: [M, d]
    topk_dense: int = 64,
    topk_rerank: int = 8,
    only_own_memories: bool = False,
) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
    """
    Deterministic retrieval baseline (NoPersonal mode):
    - Dense retrieval -> Rerank -> Top-K (no policy sampling, no user vector influence)

    Returns same structure as retrieve_with_policy for compatibility:
    (candidates, candidate_item_vectors, base_scores, chosen_indices, rerank_scores_for_chosen)

    Note: candidate_item_vectors is empty array (not used in NoPersonal mode)
          The last return value is rerank scores instead of policy probs

    Args:
        user_id: Only used to filter cards when ``only_own_memories`` is set.
        query: Query text for dense retrieval and reranking.
        embed_model: Encoder handed to ``dense_topk_indices``.
        reranker: Object whose ``score(query, docs)`` yields one score per
            candidate document.
        memory_cards: All memory cards; ``user_id`` and ``note_text`` are read.
        memory_embeddings: Dense embeddings aligned with ``memory_cards``.
        topk_dense: Number of candidates kept after dense retrieval.
        topk_rerank: Number of candidates to select. Values <= 0 select none.
        only_own_memories: If True, restrict the search to cards whose
            ``user_id`` matches.
    """
    # 0. Filter indices if needed
    valid_indices = None
    if only_own_memories:
        valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
        if not valid_indices:
            return [], np.array([]), np.array([]), [], np.array([])

    # 1. Dense retrieval
    dense_idx = dense_topk_indices(
        query,
        embed_model,
        memory_embeddings,
        valid_indices=valid_indices,
        topk=topk_dense
    )

    if not dense_idx:
        return [], np.array([]), np.array([]), [], np.array([])

    candidates = [memory_cards[i] for i in dense_idx]
    candidate_docs = [c.note_text for c in candidates]

    # 2. Rerank base score (P(yes|q,m))
    base_scores = np.array(reranker.score(query, candidate_docs))

    # 3. Deterministic Top-K selection based on rerank scores ONLY (no policy)
    # k is clamped to be non-negative, and the descending order is sliced
    # from the front: the former ``[-k:]`` form returned *all* candidates for
    # k == 0 (``[-0:]`` is the full array) and mis-sliced for negative k.
    k = max(0, min(topk_rerank, len(base_scores)))
    top_indices_local = base_scores.argsort()[::-1][:k]
    chosen_indices = top_indices_local.tolist()

    # Get scores for chosen items (for logging compatibility)
    chosen_scores = base_scores[top_indices_local]

    # Return empty item vectors (not used in NoPersonal mode)
    # Return rerank scores as the "probs" field for logging compatibility
    return candidates, np.array([]), base_scores, chosen_indices, chosen_scores


def retrieve_with_rerank(
    user_id: str,
    query: str,
    embed_model: EmbeddingModel,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,  # shape: [M, d]
    user_store: UserTensorStore,
    item_vectors: np.ndarray,       # shape: [M, k], v_m
    topk_dense: int = 64,
    topk_rerank: int = 8,
    beta_long: float = 0.0,
    beta_short: float = 0.0,
    only_own_memories: bool = False,
) -> List[MemoryCard]:
    """
    Inference convenience wrapper over retrieve_with_policy.

    Forwards every argument unchanged, fixes ``tau`` at 1.0, leaves ``sample``
    at its default, and returns only the selected memory cards (in the order
    of the chosen indices) instead of the full policy-update tuple.
    """
    policy_result = retrieve_with_policy(
        user_id=user_id,
        query=query,
        embed_model=embed_model,
        reranker=reranker,
        memory_cards=memory_cards,
        memory_embeddings=memory_embeddings,
        user_store=user_store,
        item_vectors=item_vectors,
        topk_dense=topk_dense,
        topk_rerank=topk_rerank,
        beta_long=beta_long,
        beta_short=beta_short,
        tau=1.0,  # Default tau
        only_own_memories=only_own_memories,
    )

    # Unpack only what is needed: the candidate pool and the chosen positions.
    pool, _vectors, _base_scores, picked, _probs = policy_result

    selected_cards = []
    for position in picked:
        selected_cards.append(pool[position])
    return selected_cards