summaryrefslogtreecommitdiff
path: root/scripts/day4_offline_rl_replay.py
blob: 1c5cdd921617794bb64def6130081261b65d7db7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python3
"""
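Day 4: Offline RL replay script.

Replays multi-turn OASST1 user sessions offline: for each consecutive pair of
user queries it runs policy retrieval, estimates reward/gating, and applies a
REINFORCE update to that user's long- and short-term state vectors.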
"""

import sys
import os
import json
import numpy as np
import random
from tqdm import tqdm
from collections import defaultdict

# Add src to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.user_model.tensor_store import UserTensorStore
from personalization.retrieval.pipeline import retrieve_with_policy
from personalization.feedback.handlers import eval_step
from personalization.user_model.policy.reinforce import reinforce_update_user_state

def _iter_nonblank_lines(path):
    """Yield every non-blank line of the UTF-8 text file at ``path``.

    The file is decoded explicitly as UTF-8 because the OASST1 data is
    multilingual and the platform default encoding may not be UTF-8. Blank
    lines (e.g. a doubled trailing newline in a JSONL file) are skipped because
    parsing an empty string as JSON raises.
    """
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield line


def _load_sessions(chat_path):
    """Group the JSONL chat-log rows at ``chat_path`` by their ``session_id``.

    Returns:
        A ``defaultdict(list)`` mapping session_id -> list of row dicts, with
        rows in file order within each session.
    """
    sessions = defaultdict(list)
    for line in _iter_nonblank_lines(chat_path):
        row = json.loads(line)
        sessions[row["session_id"]].append(row)
    return sessions


def _replay_session(session, user_id, *, embedder, reranker, cards,
                    memory_embeddings, item_vectors, user_store, rl_cfg):
    """Run retrieve -> eval -> REINFORCE update for each consecutive query pair.

    ``session`` is expected to be sorted by turn already. Pairs for which
    retrieval returns no candidates are skipped. The user's state is fetched,
    updated and saved back to ``user_store`` after every evaluated pair.

    Returns:
        A list of ``(r_hat, g_hat, updated)`` tuples, one per evaluated pair,
        in turn order.
    """
    # q_{t+1} of one pair is q_t of the next, so cache raw encodings by text
    # to encode each query only once per session.
    encoded = {}

    def embed(text):
        if text not in encoded:
            encoded[text] = embedder.encode([text], return_tensor=False)[0]
        # np.array copies, so eval_step always receives a fresh array.
        return np.array(encoded[text])

    results = []
    for row_t, row_t1 in zip(session, session[1:]):
        q_t = row_t["original_query"]
        q_t1 = row_t1["original_query"]
        # The queries file contains user turns only, so the assistant answer
        # a_t is unavailable; pass an empty string. (Original design note: the
        # current reward model mainly looks at q_{t+1} keywords — confirm
        # against eval_step.)
        a_t = ""

        # Global search (only_own_memories=False): a user with no memories of
        # their own would otherwise get no candidates, leaving the policy
        # nothing to select. A real system would search both.
        candidates, cand_vectors, _base_scores, chosen_idx, probs = retrieve_with_policy(
            user_id=user_id,
            query=q_t,
            embed_model=embedder,
            reranker=reranker,
            memory_cards=cards,
            memory_embeddings=memory_embeddings,
            user_store=user_store,
            item_vectors=item_vectors,
            topk_dense=32,
            topk_rerank=5,
            beta_long=rl_cfg["beta_long"],
            beta_short=rl_cfg["beta_short"],
            tau=rl_cfg["tau"],
            only_own_memories=False,
        )
        if not candidates:
            continue

        chosen_memories = [candidates[j] for j in chosen_idx]

        # Reward and gating estimates for this step.
        r_hat, g_hat = eval_step(
            q_t, a_t, q_t1,
            chosen_memories,
            query_embedding_t=embed(q_t),
            query_embedding_t1=embed(q_t1),
        )

        # REINFORCE update of the user's state.
        state = user_store.get_state(user_id)
        updated = reinforce_update_user_state(
            user_state=state,
            item_vectors=cand_vectors,
            chosen_indices=chosen_idx,
            policy_probs=probs,
            reward_hat=r_hat,
            gating=g_hat,
            tau=rl_cfg["tau"],
            eta_long=rl_cfg["eta_long"],
            eta_short=rl_cfg["eta_short"],
            ema_alpha=rl_cfg["ema_alpha"],
            short_decay=rl_cfg["short_decay"],
        )
        # Saved unconditionally, whether or not an update was applied.
        user_store.save_state(state)
        results.append((r_hat, g_hat, updated))
    return results


def main(num_sessions=50, seed=None):
    """Replay sampled OASST1 sessions and apply REINFORCE updates to user state.

    Each sampled session with at least two queries is sorted by ``turn_id``.
    Every consecutive pair (q_t, q_{t+1}) then goes through policy retrieval,
    ``eval_step`` reward/gating estimation and a REINFORCE update of the user's
    state. Per-user state-norm changes and the average reward/gating are
    printed, and the user store is persisted at the end. Exits with status 1 if
    any required input file is missing.

    Args:
        num_sessions: Maximum number of sessions to replay (default 50).
        seed: Optional seed for session sampling. ``None`` draws from the
            global ``random`` module, as before.
    """
    # Configuration
    cfg = load_local_models_config()
    # Hyperparameters are hardcoded here for clarity; they are not read from YAML.
    rl_cfg = {
        # NOTE(review): presumably must match item_vectors.shape[1] — confirm.
        "item_dim": 256,
        "beta_long": 0.1,
        "beta_short": 0.3,
        "tau": 1.0,
        "eta_long": 1e-3,
        "eta_short": 5e-3,
        "ema_alpha": 0.05,
        "short_decay": 0.1
    }

    # Paths
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"
    user_store_path = "data/users/user_store.npz"
    chat_path = "data/raw_datasets/oasst1_queries.jsonl"

    # 1. Load data
    print("Loading data stores...")
    # Check every required input up front so a missing file fails fast,
    # before the expensive model loading below. (The user store may be absent.)
    missing = [p for p in (cards_path, embs_path, item_proj_path, chat_path)
               if not os.path.exists(p)]
    if missing:
        print(f"Data missing: {', '.join(missing)}", file=sys.stderr)
        sys.exit(1)

    cards = [MemoryCard.model_validate_json(line)
             for line in _iter_nonblank_lines(cards_path)]
    memory_embeddings = np.load(embs_path)
    # np.load on an .npz returns an NpzFile that holds the file open; close it.
    with np.load(item_proj_path) as item_proj:
        item_vectors = item_proj["V"]

    # 2. Load models
    print("Loading models...")
    embedder = Qwen3Embedding8B.from_config(cfg)
    reranker = Qwen3Reranker.from_config(cfg)

    user_store = UserTensorStore(k=rl_cfg["item_dim"], path=user_store_path)

    # 3. Load chat logs and group by session
    print("Loading chat logs...")
    sessions = _load_sessions(chat_path)

    # A session needs at least 2 user turns to form a (q_t, q_{t+1}) pair.
    valid_sessions = [s for s in sessions.values() if len(s) >= 2]
    print(f"Found {len(valid_sessions)} valid sessions for replay.")

    # Sample a subset of sessions for speed.
    sampler = random if seed is None else random.Random(seed)
    sampled_sessions = sampler.sample(valid_sessions, min(num_sessions, len(valid_sessions)))

    total_reward = 0.0
    total_gating = 0.0
    steps = 0

    # 4. Replay loop
    print("\nStarting RL Replay...")

    for session in tqdm(sampled_sessions, desc="Replay Sessions"):
        session.sort(key=lambda row: row.get("turn_id", 0))
        user_id = session[0]["user_id"]

        # Snapshot the state norms before any update in this session.
        state_before = user_store.get_state(user_id)
        norm_long_before = np.linalg.norm(state_before.z_long)
        norm_short_before = np.linalg.norm(state_before.z_short)

        num_updates = 0
        for r_hat, g_hat, updated in _replay_session(
            session,
            user_id,
            embedder=embedder,
            reranker=reranker,
            cards=cards,
            memory_embeddings=memory_embeddings,
            item_vectors=item_vectors,
            user_store=user_store,
            rl_cfg=rl_cfg,
        ):
            total_reward += r_hat
            total_gating += g_hat
            steps += 1
            if updated:
                num_updates += 1

        state_after = user_store.get_state(user_id)
        norm_long_after = np.linalg.norm(state_after.z_long)
        norm_short_after = np.linalg.norm(state_after.z_short)

        # Report only users whose state was actually updated. tqdm.write keeps
        # these lines from garbling the progress bar.
        if num_updates > 0:
            tqdm.write(f"User {user_id}: {num_updates} updates")
            tqdm.write(f"  ||z_long|| {norm_long_before:.8f} -> {norm_long_after:.8f}")
            tqdm.write(f"  ||z_short|| {norm_short_before:.8f} -> {norm_short_after:.8f}")

    print("\n--- Replay Finished ---")
    print(f"Total Steps: {steps}")
    # max(1, steps) avoids division by zero when no step produced candidates.
    print(f"Avg Reward: {total_reward / max(1, steps):.8f}")
    print(f"Avg Gating: {total_gating / max(1, steps):.8f}")

    user_store.persist()
    print(f"User store saved to {user_store_path}")


if __name__ == "__main__":
    main()