diff options
Diffstat (limited to 'scripts/personamem_eval_base_vs_ours.py')
| -rw-r--r-- | scripts/personamem_eval_base_vs_ours.py | 299 |
1 files changed, 299 insertions, 0 deletions
#!/usr/bin/env python3
"""
Evaluation script for PersonaMem task: Base vs Ours (User Vector).
Metric: Accuracy (Top-1 correct option)
"""

import sys
import os
import numpy as np
from tqdm import tqdm
from typing import Dict, List, Tuple

# Add src to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.data.personamem_loader import load_personamem_questions_32k
from personalization.user_model.features import ItemProjection
from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.user_model.tensor_store import UserState


def cosine_sim_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity of every row of ``b`` against every row of ``a``.

    Args:
        a: A single vector ``[d]`` or a stack ``[N, d]``.
        b: A single vector ``[d]`` or a stack ``[M, d]`` (typically the
            memories or the answer options).

    Returns:
        ``dot(b, a.T) / (|b| * |a| + 1e-8)``. With ``a`` of shape ``[d]`` and
        ``b`` of shape ``[M, d]`` the result is ``[M]``. With 2-D ``a`` it is
        ``[M, N]``, except that a single-column ``[M, 1]`` result is
        flattened to ``[M]``; callers that need a 2-D matrix must reshape.
    """
    norm_a = np.linalg.norm(a, axis=-1, keepdims=True)
    norm_b = np.linalg.norm(b, axis=-1, keepdims=True)

    dot = np.dot(b, a.T)
    # The epsilon keeps all-zero vectors from dividing by zero.
    denom = np.dot(norm_b, norm_a.T) + 1e-8

    sim = dot / denom
    if sim.ndim == 2 and sim.shape[1] == 1:
        return sim.flatten()
    return sim


def _top_k_desc(scores: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the ``k`` largest scores, best first.

    ``k`` is clamped to ``[0, len(scores)]``. Slicing the reversed argsort
    (rather than ``argsort[-k:]``) makes ``k == 0`` yield an empty result
    instead of every index, because ``x[-0:]`` is the whole array.
    """
    k = max(0, min(k, len(scores)))
    return np.argsort(scores)[::-1][:k]


def dense_retrieval(
    e_q: np.ndarray,
    memory_embeddings: np.ndarray,
    topk: int = 5
) -> np.ndarray:
    """Returns topk indices of memories most similar to query.

    Args:
        e_q: Query embedding ``[d]``.
        memory_embeddings: Memory embeddings ``[M, d]``.
        topk: Maximum number of indices to return.

    Returns:
        Integer indices into ``memory_embeddings`` ordered by decreasing
        cosine similarity; empty when there are no memories or ``topk <= 0``.
    """
    if memory_embeddings.shape[0] == 0:
        return np.array([], dtype=int)

    sims = cosine_sim_matrix(e_q, memory_embeddings)
    return _top_k_desc(sims, topk)


def policy_retrieval(
    e_q: np.ndarray,
    memory_embeddings: np.ndarray,
    item_vectors: np.ndarray,
    z_user: np.ndarray,
    item_proj: ItemProjection,
    topk_dense: int = 20,
    topk_final: int = 5,
    beta: float = 0.2
) -> np.ndarray:
    """
    Simulates retrieve_with_policy:
    1. Dense retrieval (topk_dense)
    2. Policy Scoring: s = s_base + beta * (z . v_m)
    3. Select topk_final

    Args:
        e_q: Query embedding ``[d]``.
        memory_embeddings: Memory embeddings ``[M, d]``.
        item_vectors: Per-memory item vectors ``[M, k]``, row-aligned with
            ``memory_embeddings``.
        z_user: User vector ``[k]``.
        item_proj: Currently unused by this function; kept in the signature.
        topk_dense: Size of the dense candidate pool.
        topk_final: Number of indices returned after re-ranking.
        beta: Weight of the user-vector bonus.

    Returns:
        Indices into ``memory_embeddings`` ordered by decreasing policy score.
    """
    if memory_embeddings.shape[0] == 0:
        return np.array([], dtype=int)

    # 1. Dense candidate generation
    dense_idx = dense_retrieval(e_q, memory_embeddings, topk=topk_dense)
    if len(dense_idx) == 0:
        return np.array([], dtype=int)

    candidates_e = memory_embeddings[dense_idx]
    candidates_v = item_vectors[dense_idx]

    # 2. Base scores: Sim(q, m) restricted to the candidates
    base_scores = cosine_sim_matrix(e_q, candidates_e)

    # 3. Policy bonus: z is [k], v is [K, k] -> [K]
    bonus = np.dot(candidates_v, z_user)

    total_scores = base_scores + beta * bonus

    # 4. Final selection, mapped back to global memory indices
    local_top_idx = _top_k_desc(total_scores, topk_final)
    return dense_idx[local_top_idx]


def score_rag(
    e_q: np.ndarray,
    e_opts: np.ndarray,
    retrieved_embeddings: np.ndarray
) -> np.ndarray:
    """
    Computes score for each option based on query match AND memory match.
    Score = Sim(Q, Opt) + Max(Sim(Memories, Opt))

    Args:
        e_q: Query embedding ``[d]``.
        e_opts: Option embeddings ``[n_options, d]``.
        retrieved_embeddings: Retrieved memory embeddings ``[K, d]``; may be
            empty, in which case only the query-option term is returned.

    Returns:
        One score per option, shape ``[n_options]``.
    """
    # Base: query-option similarity, e_q [d] x e_opts [n, d] -> [n]
    s_q_opt = cosine_sim_matrix(e_q, e_opts)

    if len(retrieved_embeddings) == 0:
        return s_q_opt

    # Memory-option similarity: options x memories, i.e. [n, K].
    # cosine_sim_matrix flattens a single-column result, so with exactly one
    # retrieved memory it returns [n] and the axis=1 reduction below would
    # raise. Reshape restores the [n, K] layout in every case.
    s_mem_opt = cosine_sim_matrix(retrieved_embeddings, e_opts)
    s_mem_opt = np.reshape(s_mem_opt, (e_opts.shape[0], -1))

    # Max pooling: "does any retrieved memory support this option".
    s_rag = np.max(s_mem_opt, axis=1)  # [n]

    return s_q_opt + s_rag


def _load_memory_store(
    cards_path: str,
    item_proj: ItemProjection
) -> Tuple[Dict, Dict, Dict]:
    """Load memory cards and group them, with their embeddings, by user.

    Returns ``(cards_by_user, embs_by_user, vecs_by_user)``. For each user
    the embedding matrix rows follow the order of that user's card list, so
    a row index can be used to look the card up again.
    """
    cards_by_user: Dict = {}
    raw_embs_by_user: Dict = {}

    # The whole store is held in memory; this assumes the card file is small
    # enough for that.
    with open(cards_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline) in the JSONL.
                continue
            card = MemoryCard.model_validate_json(line)
            cards_by_user.setdefault(card.user_id, []).append(card)
            raw_embs_by_user.setdefault(card.user_id, []).append(card.embedding_e)

    embs_by_user: Dict = {}
    vecs_by_user: Dict = {}
    for uid, rows in raw_embs_by_user.items():
        E = np.array(rows, dtype=np.float32)
        embs_by_user[uid] = E
        # Item vectors are not read from a file here; project them on the fly.
        vecs_by_user[uid] = item_proj.transform_embeddings(E)

    return cards_by_user, embs_by_user, vecs_by_user


def _print_case_analysis(q, pred_base, pred_ours, s_base, s_ours, retrieved_notes) -> None:
    """Print one question where the base and personalized predictions differ."""
    base_correct = (pred_base == q.correct_index)
    is_correct = (pred_ours == q.correct_index)

    print("\n" + "=" * 60)
    print(f"[CASE ANALYSIS] QID: {q.question_id}")
    print(f"User Question: {q.user_question_or_message}")
    print(f"Correct Option ({q.correct_index}): {q.all_options[q.correct_index]}")
    print("-" * 30)
    print(f"Base Pred ({pred_base}): {q.all_options[pred_base]} [{'CORRECT' if base_correct else 'WRONG'}]")
    print(f"Ours Pred ({pred_ours}): {q.all_options[pred_ours]} [{'CORRECT' if is_correct else 'WRONG'}]")
    print("-" * 30)
    print(f"Retrieved Memories (Top 3 of {len(retrieved_notes)}):")
    for note in retrieved_notes[:3]:
        print(f" - {note}")
    print("-" * 30)
    print(f"Scores Base: {s_base}")
    print(f"Scores Ours: {s_ours}")
    print("=" * 60 + "\n")


def main():
    """Evaluate no-RAG base scoring against user-vector RAG on PersonaMem."""
    # Paths
    q_path = "data/raw_datasets/personamem/questions_32k.csv"
    vec_path = "data/personamem/user_vectors.npz"
    cards_path = "data/personamem/memory_cards.jsonl"  # Generated by builder
    item_proj_path = "data/corpora/item_projection.npz"

    # item_proj_path is loaded unconditionally below, so it is checked too.
    required_paths = [q_path, vec_path, cards_path, item_proj_path]
    missing_paths = [p for p in required_paths if not os.path.exists(p)]
    if missing_paths:
        print("Data missing. Run personamem_build_user_vectors.py first.")
        for p in missing_paths:
            print(f"Missing: {p}")
        sys.exit(1)

    print("Loading resources...")
    cfg = load_local_models_config()
    embed_model = Qwen3Embedding8B.from_config(cfg)

    proj_data = np.load(item_proj_path)
    item_proj = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])

    # Load User Vectors
    # NOTE(review): allow_pickle=True unpickles object arrays and can execute
    # arbitrary code; only point this at a trusted, locally built file.
    uv_data = np.load(vec_path, allow_pickle=True)
    user_ids = uv_data["user_ids"]
    Z = uv_data["Z"]
    user_vector_map = {uid: Z[i] for i, uid in enumerate(user_ids)}
    print(f"Loaded {len(user_vector_map)} user vectors.")

    # Load Memory Cards & Embeddings
    print("Loading memory store...")
    cards_by_user, embs_by_user, vecs_by_user = _load_memory_store(cards_path, item_proj)
    print(f"Loaded memories for {len(cards_by_user)} users.")

    # Load Questions
    questions = load_personamem_questions_32k(q_path)
    print(f"Loaded {len(questions)} questions.")

    # Hyperparams
    # Ours: dense top-20 -> re-rank by (base + beta * bonus) -> top-5.
    # With beta=0.0 the re-ranking key is Sim(q, m), the same key dense
    # retrieval already sorted by, so it selects the plain dense top-5.
    # That run is a sanity check for the retrieval path.
    betas = [0.0, 1.0]

    # NOTE: every beta pass re-encodes each question and its options.
    for beta in betas:
        print(f"\nEvaluating with RAG (beta={beta})...")

        correct_base: List[int] = []
        correct_ours: List[int] = []

        case_count = 0

        for q in tqdm(questions):
            target_id = q.shared_context_id

            # Skip if no vector or no memories
            if target_id not in user_vector_map or target_id not in embs_by_user:
                continue
            if not q.all_options:
                continue
            # Questions without a label never count towards accuracy. Skip
            # them up front so the case analysis cannot show
            # all_options[-1] as the "correct" option.
            if q.correct_index == -1:
                continue

            z_user = user_vector_map[target_id]
            mem_E = embs_by_user[target_id]
            mem_V = vecs_by_user[target_id]

            # Embed Query
            e_q = embed_model.encode([q.user_question_or_message], return_tensor=False)[0]
            e_q = np.array(e_q, dtype=np.float32)

            # Embed Options
            e_opts = embed_model.encode(q.all_options, return_tensor=False)
            e_opts = np.array(e_opts, dtype=np.float32)

            # --- BASE (No RAG) ---
            s_base = score_rag(e_q, e_opts, np.array([]))

            # --- OURS (Personalized RAG) ---
            ours_idx = policy_retrieval(
                e_q, mem_E, mem_V, z_user, item_proj,
                topk_dense=20, topk_final=5, beta=beta
            )
            ours_mem_E = mem_E[ours_idx]
            s_ours = score_rag(e_q, e_opts, ours_mem_E)

            pred_base = int(np.argmax(s_base))
            pred_ours = int(np.argmax(s_ours))

            # Detailed Case Print (Task A)
            # Print only when Beta > 0 (to avoid duplicate logs) and when they disagree
            if beta > 0.0 and pred_base != pred_ours and case_count < 5:
                case_count += 1

                # ours_idx indexes mem_E, whose rows follow the card list order.
                user_cards = cards_by_user[target_id]
                retrieved_notes = [user_cards[i].note_text for i in ours_idx]

                _print_case_analysis(q, pred_base, pred_ours, s_base, s_ours, retrieved_notes)

            correct_base.append(1 if pred_base == q.correct_index else 0)
            correct_ours.append(1 if pred_ours == q.correct_index else 0)

        if not correct_base:
            print("No valid evaluation samples processed.")
            continue

        acc_base = np.mean(correct_base)
        acc_ours = np.mean(correct_ours)

        # Win rate: ours is right where base is wrong
        wins = [1 if (c_o == 1 and c_b == 0) else 0 for c_o, c_b in zip(correct_ours, correct_base)]
        win_rate = np.mean(wins)

        print(f"\n--- Results (Beta={beta}) ---")
        print(f"Total Samples: {len(correct_base)}")
        print(f"Accuracy (Base No-RAG): {acc_base:.4f}")
        print(f"Accuracy (Ours RAG): {acc_ours:.4f}")
        print(f"Win Rate: {win_rate:.4f}")


if __name__ == "__main__":
    main()
