#!/usr/bin/env python3
"""
Script to initialize User States (z_long) from Memory Embeddings.

Reads memory cards (one JSON object per line) and an item projection
matrix ``V``, groups the cards by ``user_id``, mean-pools each user's
rows of ``V`` into ``z_long``, and writes the resulting states through
``UserTensorStore``.
"""

import sys
import os
import numpy as np
import json
from collections import defaultdict

# Add src to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.user_model.tensor_store import UserTensorStore, UserState
from personalization.retrieval.preference_store.schemas import MemoryCard


def _load_cards(cards_path):
    """Parse one MemoryCard per non-blank line of the JSONL file at cards_path."""
    cards = []
    with open(cards_path, "r", encoding="utf-8") as f:
        for line in f:
            # A blank line (e.g. a trailing newline at end of file) is not a
            # card; passing it to the validator would abort the whole run.
            if not line.strip():
                continue
            cards.append(MemoryCard.model_validate_json(line))
    return cards


def _group_rows_by_user(cards, n_rows):
    """Map each user_id to the positions of its cards, limited to rows < n_rows."""
    user_indices = defaultdict(list)
    # Card i is assumed to correspond to row i of V, so positions at or
    # beyond n_rows have no vector and cannot be pooled.
    for idx, card in enumerate(cards[:n_rows]):
        user_indices[card.user_id].append(idx)
    return user_indices


def main():
    """Build and persist an initial state for every user that has memory cards.

    Returns early (after printing a message) when the memory-card file or the
    item-projection file is missing. Otherwise, for each user, ``z_long`` is
    set to the mean of that user's rows of ``V``, ``z_short`` to a float32
    zero vector of length ``k``, and ``reward_ma`` to ``0.0``.
    """
    cards_path = "data/corpora/memory_cards.jsonl"
    item_proj_path = "data/corpora/item_projection.npz"
    user_store_path = "data/users/user_store.npz"

    # Ensure user dir
    os.makedirs(os.path.dirname(user_store_path), exist_ok=True)

    # 1. Load data
    print("Loading memory cards...")
    if not os.path.exists(cards_path):
        print("No memory cards found. Exiting.")
        return
    cards = _load_cards(cards_path)

    print("Loading item projection V...")
    if not os.path.exists(item_proj_path):
        print("Item projection not found. Run build_item_space.py first.")
        return

    # np.load on an .npz returns a lazy archive handle; the context manager
    # closes it once V has been read into memory.
    with np.load(item_proj_path) as proj_data:
        V = proj_data["V"]  # [M, k]

    n_rows, k = V.shape[0], V.shape[1]

    if len(cards) != n_rows:
        print(f"Warning: Number of cards ({len(cards)}) != V rows ({V.shape[0]}). Mismatch?")
        # Best effort: keep assuming card i <-> row i, but never index past
        # the end of V, which would raise IndexError in the loop below.
        ignored = len(cards) - n_rows
        if ignored > 0:
            print(f"Ignoring {ignored} trailing cards that have no row in V.")

    # 2. Group by user
    user_indices = _group_rows_by_user(cards, n_rows)

    # 3. Initialize Store
    print(f"Initializing UserStore at {user_store_path}...")
    store = UserTensorStore(k=k, path=user_store_path)

    # 4. Compute z_long and save
    print(f"Processing {len(user_indices)} users...")

    for uid, indices in user_indices.items():
        # Fancy-index the rows belonging to this user, then mean-pool them.
        # NOTE(review): z_long inherits V's dtype via np.mean while z_short is
        # float32 - confirm whether the store expects a single dtype.
        z_long = np.mean(V[indices], axis=0)

        # Get/Create state
        state = store.get_state(uid)
        state.z_long = z_long
        state.z_short = np.zeros(k, dtype=np.float32)
        state.reward_ma = 0.0

        store.save_state(state)

    store.persist()
    print("Done. User states initialized.")


if __name__ == "__main__":
    main()