diff options
Diffstat (limited to 'src/personalization/feedback/sampler.py')
| -rw-r--r-- | src/personalization/feedback/sampler.py | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/src/personalization/feedback/sampler.py b/src/personalization/feedback/sampler.py new file mode 100644 index 0000000..9e26912 --- /dev/null +++ b/src/personalization/feedback/sampler.py @@ -0,0 +1,109 @@ +from typing import Iterable, List, Optional +import numpy as np +from tqdm import tqdm + +from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard +from personalization.feedback.schemas import TurnSample +from personalization.retrieval.pipeline import retrieve_with_rerank +from personalization.models.llm.qwen_instruct import QwenInstruct +from personalization.models.embedding.base import EmbeddingModel +from personalization.models.reranker.base import Reranker +from personalization.user_model.tensor_store import UserTensorStore + +def build_turn_samples_from_sessions( + sessions: Iterable[List[ChatTurn]], + embed_model: EmbeddingModel, + llm: QwenInstruct, + reranker: Reranker, + memory_cards: List[MemoryCard], + memory_embeddings: np.ndarray, + user_store: UserTensorStore, + item_vectors: np.ndarray, + max_samples: Optional[int] = None, + topk_dense: int = 64, + topk_rerank: int = 3, +) -> List[TurnSample]: + samples = [] + + for turns in tqdm(sessions, desc="Building TurnSamples"): + if max_samples and len(samples) >= max_samples: + break + + # Ensure sorted by turn_id + sorted_turns = sorted(turns, key=lambda x: x.turn_id) + + # Iterate to find (q_t, a_t, q_{t+1}) + for i in range(len(sorted_turns)): + if max_samples and len(samples) >= max_samples: + break + + q_t = sorted_turns[i] + if q_t.role != "user": + continue + + # Find next user turn + # Also try to find assistant response in between + a_t_text = "" + q_t1 = None + + # Look ahead + for j in range(i + 1, len(sorted_turns)): + next_turn = sorted_turns[j] + if next_turn.role == "assistant" and not a_t_text: + a_t_text = next_turn.text + elif next_turn.role == "user": + q_t1 = next_turn + break + + if not q_t1: + # End of session or no subsequent user query + continue + + # We have q_t, a_t (optional but preferred), q_t1 + # If a_t is missing, we might skip or use empty string. + # For RL, we usually need the answer to evaluate quality. + # If dataset doesn't have assistant turns, we might need to generate one? + # For now, let's proceed even if a_t is empty, or maybe require it. + if not a_t_text: + # Try to use LLM to generate if needed, but for offline sampling + # from existing chats, we prefer existing answers. + # If using OASST1, it should have assistant turns. + pass + + # 3. Retrieve memories for q_t + memories_t = retrieve_with_rerank( + user_id=q_t.user_id, + query=q_t.text, + embed_model=embed_model, + reranker=reranker, + memory_cards=memory_cards, + memory_embeddings=memory_embeddings, + user_store=user_store, + item_vectors=item_vectors, + topk_dense=topk_dense, + topk_rerank=topk_rerank, + beta_long=0.0, + beta_short=0.0, + only_own_memories=True # Assume we want user specific memories + ) + + # 4. Precompute embeddings + # We can do this efficiently later or batch, but here per sample + e_q_t = embed_model.encode([q_t.text], return_tensor=False)[0] + e_q_t1 = embed_model.encode([q_t1.text], return_tensor=False)[0] + + sample = TurnSample( + user_id=q_t.user_id, + session_id=q_t.session_id, + turn_id=q_t.turn_id, + query_t=q_t.text, + answer_t=a_t_text, + query_t1=q_t1.text, + memories=memories_t, + query_embedding_t=np.array(e_q_t), + query_embedding_t1=np.array(e_q_t1) + ) + samples.append(sample) + + return samples + |
