summaryrefslogtreecommitdiff
path: root/scripts/day2_demo.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/day2_demo.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'scripts/day2_demo.py')
-rw-r--r--scripts/day2_demo.py162
1 files changed, 162 insertions, 0 deletions
diff --git a/scripts/day2_demo.py b/scripts/day2_demo.py
new file mode 100644
index 0000000..ca81d99
--- /dev/null
+++ b/scripts/day2_demo.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+"""
+Day 2 Demo: End-to-end Memory RAG with Reranker and (Shell) Personalization.
+"""
+
+import sys
+import os
+import numpy as np
+import torch
+from typing import List
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.llm.qwen_instruct import QwenInstruct
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserTensorStore
+from personalization.retrieval.pipeline import retrieve_with_rerank
+
def cos_sim(a, b):
    """Return the cosine similarity of two 1-D vectors.

    Returns 0.0 when either vector has zero norm (avoids a division
    by zero for users whose profile vector is still empty).
    """
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(a, b) / (norm_a * norm_b))


def main():
    """Run the Day 2 end-to-end demo.

    Loads the memory-card corpus, embeddings, item projection and user
    tensor store, then:
      1. sanity-checks pairwise user-vector similarity,
      2. runs the retrieve+rerank pipeline (global and own-memories-only),
      3. generates an answer with the instruct LLM grounded on the
         user's own retrieved memories.

    Exits with status 1 if the prerequisite data artifacts are missing.
    """
    import random  # stdlib; used only for sampling demo users

    # Paths to artifacts produced by earlier pipeline steps.
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"
    user_store_path = "data/users/user_store.npz"

    # 1. Load data stores (fail fast with a pointer to the producing script).
    print("Loading data stores...")
    if not os.path.exists(cards_path) or not os.path.exists(embs_path):
        print("Memory data missing. Run migrate_preferences.py")
        sys.exit(1)

    cards = []
    with open(cards_path, "r", encoding="utf-8") as f:
        for line in f:
            cards.append(MemoryCard.model_validate_json(line))

    memory_embeddings = np.load(embs_path)

    if not os.path.exists(item_proj_path):
        print("Item projection missing. Run build_item_space.py")
        sys.exit(1)

    proj_data = np.load(item_proj_path)
    item_vectors = proj_data["V"]

    # 2. Load models.
    print("Loading models...")
    cfg = load_local_models_config()

    embedder = Qwen3Embedding8B.from_config(cfg)
    reranker = Qwen3Reranker.from_config(cfg)
    llm = QwenInstruct.from_config(cfg)

    # 3. Load user store. Derive k from the item projection instead of
    # hard-coding 256 so the store dimensionality always matches the
    # item space it was built against.
    k = item_vectors.shape[1]
    user_store = UserTensorStore(k=k, path=user_store_path)

    # --- CHECK 1: User Vector Similarity ---
    print("\n--- CHECK 1: User Vector Similarity ---")
    # NOTE(review): this reaches into the private `_states` attribute of
    # UserTensorStore — consider exposing a public iterator; verify the
    # attribute still exists after store refactors.
    valid_users = [
        uid
        for uid, state in user_store._states.items()
        if np.linalg.norm(state.z_long) > 1e-6  # keep only non-zero profiles
    ]

    if len(valid_users) < 3:
        print(f"Not enough users with memories found (found {len(valid_users)}). Skipping similarity check.")
    else:
        # Sample 3 users and report all pairwise cosine similarities.
        selected_users = random.sample(valid_users, 3)
        vectors = [user_store.get_state(uid).z_long for uid in selected_users]

        print(f"Selected Users: {selected_users}")
        print(f"Sim(0, 1): {cos_sim(vectors[0], vectors[1]):.4f}")
        print(f"Sim(0, 2): {cos_sim(vectors[0], vectors[2]):.4f}")
        print(f"Sim(1, 2): {cos_sim(vectors[1], vectors[2]):.4f}")

    # --- CHECK 2: Real User Retrieval ---
    print("\n--- CHECK 2: Real User Retrieval ---")

    if len(valid_users) > 0:
        target_user = valid_users[0]
        # A generic coding query: OASST1-derived corpora contain many
        # such prompts, so this is likely to hit some memories. Ideally
        # the query would come from the user's own history.
        query = "How do I write a Python function for fibonacci?"

        print(f"User: {target_user}")
        print(f"Query: {query}")

        def run_retrieval(only_own):
            """Retrieve + rerank memories for the target user/query."""
            return retrieve_with_rerank(
                user_id=target_user,
                query=query,
                embed_model=embedder,
                reranker=reranker,
                memory_cards=cards,
                memory_embeddings=memory_embeddings,
                user_store=user_store,
                item_vectors=item_vectors,
                topk_dense=64,
                topk_rerank=3,
                beta_long=0.0,  # personalization disabled for this demo
                beta_short=0.0,
                only_own_memories=only_own,
            )

        # 5. Retrieval pipeline — global (all users') memories.
        print("\nRunning Retrieval Pipeline (GLOBAL search)...")
        hits_global = run_retrieval(False)

        print(f"\nTop {len(hits_global)} Memories (Global):")
        for h in hits_global:
            print(f" - [{h.kind}] {h.note_text} (User: {h.user_id})")

        # Retrieval restricted to the target user's own memories.
        print("\nRunning Retrieval Pipeline (OWN memories only)...")
        hits_own = run_retrieval(True)

        print(f"\nTop {len(hits_own)} Memories (Own):")
        notes = []
        for h in hits_own:
            print(f" - [{h.kind}] {h.note_text} (User: {h.user_id})")
            notes.append(h.note_text)

        # 6. Generate an answer grounded on the user's own memories.
        print("\nGenerating Answer (using Own Memories)...")
        history = [{"role": "user", "content": query}]
        answer = llm.answer(history, notes)

        print("-" * 40)
        print(answer)
        print("-" * 40)
    else:
        print("No valid users found for demo.")


if __name__ == "__main__":
    main()
+