From e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Wed, 17 Dec 2025 04:29:37 -0600 Subject: Initial commit (clean history) --- src/personalization/user_model/tensor_store.py | 80 ++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 src/personalization/user_model/tensor_store.py (limited to 'src/personalization/user_model/tensor_store.py') diff --git a/src/personalization/user_model/tensor_store.py b/src/personalization/user_model/tensor_store.py new file mode 100644 index 0000000..42dbf4e --- /dev/null +++ b/src/personalization/user_model/tensor_store.py @@ -0,0 +1,80 @@ +import numpy as np +from dataclasses import dataclass +from typing import Dict, Optional +import os + +@dataclass +class UserState: + user_id: str + z_long: np.ndarray # [k] + z_short: np.ndarray # [k] + reward_ma: float # baseline for reward, init 0.0 + +class UserTensorStore: + def __init__(self, k: int, path: str): + self.k = k + self.path = path + self._states: Dict[str, UserState] = {} + self._load() + + # Calculate global mean for initialization + if self._states: + z_all = np.stack([st.z_long for st in self._states.values()]) + self.global_init_z = np.mean(z_all, axis=0) + else: + self.global_init_z = np.zeros(self.k, dtype=np.float32) + + def _load(self): + if os.path.exists(self.path): + try: + data = np.load(self.path, allow_pickle=True) + # Assume saved as dict of user_id -> dict/object + # For simplicity, let's say we save a single dict in a .npy or .npz + # But np.save/load with pickle is tricky for complex objects. + # Let's save as .npz where each key is user_id and value is a structured array or just use z_long for now? + # A robust way for prototype: + # save multiple arrays: "u1_long", "u1_short", "u1_meta" + pass + # For Day 2 prototype, we might just re-init from init script or rely on memory if not persisting strictly. + # But let's try to load if we can. + + # Let's implement a simple npz schema: + # keys: "{uid}_long", "{uid}_short", "{uid}_meta" (meta=[reward_ma]) + for key in data.files: + if key.endswith("_long"): + uid = key[:-5] + z_long = data[key] + z_short = data.get(f"{uid}_short", np.zeros(self.k)) + meta = data.get(f"{uid}_meta", np.array([0.0])) + self._states[uid] = UserState(uid, z_long, z_short, float(meta[0])) + except Exception as e: + print(f"Warning: Failed to load UserStore from {self.path}: {e}") + + def _save(self): + # Save to npz + save_dict = {} + for uid, state in self._states.items(): + save_dict[f"{uid}_long"] = state.z_long + save_dict[f"{uid}_short"] = state.z_short + save_dict[f"{uid}_meta"] = np.array([state.reward_ma]) + np.savez(self.path, **save_dict) + + def get_state(self, user_id: str) -> UserState: + if user_id not in self._states: + # Lazy init with global mean for new users + state = UserState( + user_id=user_id, + z_long=self.global_init_z.copy(), + z_short=np.zeros(self.k, dtype=np.float32), + reward_ma=0.0, + ) + self._states[user_id] = state + return self._states[user_id] + + def save_state(self, state: UserState) -> None: + self._states[state.user_id] = state + + def persist(self): + """Public method to force save to disk.""" + self._save() + -- cgit v1.2.3