Diffstat (limited to 'scripts/day4_offline_rl_replay.py')
-rw-r--r--  scripts/day4_offline_rl_replay.py | 208
1 file changed, 208 insertions(+), 0 deletions(-)
diff --git a/scripts/day4_offline_rl_replay.py b/scripts/day4_offline_rl_replay.py
new file mode 100644
index 0000000..1c5cdd9
--- /dev/null
+++ b/scripts/day4_offline_rl_replay.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+"""
+Day 4: Offline RL Replay Script.
+Simulates REINFORCE update on user state using OASST1 chat logs.
+"""
+
+import sys
+import os
+import json
+import numpy as np
+import random
+from tqdm import tqdm
+from collections import defaultdict
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+from personalization.retrieval.preference_store.schemas import MemoryCard
+from personalization.user_model.tensor_store import UserTensorStore
+from personalization.retrieval.pipeline import retrieve_with_policy
+from personalization.feedback.handlers import eval_step
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+
+def main():
+ # Configuration
+ cfg = load_local_models_config()
+    rl_cfg = {
+        "item_dim": 256,      # dimension of the item / user-state vectors
+        "beta_long": 0.1,     # weight of the long-term user state in policy scoring
+        "beta_short": 0.3,    # weight of the short-term user state in policy scoring
+        "tau": 1.0,           # softmax temperature of the selection policy
+        "eta_long": 1e-3,     # learning rate for z_long
+        "eta_short": 5e-3,    # learning rate for z_short
+        "ema_alpha": 0.05,    # EMA coefficient used in the z_long update
+        "short_decay": 0.1    # decay applied to z_short
+    }
+    # These values could be loaded from YAML; they are hardcoded here for clarity in the demo.
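+    # Minimal sketch of loading the same values from a YAML file instead
+    # (assumes a hypothetical configs/rl_replay.yaml with these exact keys):
+    #
+    #     import yaml
+    #     with open("configs/rl_replay.yaml") as f:
+    #         rl_cfg = yaml.safe_load(f)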
+
+ # Paths
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+ item_proj_path = "data/corpora/item_projection.npz"
+ user_store_path = "data/users/user_store.npz"
+ chat_path = "data/raw_datasets/oasst1_queries.jsonl"
+
+ # 1. Load Data
+ print("Loading data stores...")
+ if not os.path.exists(cards_path):
+        print(f"Data missing: {cards_path} not found.")
+ sys.exit(1)
+
+ cards = []
+ with open(cards_path, "r") as f:
+ for line in f:
+ cards.append(MemoryCard.model_validate_json(line))
+
+ memory_embeddings = np.load(embs_path)
+ item_vectors = np.load(item_proj_path)["V"]
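+    # "V" is presumably the projection that maps memory embeddings into the
+    # item_dim-dimensional space used by the user-state policy.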
+
+ # 2. Load Models
+ print("Loading models...")
+ embedder = Qwen3Embedding8B.from_config(cfg)
+ reranker = Qwen3Reranker.from_config(cfg)
+
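+    # Per-user state store: get_state(user_id) returns a state with z_long / z_short
+    # vectors of dimension item_dim; save_state / persist write changes back to disk.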
+ user_store = UserTensorStore(k=rl_cfg["item_dim"], path=user_store_path)
+
+ # 3. Load Chat Logs and Group by Session
+ print("Loading chat logs...")
+ sessions = defaultdict(list)
+ with open(chat_path, "r") as f:
+ for line in f:
+ row = json.loads(line)
+ sessions[row["session_id"]].append(row)
+
+ # Filter sessions with at least 2 user turns to have q_t and q_t+1
+ valid_sessions = [s for s in sessions.values() if len(s) >= 2]
+ print(f"Found {len(valid_sessions)} valid sessions for replay.")
+
+    # For speed, sample at most 50 sessions to replay
+ sampled_sessions = random.sample(valid_sessions, min(50, len(valid_sessions)))
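+    # Note: no random seed is fixed, so each run replays a different sample of sessions.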
+
+ total_reward = 0.0
+ total_gating = 0.0
+ steps = 0
+
+ # 4. Replay Loop
+ print("\nStarting RL Replay...")
+
+ for session in tqdm(sampled_sessions, desc="Replay Sessions"):
+ # Sort by turn
+ session.sort(key=lambda x: x.get("turn_id", 0))
+
+        # Each step needs (q_t, a_t, q_t+1). The extracted OASST1 query file only
+        # contains user turns, so the assistant answer a_t is not available here.
+        # The Day 3 eval_step signature still takes a_t, so we pass an empty string;
+        # this is acceptable because the current reward model relies mainly on
+        # keywords in q_t+1, and if q_t+1 exists an assistant answer occurred in between.
+
+ user_id = session[0]["user_id"]
+
+ # Pre-check user state
+ state_before = user_store.get_state(user_id)
+ norm_long_before = np.linalg.norm(state_before.z_long)
+ norm_short_before = np.linalg.norm(state_before.z_short)
+
+ num_updates = 0
+ for i in range(len(session) - 1):
+ q_t_row = session[i]
+ q_t1_row = session[i+1]
+
+ q_t = q_t_row["original_query"]
+ q_t1 = q_t1_row["original_query"]
+ a_t = "" # Missing in this file
+
+            # 1. Retrieve with Policy
+            # only_own_memories=True would restrict retrieval to the user's personal
+            # memory bank, but a new user with no stored memories would then get an
+            # empty candidate set and the policy would have nothing to select.
+            # For the replay we therefore search the global memory bank
+            # (only_own_memories=False) so the reranker always has candidates;
+            # a production system would search both.
+
+ candidates, cand_vectors, base_scores, chosen_idx, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=q_t,
+ embed_model=embedder,
+ reranker=reranker,
+ memory_cards=cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=32,
+ topk_rerank=5,
+ beta_long=rl_cfg["beta_long"],
+ beta_short=rl_cfg["beta_short"],
+ tau=rl_cfg["tau"],
+ only_own_memories=False
+ )
+
+ if not candidates:
+ continue
+
+            chosen_memories = [candidates[j] for j in chosen_idx]
+
+ # 2. Eval (Reward & Gating)
+            # Embed q_t and q_t+1 for the reward / gating estimate
+ e_q = embedder.encode([q_t], return_tensor=False)[0]
+ e_q1 = embedder.encode([q_t1], return_tensor=False)[0]
+
+ r_hat, g_hat = eval_step(
+ q_t, a_t, q_t1,
+ chosen_memories,
+ query_embedding_t=np.array(e_q),
+ query_embedding_t1=np.array(e_q1)
+ )
+
+ total_reward += r_hat
+ total_gating += g_hat
+ steps += 1
+
+ # 3. Update (REINFORCE)
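+            # Rough sketch of the intended rule (the exact implementation lives in
+            # reinforce_update_user_state and may differ): with policy
+            # p = softmax(scores / tau), a REINFORCE step for each chosen item is
+            #     z += eta * g_hat * r_hat * (v_chosen - sum_j p_j * v_j)
+            # applied with eta_short (plus short_decay) to z_short and, via an EMA
+            # with ema_alpha, to z_long; constant beta/tau factors folded into eta.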
+ state = user_store.get_state(user_id)
+ updated = reinforce_update_user_state(
+ user_state=state,
+ item_vectors=cand_vectors,
+ chosen_indices=chosen_idx,
+ policy_probs=probs,
+ reward_hat=r_hat,
+ gating=g_hat,
+ tau=rl_cfg["tau"],
+ eta_long=rl_cfg["eta_long"],
+ eta_short=rl_cfg["eta_short"],
+ ema_alpha=rl_cfg["ema_alpha"],
+ short_decay=rl_cfg["short_decay"]
+ )
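+            # Assumption: reinforce_update_user_state mutates `state` in place and
+            # returns a truthy flag when an update was actually applied; that is why
+            # `state` itself (not the return value) is saved back below.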
+ if updated:
+ num_updates += 1
+
+ user_store.save_state(state)
+
+ # Post-check
+ state_after = user_store.get_state(user_id)
+ norm_long_after = np.linalg.norm(state_after.z_long)
+ norm_short_after = np.linalg.norm(state_after.z_short)
+
+ # Print change only if updated
+ if num_updates > 0:
+ print(f"User {user_id}: {num_updates} updates")
+ print(f" ||z_long|| {norm_long_before:.8f} -> {norm_long_after:.8f}")
+ print(f" ||z_short|| {norm_short_before:.8f} -> {norm_short_after:.8f}")
+
+ print("\n--- Replay Finished ---")
+ print(f"Total Steps: {steps}")
+ print(f"Avg Reward: {total_reward / max(1, steps):.8f}")
+ print(f"Avg Gating: {total_gating / max(1, steps):.8f}")
+
+ user_store.persist()
+ print(f"User store saved to {user_store_path}")
+
+if __name__ == "__main__":
+ main()
+