diff options
Diffstat (limited to 'src/personalization/evaluation/baselines/rag_memory.py')
| -rw-r--r-- | src/personalization/evaluation/baselines/rag_memory.py | 204 |
1 files changed, 204 insertions, 0 deletions
diff --git a/src/personalization/evaluation/baselines/rag_memory.py b/src/personalization/evaluation/baselines/rag_memory.py new file mode 100644 index 0000000..2b391c3 --- /dev/null +++ b/src/personalization/evaluation/baselines/rag_memory.py @@ -0,0 +1,204 @@ +""" +RAG Memory Baseline (Y3/Y4) + +Wraps the PersonalizedLLM for use in the evaluation framework. +Y3: Extractor + RAG (mode="nopersonal") +Y4: Extractor + RAG + User Vector (mode="full") +""" + +from typing import List, Dict, Any, Optional +import os +import sys + +from .base import BaselineAgent, AgentResponse + +# Add src to path for imports +_src_path = os.path.join(os.path.dirname(__file__), "../../../..") +if _src_path not in sys.path: + sys.path.insert(0, _src_path) + + +class RAGMemoryAgent(BaselineAgent): + """ + Y3/Y4: RAG-based memory with optional user vector. + + This agent: + - Extracts preferences from conversations using the extractor + - Stores preferences as memory cards + - Retrieves relevant memories using RAG for each query + - (Y4 only) Uses user vector to personalize retrieval + """ + + def __init__( + self, + model_name: str = "llama-8b", + mode: str = "nopersonal", # "nopersonal" for Y3, "full" for Y4 + memory_cards_path: str = None, + memory_embeddings_path: str = None, + enable_preference_extraction: bool = True, + enable_rl_updates: bool = False, + only_own_memories: bool = True, + **kwargs + ): + """ + Args: + model_name: LLM model to use + mode: "nopersonal" (Y3) or "full" (Y4) + memory_cards_path: Path to memory cards file + memory_embeddings_path: Path to embeddings file + enable_preference_extraction: Whether to extract preferences + enable_rl_updates: Whether to update user vectors (Y4 only) + only_own_memories: Only retrieve user's own memories + """ + super().__init__(model_name, **kwargs) + + self.mode = mode + self.enable_rl_updates = enable_rl_updates and (mode == "full") + + # Default paths + base_dir = os.path.join(os.path.dirname(__file__), "../../../../..") + self.memory_cards_path = memory_cards_path or os.path.join( + base_dir, "data/eval/memory_cards.jsonl" + ) + self.memory_embeddings_path = memory_embeddings_path or os.path.join( + base_dir, "data/eval/memory_embeddings.npy" + ) + + self.enable_preference_extraction = enable_preference_extraction + self.only_own_memories = only_own_memories + + # Lazy initialization + self._llm = None + self._initialized = False + + def _ensure_initialized(self): + """Lazy initialization of PersonalizedLLM.""" + if self._initialized: + return + + try: + from personalization.serving.personalized_llm import PersonalizedLLM + + self._llm = PersonalizedLLM( + mode=self.mode, + enable_preference_extraction=self.enable_preference_extraction, + enable_rl_updates=self.enable_rl_updates, + only_own_memories=self.only_own_memories, + memory_cards_path=self.memory_cards_path, + memory_embeddings_path=self.memory_embeddings_path, + eval_mode=True, # Deterministic selection + ) + self._initialized = True + + except Exception as e: + print(f"Warning: Could not initialize PersonalizedLLM: {e}") + print("Falling back to simple response mode.") + self._llm = None + self._initialized = True + + def respond( + self, + user_id: str, + query: str, + conversation_history: List[Dict[str, str]], + **kwargs + ) -> AgentResponse: + """Generate response using RAG memory.""" + + self._ensure_initialized() + + if self._llm is None: + # Fallback mode + return AgentResponse( + answer=f"[RAGMemoryAgent-{self.mode}] Response to: {query[:50]}...", + debug_info={"mode": "fallback"}, + ) + + try: + # Use PersonalizedLLM's chat interface + response = self._llm.chat(user_id, query) + + debug_info = { + "mode": self.mode, + "num_memories_retrieved": len(response.debug.selected_memory_ids) if response.debug else 0, + "selected_memories": response.debug.selected_memory_notes if response.debug else [], + "extracted_preferences": response.debug.extracted_preferences if response.debug else [], + } + + if response.debug and response.debug.extra: + debug_info.update(response.debug.extra) + + return AgentResponse( + answer=response.answer, + debug_info=debug_info, + ) + + except Exception as e: + print(f"Error in RAGMemoryAgent.respond: {e}") + return AgentResponse( + answer=f"I apologize for the error. Regarding: {query[:100]}", + debug_info={"error": str(e)}, + ) + + def end_session(self, user_id: str, conversation: List[Dict[str, str]]): + """ + Called at end of session. + PersonalizedLLM already extracts preferences during chat(), + so we just reset the session state. + """ + self._ensure_initialized() + + if self._llm is not None: + self._llm.reset_session(user_id) + + def reset_user(self, user_id: str): + """Reset all state for a user.""" + self._ensure_initialized() + + if self._llm is not None: + self._llm.reset_user(user_id) + + def apply_feedback(self, user_id: str, reward: float, gating: float = 1.0): + """ + Apply feedback for user vector updates (Y4 only). + + Args: + user_id: User identifier + reward: Reward signal (e.g., from preference satisfaction) + gating: Gating signal (1.0 = use this feedback, 0.0 = skip) + """ + if not self.enable_rl_updates or self._llm is None: + return + + try: + from personalization.serving.personalized_llm import Feedback + + feedback = Feedback( + user_id=user_id, + turn_id=0, # Not used in current implementation + reward=reward, + gating=gating, + ) + self._llm.apply_feedback(feedback) + + except Exception as e: + print(f"Error applying feedback: {e}") + + def get_user_state(self, user_id: str) -> Dict[str, Any]: + """Get user state summary (for Y4 analysis).""" + self._ensure_initialized() + + if self._llm is not None: + return self._llm.get_user_state_summary(user_id) + return {} + + def persist(self): + """Save all state to disk.""" + if self._llm is not None: + self._llm.persist() + + def get_name(self) -> str: + mode_name = "RAG" if self.mode == "nopersonal" else "RAG+UV" + return f"{mode_name}({self.model_name})" + + |
