| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /src/personalization/retrieval/pipeline.py | |
Diffstat (limited to 'src/personalization/retrieval/pipeline.py')
| -rw-r--r-- | src/personalization/retrieval/pipeline.py | 250 |
1 file changed, 250 insertions, 0 deletions
diff --git a/src/personalization/retrieval/pipeline.py b/src/personalization/retrieval/pipeline.py
new file mode 100644
index 0000000..3d3eeb7
--- /dev/null
+++ b/src/personalization/retrieval/pipeline.py
@@ -0,0 +1,250 @@
+from typing import List, Optional, Tuple
+import numpy as np
+
+from personalization.models.embedding.base import EmbeddingModel
+from personalization.models.reranker.base import Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserTensorStore, UserState
+from personalization.user_model.scoring import score_with_user
+from personalization.user_model.policy.reinforce import compute_policy_scores
+
+def cosine_similarity_matrix(E: np.ndarray, e_q: np.ndarray) -> np.ndarray:
+    # E: [M, d], e_q: [d]; embeddings are L2-normalized, so the dot product is the cosine similarity
+    return np.dot(E, e_q)
+
+def dense_topk_indices(
+    query: str,
+    embed_model: EmbeddingModel,
+    memory_embeddings: np.ndarray,
+    valid_indices: Optional[List[int]] = None,
+    topk: int = 64
+) -> List[int]:
+    """
+    Return the indices of the top-k memories by dense embedding similarity.
+    If valid_indices is provided, only search within that subset.
+    """
+    if valid_indices is not None and len(valid_indices) == 0:
+        return []
+
+    e_q_list = embed_model.encode([query], normalize=True, return_tensor=False)
+    e_q = np.array(e_q_list[0], dtype=np.float32)
+
+    # Select the subset of embeddings if the search is restricted
+    if valid_indices is not None:
+        # valid_indices may be an arbitrary subset, so take only those rows
+        # and compute the dot product against the subset.
+        # E_sub: [M_sub, d]
+        E_sub = memory_embeddings[valid_indices]
+        sims_sub = np.dot(E_sub, e_q)
+
+        # Top-k within the subset
+        k = min(topk, len(sims_sub))
+        if k == 0:
+            return []
+
+        # argsort gives indices relative to E_sub (0..M_sub-1);
+        # map them back to the original indices
+        idx_sub = np.argsort(sims_sub)[-k:][::-1]
+
+        return [valid_indices[i] for i in idx_sub]
+
+    # Global search
+    sims = np.dot(memory_embeddings, e_q)
+    k = min(topk, len(memory_embeddings))
+    if k == 0:
+        return []
+
+    idx = np.argsort(sims)[-k:][::-1]
+    return idx.tolist()
+
+def retrieve_with_policy(
+    user_id: str,
+    query: str,
+    embed_model: EmbeddingModel,
+    reranker: Reranker,
+    memory_cards: List[MemoryCard],
+    memory_embeddings: np.ndarray,  # shape: [M, d]
+    user_store: UserTensorStore,
+    item_vectors: np.ndarray,  # shape: [M, k], v_m
+    topk_dense: int = 64,
+    topk_rerank: int = 8,
+    beta_long: float = 0.0,
+    beta_short: float = 0.0,
+    tau: float = 1.0,
+    only_own_memories: bool = False,
+    sample: bool = False,
+) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
+    """
+    Returns extended info for the policy update:
+    (candidates, candidate_item_vectors, base_scores, chosen_indices, policy_probs)
+
+    Args:
+        sample: If True, sample stochastically from the policy distribution (for training/exploration).
+                If False, take the deterministic top-k by policy scores (for evaluation).
+    """
+    # 0. Filter indices if needed
+    valid_indices = None
+    if only_own_memories:
+        valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
+        if not valid_indices:
+            return [], np.array([]), np.array([]), [], np.array([])
+
+    # 1. Dense retrieval
+    dense_idx = dense_topk_indices(
+        query,
+        embed_model,
+        memory_embeddings,
+        valid_indices=valid_indices,
+        topk=topk_dense
+    )
+    # DEBUG: check for duplicate or out-of-bounds indices
+    if len(dense_idx) > 0:
+        import os
+        if os.getenv("RETRIEVAL_DEBUG") == "1":
+            print(f" [Pipeline] Dense Indices (Top {len(dense_idx)}): {dense_idx[:10]}...")
+            print(f" [Pipeline] Max Index: {max(dense_idx)} | Memory Size: {len(memory_cards)}")
+
+    if not dense_idx:
+        return [], np.array([]), np.array([]), [], np.array([])
+
+    candidates = [memory_cards[i] for i in dense_idx]
+    candidate_docs = [c.note_text for c in candidates]
+
+    # 2. Rerank base score (P(yes|q,m))
+    base_scores = np.array(reranker.score(query, candidate_docs))
+
+    # 3. Policy scoring (softmax)
+    user_state: UserState = user_store.get_state(user_id)
+    candidate_vectors = item_vectors[dense_idx]  # [K, k]
+
+    policy_out = compute_policy_scores(
+        base_scores=base_scores,
+        user_state=user_state,
+        item_vectors=candidate_vectors,
+        beta_long=beta_long,
+        beta_short=beta_short,
+        tau=tau
+    )
+
+    # 4. Selection: greedy (eval) or stochastic (training)
+    k = min(topk_rerank, len(policy_out.scores))
+
+    if sample:
+        # Stochastic sampling from the policy distribution (for training/exploration):
+        # sample k indices without replacement, weighted by the policy probabilities
+        probs = policy_out.probs
+        # Renormalize so the probabilities sum to 1 (guards against numerical issues)
+        probs = probs / (probs.sum() + 1e-10)
+        # Sample without replacement
+        chosen_indices = np.random.choice(
+            len(probs), size=k, replace=False, p=probs
+        ).tolist()
+    else:
+        # Deterministic top-k by policy scores (for evaluation)
+        top_indices_local = policy_out.scores.argsort()[-k:][::-1]
+        chosen_indices = top_indices_local.tolist()
+
+    import os
+    if os.getenv("RETRIEVAL_DEBUG") == "1":
+        print(f" [Pipeline] Candidates: {len(candidates)} | Chosen Indices: {chosen_indices} | Sample: {sample}")
+
+    return candidates, candidate_vectors, base_scores, chosen_indices, policy_out.probs
+
+def retrieve_no_policy(
+    user_id: str,
+    query: str,
+    embed_model: EmbeddingModel,
+    reranker: Reranker,
+    memory_cards: List[MemoryCard],
+    memory_embeddings: np.ndarray,  # shape: [M, d]
+    topk_dense: int = 64,
+    topk_rerank: int = 8,
+    only_own_memories: bool = False,
+) -> Tuple[List[MemoryCard], np.ndarray, np.ndarray, List[int], np.ndarray]:
+    """
+    Deterministic retrieval baseline (NoPersonal mode):
+    dense retrieval -> rerank -> top-k (no policy sampling, no user-vector influence).
+
+    Returns the same structure as retrieve_with_policy for compatibility:
+    (candidates, candidate_item_vectors, base_scores, chosen_indices, rerank_scores_for_chosen)
+
+    Note: candidate_item_vectors is an empty array (not used in NoPersonal mode),
+    and the last return value holds rerank scores instead of policy probabilities.
+    """
+    # 0. Filter indices if needed
+    valid_indices = None
+    if only_own_memories:
+        valid_indices = [i for i, card in enumerate(memory_cards) if card.user_id == user_id]
+        if not valid_indices:
+            return [], np.array([]), np.array([]), [], np.array([])
+
+    # 1. Dense retrieval
+    dense_idx = dense_topk_indices(
+        query,
+        embed_model,
+        memory_embeddings,
+        valid_indices=valid_indices,
+        topk=topk_dense
+    )
+
+    if not dense_idx:
+        return [], np.array([]), np.array([]), [], np.array([])
+
+    candidates = [memory_cards[i] for i in dense_idx]
+    candidate_docs = [c.note_text for c in candidates]
+
+    # 2. Rerank base score (P(yes|q,m))
+    base_scores = np.array(reranker.score(query, candidate_docs))
+
+    # 3. Deterministic top-k selection based on rerank scores ONLY (no policy)
+    k = min(topk_rerank, len(base_scores))
+    top_indices_local = base_scores.argsort()[-k:][::-1]
+    chosen_indices = top_indices_local.tolist()
+
+    # Scores for the chosen items (for logging compatibility)
+    chosen_scores = base_scores[top_indices_local]
+
+    # Return empty item vectors (not used in NoPersonal mode) and
+    # rerank scores in place of policy probabilities, for logging compatibility
+    return candidates, np.array([]), base_scores, chosen_indices, chosen_scores
+
+
+def retrieve_with_rerank(
+    user_id: str,
+    query: str,
+    embed_model: EmbeddingModel,
+    reranker: Reranker,
+    memory_cards: List[MemoryCard],
+    memory_embeddings: np.ndarray,  # shape: [M, d]
+    user_store: UserTensorStore,
+    item_vectors: np.ndarray,  # shape: [M, k], v_m
+    topk_dense: int = 64,
+    topk_rerank: int = 8,
+    beta_long: float = 0.0,
+    beta_short: float = 0.0,
+    only_own_memories: bool = False,
+) -> List[MemoryCard]:
+    """
+    Wrapper around retrieve_with_policy for standard inference.
+    """
+    candidates, _, _, chosen_indices, _ = retrieve_with_policy(
+        user_id=user_id,
+        query=query,
+        embed_model=embed_model,
+        reranker=reranker,
+        memory_cards=memory_cards,
+        memory_embeddings=memory_embeddings,
+        user_store=user_store,
+        item_vectors=item_vectors,
+        topk_dense=topk_dense,
+        topk_rerank=topk_rerank,
+        beta_long=beta_long,
+        beta_short=beta_short,
+        tau=1.0,  # default temperature
+        only_own_memories=only_own_memories
+    )
+
+    return [candidates[i] for i in chosen_indices]
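For reference, a minimal usage sketch of the two entry points added by this commit, based only on the signatures above. The surrounding objects (`embed_model`, `reranker`, `user_store`, `memory_cards`, `memory_embeddings`, `item_vectors`) are assumed to be constructed elsewhere in the codebase, and the user id, query, and beta/tau values below are hypothetical placeholders, not values prescribed by this change.

```python
# Illustrative sketch only -- `embed_model`, `reranker`, `user_store`,
# `memory_cards`, `memory_embeddings`, and `item_vectors` are assumed to be
# built by the rest of the codebase; user_id/query/beta values are made up.
from personalization.retrieval.pipeline import (
    retrieve_with_policy,
    retrieve_with_rerank,
)

# Standard inference: deterministic top-k under the current user state.
cards = retrieve_with_rerank(
    user_id="user_001",
    query="any notes about my coffee preferences?",
    embed_model=embed_model,
    reranker=reranker,
    memory_cards=memory_cards,
    memory_embeddings=memory_embeddings,  # [M, d]
    user_store=user_store,
    item_vectors=item_vectors,            # [M, k]
    topk_dense=64,
    topk_rerank=8,
    beta_long=0.5,
    beta_short=0.5,
    only_own_memories=True,
)

# Training rollout: stochastic sampling plus the extra outputs needed
# for the policy update (candidate vectors, base scores, chosen indices, probs).
candidates, cand_vectors, base_scores, chosen, probs = retrieve_with_policy(
    user_id="user_001",
    query="any notes about my coffee preferences?",
    embed_model=embed_model,
    reranker=reranker,
    memory_cards=memory_cards,
    memory_embeddings=memory_embeddings,
    user_store=user_store,
    item_vectors=item_vectors,
    beta_long=0.5,
    beta_short=0.5,
    tau=1.0,
    only_own_memories=True,
    sample=True,
)
selected_cards = [candidates[i] for i in chosen]
```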
