#!/usr/bin/env python3
"""
Online Personalization REPL Demo.

Interactive CLI for chatting with the Personalized Memory RAG system.

Includes:
- Extractor-0.6B for online preference extraction
- Reranker + Policy Retrieval
- Online RL updates (REINFORCE)
"""
import sys
import os
import uuid
import numpy as np
import torch
import readline  # For better input handling
import yaml
from collections import Counter

# Add src to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.config.registry import (
    get_preference_extractor,
    get_chat_model,
)
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
# from personalization.models.llm.qwen_instruct import QwenInstruct  # Deprecated direct import
from personalization.user_model.tensor_store import UserTensorStore
from personalization.user_model.session_state import OnlineSessionState
from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList
from personalization.retrieval.pipeline import retrieve_with_policy
from personalization.feedback.handlers import eval_step
from personalization.user_model.policy.reinforce import reinforce_update_user_state
from personalization.user_model.features import ItemProjection


def load_memory_store():
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"

    if not (os.path.exists(cards_path) and os.path.exists(embs_path) and os.path.exists(item_proj_path)):
        # Starting with an empty memory store would be possible, but the PCA
        # item space is defined by the base data, so this demo requires it.
        print("Memory data missing; this demo needs the base data to define the item space.")
        sys.exit(1)

    print(f"Loading memory cards from {cards_path}...")
    cards = []
    with open(cards_path, "r") as f:
        for line in f:
            cards.append(MemoryCard.model_validate_json(line))

    memory_embeddings = np.load(embs_path)

    # Load the PCA projection. The ItemProjection object is reconstructed so
    # that new memories can be transformed into the item space.
    proj_data = np.load(item_proj_path)
    projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
    item_vectors = proj_data["V"]

    return cards, memory_embeddings, item_vectors, projection


def build_user_turn(user_id: str, text: str, turn_id: int) -> ChatTurn:
    return ChatTurn(
        user_id=user_id,
        session_id="online_debug_session",
        turn_id=turn_id,
        role="user",
        text=text,
        meta={"source": "repl"},
    )


def build_assistant_turn(user_id: str, text: str, turn_id: int) -> ChatTurn:
    return ChatTurn(
        user_id=user_id,
        session_id="online_debug_session",
        turn_id=turn_id,
        role="assistant",
        text=text,
        meta={"source": "repl"},
    )


def add_preferences_as_memory_cards(
    prefs: PreferenceList,
    query: str,
    user_id: str,
    turn_id: int,
    embed_model: Qwen3Embedding8B,
    projection: ItemProjection,
    memory_cards: list,
    memory_embeddings: np.ndarray,
    item_vectors: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Adds extracted preferences as new memory cards.
    Returns updated memory_embeddings and item_vectors.
    """
    if not prefs.preferences:
        print(" [Extractor] No preferences found in this turn.")
        return memory_embeddings, item_vectors

    e_m_list = []
    v_m_list = []

    # The embedding is computed once: the query is the source for all
    # preferences extracted from this turn. Alternatively, the note text
    # could be embedded; the current design uses the original query
    # embedding e_m.
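    # A minimal sketch of that alternative (an assumption, not what this demo
    # does): embed each note's text instead of reusing the query embedding.
    #
    #     note_texts = [f"When {p.condition}, {p.action}." for p in prefs.preferences]
    #     note_embs = embed_model.encode(note_texts, return_tensor=False)
    #     note_vecs = [projection.transform_vector(np.array(e)) for e in note_embs]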
    e_q = embed_model.encode([query], return_tensor=False)[0]
    v_q = projection.transform_vector(np.array(e_q))

    print(f" [Extractor] Extracted {len(prefs.preferences)} preferences:")
    for pref in prefs.preferences:
        note_text = f"When {pref.condition}, {pref.action}."
        print(f" - {note_text}")

        # Simple deduplication on exact note_text for this user. A real
        # system would use vector similarity or hashing (see the sketch at
        # the end of this file).
        is_duplicate = False
        for card in memory_cards:
            if card.user_id == user_id and card.note_text == note_text:
                is_duplicate = True
                break
        if is_duplicate:
            print(" (Duplicate, skipping add)")
            continue

        card = MemoryCard(
            card_id=str(uuid.uuid4()),
            user_id=user_id,
            source_session_id="online_debug_session",
            source_turn_ids=[turn_id],
            raw_queries=[query],
            preference_list=PreferenceList(preferences=[pref]),
            note_text=note_text,
            embedding_e=e_q,  # Stored as list[float]
            kind="pref",
        )
        memory_cards.append(card)
        e_m_list.append(e_q)
        v_m_list.append(v_q)

    # Update the numpy arrays with any newly added cards.
    if e_m_list:
        new_embs = np.array(e_m_list)
        new_vecs = np.array(v_m_list)
        memory_embeddings = np.vstack([memory_embeddings, new_embs])
        item_vectors = np.vstack([item_vectors, new_vecs])
        print(f" [Debug] Added {len(e_m_list)} new cards. Total cards: {len(memory_cards)}")

    return memory_embeddings, item_vectors


def main():
    # 1. Load Config & Models
    print("Loading configuration...")
    cfg = load_local_models_config()

    # RL config. This should be loaded from user_model.yaml; it is hardcoded
    # here for safety in the demo (see the sketch after the memory store is
    # loaded below).
    rl_cfg = {
        "item_dim": 256,
        "beta_long": 0.1,
        "beta_short": 0.3,
        "tau": 1.0,
        "eta_long": 1e-3,
        "eta_short": 5e-3,
        "ema_alpha": 0.05,
        "short_decay": 0.1,
        "dense_topk": 64,
        "rerank_topk": 3,
        "max_new_tokens": 512,
    }

    print("Loading models and stores...")
    # Explicit classes are used here for clarity; the registry works as well.
    embed_model = Qwen3Embedding8B.from_config(cfg)
    reranker = Qwen3Reranker.from_config(cfg)

    # Use the registry for the ChatModel (supports switching backends).
    # Default to "qwen_1_5b" if not specified in user_model.yaml.
    llm_name = "qwen_1_5b"
    # Try loading from config safely
    try:
        config_path = os.path.join(os.path.dirname(__file__), "../configs/user_model.yaml")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                user_cfg = yaml.safe_load(f)
            if user_cfg and "llm_name" in user_cfg:
                llm_name = user_cfg["llm_name"]
                print(f"Loaded llm_name from config: {llm_name}")
        else:
            print(f"Warning: Config file not found at {config_path}")
    except Exception as e:
        print(f"Failed to load user_model.yaml: {e}")

    print(f"Loading ChatModel: {llm_name}...")
    chat_model = get_chat_model(llm_name)

    # Use the registry for the extractor to support switching.
    extractor_name = "qwen3_0_6b_sft"  # Default per design doc
    print(f"Loading extractor: {extractor_name}...")
    try:
        extractor = get_preference_extractor(extractor_name)
    except Exception as e:
        print(f"Failed to load {extractor_name}: {e}. Falling back to rule-based extractor.")
        extractor = get_preference_extractor("rule")

    user_store = UserTensorStore(
        k=rl_cfg["item_dim"],
        path="data/users/user_store_online.npz",
    )

    # Load Memory
    memory_cards, memory_embeddings, item_vectors, projection = load_memory_store()
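    # A hedged sketch of loading the RL hyperparameters from user_model.yaml
    # instead of hardcoding rl_cfg above (assumes an optional "rl" mapping in
    # that file, which the current code does not require):
    #
    #     if os.path.exists(config_path):
    #         with open(config_path, "r") as f:
    #             rl_cfg.update((yaml.safe_load(f) or {}).get("rl", {}))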
    # 2. Init Session
    user_id = "debug_user"
    user_state = user_store.get_state(user_id)
    session_state = OnlineSessionState(user_id=user_id)

    print(f"\n--- Online Personalization REPL (User: {user_id}) ---")
    print(f"Initial State: ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}")
    print("Type 'exit' or 'quit' to stop.\n")

    while True:
        try:
            q_t = input("User: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting...")
            break
        if q_t.lower() in ("exit", "quit"):
            break
        if not q_t:
            continue

        # 3. RL Update (from previous turn)
        e_q_t = embed_model.encode([q_t], return_tensor=False)[0]
        e_q_t = np.array(e_q_t)

        if session_state.last_query is not None:
            r_hat, g_hat = eval_step(
                q_t=session_state.last_query,
                answer_t=session_state.last_answer,
                q_t1=q_t,
                memories_t=session_state.last_memories,
                query_embedding_t=session_state.last_query_embedding,
                query_embedding_t1=e_q_t,
            )
            print(f" [Feedback] Reward: {r_hat:.2f}, Gating: {g_hat:.2f}")

            if (session_state.last_candidate_item_vectors is not None
                    and session_state.last_policy_probs is not None
                    and len(session_state.last_chosen_indices) > 0):
                # IMPORTANT: extract the vectors of the chosen items so they
                # align with the stored probabilities:
                #   last_candidate_item_vectors: [64, dim] candidate vectors
                #   last_chosen_indices:         [3] indices into the 64 candidates
                #   last_policy_probs:           [3] probabilities of the chosen items
                chosen_vectors = session_state.last_candidate_item_vectors[session_state.last_chosen_indices]

                updated = reinforce_update_user_state(
                    user_state=user_state,
                    item_vectors=chosen_vectors,  # Corrected: pass only the chosen vectors, [3, dim]
                    chosen_indices=np.arange(len(session_state.last_chosen_indices)),  # Indices are now 0,1,2 relative to chosen_vectors
                    policy_probs=session_state.last_policy_probs,
                    reward_hat=r_hat,
                    gating=g_hat,
                    tau=rl_cfg["tau"],
                    eta_long=rl_cfg["eta_long"],
                    eta_short=rl_cfg["eta_short"],
                    ema_alpha=rl_cfg["ema_alpha"],
                    short_decay=rl_cfg["short_decay"],
                )
                if updated:
                    print(" [RL] User state updated.")
                    user_store.save_state(user_state)  # Save immediately for safety

        # 4. Update History
        user_turn = build_user_turn(user_id, q_t, len(session_state.history))
        session_state.history.append(user_turn)

        # 5. Extract Preferences -> New Memory (from recent history)
        prefs = extractor.extract_turn(session_state.history)
        memory_embeddings, item_vectors = add_preferences_as_memory_cards(
            prefs, q_t, user_id, user_turn.turn_id,
            embed_model, projection,
            memory_cards, memory_embeddings, item_vectors,
        )
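        # The policy retrieval below is assumed (a sketch inferred from the
        # parameter names, not defined in this file) to score each reranked
        # candidate m with item vector v_m as
        #
        #     s(m) = base_score(m) + beta_long * <z_long, v_m> + beta_short * <z_short, v_m>
        #
        # and to sample the top-k indices from softmax(s / tau); that sampling
        # distribution is what the REINFORCE update in step 3 differentiates.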
        # 6. Retrieve + Policy
        # only_own_memories=True enforces strict privacy for the demo.
        # Unpacking order: the pipeline returns (candidates, vecs, scores, indices, probs).
        candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
            user_id=user_id,
            query=q_t,
            embed_model=embed_model,
            reranker=reranker,
            memory_cards=memory_cards,
            memory_embeddings=memory_embeddings,
            user_store=user_store,
            item_vectors=item_vectors,
            topk_dense=rl_cfg["dense_topk"],
            topk_rerank=rl_cfg["rerank_topk"],
            beta_long=rl_cfg["beta_long"],
            beta_short=rl_cfg["beta_short"],
            tau=rl_cfg["tau"],
            only_own_memories=True,
        )

        # chosen_indices map back into the candidates list (0..K-1).
        memories_t = [candidates[int(i)] for i in chosen_indices]

        if memories_t:
            print(f" [Retrieval] Found {len(memories_t)} memories:")
            # Deduplicate the display by grouping identical note_text.
            content_counts = Counter([m.note_text for m in memories_t])
            for text, count in content_counts.most_common():
                dup_info = f" (x{count})" if count > 1 else ""
                print(f" - {text}{dup_info}")

        # 7. LLM Answer
        memory_notes = [m.note_text for m in memories_t]
        # session_state.history is already a list of ChatTurn objects, which
        # is what chat_model.answer expects (not dicts).
        answer_t = chat_model.answer(
            history=session_state.history,
            memory_notes=memory_notes,
            max_new_tokens=rl_cfg["max_new_tokens"],
        )
        print(f"Assistant: {answer_t}")

        # 8. Update State for Next Turn
        assist_turn = build_assistant_turn(user_id, answer_t, len(session_state.history))
        session_state.history.append(assist_turn)

        session_state.last_query = q_t
        session_state.last_answer = answer_t
        session_state.last_memories = memories_t
        session_state.last_query_embedding = e_q_t
        session_state.last_candidate_item_vectors = cand_item_vecs
        session_state.last_policy_probs = probs
        session_state.last_chosen_indices = chosen_indices

        print(f" [State] ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}")
        print("-" * 40)

    print("Saving final user state...")
    user_store.save_state(user_state)
    user_store.persist()

    # Save updated memories. Ideally this would be atomic or append-only,
    # but for the demo a direct rewrite of the files is acceptable.
    print(f"Saving {len(memory_cards)} memory cards to disk...")
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"

    with open(cards_path, "w", encoding="utf-8") as f:
        for card in memory_cards:
            f.write(card.model_dump_json() + "\n")
    np.save(embs_path, memory_embeddings)

    # item_projection.npz stores the projection matrix P and the mean; the
    # item vectors V in it are just a cache. Update V so the next load sees
    # the new memories, keeping P and mean from the original file.
    proj_data = np.load(item_proj_path)
    np.savez(
        item_proj_path,
        P=proj_data["P"],
        mean=proj_data["mean"],
        V=item_vectors,
    )
    print("Memory store updated.")


if __name__ == "__main__":
    main()
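# A hedged sketch of the vector-similarity deduplication mentioned in
# add_preferences_as_memory_cards (the helper name, the 2D user_embs array,
# and the 0.95 threshold are assumptions, not part of the current code):
#
#     def is_near_duplicate(e_new, user_embs, threshold=0.95):
#         """Cosine-similarity check of a new embedding against a user's [N, d] embeddings."""
#         if len(user_embs) == 0:
#             return False
#         e = np.asarray(e_new)
#         sims = user_embs @ e / (
#             np.linalg.norm(user_embs, axis=1) * np.linalg.norm(e) + 1e-8
#         )
#         return bool(np.max(sims) >= threshold)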