#!/usr/bin/env python3
"""
Script to build user vectors for the PersonaMem 32k dataset.

1. Load shared contexts.
2. Convert them to ChatTurns.
3. Extract preferences -> MemoryCards.
4. Compute embeddings & item vectors.
5. Aggregate item vectors into per-user vectors.
"""
import sys
import os
import json
import uuid

import numpy as np
from tqdm import tqdm
from typing import List, Dict

# Add src to sys.path so the personalization package can be imported.
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.config.registry import get_preference_extractor
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList
from personalization.data.personamem_loader import load_personamem_contexts_32k, PersonaMemContext
from personalization.user_model.features import ItemProjection


def context_to_chatturns(context: PersonaMemContext) -> List[ChatTurn]:
    """Convert a PersonaMem shared context into a list of ChatTurns."""
    turns = []
    for i, msg in enumerate(context.messages):
        if isinstance(msg, dict):
            role = msg.get("role", "user")
            text = msg.get("content", "")
        elif isinstance(msg, str):
            # Fallback for plain-string messages: treat them as user turns.
            role = "user"
            text = msg
        else:
            continue
        turns.append(ChatTurn(
            user_id=context.shared_context_id,  # The context id stands in for a user id when building vectors.
            session_id=context.shared_context_id,
            turn_id=i,
            role="user" if role == "user" else "assistant",
            text=text,
            timestamp=None,
            meta={},
        ))
    return turns


def sliding_windows(seq: List, window_size: int, step: int):
    """Yield overlapping windows of up to `window_size` items, advancing by `step`."""
    for i in range(0, len(seq), step):
        yield seq[i:i + window_size]


def find_last_user_text(window: List[ChatTurn]) -> str:
    """Return the text of the most recent user turn in the window, or '' if none."""
    for t in reversed(window):
        if t.role == "user":
            return t.text
    return ""


def main():
    # Paths (adjust as needed; assumes the dataset was downloaded to data/raw_datasets/personamem).
    ctx_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl"
    item_proj_path = "data/corpora/item_projection.npz"
    output_vec_path = "data/personamem/user_vectors.npz"
    output_cards_path = "data/personamem/memory_cards.jsonl"

    # Ensure the output directory exists.
    os.makedirs(os.path.dirname(output_vec_path), exist_ok=True)

    if not os.path.exists(ctx_path):
        print(f"Error: Context file not found at {ctx_path}")
        return
    if not os.path.exists(item_proj_path):
        print(f"Error: Item projection not found at {item_proj_path}. Run build_item_space.py first.")
        return

    # Load models.
    print("Loading models...")
    cfg = load_local_models_config()
    # Explicitly use Qwen3Embedding8B for query embeddings.
    embed_model = Qwen3Embedding8B.from_config(cfg)

    # Use the registry for the preference extractor (SFT model), with a rule-based fallback.
    extractor_name = "qwen3_0_6b_sft"
    try:
        extractor = get_preference_extractor(extractor_name)
    except Exception:
        print(f"Extractor '{extractor_name}' not found; falling back to the rule-based extractor.")
        extractor = get_preference_extractor("rule")

    # Load the item-space projection.
    proj_data = np.load(item_proj_path)
    projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])

    # Load contexts.
    print("Loading contexts...")
    contexts = load_personamem_contexts_32k(ctx_path)
    print(f"Loaded {len(contexts)} contexts.")

    all_cards = []

    # Extract preferences from each context. A full run over every context can
    # take a while; extraction simply loops over sliding windows (no batching).
    print("Extracting preferences...")
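    # Optional cap on the number of contexts, useful for a quick smoke test.
    # This is a sketch, not part of the original pipeline: PM_MAX_CONTEXTS is a
    # hypothetical environment variable; when unset, every context is processed.
    max_contexts = int(os.environ.get("PM_MAX_CONTEXTS", "0"))
    if max_contexts > 0:
        contexts = dict(list(contexts.items())[:max_contexts])
        print(f"Limiting run to {len(contexts)} contexts (PM_MAX_CONTEXTS).")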
    for ctx_id, ctx in tqdm(contexts.items()):
        turns = context_to_chatturns(ctx)

        # Sliding-window extraction over the conversation.
        for window in sliding_windows(turns, window_size=6, step=3):
            # Only extract if the window contains at least one user turn.
            if not any(t.role == "user" for t in window):
                continue
            try:
                prefs = extractor.extract_turn(window)
            except Exception as e:
                print(f"Extraction failed: {e}")
                continue
            if not prefs.preferences:
                continue

            source_query = find_last_user_text(window)
            if not source_query:
                continue

            # Embed the source query. The projection into item space is applied
            # later, when user vectors are aggregated.
            e_m = embed_model.encode([source_query], return_tensor=False)[0]

            # Serialize the extracted preferences into a note.
            notes = [f"When {p.condition}, {p.action}." for p in prefs.preferences]
            note_text = " ".join(notes)

            card = MemoryCard(
                card_id=str(uuid.uuid4()),
                user_id=ctx_id,  # persona_id / context_id
                source_session_id=ctx_id,
                source_turn_ids=[t.turn_id for t in window if t.role == "user"],
                raw_queries=[source_query],
                preference_list=prefs,
                note_text=note_text,
                embedding_e=e_m,
                kind="pref",
            )
            all_cards.append(card)

    print(f"Extracted {len(all_cards)} memory cards.")

    # Build user vectors by averaging projected card embeddings per user.
    print("Building user vectors...")
    z_by_user = {}

    # Group cards by user.
    cards_by_user = {}
    for c in all_cards:
        cards_by_user.setdefault(c.user_id, []).append(c)

    for uid, u_cards in cards_by_user.items():
        # Stack projected embeddings v_m into a [M_u, k] matrix and average.
        V = np.stack(
            [projection.transform_vector(np.array(c.embedding_e, dtype=np.float32)) for c in u_cards],
            axis=0,
        )
        z = np.mean(V, axis=0)
        z_by_user[uid] = z

    # Save memory cards.
    print(f"Saving {len(all_cards)} cards to {output_cards_path}...")
    with open(output_cards_path, "w", encoding="utf-8") as f:
        for c in all_cards:
            f.write(c.model_dump_json() + "\n")

    # Save user vectors.
    print(f"Saving user vectors to {output_vec_path}...")
    user_ids = list(z_by_user.keys())
    Z = np.array([z_by_user[uid] for uid in user_ids], dtype=np.float32)
    np.savez(output_vec_path, user_ids=user_ids, Z=Z)
    print("Done.")


if __name__ == "__main__":
    main()
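
# Usage sketch (not executed; assumes the default output path above and a
# placeholder `some_context_id`). After a run, the saved vectors can be loaded
# and looked up roughly like this:
#
#   data = np.load("data/personamem/user_vectors.npz")
#   user_ids = list(data["user_ids"])
#   Z = data["Z"]  # shape [num_users, k]: one aggregated vector per context/user
#   z = Z[user_ids.index(some_context_id)]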