Diffstat (limited to 'scripts/online_personalization_demo.py')
-rw-r--r--  scripts/online_personalization_demo.py  399
1 file changed, 399 insertions, 0 deletions
diff --git a/scripts/online_personalization_demo.py b/scripts/online_personalization_demo.py
new file mode 100644
index 0000000..f5b6d68
--- /dev/null
+++ b/scripts/online_personalization_demo.py
@@ -0,0 +1,399 @@
+#!/usr/bin/env python3
+"""
+Online Personalization REPL Demo.
+Interactive CLI for chatting with the Personalized Memory RAG system.
+Includes:
+- Extractor-0.6B for online preference extraction
+- Reranker + Policy Retrieval
+- Online RL updates (REINFORCE)
+"""
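+#
+# Usage (assumed; the data paths below are relative, so run from the repository root):
+#   python scripts/online_personalization_demo.py
+# Requires data/corpora/memory_cards.jsonl, memory_embeddings.npy and item_projection.npz
+# to exist (see load_memory_store below).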
+
+import sys
+import os
+import uuid
+import numpy as np
+import torch
+import readline # For better input handling
+import yaml
+from collections import Counter
+
+# Add src to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.config.registry import (
+ get_preference_extractor,
+ get_chat_model,
+)
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
+# from personalization.models.llm.qwen_instruct import QwenInstruct # Deprecated direct import
+from personalization.user_model.tensor_store import UserTensorStore
+from personalization.user_model.session_state import OnlineSessionState
+from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList
+from personalization.retrieval.pipeline import retrieve_with_policy
+from personalization.feedback.handlers import eval_step
+from personalization.user_model.policy.reinforce import reinforce_update_user_state
+from personalization.user_model.features import ItemProjection
+
+def load_memory_store():
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+ item_proj_path = "data/corpora/item_projection.npz"
+
+ if not os.path.exists(cards_path) or not os.path.exists(embs_path):
+        print("Memory data missing. An empty memory store is possible in principle, but the item space (PCA projection) requires base data.")
+        # For this demo, we assume the base data exists to define the PCA space.
+ sys.exit(1)
+
+ print(f"Loading memory cards from {cards_path}...")
+ cards = []
+ with open(cards_path, "r") as f:
+ for line in f:
+ cards.append(MemoryCard.model_validate_json(line))
+
+ memory_embeddings = np.load(embs_path)
+
+ # Load PCA projection
+ proj_data = np.load(item_proj_path)
+    # Reconstruct the ItemProjection object so new memories can be projected into the item space
+ projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
+ item_vectors = proj_data["V"]
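+    # Assumed shapes: memory_embeddings is [N, d] in the raw embedding space and
+    # item_vectors is [N, k] in the PCA-projected item space (k = item_dim, 256 in this demo),
+    # with rows aligned one-to-one with `cards`.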
+
+ return cards, memory_embeddings, item_vectors, projection
+
+def build_user_turn(user_id: str, text: str, turn_id: int) -> ChatTurn:
+ return ChatTurn(
+ user_id=user_id,
+ session_id="online_debug_session",
+ turn_id=turn_id,
+ role="user",
+ text=text,
+ meta={"source": "repl"}
+ )
+
+def build_assistant_turn(user_id: str, text: str, turn_id: int) -> ChatTurn:
+ return ChatTurn(
+ user_id=user_id,
+ session_id="online_debug_session",
+ turn_id=turn_id,
+ role="assistant",
+ text=text,
+ meta={"source": "repl"}
+ )
+
+def add_preferences_as_memory_cards(
+ prefs: PreferenceList,
+ query: str,
+ user_id: str,
+ turn_id: int,
+ embed_model: Qwen3Embedding8B,
+ projection: ItemProjection,
+ memory_cards: list,
+ memory_embeddings: np.ndarray,
+ item_vectors: np.ndarray,
+) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Adds extracted preferences as new memory cards.
+ Returns updated memory_embeddings and item_vectors.
+ """
+ if not prefs.preferences:
+ print(" [Extractor] No preferences found in this turn.")
+ return memory_embeddings, item_vectors
+
+ e_m_list = []
+ v_m_list = []
+
+    # Embed the query once: the original query embedding e_m is used as the source
+    # for every preference extracted from this turn. (Alternatively the note text
+    # could be embedded; the current design stores the query embedding.)
+ e_q = embed_model.encode([query], return_tensor=False)[0]
+ v_q = projection.transform_vector(np.array(e_q))
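+    # transform_vector presumably applies the stored PCA projection, roughly projecting
+    # (e_q - mean) through P to map the raw embedding into the k-dim item space.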
+
+ print(f" [Extractor] Extracted {len(prefs.preferences)} preferences:")
+ for pref in prefs.preferences:
+ note_text = f"When {pref.condition}, {pref.action}."
+ print(f" - {note_text}")
+
+        # Simple deduplication check based on note_text for this user.
+        # A real system would use vector similarity or hashing instead.
+ is_duplicate = False
+ for card in memory_cards:
+ if card.user_id == user_id and card.note_text == note_text:
+ is_duplicate = True
+ break
+
+ if is_duplicate:
+ print(" (Duplicate, skipping add)")
+ continue
+
+ card = MemoryCard(
+ card_id=str(uuid.uuid4()),
+ user_id=user_id,
+ source_session_id="online_debug_session",
+ source_turn_ids=[turn_id],
+ raw_queries=[query],
+ preference_list=PreferenceList(preferences=[pref]),
+ note_text=note_text,
+ embedding_e=e_q, # Store list[float]
+ kind="pref",
+ )
+
+ memory_cards.append(card)
+ e_m_list.append(e_q)
+ v_m_list.append(v_q)
+
+ # Update numpy arrays
+ if e_m_list:
+ new_embs = np.array(e_m_list)
+ new_vecs = np.array(v_m_list)
+ memory_embeddings = np.vstack([memory_embeddings, new_embs])
+ item_vectors = np.vstack([item_vectors, new_vecs])
+ print(f" [Debug] Added {len(e_m_list)} new cards. Total cards: {len(memory_cards)}")
+
+ return memory_embeddings, item_vectors
+
+def main():
+ # 1. Load Config & Models
+ print("Loading configuration...")
+ cfg = load_local_models_config()
+
+    # RL config (should be loaded from user_model.yaml; hardcoded here to keep the demo self-contained)
+ rl_cfg = {
+ "item_dim": 256,
+ "beta_long": 0.1,
+ "beta_short": 0.3,
+ "tau": 1.0,
+ "eta_long": 1e-3,
+ "eta_short": 5e-3,
+ "ema_alpha": 0.05,
+ "short_decay": 0.1,
+ "dense_topk": 64,
+ "rerank_topk": 3,
+ "max_new_tokens": 512
+ }
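+    # Assumed semantics of these hyperparameters (see retrieve_with_policy and
+    # reinforce_update_user_state): beta_long/beta_short weight the long- and short-term
+    # user-state contributions to the policy scores, tau is the softmax temperature,
+    # eta_* are learning rates, ema_alpha smooths the reward baseline, and short_decay
+    # decays z_short between updates.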
+
+ print("Loading models and stores...")
+    # Explicit classes are used here for clarity; the registry could be used instead
+ embed_model = Qwen3Embedding8B.from_config(cfg)
+ reranker = Qwen3Reranker.from_config(cfg)
+
+ # Use registry for ChatModel (supports switching backends)
+ # Default to "qwen_1_5b" if not specified in user_model.yaml
+ llm_name = "qwen_1_5b"
+
+ # Try loading from config safely
+ try:
+ config_path = os.path.join(os.path.dirname(__file__), "../configs/user_model.yaml")
+ if os.path.exists(config_path):
+ with open(config_path, "r") as f:
+ user_cfg = yaml.safe_load(f)
+ if user_cfg and "llm_name" in user_cfg:
+ llm_name = user_cfg["llm_name"]
+ print(f"Loaded llm_name from config: {llm_name}")
+ else:
+ print(f"Warning: Config file not found at {config_path}")
+ except Exception as e:
+ print(f"Failed to load user_model.yaml: {e}")
+
+ print(f"Loading ChatModel: {llm_name}...")
+ chat_model = get_chat_model(llm_name)
+
+ # Use registry for extractor to support switching
+ extractor_name = "qwen3_0_6b_sft" # Default per design doc
+ print(f"Loading extractor: {extractor_name}...")
+ try:
+ extractor = get_preference_extractor(extractor_name)
+ except Exception as e:
+ print(f"Failed to load {extractor_name}: {e}. Fallback to rule.")
+ extractor = get_preference_extractor("rule")
+
+ user_store = UserTensorStore(
+ k=rl_cfg["item_dim"],
+ path="data/users/user_store_online.npz",
+ )
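+    # UserTensorStore presumably keeps one state per user (z_long, z_short and any RL
+    # bookkeeping such as the reward baseline) as k-dim vectors, persisted to the .npz path above.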
+
+ # Load Memory
+ memory_cards, memory_embeddings, item_vectors, projection = load_memory_store()
+
+ # 2. Init Session
+ user_id = "debug_user"
+ user_state = user_store.get_state(user_id)
+ session_state = OnlineSessionState(user_id=user_id)
+
+ print(f"\n--- Online Personalization REPL (User: {user_id}) ---")
+ print(f"Initial State: ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}")
+ print("Type 'exit' or 'quit' to stop.\n")
+
+ while True:
+ try:
+ q_t = input("User: ").strip()
+ except (EOFError, KeyboardInterrupt):
+ print("\nExiting...")
+ break
+
+ if q_t.lower() in ("exit", "quit"):
+ break
+ if not q_t:
+ continue
+
+ # 3. RL Update (from previous turn)
+ e_q_t = embed_model.encode([q_t], return_tensor=False)[0]
+ e_q_t = np.array(e_q_t)
+
+ if session_state.last_query is not None:
+ r_hat, g_hat = eval_step(
+ q_t=session_state.last_query,
+ answer_t=session_state.last_answer,
+ q_t1=q_t,
+ memories_t=session_state.last_memories,
+ query_embedding_t=session_state.last_query_embedding,
+ query_embedding_t1=e_q_t,
+ )
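+            # eval_step presumably derives an implicit reward r_hat and a gating weight g_hat
+            # from how the follow-up query q_t relates to the previous answer and retrieved
+            # memories (e.g. via the two query embeddings); this is an assumption about its internals.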
+
+ print(f" [Feedback] Reward: {r_hat:.2f}, Gating: {g_hat:.2f}")
+
+ if (session_state.last_candidate_item_vectors is not None and
+ session_state.last_policy_probs is not None and
+ len(session_state.last_chosen_indices) > 0):
+
+                # IMPORTANT: extract the vectors of the chosen items so they align with the stored probs.
+                #   last_candidate_item_vectors: [dense_topk, dim] vectors for all candidates
+                #   last_chosen_indices:         [rerank_topk] indices into those candidates
+                #   last_policy_probs:           [rerank_topk] probabilities of the chosen items
+ chosen_vectors = session_state.last_candidate_item_vectors[session_state.last_chosen_indices]
+
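+                # Rough sketch of the assumed REINFORCE update performed below (the details
+                # live in reinforce_update_user_state):
+                #   pi(i | z)  = softmax_i(<z, v_i> / tau)           over the chosen vectors
+                #   grad       = sum_i (r_hat - baseline) * d/dz log pi(i | z)
+                #   z_long    += eta_long  * g_hat * grad
+                #   z_short   += eta_short * g_hat * grad            (after decaying by short_decay)
+                # with the baseline tracked as an EMA of past rewards (ema_alpha).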
+ updated = reinforce_update_user_state(
+ user_state=user_state,
+                    item_vectors=chosen_vectors,  # only the chosen vectors, [rerank_topk, dim]
+                    chosen_indices=np.arange(len(session_state.last_chosen_indices)),  # 0..rerank_topk-1, relative to chosen_vectors
+ policy_probs=session_state.last_policy_probs,
+ reward_hat=r_hat,
+ gating=g_hat,
+ tau=rl_cfg["tau"],
+ eta_long=rl_cfg["eta_long"],
+ eta_short=rl_cfg["eta_short"],
+ ema_alpha=rl_cfg["ema_alpha"],
+ short_decay=rl_cfg["short_decay"],
+ )
+ if updated:
+ print(" [RL] User state updated.")
+ user_store.save_state(user_state) # Save immediately for safety
+
+ # 4. Update History
+ user_turn = build_user_turn(user_id, q_t, len(session_state.history))
+ session_state.history.append(user_turn)
+
+ # 5. Extract Preferences -> New Memory
+ # Extract from recent history
+ prefs = extractor.extract_turn(session_state.history)
+ memory_embeddings, item_vectors = add_preferences_as_memory_cards(
+ prefs, q_t, user_id, user_turn.turn_id,
+ embed_model, projection, memory_cards, memory_embeddings, item_vectors
+ )
+
+ # 6. Retrieve + Policy
+        # only_own_memories=True restricts retrieval to this user's own memories (strict privacy for the demo).
+        # The pipeline returns (candidates, item_vectors, base_scores, chosen_indices, policy_probs).
+ candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
+ user_id=user_id,
+ query=q_t,
+ embed_model=embed_model,
+ reranker=reranker,
+ memory_cards=memory_cards,
+ memory_embeddings=memory_embeddings,
+ user_store=user_store,
+ item_vectors=item_vectors,
+ topk_dense=rl_cfg["dense_topk"],
+ topk_rerank=rl_cfg["rerank_topk"],
+ beta_long=rl_cfg["beta_long"],
+ beta_short=rl_cfg["beta_short"],
+ tau=rl_cfg["tau"],
+            only_own_memories=True,  # restrict to this user's own memories for the demo
+ )
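+        # Assumed behaviour of retrieve_with_policy: dense retrieval over memory_embeddings
+        # (top dense_topk), reranking with the Qwen3 reranker, then a personalized score roughly
+        #   s_m = rerank_score_m + beta_long * <z_long, v_m> + beta_short * <z_short, v_m>,
+        # from which rerank_topk items are sampled via a softmax with temperature tau.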
+
+ # Map back to indices in candidates list (0..K-1)
+ print(f"DEBUG: candidates len={len(candidates)}, type={type(candidates)}")
+ print(f"DEBUG: chosen_indices={chosen_indices}, type={type(chosen_indices)}")
+ if len(chosen_indices) > 0:
+ print(f"DEBUG: first idx type={type(chosen_indices[0])}, val={chosen_indices[0]}")
+
+ memories_t = [candidates[int(i)] for i in chosen_indices]
+ if memories_t:
+ print(f" [Retrieval] Found {len(memories_t)} memories:")
+
+            # Display deduplication: group retrieved memories by note_text
+            content_counts = Counter([m.note_text for m in memories_t])
+
+            # Print unique contents with duplicate counts
+            for text, count in content_counts.most_common():
+                dup_info = f" (x{count})" if count > 1 else ""
+                print(f"    - {text}{dup_info}")
+
+ # 7. LLM Answer
+ memory_notes = [m.note_text for m in memories_t]
+        # session_state.history is already a list of ChatTurn objects, which is what chat_model.answer expects
+ answer_t = chat_model.answer(
+ history=session_state.history,
+ memory_notes=memory_notes,
+ max_new_tokens=rl_cfg["max_new_tokens"]
+ )
+
+ print(f"Assistant: {answer_t}")
+
+ # 8. Update State for Next Turn
+ assist_turn = build_assistant_turn(user_id, answer_t, len(session_state.history))
+ session_state.history.append(assist_turn)
+
+ session_state.last_query = q_t
+ session_state.last_answer = answer_t
+ session_state.last_memories = memories_t
+ session_state.last_query_embedding = e_q_t
+ session_state.last_candidate_item_vectors = cand_item_vecs
+ session_state.last_policy_probs = probs
+ session_state.last_chosen_indices = chosen_indices
+
+ print(f" [State] ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}")
+ print("-" * 40)
+
+ print("Saving final user state...")
+ user_store.save_state(user_state)
+ user_store.persist()
+
+ # Save updated memories
+ print(f"Saving {len(memory_cards)} memory cards to disk...")
+    # Ideally this would be atomic or append-only; for the demo we simply rewrite the files in place
+ cards_path = "data/corpora/memory_cards.jsonl"
+ embs_path = "data/corpora/memory_embeddings.npy"
+ item_proj_path = "data/corpora/item_projection.npz"
+
+ with open(cards_path, "w", encoding="utf-8") as f:
+ for card in memory_cards:
+ f.write(card.model_dump_json() + "\n")
+
+ np.save(embs_path, memory_embeddings)
+
+    # item_projection.npz stores the projection matrix P and the mean; the item vectors V are a cache.
+    # Rewrite the file with the updated V so the next load picks up the new memories, keeping P and mean unchanged.
+ proj_data = np.load(item_proj_path)
+ np.savez(
+ item_proj_path,
+ P=proj_data["P"],
+ mean=proj_data["mean"],
+ V=item_vectors
+ )
+ print("Memory store updated.")
+
+if __name__ == "__main__":
+ main()
+
+