summaryrefslogtreecommitdiff
path: root/src/personalization/evaluation/baselines/rag_memory.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/evaluation/baselines/rag_memory.py')
-rw-r--r--src/personalization/evaluation/baselines/rag_memory.py204
1 file changed, 204 insertions, 0 deletions
diff --git a/src/personalization/evaluation/baselines/rag_memory.py b/src/personalization/evaluation/baselines/rag_memory.py
new file mode 100644
index 0000000..2b391c3
--- /dev/null
+++ b/src/personalization/evaluation/baselines/rag_memory.py
@@ -0,0 +1,204 @@
+"""
+RAG Memory Baseline (Y3/Y4)
+
+Wraps the PersonalizedLLM for use in the evaluation framework.
+Y3: Extractor + RAG (mode="nopersonal")
+Y4: Extractor + RAG + User Vector (mode="full")
+"""
+
+from typing import List, Dict, Any, Optional
+import os
+import sys
+
+from .base import BaselineAgent, AgentResponse
+
+# Add src to path for imports
+_src_path = os.path.join(os.path.dirname(__file__), "../../../..")
+if _src_path not in sys.path:
+ sys.path.insert(0, _src_path)
+
+
class RAGMemoryAgent(BaselineAgent):
    """
    Y3/Y4: RAG-based memory with optional user vector.

    This agent:
    - Extracts preferences from conversations using the extractor
    - Stores preferences as memory cards
    - Retrieves relevant memories using RAG for each query
    - (Y4 only) Uses user vector to personalize retrieval

    The heavy lifting is delegated to ``PersonalizedLLM``, which is
    constructed lazily on first use so that simply instantiating this
    agent stays cheap and import-error free.
    """

    def __init__(
        self,
        model_name: str = "llama-8b",
        mode: str = "nopersonal",  # "nopersonal" for Y3, "full" for Y4
        memory_cards_path: Optional[str] = None,
        memory_embeddings_path: Optional[str] = None,
        enable_preference_extraction: bool = True,
        enable_rl_updates: bool = False,
        only_own_memories: bool = True,
        **kwargs
    ):
        """
        Args:
            model_name: LLM model to use
            mode: "nopersonal" (Y3) or "full" (Y4)
            memory_cards_path: Path to memory cards file; defaults to
                ``data/eval/memory_cards.jsonl`` under the repository root
            memory_embeddings_path: Path to embeddings file; defaults to
                ``data/eval/memory_embeddings.npy`` under the repository root
            enable_preference_extraction: Whether to extract preferences
            enable_rl_updates: Whether to update user vectors (Y4 only;
                silently ignored unless ``mode == "full"``)
            only_own_memories: Only retrieve user's own memories
        """
        super().__init__(model_name, **kwargs)

        self.mode = mode
        # RL updates require the user vector, which only exists in "full"
        # (Y4) mode; force them off for the plain-RAG Y3 configuration.
        self.enable_rl_updates = enable_rl_updates and (mode == "full")

        # Default paths resolve relative to the repository root
        # (five directory levels above this file).
        base_dir = os.path.join(os.path.dirname(__file__), "../../../../..")
        self.memory_cards_path = memory_cards_path or os.path.join(
            base_dir, "data/eval/memory_cards.jsonl"
        )
        self.memory_embeddings_path = memory_embeddings_path or os.path.join(
            base_dir, "data/eval/memory_embeddings.npy"
        )

        self.enable_preference_extraction = enable_preference_extraction
        self.only_own_memories = only_own_memories

        # Lazy initialization: the underlying PersonalizedLLM is built on
        # first use (see _ensure_initialized), not at construction time.
        self._llm = None
        self._initialized = False

    def _ensure_initialized(self) -> None:
        """Lazy initialization of PersonalizedLLM.

        On failure, falls back to a stub mode (``self._llm is None``) so
        that respond() still returns a placeholder instead of crashing.
        Idempotent: subsequent calls are no-ops either way.
        """
        if self._initialized:
            return

        try:
            from personalization.serving.personalized_llm import PersonalizedLLM

            self._llm = PersonalizedLLM(
                mode=self.mode,
                enable_preference_extraction=self.enable_preference_extraction,
                enable_rl_updates=self.enable_rl_updates,
                only_own_memories=self.only_own_memories,
                memory_cards_path=self.memory_cards_path,
                memory_embeddings_path=self.memory_embeddings_path,
                eval_mode=True,  # Deterministic selection
            )
            self._initialized = True

        except Exception as e:
            # Best-effort: evaluation should proceed even if the full
            # serving stack is unavailable in this environment.
            print(f"Warning: Could not initialize PersonalizedLLM: {e}")
            print("Falling back to simple response mode.")
            self._llm = None
            self._initialized = True

    def respond(
        self,
        user_id: str,
        query: str,
        conversation_history: List[Dict[str, str]],
        **kwargs
    ) -> AgentResponse:
        """Generate response using RAG memory.

        Args:
            user_id: User identifier passed through to PersonalizedLLM.
            query: The user's current question.
            conversation_history: Prior turns (unused here; PersonalizedLLM
                tracks session state internally).

        Returns:
            AgentResponse with the answer and retrieval debug info. Any
            exception from the underlying LLM is caught and surfaced as a
            polite fallback answer with the error in ``debug_info``.
        """
        self._ensure_initialized()

        if self._llm is None:
            # Fallback mode (PersonalizedLLM failed to initialize).
            return AgentResponse(
                answer=f"[RAGMemoryAgent-{self.mode}] Response to: {query[:50]}...",
                debug_info={"mode": "fallback"},
            )

        try:
            # Use PersonalizedLLM's chat interface
            response = self._llm.chat(user_id, query)

            debug_info = {
                "mode": self.mode,
                "num_memories_retrieved": len(response.debug.selected_memory_ids) if response.debug else 0,
                "selected_memories": response.debug.selected_memory_notes if response.debug else [],
                "extracted_preferences": response.debug.extracted_preferences if response.debug else [],
            }

            if response.debug and response.debug.extra:
                debug_info.update(response.debug.extra)

            return AgentResponse(
                answer=response.answer,
                debug_info=debug_info,
            )

        except Exception as e:
            print(f"Error in RAGMemoryAgent.respond: {e}")
            return AgentResponse(
                answer=f"I apologize for the error. Regarding: {query[:100]}",
                debug_info={"error": str(e)},
            )

    def end_session(self, user_id: str, conversation: List[Dict[str, str]]) -> None:
        """
        Called at end of session.
        PersonalizedLLM already extracts preferences during chat(),
        so we just reset the session state.
        """
        self._ensure_initialized()

        if self._llm is not None:
            self._llm.reset_session(user_id)

    def reset_user(self, user_id: str) -> None:
        """Reset all state for a user (memories, session, user vector)."""
        self._ensure_initialized()

        if self._llm is not None:
            self._llm.reset_user(user_id)

    def apply_feedback(self, user_id: str, reward: float, gating: float = 1.0) -> None:
        """
        Apply feedback for user vector updates (Y4 only).

        No-op unless RL updates are enabled and the LLM is initialized.

        Args:
            user_id: User identifier
            reward: Reward signal (e.g., from preference satisfaction)
            gating: Gating signal (1.0 = use this feedback, 0.0 = skip)
        """
        if not self.enable_rl_updates or self._llm is None:
            return

        try:
            from personalization.serving.personalized_llm import Feedback

            feedback = Feedback(
                user_id=user_id,
                turn_id=0,  # Not used in current implementation
                reward=reward,
                gating=gating,
            )
            self._llm.apply_feedback(feedback)

        except Exception as e:
            # Feedback is best-effort; never let it break evaluation.
            print(f"Error applying feedback: {e}")

    def get_user_state(self, user_id: str) -> Dict[str, Any]:
        """Get user state summary (for Y4 analysis); {} in fallback mode."""
        self._ensure_initialized()

        if self._llm is not None:
            return self._llm.get_user_state_summary(user_id)
        return {}

    def persist(self) -> None:
        """Save all state to disk (no-op in fallback mode)."""
        if self._llm is not None:
            self._llm.persist()

    def get_name(self) -> str:
        """Return a short display name encoding mode and model."""
        mode_name = "RAG" if self.mode == "nopersonal" else "RAG+UV"
        return f"{mode_name}({self.model_name})"
+
+