summaryrefslogtreecommitdiff
path: root/src/personalization/evaluation/baselines/rag_memory.py
blob: 2b391c3c4354bd779bfa5a06bdf5add10f3af81d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
RAG Memory Baseline (Y3/Y4)

Wraps the PersonalizedLLM for use in the evaluation framework.
Y3: Extractor + RAG (mode="nopersonal")
Y4: Extractor + RAG + User Vector (mode="full")
"""

from typing import List, Dict, Any, Optional
import os
import sys

from .base import BaselineAgent, AgentResponse

# Put the ``src`` directory on sys.path so the lazy
# ``personalization.serving.personalized_llm`` import in RAGMemoryAgent
# resolves as a top-level package. This file lives at
# src/personalization/evaluation/baselines/, so ``src`` is three directories
# above this file's directory. abspath() also normalizes the "..", which keeps
# the membership check below from adding duplicate, differently-spelled entries.
_src_path = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../..")
)
if _src_path not in sys.path:
    sys.path.insert(0, _src_path)


class RAGMemoryAgent(BaselineAgent):
    """
    Y3/Y4: RAG-based memory with optional user vector.

    This agent:
    - Extracts preferences from conversations using the extractor
    - Stores preferences as memory cards
    - Retrieves relevant memories using RAG for each query
    - (Y4 only) Uses user vector to personalize retrieval

    The underlying PersonalizedLLM is constructed lazily on first use. If
    construction fails, a warning is printed and the agent degrades to a
    placeholder "fallback" mode instead of raising.
    """

    def __init__(
        self,
        model_name: str = "llama-8b",
        mode: str = "nopersonal",  # "nopersonal" for Y3, "full" for Y4
        memory_cards_path: Optional[str] = None,
        memory_embeddings_path: Optional[str] = None,
        enable_preference_extraction: bool = True,
        enable_rl_updates: bool = False,
        only_own_memories: bool = True,
        **kwargs
    ):
        """
        Args:
            model_name: LLM model name. Passed to BaselineAgent and used in
                get_name().
            mode: "nopersonal" (Y3) or "full" (Y4)
            memory_cards_path: Path to memory cards file. Defaults to
                <repo>/data/eval/memory_cards.jsonl.
            memory_embeddings_path: Path to embeddings file. Defaults to
                <repo>/data/eval/memory_embeddings.npy.
            enable_preference_extraction: Whether to extract preferences
            enable_rl_updates: Whether to update user vectors (Y4 only;
                forced to False for any mode other than "full").
            only_own_memories: Only retrieve user's own memories
            **kwargs: Forwarded to BaselineAgent.__init__.
        """
        super().__init__(model_name, **kwargs)

        self.mode = mode
        # User-vector RL updates only make sense in "full" (Y4) mode.
        self.enable_rl_updates = enable_rl_updates and (mode == "full")

        # Default paths. This file lives at
        # <repo>/src/personalization/evaluation/baselines/, so the repository
        # root is four directories above this file's directory.
        base_dir = os.path.abspath(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../..")
        )
        self.memory_cards_path = memory_cards_path or os.path.join(
            base_dir, "data/eval/memory_cards.jsonl"
        )
        self.memory_embeddings_path = memory_embeddings_path or os.path.join(
            base_dir, "data/eval/memory_embeddings.npy"
        )

        self.enable_preference_extraction = enable_preference_extraction
        self.only_own_memories = only_own_memories

        # Lazy initialization (see _ensure_initialized).
        # NOTE(review): model_name is not forwarded to PersonalizedLLM; confirm
        # the LLM backend is selected elsewhere (e.g. PersonalizedLLM config).
        self._llm = None
        self._initialized = False

    def _ensure_initialized(self):
        """Create the PersonalizedLLM on first call; later calls are no-ops.

        On any construction error, print a warning and leave ``self._llm``
        as None. ``_initialized`` is still set, so construction is attempted
        only once and later calls use fallback mode.
        """
        if self._initialized:
            return

        try:
            from personalization.serving.personalized_llm import PersonalizedLLM

            self._llm = PersonalizedLLM(
                mode=self.mode,
                enable_preference_extraction=self.enable_preference_extraction,
                enable_rl_updates=self.enable_rl_updates,
                only_own_memories=self.only_own_memories,
                memory_cards_path=self.memory_cards_path,
                memory_embeddings_path=self.memory_embeddings_path,
                eval_mode=True,  # Deterministic selection
            )
            self._initialized = True

        except Exception as e:
            # Best-effort: keep the evaluation run alive in fallback mode.
            print(f"Warning: Could not initialize PersonalizedLLM: {e}")
            print("Falling back to simple response mode.")
            self._llm = None
            self._initialized = True

    def respond(
        self,
        user_id: str,
        query: str,
        conversation_history: List[Dict[str, str]],
        **kwargs
    ) -> AgentResponse:
        """Generate a response to ``query`` using RAG memory.

        Only ``user_id`` and ``query`` are passed to ``PersonalizedLLM.chat``;
        ``conversation_history`` and ``kwargs`` are accepted for interface
        compatibility but not used here (presumably PersonalizedLLM tracks the
        session itself; verify).

        Returns:
            An AgentResponse. In fallback mode the answer is a placeholder.
            If chat() raises, the answer is an apology string and
            ``debug_info["error"]`` holds the message.
        """
        self._ensure_initialized()

        if self._llm is None:
            # Fallback mode
            return AgentResponse(
                answer=f"[RAGMemoryAgent-{self.mode}] Response to: {query[:50]}...",
                debug_info={"mode": "fallback"},
            )

        try:
            # Use PersonalizedLLM's chat interface
            response = self._llm.chat(user_id, query)
            dbg = response.debug

            # Guard against missing/None debug fields so a valid answer is not
            # discarded by the except-branch below because of a debug detail.
            if dbg:
                memory_ids = dbg.selected_memory_ids or []
                memory_notes = dbg.selected_memory_notes or []
                extracted = dbg.extracted_preferences or []
            else:
                memory_ids, memory_notes, extracted = [], [], []

            debug_info = {
                "mode": self.mode,
                "num_memories_retrieved": len(memory_ids),
                "selected_memories": memory_notes,
                "extracted_preferences": extracted,
            }

            # Extra debug fields are merged last and may override the keys above.
            if dbg and dbg.extra:
                debug_info.update(dbg.extra)

            return AgentResponse(
                answer=response.answer,
                debug_info=debug_info,
            )

        except Exception as e:
            print(f"Error in RAGMemoryAgent.respond: {e}")
            return AgentResponse(
                answer=f"I apologize for the error. Regarding: {query[:100]}",
                debug_info={"error": str(e)},
            )

    def end_session(self, user_id: str, conversation: List[Dict[str, str]]):
        """
        Called at end of session.
        PersonalizedLLM already extracts preferences during chat(),
        so we just reset the session state. ``conversation`` is unused.
        """
        self._ensure_initialized()

        if self._llm is not None:
            self._llm.reset_session(user_id)

    def reset_user(self, user_id: str):
        """Reset all state for a user (delegates to PersonalizedLLM.reset_user)."""
        self._ensure_initialized()

        if self._llm is not None:
            self._llm.reset_user(user_id)

    def apply_feedback(self, user_id: str, reward: float, gating: float = 1.0):
        """
        Apply feedback for user vector updates (Y4 only).

        No-op when RL updates are disabled or when the PersonalizedLLM has
        not been created yet (i.e. no respond()/end_session()/reset_user()
        call has happened). Errors from the underlying update are printed,
        not raised.

        Args:
            user_id: User identifier
            reward: Reward signal (e.g., from preference satisfaction)
            gating: Gating signal (1.0 = use this feedback, 0.0 = skip)
        """
        if not self.enable_rl_updates or self._llm is None:
            return

        try:
            from personalization.serving.personalized_llm import Feedback

            feedback = Feedback(
                user_id=user_id,
                turn_id=0,  # Not used in current implementation
                reward=reward,
                gating=gating,
            )
            self._llm.apply_feedback(feedback)

        except Exception as e:
            print(f"Error applying feedback: {e}")

    def get_user_state(self, user_id: str) -> Dict[str, Any]:
        """Get user state summary (for Y4 analysis); empty dict in fallback mode."""
        self._ensure_initialized()

        if self._llm is not None:
            return self._llm.get_user_state_summary(user_id)
        return {}

    def persist(self):
        """Save all state to disk; does nothing if the LLM was never created."""
        if self._llm is not None:
            self._llm.persist()

    def get_name(self) -> str:
        """Return a display name such as ``RAG(llama-8b)`` or ``RAG+UV(llama-8b)``."""
        mode_name = "RAG" if self.mode == "nopersonal" else "RAG+UV"
        return f"{mode_name}({self.model_name})"