Diffstat (limited to 'scripts/day4_offline_rl_replay.py')
| -rw-r--r-- | scripts/day4_offline_rl_replay.py | 208 |
1 files changed, 208 insertions, 0 deletions
diff --git a/scripts/day4_offline_rl_replay.py b/scripts/day4_offline_rl_replay.py
new file mode 100644
index 0000000..1c5cdd9
--- /dev/null
+++ b/scripts/day4_offline_rl_replay.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+"""
+Day 4: Offline RL Replay Script.
+Simulates REINFORCE updates on the user state using OASST1 chat logs.
+"""
+
+import sys
+import os
+import json
+import numpy as np
+import random
+from tqdm import tqdm
+from collections import defaultdict
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserTensorStore
+from personalization.retrieval.pipeline import retrieve_with_policy
+from personalization.feedback.handlers import eval_step
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+
+
+def main():
+    # Configuration
+    cfg = load_local_models_config()
+    rl_cfg = {
+        "item_dim": 256,
+        "beta_long": 0.1,
+        "beta_short": 0.3,
+        "tau": 1.0,
+        "eta_long": 1e-3,
+        "eta_short": 5e-3,
+        "ema_alpha": 0.05,
+        "short_decay": 0.1
+    }
+    # Could be loaded from YAML if available; hardcoded here for clarity in the demo.
+
+    # Paths
+    cards_path = "data/corpora/memory_cards.jsonl"
+    embs_path = "data/corpora/memory_embeddings.npy"
+    item_proj_path = "data/corpora/item_projection.npz"
+    user_store_path = "data/users/user_store.npz"
+    chat_path = "data/raw_datasets/oasst1_queries.jsonl"
+
+    # 1. Load Data
+    print("Loading data stores...")
+    if not os.path.exists(cards_path):
+        print("Data missing.")
+        sys.exit(1)
+
+    cards = []
+    with open(cards_path, "r") as f:
+        for line in f:
+            cards.append(MemoryCard.model_validate_json(line))
+
+    memory_embeddings = np.load(embs_path)
+    item_vectors = np.load(item_proj_path)["V"]
+
+    # 2. Load Models
+    print("Loading models...")
+    embedder = Qwen3Embedding8B.from_config(cfg)
+    reranker = Qwen3Reranker.from_config(cfg)
+
+    user_store = UserTensorStore(k=rl_cfg["item_dim"], path=user_store_path)
+
+    # 3. Load Chat Logs and Group by Session
+    print("Loading chat logs...")
+    sessions = defaultdict(list)
+    with open(chat_path, "r") as f:
+        for line in f:
+            row = json.loads(line)
+            sessions[row["session_id"]].append(row)
+
+    # Keep only sessions with at least 2 user turns, so both q_t and q_t+1 exist.
+    valid_sessions = [s for s in sessions.values() if len(s) >= 2]
+    print(f"Found {len(valid_sessions)} valid sessions for replay.")
+
+    # Sample a few sessions to replay; 50 keeps the demo fast.
+    sampled_sessions = random.sample(valid_sessions, min(50, len(valid_sessions)))
+
+    total_reward = 0.0
+    total_gating = 0.0
+    steps = 0
+
+    # 4. Replay Loop
+    print("\nStarting RL Replay...")
+
+    for session in tqdm(sampled_sessions, desc="Replay Sessions"):
+        # Sort by turn
+        session.sort(key=lambda x: x.get("turn_id", 0))
+
+        # We need (q_t, a_t, q_t+1). The OASST1 query export only contains user
+        # queries, so the assistant answer a_t is usually unavailable. If q_t+1
+        # exists, we know an answer was given, so we pass an empty string for a_t;
+        # the Day 3 eval_step reward model mainly looks at q_t+1 keywords, which
+        # makes this acceptable for replay.
+
+        user_id = session[0]["user_id"]
+
+        # Pre-check user state
+        state_before = user_store.get_state(user_id)
+        norm_long_before = np.linalg.norm(state_before.z_long)
+        norm_short_before = np.linalg.norm(state_before.z_short)
+
+        num_updates = 0
+        for i in range(len(session) - 1):
+            q_t_row = session[i]
+            q_t1_row = session[i + 1]
+
+            q_t = q_t_row["original_query"]
+            q_t1 = q_t1_row["original_query"]
+            a_t = ""  # Missing in this file
+
+            # 1. Retrieve with Policy
+            # only_own_memories=True would simulate a personal memory bank, but a
+            # new user with no stored memories would get an empty candidate set and
+            # the policy could not select anything. For replay we search globally
+            # so there are always candidates to rerank; a real system would do both.
+            candidates, cand_vectors, base_scores, chosen_idx, probs = retrieve_with_policy(
+                user_id=user_id,
+                query=q_t,
+                embed_model=embedder,
+                reranker=reranker,
+                memory_cards=cards,
+                memory_embeddings=memory_embeddings,
+                user_store=user_store,
+                item_vectors=item_vectors,
+                topk_dense=32,
+                topk_rerank=5,
+                beta_long=rl_cfg["beta_long"],
+                beta_short=rl_cfg["beta_short"],
+                tau=rl_cfg["tau"],
+                only_own_memories=False
+            )
+
+            if not candidates:
+                continue
+
+            chosen_memories = [candidates[j] for j in chosen_idx]
+
+            # 2. Eval (Reward & Gating)
+            # Need embeddings for q_t and q_t+1
+            e_q = embedder.encode([q_t], return_tensor=False)[0]
+            e_q1 = embedder.encode([q_t1], return_tensor=False)[0]
+
+            r_hat, g_hat = eval_step(
+                q_t, a_t, q_t1,
+                chosen_memories,
+                query_embedding_t=np.array(e_q),
+                query_embedding_t1=np.array(e_q1)
+            )
+
+            total_reward += r_hat
+            total_gating += g_hat
+            steps += 1
+
+            # 3. Update (REINFORCE)
+            state = user_store.get_state(user_id)
+            updated = reinforce_update_user_state(
+                user_state=state,
+                item_vectors=cand_vectors,
+                chosen_indices=chosen_idx,
+                policy_probs=probs,
+                reward_hat=r_hat,
+                gating=g_hat,
+                tau=rl_cfg["tau"],
+                eta_long=rl_cfg["eta_long"],
+                eta_short=rl_cfg["eta_short"],
+                ema_alpha=rl_cfg["ema_alpha"],
+                short_decay=rl_cfg["short_decay"]
+            )
+            if updated:
+                num_updates += 1
+
+            user_store.save_state(state)
+
+        # Post-check
+        state_after = user_store.get_state(user_id)
+        norm_long_after = np.linalg.norm(state_after.z_long)
+        norm_short_after = np.linalg.norm(state_after.z_short)
+
+        # Print the change only if this user's state was actually updated.
+        if num_updates > 0:
+            print(f"User {user_id}: {num_updates} updates")
+            print(f"  ||z_long||  {norm_long_before:.8f} -> {norm_long_after:.8f}")
+            print(f"  ||z_short|| {norm_short_before:.8f} -> {norm_short_after:.8f}")
+
+    print("\n--- Replay Finished ---")
+    print(f"Total Steps: {steps}")
+    print(f"Avg Reward: {total_reward / max(1, steps):.8f}")
+    print(f"Avg Gating: {total_gating / max(1, steps):.8f}")
+
+    user_store.persist()
+    print(f"User store saved to {user_store_path}")
+
+
+if __name__ == "__main__":
+    main()
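For orientation, below is a minimal sketch of the user-state update that the replay loop drives. The real reinforce_update_user_state lives in personalization.user_model.policy.reinforce and is not part of this diff; the sketch only assumes a softmax policy with temperature tau whose candidate scores are linear in the user state (base + beta_long * <z_long, v> + beta_short * <z_short, v>), which matches the parameters the script passes. The decay/EMA handling of z_short is likewise an assumption.

import numpy as np

def reinforce_update_sketch(z_long, z_short, item_vectors, policy_probs, chosen_indices,
                            reward_hat, gating, tau, beta_long, beta_short,
                            eta_long, eta_short, ema_alpha, short_decay):
    # Hypothetical stand-in for reinforce_update_user_state (not the real module).
    # For a softmax policy over scores s_i = base_i + beta_long*<z_long, v_i>
    # + beta_short*<z_short, v_i>, the gradient of log p(chosen) w.r.t. z_long
    # is (beta_long / tau) * (v_chosen - E_p[v]).
    V = np.asarray(item_vectors, dtype=float)   # (n_candidates, k)
    p = np.asarray(policy_probs, dtype=float)   # (n_candidates,)
    expected_v = p @ V                          # E_p[v], shape (k,)

    if gating <= 0.0 or len(chosen_indices) == 0:
        return z_long, z_short, False           # gated out: no update this step

    for idx in chosen_indices:
        grad = V[idx] - expected_v              # REINFORCE direction for this choice
        z_long = z_long + eta_long * gating * reward_hat * (beta_long / tau) * grad
        z_short = z_short + eta_short * gating * reward_hat * (beta_short / tau) * grad

    # Assumed behavior: short-term state decays each step and tracks recently
    # chosen items via an exponential moving average.
    recent = V[list(chosen_indices)].mean(axis=0)
    z_short = (1.0 - short_decay) * z_short
    z_short = (1.0 - ema_alpha) * z_short + ema_alpha * recent
    return z_long, z_short, True

In the script the state object itself is passed in and saved back via user_store.save_state, so the real implementation presumably mutates the state in place rather than returning new vectors as this sketch does.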

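The Day 3 eval_step is also not included in this diff; the sketch below is only a guess at its shape, based on the script's comment that the reward model mainly looks at q_t+1 keywords and on the two query embeddings it receives. The marker list, similarity threshold, and gating rule are illustrative assumptions, not the real implementation.

import numpy as np

# Illustrative keyword markers; the real reward model's vocabulary is not shown here.
NEGATIVE_MARKERS = ("that's wrong", "not what i asked", "no, i meant", "try again")

def eval_step_sketch(q_t, a_t, q_t1, chosen_memories,
                     query_embedding_t, query_embedding_t1):
    """Return (reward_hat, gating) for one replayed turn (hypothetical stand-in)."""
    follow_up = q_t1.lower()
    negative = any(marker in follow_up for marker in NEGATIVE_MARKERS)

    # Cosine similarity between consecutive queries: a near-duplicate follow-up
    # suggests the previous turn did not resolve the request.
    e_t = query_embedding_t / (np.linalg.norm(query_embedding_t) + 1e-8)
    e_t1 = query_embedding_t1 / (np.linalg.norm(query_embedding_t1) + 1e-8)
    repeat_like = float(e_t @ e_t1) > 0.95

    reward_hat = -1.0 if (negative or repeat_like) else 1.0
    # Gate the policy update off when no memories were injected into the turn.
    gating = 1.0 if chosen_memories else 0.0
    return reward_hat, gating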