#!/usr/bin/env python3
"""
Day 4: Offline RL Replay Script.
Simulates REINFORCE updates on user state using OASST1 chat logs.
"""
import sys
import os
import json
import random
from collections import defaultdict

import numpy as np
from tqdm import tqdm

# Add src to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
from personalization.retrieval.preference_store.schemas import MemoryCard
from personalization.user_model.tensor_store import UserTensorStore
from personalization.retrieval.pipeline import retrieve_with_policy
from personalization.feedback.handlers import eval_step
from personalization.user_model.policy.reinforce import reinforce_update_user_state


def main():
    # Configuration. These values could be loaded from YAML; they are
    # hardcoded here for clarity in the demo.
    cfg = load_local_models_config()
    rl_cfg = {
        "item_dim": 256,
        "beta_long": 0.1,
        "beta_short": 0.3,
        "tau": 1.0,
        "eta_long": 1e-3,
        "eta_short": 5e-3,
        "ema_alpha": 0.05,
        "short_decay": 0.1,
    }

    # Paths
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"
    user_store_path = "data/users/user_store.npz"
    chat_path = "data/raw_datasets/oasst1_queries.jsonl"

    # 1. Load data
    print("Loading data stores...")
    if not os.path.exists(cards_path):
        print(f"Data missing: {cards_path}")
        sys.exit(1)

    cards = []
    with open(cards_path, "r") as f:
        for line in f:
            cards.append(MemoryCard.model_validate_json(line))
    memory_embeddings = np.load(embs_path)
    item_vectors = np.load(item_proj_path)["V"]

    # 2. Load models
    print("Loading models...")
    embedder = Qwen3Embedding8B.from_config(cfg)
    reranker = Qwen3Reranker.from_config(cfg)
    user_store = UserTensorStore(k=rl_cfg["item_dim"], path=user_store_path)

    # 3. Load chat logs and group by session
    print("Loading chat logs...")
    sessions = defaultdict(list)
    with open(chat_path, "r") as f:
        for line in f:
            row = json.loads(line)
            sessions[row["session_id"]].append(row)

    # Keep only sessions with at least 2 user turns, so both q_t and q_t+1 exist.
    valid_sessions = [s for s in sessions.values() if len(s) >= 2]
    print(f"Found {len(valid_sessions)} valid sessions for replay.")

    # Sample a subset of sessions to replay (capped at 50 for speed).
    sampled_sessions = random.sample(valid_sessions, min(50, len(valid_sessions)))

    total_reward = 0.0
    total_gating = 0.0
    steps = 0

    # 4. Replay loop
    print("\nStarting RL Replay...")
    for session in tqdm(sampled_sessions, desc="Replay Sessions"):
        # Sort turns chronologically
        session.sort(key=lambda x: x.get("turn_id", 0))

        # Each step needs (q_t, a_t, q_t+1). The OASST1 query file only contains
        # user turns, so the assistant answer a_t is unavailable here. Since the
        # Day 3 eval_step reward mainly looks at q_t+1 keywords, we pass an empty
        # string for a_t; the presence of q_t+1 implies an answer was given.
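        # Each adjacent pair of user turns (q_t, q_t+1) becomes one replay step:
        # retrieve memories for q_t under the current policy, score the step via
        # eval_step with q_t+1 as the follow-up signal, then apply a REINFORCE
        # update to this user's state vectors.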
        user_id = session[0]["user_id"]

        # Pre-check: user state norms before replaying this session.
        state_before = user_store.get_state(user_id)
        norm_long_before = np.linalg.norm(state_before.z_long)
        norm_short_before = np.linalg.norm(state_before.z_short)
        num_updates = 0

        for i in range(len(session) - 1):
            q_t_row = session[i]
            q_t1_row = session[i + 1]
            q_t = q_t_row["original_query"]
            q_t1 = q_t1_row["original_query"]
            a_t = ""  # Assistant answer is missing in this file

            # 1. Retrieve with policy.
            # only_own_memories=False: replayed users may have no memories of
            # their own yet, and the policy needs candidates to select from,
            # so we search the global memory bank here. A real system would
            # search both the user's own memories and the global bank.
            candidates, cand_vectors, base_scores, chosen_idx, probs = retrieve_with_policy(
                user_id=user_id,
                query=q_t,
                embed_model=embedder,
                reranker=reranker,
                memory_cards=cards,
                memory_embeddings=memory_embeddings,
                user_store=user_store,
                item_vectors=item_vectors,
                topk_dense=32,
                topk_rerank=5,
                beta_long=rl_cfg["beta_long"],
                beta_short=rl_cfg["beta_short"],
                tau=rl_cfg["tau"],
                only_own_memories=False,
            )
            if not candidates:
                continue

            chosen_memories = [candidates[j] for j in chosen_idx]

            # 2. Eval (reward & gating) needs embeddings of q_t and q_t+1.
            e_q = embedder.encode([q_t], return_tensor=False)[0]
            e_q1 = embedder.encode([q_t1], return_tensor=False)[0]

            r_hat, g_hat = eval_step(
                q_t,
                a_t,
                q_t1,
                chosen_memories,
                query_embedding_t=np.array(e_q),
                query_embedding_t1=np.array(e_q1),
            )
            total_reward += r_hat
            total_gating += g_hat
            steps += 1

            # 3. Update user state (REINFORCE)
            state = user_store.get_state(user_id)
            updated = reinforce_update_user_state(
                user_state=state,
                item_vectors=cand_vectors,
                chosen_indices=chosen_idx,
                policy_probs=probs,
                reward_hat=r_hat,
                gating=g_hat,
                tau=rl_cfg["tau"],
                eta_long=rl_cfg["eta_long"],
                eta_short=rl_cfg["eta_short"],
                ema_alpha=rl_cfg["ema_alpha"],
                short_decay=rl_cfg["short_decay"],
            )
            if updated:
                num_updates += 1
            user_store.save_state(state)

        # Post-check: report the change in state norms only if any update happened.
        state_after = user_store.get_state(user_id)
        norm_long_after = np.linalg.norm(state_after.z_long)
        norm_short_after = np.linalg.norm(state_after.z_short)

        if num_updates > 0:
            print(f"User {user_id}: {num_updates} updates")
            print(f"  ||z_long||  {norm_long_before:.8f} -> {norm_long_after:.8f}")
            print(f"  ||z_short|| {norm_short_before:.8f} -> {norm_short_after:.8f}")

    print("\n--- Replay Finished ---")
    print(f"Total Steps: {steps}")
    print(f"Avg Reward: {total_reward / max(1, steps):.8f}")
    print(f"Avg Gating: {total_gating / max(1, steps):.8f}")

    user_store.persist()
    print(f"User store saved to {user_store_path}")


if __name__ == "__main__":
    main()