diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/personamem_build_user_vectors.py | |
Diffstat (limited to 'scripts/personamem_build_user_vectors.py')
| -rw-r--r-- | scripts/personamem_build_user_vectors.py | 193 |
1 file changed, 193 insertions, 0 deletions
#!/usr/bin/env python3
"""
Build per-user preference vectors for the PersonaMem 32k dataset.

Pipeline:
1. Load shared contexts.
2. Convert each context's messages to ChatTurns.
3. Extract preferences from sliding windows of turns -> MemoryCards.
4. Compute embeddings for the source queries and project to item vectors.
5. Aggregate item vectors per user (mean) into user vectors.
"""

import sys
import os
import json
import uuid
from collections import defaultdict
from typing import List, Dict

import numpy as np
from tqdm import tqdm

# Make the project's src/ importable when this file is run as a script.
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.config.registry import get_preference_extractor
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList
from personalization.data.personamem_loader import load_personamem_contexts_32k, PersonaMemContext
from personalization.user_model.features import ItemProjection


def context_to_chatturns(context: PersonaMemContext) -> List[ChatTurn]:
    """Convert a PersonaMem context's raw messages into ChatTurn objects.

    Dict messages supply role/content; bare-string messages are treated as
    user turns; anything else is skipped. The shared context id doubles as
    both user_id and session_id, since PersonaMem keys personas by context.
    """
    turns: List[ChatTurn] = []
    for i, msg in enumerate(context.messages):
        if isinstance(msg, dict):
            role = msg.get("role", "user")
            text = msg.get("content", "")
        elif isinstance(msg, str):
            # Fallback for plain-string messages: assume they are user turns.
            role = "user"
            text = msg
        else:
            continue

        turns.append(ChatTurn(
            user_id=context.shared_context_id,  # context id stands in for the user id
            session_id=context.shared_context_id,
            turn_id=i,
            role="user" if role == "user" else "assistant",
            text=text,
            timestamp=None,
            meta={}
        ))
    return turns


def sliding_windows(seq: List, window_size: int, step: int):
    """Yield overlapping windows of ``seq`` of length ``window_size``.

    Windows start every ``step`` elements; trailing windows may be shorter
    than ``window_size`` when the sequence length is not step-aligned.
    """
    for i in range(0, len(seq), step):
        yield seq[i : i + window_size]


def find_last_user_text(window: List[ChatTurn]) -> str:
    """Return the text of the most recent user turn in ``window``, or ""."""
    for t in reversed(window):
        if t.role == "user":
            return t.text
    return ""


def main():
    """Run the full extraction/embedding/aggregation pipeline and save outputs."""
    # Paths (adjust as needed, assume downloaded to data/raw_datasets/personamem)
    ctx_path = "data/raw_datasets/personamem/shared_contexts_32k.jsonl"
    item_proj_path = "data/corpora/item_projection.npz"
    output_vec_path = "data/personamem/user_vectors.npz"
    output_cards_path = "data/personamem/memory_cards.jsonl"

    # Ensure output directory exists (both outputs share it).
    os.makedirs(os.path.dirname(output_vec_path), exist_ok=True)

    if not os.path.exists(ctx_path):
        print(f"Error: Context file not found at {ctx_path}")
        return

    if not os.path.exists(item_proj_path):
        print(f"Error: Item projection not found at {item_proj_path}. Run build_item_space.py first.")
        return

    # Load models.
    print("Loading models...")
    cfg = load_local_models_config()
    # Explicitly use Qwen3Embedding8B for query embeddings.
    embed_model = Qwen3Embedding8B.from_config(cfg)

    # Use the registry for the extractor (SFT model), falling back to the
    # rule-based extractor when the SFT model is not registered.
    extractor_name = "qwen3_0_6b_sft"
    try:
        extractor = get_preference_extractor(extractor_name)
    except Exception as e:  # narrow from bare except: don't swallow SystemExit/KeyboardInterrupt
        print(f"Extractor '{extractor_name}' not available ({e}); falling back to rule extractor.")
        extractor = get_preference_extractor("rule")

    # Load the item-space projection (P matrix + mean for centering).
    proj_data = np.load(item_proj_path)
    projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])

    # Load contexts.
    print("Loading contexts...")
    contexts = load_personamem_contexts_32k(ctx_path)
    print(f"Loaded {len(contexts)} contexts.")

    all_cards = []

    # Process each context: sliding-window preference extraction.
    print("Extracting preferences...")
    for ctx_id, ctx in tqdm(contexts.items()):
        turns = context_to_chatturns(ctx)

        for window in sliding_windows(turns, window_size=6, step=3):
            # Only extract if the window contains at least one user turn.
            if not any(t.role == "user" for t in window):
                continue

            try:
                prefs = extractor.extract_turn(window)
            except Exception as e:
                print(f"Extraction failed: {e}")
                continue

            if not prefs.preferences:
                continue

            source_query = find_last_user_text(window)
            if not source_query:
                continue

            # Embed the triggering user query and project into item space.
            e_m = embed_model.encode([source_query], return_tensor=False)[0]
            e_m_np = np.array(e_m)
            v_m = projection.transform_vector(e_m_np)

            # Serialize the preference list into a human-readable note.
            notes = [f"When {p.condition}, {p.action}." for p in prefs.preferences]
            note_text = " ".join(notes)

            card = MemoryCard(
                card_id=str(uuid.uuid4()),
                user_id=ctx_id,  # persona_id/context_id
                source_session_id=ctx_id,
                source_turn_ids=[t.turn_id for t in window if t.role == "user"],
                raw_queries=[source_query],
                preference_list=prefs,
                note_text=note_text,
                embedding_e=e_m,
                kind="pref"
            )
            all_cards.append(card)

    print(f"Extracted {len(all_cards)} memory cards.")

    # Build user vectors: mean of the projected card embeddings per user.
    print("Building user vectors...")
    cards_by_user = defaultdict(list)
    for c in all_cards:
        cards_by_user[c.user_id].append(c)

    z_by_user = {}
    for uid, u_cards in cards_by_user.items():
        # Stack v_m: [M_u, k], then average over the user's cards.
        V = np.stack(
            [projection.transform_vector(np.array(c.embedding_e, dtype=np.float32)) for c in u_cards],
            axis=0,
        )
        z_by_user[uid] = np.mean(V, axis=0)

    # Save memory cards as JSONL.
    print(f"Saving {len(all_cards)} cards to {output_cards_path}...")
    with open(output_cards_path, "w", encoding="utf-8") as f:
        for c in all_cards:
            f.write(c.model_dump_json() + "\n")

    # Save user vectors (ids + matrix) as a single .npz archive.
    print(f"Saving user vectors to {output_vec_path}...")
    user_ids = list(z_by_user.keys())
    Z = np.array([z_by_user[uid] for uid in user_ids], dtype=np.float32)
    np.savez(output_vec_path, user_ids=user_ids, Z=Z)

    print("Done.")


if __name__ == "__main__":
    main()
