diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/day2_demo.py | |
Diffstat (limited to 'scripts/day2_demo.py')
| -rw-r--r-- | scripts/day2_demo.py | 162 |
1 file changed, 162 insertions, 0 deletions
#!/usr/bin/env python3
"""
Day 2 Demo: End-to-end Memory RAG with Reranker and (Shell) Personalization.

Loads the memory-card corpus, its dense embeddings, the item-projection
matrix, and the per-user tensor store, then:

  1. CHECK 1 — pairwise cosine similarity between three random users'
     long-term vectors (sanity check that user vectors differ).
  2. CHECK 2 — runs the retrieve-with-rerank pipeline for one real user,
     once over all users' memories (global) and once restricted to the
     user's own memories.
  3. Generates an answer with the instruct LLM conditioned on the user's
     own retrieved memory notes.

Prerequisites: run migrate_preferences.py and build_item_space.py first so
the data artifacts under data/ exist.
"""

import os
import random
import sys

import numpy as np
import torch
from typing import List

# Add src to sys.path so the package imports below resolve when this file is
# run directly as a script (rather than installed as a package).
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.models.llm.qwen_instruct import QwenInstruct
from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.user_model.tensor_store import UserTensorStore
from personalization.retrieval.pipeline import retrieve_with_rerank


def _cos_sim(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of *a* and *b*; 0.0 if either is zero.

    The zero-vector guard avoids a divide-by-zero for users whose long-term
    state has not been populated yet.
    """
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(a, b) / (norm_a * norm_b))


def main():
    # Paths to the artifacts produced by the earlier preparation scripts.
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"
    user_store_path = "data/users/user_store.npz"

    # 1. Load data stores.
    print("Loading data stores...")
    if not os.path.exists(cards_path) or not os.path.exists(embs_path):
        print("Memory data missing. Run migrate_preferences.py")
        sys.exit(1)

    # One MemoryCard per JSONL line; skip blank lines so a trailing newline
    # does not crash pydantic's JSON validation.
    cards = []
    with open(cards_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                cards.append(MemoryCard.model_validate_json(line))

    memory_embeddings = np.load(embs_path)

    if not os.path.exists(item_proj_path):
        print("Item projection missing. Run build_item_space.py")
        sys.exit(1)

    proj_data = np.load(item_proj_path)
    item_vectors = proj_data["V"]

    # 2. Load models.
    print("Loading models...")
    cfg = load_local_models_config()

    embedder = Qwen3Embedding8B.from_config(cfg)
    reranker = Qwen3Reranker.from_config(cfg)
    llm = QwenInstruct.from_config(cfg)

    # 3. Load user store.
    # k is the user-vector dimension; the config fixes it at 256 and the
    # persisted store was built with that value. It must also match the
    # item-projection output dimension, so fail loudly if they disagree
    # instead of silently mixing incompatible spaces.
    k = 256  # Hardcoded per config
    if item_vectors.shape[1] != k:
        print(
            f"WARNING: item vector dim {item_vectors.shape[1]} != k={k}; "
            "user and item spaces may be incompatible."
        )
    user_store = UserTensorStore(k=k, path=user_store_path)

    # --- CHECK 1: User Vector Similarity ---
    print("\n--- CHECK 1: User Vector Similarity ---")
    # Users whose long-term vector is (numerically) non-zero, i.e. users
    # that actually have memories folded into their state.
    # NOTE(review): reaches into the private _states attribute of
    # UserTensorStore — consider exposing a public iterator instead.
    valid_users = [uid for uid, state in user_store._states.items()
                   if np.linalg.norm(state.z_long) > 1e-6]  # Only non-zero users

    if len(valid_users) < 3:
        print(f"Not enough users with memories found (found {len(valid_users)}). Skipping similarity check.")
    else:
        # Pick 3 random users and compare their long-term vectors pairwise.
        selected_users = random.sample(valid_users, 3)
        vectors = [user_store.get_state(uid).z_long for uid in selected_users]

        print(f"Selected Users: {selected_users}")
        print(f"Sim(0, 1): {_cos_sim(vectors[0], vectors[1]):.4f}")
        print(f"Sim(0, 2): {_cos_sim(vectors[0], vectors[2]):.4f}")
        print(f"Sim(1, 2): {_cos_sim(vectors[1], vectors[2]):.4f}")

    # --- CHECK 2: Real User Retrieval ---
    print("\n--- CHECK 2: Real User Retrieval ---")

    if valid_users:
        # Pick one user. Ideally we would use a query this user actually
        # asked, but the dataset queries are not loaded here, so fall back
        # to a generic coding query (OASST1 has many of those).
        target_user = valid_users[0]
        query = "How do I write a Python function for fibonacci?"

        print(f"User: {target_user}")
        print(f"Query: {query}")

        # 4. Retrieval pipeline — global search across all users' memories.
        print("\nRunning Retrieval Pipeline (GLOBAL search)...")
        hits_global = retrieve_with_rerank(
            user_id=target_user,
            query=query,
            embed_model=embedder,
            reranker=reranker,
            memory_cards=cards,
            memory_embeddings=memory_embeddings,
            user_store=user_store,
            item_vectors=item_vectors,
            topk_dense=64,
            topk_rerank=3,
            beta_long=0.0,   # personalization disabled for this demo
            beta_short=0.0,
            only_own_memories=False  # Global search
        )

        print(f"\nTop {len(hits_global)} Memories (Global):")
        for h in hits_global:
            print(f"  - [{h.kind}] {h.note_text} (User: {h.user_id})")

        # 5. Retrieval pipeline — restricted to the target user's own memories.
        print("\nRunning Retrieval Pipeline (OWN memories only)...")
        hits_own = retrieve_with_rerank(
            user_id=target_user,
            query=query,
            embed_model=embedder,
            reranker=reranker,
            memory_cards=cards,
            memory_embeddings=memory_embeddings,
            user_store=user_store,
            item_vectors=item_vectors,
            topk_dense=64,
            topk_rerank=3,
            beta_long=0.0,
            beta_short=0.0,
            only_own_memories=True  # Own search
        )

        print(f"\nTop {len(hits_own)} Memories (Own):")
        notes = []
        for h in hits_own:
            print(f"  - [{h.kind}] {h.note_text} (User: {h.user_id})")
            notes.append(h.note_text)

        # 6. Generate an answer conditioned on the user's own memory notes.
        print("\nGenerating Answer (using Own Memories)...")
        history = [{"role": "user", "content": query}]
        answer = llm.answer(history, notes)

        print("-" * 40)
        print(answer)
        print("-" * 40)
    else:
        print("No valid users found for demo.")


if __name__ == "__main__":
    main()
