summaryrefslogtreecommitdiff
path: root/scripts/personamem_build_user_vectors.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/personamem_build_user_vectors.py')
-rw-r--r--scripts/personamem_build_user_vectors.py193
1 files changed, 193 insertions, 0 deletions
diff --git a/scripts/personamem_build_user_vectors.py b/scripts/personamem_build_user_vectors.py
new file mode 100644
index 0000000..4719f9c
--- /dev/null
+++ b/scripts/personamem_build_user_vectors.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+"""
+Script to build user vectors for PersonaMem 32k dataset.
+1. Load shared contexts.
+2. Convert to ChatTurns.
+3. Extract preferences -> MemoryCards.
+4. Compute embeddings & item vectors.
+5. Aggregate to user vectors.
+"""
+
+import sys
+import os
+import json
+import uuid
+import numpy as np
+from tqdm import tqdm
+from typing import List, Dict
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.config.registry import get_preference_extractor
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList
+from personalization.data.personamem_loader import load_personamem_contexts_32k, PersonaMemContext
+from personalization.user_model.features import ItemProjection
+
def context_to_chatturns(context: PersonaMemContext) -> List[ChatTurn]:
    """Convert a PersonaMem shared context into a list of ChatTurn objects.

    Dict messages contribute their "role"/"content" fields; bare string
    messages are treated as user turns; any other payload is skipped.
    Roles other than "user" are normalized to "assistant".
    """
    chat_turns: List[ChatTurn] = []
    for idx, message in enumerate(context.messages):
        if isinstance(message, dict):
            speaker = message.get("role", "user")
            body = message.get("content", "")
        elif isinstance(message, str):
            # Bare-string messages carry no role information; assume user.
            speaker, body = "user", message
        else:
            # Unrecognized message payload -- skip it.
            continue

        chat_turns.append(
            ChatTurn(
                # The shared context id doubles as the user id when building
                # per-context vectors.
                user_id=context.shared_context_id,
                session_id=context.shared_context_id,
                turn_id=idx,
                role="user" if speaker == "user" else "assistant",
                text=body,
                timestamp=None,
                meta={},
            )
        )
    return chat_turns
+
def sliding_windows(seq: List, window_size: int, step: int):
    """Yield slices of *seq* up to *window_size* items long, advancing by *step*.

    Trailing windows may be shorter than *window_size* when the sequence
    length is not a multiple of *step*; an empty sequence yields nothing.
    """
    start = 0
    end = len(seq)
    while start < end:
        yield seq[start:start + window_size]
        start += step
+
def find_last_user_text(window: List[ChatTurn]) -> str:
    """Return the text of the most recent user turn in *window*, or "" if none."""
    idx = len(window) - 1
    while idx >= 0:
        turn = window[idx]
        if turn.role == "user":
            return turn.text
        idx -= 1
    return ""
+
def main():
    """Build PersonaMem 32k user vectors.

    Pipeline: load shared contexts, slide a window over each context's
    turns, extract preferences per window, embed the triggering user query,
    store one MemoryCard per extraction, then mean-pool the projected card
    embeddings into a single vector per user and save both artifacts.
    """
    # Paths (adjust as needed, assume downloaded to data/raw_datasets/personamem)
    ctx_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl"
    item_proj_path = "data/corpora/item_projection.npz"
    output_vec_path = "data/personamem/user_vectors.npz"
    output_cards_path = "data/personamem/memory_cards.jsonl"

    # Ensure the output directory exists before any expensive work.
    os.makedirs(os.path.dirname(output_vec_path), exist_ok=True)

    if not os.path.exists(ctx_path):
        print(f"Error: Context file not found at {ctx_path}")
        return

    if not os.path.exists(item_proj_path):
        print(f"Error: Item projection not found at {item_proj_path}. Run build_item_space.py first.")
        return

    # Load models.
    print("Loading models...")
    cfg = load_local_models_config()
    # Explicitly use Qwen3Embedding8B
    embed_model = Qwen3Embedding8B.from_config(cfg)

    # Use the registry for the extractor (SFT model); fall back to the
    # rule-based extractor when the SFT model cannot be resolved.
    extractor_name = "qwen3_0_6b_sft"
    try:
        extractor = get_preference_extractor(extractor_name)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate instead of silently triggering the fallback.
        print(f"Fallback to rule extractor for {extractor_name} not found.")
        extractor = get_preference_extractor("rule")

    # Load the item-space projection (P, mean) produced by build_item_space.py.
    proj_data = np.load(item_proj_path)
    projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])

    # Load contexts.
    print("Loading contexts...")
    contexts = load_personamem_contexts_32k(ctx_path)
    print(f"Loaded {len(contexts)} contexts.")

    all_cards = []

    # Process each context.
    print("Extracting preferences...")
    for ctx_id, ctx in tqdm(contexts.items()):
        turns = context_to_chatturns(ctx)

        # Sliding-window extraction (window of 6 turns, stride 3).
        for window in sliding_windows(turns, window_size=6, step=3):
            # Only extract if the window contains at least one user turn.
            if not any(t.role == "user" for t in window):
                continue

            try:
                prefs = extractor.extract_turn(window)
            except Exception as e:
                print(f"Extraction failed: {e}")
                continue

            if not prefs.preferences:
                continue

            source_query = find_last_user_text(window)
            if not source_query:
                continue

            # Embed the triggering query. The projected vector is computed
            # later during user-vector aggregation, so only the raw
            # embedding is stored on the card (the per-window projection
            # previously computed here was never used).
            e_m = embed_model.encode([source_query], return_tensor=False)[0]

            # Serialize preferences into one natural-language note.
            notes = [f"When {p.condition}, {p.action}." for p in prefs.preferences]
            note_text = " ".join(notes)

            card = MemoryCard(
                card_id=str(uuid.uuid4()),
                user_id=ctx_id,  # persona_id/context_id
                source_session_id=ctx_id,
                source_turn_ids=[t.turn_id for t in window if t.role == "user"],
                raw_queries=[source_query],
                preference_list=prefs,
                note_text=note_text,
                embedding_e=e_m,
                kind="pref",
            )
            all_cards.append(card)

    print(f"Extracted {len(all_cards)} memory cards.")

    # Build user vectors: project each card embedding into item space and
    # mean-pool per user.
    print("Building user vectors...")
    z_by_user = {}

    # Group cards by user.
    cards_by_user = {}
    for c in all_cards:
        cards_by_user.setdefault(c.user_id, []).append(c)

    for uid, u_cards in cards_by_user.items():
        # Stack projected vectors: V has shape [M_u, k].
        V = np.stack(
            [projection.transform_vector(np.array(c.embedding_e, dtype=np.float32)) for c in u_cards],
            axis=0,
        )
        z_by_user[uid] = np.mean(V, axis=0)

    # Save cards as JSONL (one pydantic-serialized card per line).
    print(f"Saving {len(all_cards)} cards to {output_cards_path}...")
    with open(output_cards_path, "w", encoding="utf-8") as f:
        for c in all_cards:
            f.write(c.model_dump_json() + "\n")

    print(f"Saving user vectors to {output_vec_path}...")
    user_ids = list(z_by_user.keys())
    Z = np.array([z_by_user[uid] for uid in user_ids], dtype=np.float32)
    np.savez(output_vec_path, user_ids=user_ids, Z=Z)

    print("Done.")
+
# Script entry point: run the full vector-building pipeline when executed
# directly (not on import).
if __name__ == "__main__":
    main()
+