summaryrefslogtreecommitdiff
path: root/scripts/personamem_eval_base_vs_ours.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/personamem_eval_base_vs_ours.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'scripts/personamem_eval_base_vs_ours.py')
-rw-r--r--scripts/personamem_eval_base_vs_ours.py299
1 files changed, 299 insertions, 0 deletions
diff --git a/scripts/personamem_eval_base_vs_ours.py b/scripts/personamem_eval_base_vs_ours.py
new file mode 100644
index 0000000..e319765
--- /dev/null
+++ b/scripts/personamem_eval_base_vs_ours.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python3
+"""
+Evaluation script for PersonaMem task: Base vs Ours (User Vector).
+Metric: Accuracy (Top-1 correct option)
+"""
+
+import sys
+import os
+import numpy as np
+from tqdm import tqdm
+from typing import List
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.data.personamem_loader import load_personamem_questions_32k
+from personalization.user_model.features import ItemProjection
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserState
+
def cosine_sim_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return cosine similarities between ``b`` (candidates) and ``a`` (queries).

    Args:
        a: Query side, a single vector ``[d]`` or a matrix ``[N, d]``.
            Any array-like (e.g. a nested list) is accepted.
        b: Candidate side, a single vector ``[d]`` or a matrix ``[M, d]``
            (memories or options in this script). Array-likes are accepted.

    Returns:
        Similarities laid out candidate-first: shape ``[M]`` when ``a`` is a
        single vector and ``[M, N]`` when ``a`` is a matrix. A ``[M, 1]``
        result (``a`` passed as a one-row matrix) is flattened to ``[M]``.
    """
    # Coerce so that plain lists work too; `.T` below needs an ndarray.
    a = np.asarray(a)
    b = np.asarray(b)

    norm_a = np.linalg.norm(a, axis=-1, keepdims=True)
    norm_b = np.linalg.norm(b, axis=-1, keepdims=True)

    # dot(b, a.T) puts candidates on the first axis: 1-D for a single query
    # vector, [M, N] for a query matrix.
    dot = np.dot(b, a.T)
    # The outer product of the norms has the same shape as `dot`; the epsilon
    # keeps zero vectors from producing a division by zero.
    denom = np.dot(norm_b, norm_a.T) + 1e-8

    sim = dot / denom
    # Collapse a single-column matrix so callers get a flat score vector.
    if sim.ndim == 2 and sim.shape[1] == 1:
        return sim.flatten()
    return sim
+
def dense_retrieval(
    e_q: np.ndarray,
    memory_embeddings: np.ndarray,
    topk: int = 5
) -> np.ndarray:
    """Return the indices of the ``topk`` memories most similar to the query.

    Args:
        e_q: Query embedding.
        memory_embeddings: Memory embeddings, one row per memory.
        topk: Maximum number of indices to return.

    Returns:
        Integer indices into ``memory_embeddings`` ordered from most to least
        similar (cosine similarity). Empty when there are no memories or when
        ``topk`` is not positive.
    """
    # A non-positive topk must yield nothing: without this guard the slice
    # `[-0:]` below would select *every* index instead of none.
    if topk <= 0 or memory_embeddings.shape[0] == 0:
        return np.array([], dtype=int)

    sims = cosine_sim_matrix(e_q, memory_embeddings)
    k = min(topk, len(sims))
    # argsort is ascending: take the last k and reverse for best-first order.
    return np.argsort(sims)[-k:][::-1]
+
def policy_retrieval(
    e_q: np.ndarray,
    memory_embeddings: np.ndarray,
    item_vectors: np.ndarray,
    z_user: np.ndarray,
    item_proj: "ItemProjection",
    topk_dense: int = 20,
    topk_final: int = 5,
    beta: float = 0.2
) -> np.ndarray:
    """Simulate ``retrieve_with_policy``: dense candidates, then a user-aware re-rank.

    Steps:
        1. Dense retrieval of ``topk_dense`` candidates.
        2. Policy scoring: ``s = s_base + beta * (z . v_m)``.
        3. Selection of the ``topk_final`` best candidates.

    Args:
        e_q: Query embedding.
        memory_embeddings: Memory embeddings, one row per memory.
        item_vectors: Item vectors ``v_m``, row-aligned with
            ``memory_embeddings``.
        z_user: User vector ``z``.
        item_proj: Not used by this function; kept in the signature so
            existing callers keep working.
        topk_dense: Size of the dense candidate pool.
        topk_final: Number of indices to return after re-ranking.
        beta: Weight of the user-vector bonus.

    Returns:
        Integer indices into ``memory_embeddings``, best-scored first. Empty
        when there are no memories or when either top-k is not positive.
    """
    # Non-positive cut-offs must select nothing; the `[-k:]` slices used for
    # top-k selection would otherwise return everything when k == 0.
    if memory_embeddings.shape[0] == 0 or topk_dense <= 0 or topk_final <= 0:
        return np.array([], dtype=int)

    # 1. Dense Candidate Generation
    dense_idx = dense_retrieval(e_q, memory_embeddings, topk=topk_dense)
    if len(dense_idx) == 0:
        return np.array([], dtype=int)

    candidates_e = memory_embeddings[dense_idx]
    candidates_v = item_vectors[dense_idx]

    # 2. Base Scores (Sim(q, m))
    base_scores = cosine_sim_matrix(e_q, candidates_e)

    # 3. Policy Bonus
    # z: [k], v: [K, k]
    bonus = np.dot(candidates_v, z_user)

    total_scores = base_scores + beta * bonus

    # 4. Final Selection (ascending argsort -> last k, reversed)
    k = min(topk_final, len(total_scores))
    local_top_idx = np.argsort(total_scores)[-k:][::-1]

    # Map candidate-local positions back to global memory indices.
    return dense_idx[local_top_idx]
+
def score_rag(
    e_q: np.ndarray,
    e_opts: np.ndarray,
    retrieved_embeddings: np.ndarray
) -> np.ndarray:
    """Score each option by its match to the query and to the retrieved memories.

    ``Score = Sim(Q, Opt) + Max_over_memories(Sim(Memory, Opt))``

    Max pooling is used so that a single supporting memory is enough to boost
    an option.

    Args:
        e_q: Query embedding ``[d]``.
        e_opts: Option embeddings, one row per option.
        retrieved_embeddings: Embeddings of the retrieved memories, one row
            per memory; may be empty.

    Returns:
        One score per option. With no retrieved memories this is just the
        query-option similarity.
    """
    # Base: query-option similarity, one value per option.
    s_q_opt = cosine_sim_matrix(e_q, e_opts)

    if len(retrieved_embeddings) == 0:
        return s_q_opt

    # cosine_sim_matrix(a, b) computes dot(b, a.T), so with a=memories and
    # b=options the result is [n_options, K]. When K == 1 the helper flattens
    # its output to [n_options]; restore the memory axis so that the pooling
    # along axis=1 below does not fail for a single retrieved memory.
    s_mem_opt = np.asarray(cosine_sim_matrix(retrieved_embeddings, e_opts))
    s_mem_opt = s_mem_opt.reshape(len(e_opts), -1)

    # Max pooling over memories: "does any memory support this option?"
    s_rag = np.max(s_mem_opt, axis=1)

    # Combine
    return s_q_opt + s_rag
+
def main():
    """Compare option-selection accuracy of Base (no RAG) against Ours (policy RAG).

    Loads the embedding model, the item projection, the per-user vectors and
    the memory cards, then, for every beta in the sweep, scores each PersonaMem
    question with and without retrieved memories and prints accuracy and win
    rate. Up to five disagreeing cases are printed in detail for beta > 0.

    Exits with status 1 when a required input file is missing.
    """
    # Paths
    q_path = "data/raw_datasets/personamem/questions_32k.csv"
    vec_path = "data/personamem/user_vectors.npz"
    cards_path = "data/personamem/memory_cards.jsonl"  # Generated by builder
    item_proj_path = "data/corpora/item_projection.npz"

    # Check every file that is opened below, including the item projection,
    # so a missing input gives a readable message instead of a traceback.
    required_paths = (q_path, vec_path, cards_path, item_proj_path)
    missing = [p for p in required_paths if not os.path.exists(p)]
    if missing:
        print("Data missing. Run personamem_build_user_vectors.py first.")
        print(f"Missing files: {', '.join(missing)}")
        sys.exit(1)

    print("Loading resources...")
    cfg = load_local_models_config()
    embed_model = Qwen3Embedding8B.from_config(cfg)

    proj_data = np.load(item_proj_path)
    item_proj = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])

    # Load User Vectors
    # NOTE(review): allow_pickle=True can execute arbitrary code from a crafted
    # file; only point vec_path at locally generated, trusted data.
    uv_data = np.load(vec_path, allow_pickle=True)
    user_ids = uv_data["user_ids"]
    Z = uv_data["Z"]
    user_vector_map = {uid: Z[i] for i, uid in enumerate(user_ids)}
    print(f"Loaded {len(user_vector_map)} user vectors.")

    # Load Memory Cards & Embeddings (everything is held in memory).
    print("Loading memory store...")
    cards_by_user = {}
    embs_by_user = {}
    vecs_by_user = {}  # v vectors

    with open(cards_path, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline) in the JSONL file.
            if not line.strip():
                continue
            card = MemoryCard.model_validate_json(line)
            uid = card.user_id
            cards_by_user.setdefault(uid, []).append(card)
            embs_by_user.setdefault(uid, []).append(card.embedding_e)

    # Convert per-user lists to arrays; row i of the array matches card i of
    # cards_by_user[uid], which the case analysis below relies on.
    for uid, emb_list in embs_by_user.items():
        E = np.array(emb_list, dtype=np.float32)
        embs_by_user[uid] = E
        # The builder does not store V separately, so project on the fly.
        vecs_by_user[uid] = item_proj.transform_embeddings(E)

    print(f"Loaded memories for {len(cards_by_user)} users.")

    # Load Questions
    questions = load_personamem_questions_32k(q_path)
    print(f"Loaded {len(questions)} questions.")

    # Hyperparams.
    # Ours pipeline: Dense (top20) -> re-rank (base + beta * bonus) -> top5.
    # With beta=0 the re-rank keeps the dense order, so beta=0 serves as a
    # sanity check: it retrieves exactly the dense top-5.
    betas = [0.0, 1.0]

    # Question/option embeddings do not depend on beta: compute them once and
    # reuse them across the sweep instead of re-encoding for every beta.
    embedding_cache = {}

    for beta in betas:
        print(f"\nEvaluating with RAG (beta={beta})...")

        correct_base = []
        correct_ours = []

        case_count = 0

        for q_idx, q in enumerate(tqdm(questions)):
            target_id = q.shared_context_id

            # Skip if no vector or no memories
            if target_id not in user_vector_map or target_id not in embs_by_user:
                continue

            # Skip questions without options or without a labelled answer
            # (-1). They never counted towards accuracy, and skipping them
            # here also keeps them out of the case analysis, where index -1
            # would silently display the last option as the "correct" one.
            if not q.all_options or q.correct_index == -1:
                continue

            z_user = user_vector_map[target_id]
            mem_E = embs_by_user[target_id]
            mem_V = vecs_by_user[target_id]

            if q_idx not in embedding_cache:
                # Embed Query
                e_q = embed_model.encode([q.user_question_or_message], return_tensor=False)[0]
                e_q = np.array(e_q, dtype=np.float32)
                # Embed Options
                e_opts = embed_model.encode(q.all_options, return_tensor=False)
                e_opts = np.array(e_opts, dtype=np.float32)
                embedding_cache[q_idx] = (e_q, e_opts)
            e_q, e_opts = embedding_cache[q_idx]

            # --- BASE (No RAG) ---
            s_base = score_rag(e_q, e_opts, np.array([]))

            # --- OURS (Personalized RAG) ---
            ours_idx = policy_retrieval(
                e_q, mem_E, mem_V, z_user, item_proj,
                topk_dense=20, topk_final=5, beta=beta,
            )
            ours_mem_E = mem_E[ours_idx]
            s_ours = score_rag(e_q, e_opts, ours_mem_E)

            pred_base = int(np.argmax(s_base))
            pred_ours = int(np.argmax(s_ours))

            is_correct = (pred_ours == q.correct_index)
            base_correct = (pred_base == q.correct_index)

            # Detailed Case Print (Task A)
            # Print only when Beta > 0 (to avoid duplicate logs) and when they disagree
            if beta > 0.0 and pred_base != pred_ours and case_count < 5:
                case_count += 1

                # ours_idx indexes mem_E, whose rows follow the card order.
                user_cards = cards_by_user[target_id]
                retrieved_notes = [user_cards[i].note_text for i in ours_idx]

                print("\n" + "=" * 60)
                print(f"[CASE ANALYSIS] QID: {q.question_id}")
                print(f"User Question: {q.user_question_or_message}")
                print(f"Correct Option ({q.correct_index}): {q.all_options[q.correct_index]}")
                print("-" * 30)
                print(f"Base Pred ({pred_base}): {q.all_options[pred_base]} [{'CORRECT' if base_correct else 'WRONG'}]")
                print(f"Ours Pred ({pred_ours}): {q.all_options[pred_ours]} [{'CORRECT' if is_correct else 'WRONG'}]")
                print("-" * 30)
                print(f"Retrieved Memories (Top 3 of {len(retrieved_notes)}):")
                for note in retrieved_notes[:3]:
                    print(f" - {note}")
                print("-" * 30)
                print(f"Scores Base: {s_base}")
                print(f"Scores Ours: {s_ours}")
                print("=" * 60 + "\n")

            correct_base.append(1 if base_correct else 0)
            correct_ours.append(1 if is_correct else 0)

        if not correct_base:
            print("No valid evaluation samples processed.")
            continue

        acc_base = np.mean(correct_base)
        acc_ours = np.mean(correct_ours)

        # Win rate: fraction of samples where Ours is right and Base is wrong.
        wins = [1 if (c_o == 1 and c_b == 0) else 0 for c_o, c_b in zip(correct_ours, correct_base)]
        win_rate = np.mean(wins)

        print(f"\n--- Results (Beta={beta}) ---")
        print(f"Total Samples: {len(correct_base)}")
        print(f"Accuracy (Base No-RAG): {acc_base:.4f}")
        print(f"Accuracy (Ours RAG): {acc_ours:.4f}")
        print(f"Win Rate: {win_rate:.4f}")
+
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
+