path: root/src/personalization/user_model/tensor_store.py
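"""Per-user latent-state store for the personalization user model.

Holds each user's long- and short-term embedding vectors (z_long, z_short)
and a reward moving-average baseline, with simple .npz persistence.
"""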
import os
from dataclasses import dataclass
from typing import Dict

import numpy as np

@dataclass
class UserState:
    user_id: str
    z_long: np.ndarray   # [k]
    z_short: np.ndarray  # [k]
    reward_ma: float     # baseline for reward, init 0.0

class UserTensorStore:
    def __init__(self, k: int, path: str):
        self.k = k
        self.path = path
        self._states: Dict[str, UserState] = {}
        self._load()
        
        # Global mean of existing users' long-term vectors, used to
        # initialize new (cold-start) users; falls back to zeros.
        if self._states:
            z_all = np.stack([st.z_long for st in self._states.values()])
            self.global_init_z = np.mean(z_all, axis=0).astype(np.float32)
        else:
            self.global_init_z = np.zeros(self.k, dtype=np.float32)

    def _load(self):
        """Load persisted states, if any.

        The .npz schema stores three arrays per user:
        "{uid}_long", "{uid}_short", and "{uid}_meta" (meta = [reward_ma]).
        """
        if not os.path.exists(self.path):
            return
        try:
            data = np.load(self.path, allow_pickle=True)
            for key in data.files:
                if key.endswith("_long"):
                    uid = key[: -len("_long")]
                    z_long = data[key]
                    z_short = (
                        data[f"{uid}_short"]
                        if f"{uid}_short" in data.files
                        else np.zeros(self.k, dtype=np.float32)
                    )
                    meta = (
                        data[f"{uid}_meta"]
                        if f"{uid}_meta" in data.files
                        else np.array([0.0])
                    )
                    self._states[uid] = UserState(uid, z_long, z_short, float(meta[0]))
        except Exception as e:
            print(f"Warning: Failed to load UserTensorStore from {self.path}: {e}")

    def _save(self):
        # Persist all states to a single .npz file. Writing through an open
        # file handle keeps the exact path (np.savez would otherwise append
        # ".npz" when the path lacks that extension, breaking _load()).
        save_dict = {}
        for uid, state in self._states.items():
            save_dict[f"{uid}_long"] = state.z_long
            save_dict[f"{uid}_short"] = state.z_short
            save_dict[f"{uid}_meta"] = np.array([state.reward_ma])
        with open(self.path, "wb") as f:
            np.savez(f, **save_dict)

    def get_state(self, user_id: str) -> UserState:
        if user_id not in self._states:
            # Lazy init with global mean for new users
            state = UserState(
                user_id=user_id,
                z_long=self.global_init_z.copy(),
                z_short=np.zeros(self.k, dtype=np.float32),
                reward_ma=0.0,
            )
            self._states[user_id] = state
        return self._states[user_id]

    def save_state(self, state: UserState) -> None:
        self._states[state.user_id] = state
    
    def persist(self):
        """Flush all in-memory user states to disk."""
        self._save()
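
if __name__ == "__main__":
    # Minimal round-trip sketch (not part of the store itself): the path and
    # k below are illustrative values, assuming a small latent dimension and
    # a scratch file in the working directory.
    store = UserTensorStore(k=8, path="user_states.npz")

    # New users are lazily initialized from the global mean (zeros on an
    # empty store), then mutated in place and written back.
    state = store.get_state("u1")
    state.z_short = state.z_short + 0.1
    state.reward_ma = 0.5
    store.save_state(state)
    store.persist()

    # A fresh store instance should see the persisted state again.
    reloaded = UserTensorStore(k=8, path="user_states.npz").get_state("u1")
    print(reloaded.reward_ma, reloaded.z_short[:3])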