summaryrefslogtreecommitdiff
path: root/src/personalization/evaluation/user_simulator/simulator.py
blob: 5f5f701a8ca65a34442da48093f1c07c985a2ff7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""
User Simulator

Simulates a user with specific preferences who:
1. Presents problems to the agent
2. Checks if agent responses satisfy their preferences
3. Enforces preferences when violated
4. Tracks draft answer and decides when to terminate
"""

import json
import os
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional

from ..profiles.generator import UserProfile
from ..preference_bank.schemas import PreferenceItem


# User simulator system prompt template.
# Placeholders are filled by UserSimulator._build_system_prompt():
#   {task_description}, {problem}, {persona}, {preferences_grouped}.
# NOTE: the doubled braces ({{ ... }}) escape literal JSON braces so the
# example output schema survives str.format().
USER_SYSTEM_PROMPT = """You are simulating a user who is collaborating with an AI assistant to solve a problem. You have specific preferences about how the assistant should respond.

# Problem to Solve
{task_description}
{problem}
Note: The assistant cannot see this problem description directly. You need to communicate with them.

# Your Persona
{persona}

# Your Preferences (Grouped by Topic)
{preferences_grouped}

# Preference Enforcement Rules
- For each assistant response, check which of YOUR preferences are RELEVANT to the current context
- A preference is relevant if the assistant's response touches on that topic/condition
- If a relevant preference is VIOLATED, you MUST enforce it before proceeding
- Do NOT update your draft answer or proceed until violated preferences are fixed
- Only check preferences that apply to the current response (e.g., coding preferences for code responses)

# Draft Answer Management
- Maintain a working draft answer to the problem
- Start with "I don't know"
- Update it based on helpful information from the assistant
- Do NOT update if you're enforcing preferences

# Conversation Guidelines
- Be somewhat vague initially, let the assistant ask clarifying questions
- Respond naturally like a real user
- Do not copy the problem description directly

# Termination
Terminate when:
- Your draft answer seems correct and complete
- The assistant cannot help further

When ready to terminate, include "TERMINATE" in your response.

# Output Format (JSON)
{{
    "preference_checks": [
        {{
            "preference_id": str,
            "topic": str,
            "relevant": bool,
            "satisfied": bool or null,
            "violation_detail": str
        }}
    ],
    "any_violation": bool,
    "enforcement_needed": bool,
    "reasoning": str,
    "draft_answer": str,
    "should_terminate": bool,
    "response": str
}}

IMPORTANT: Only include preferences that are RELEVANT to the current assistant response in preference_checks.
Output valid JSON only, no other text."""


@dataclass
class PreferenceCheck:
    """Result of checking one preference against an assistant response.

    Instances are built by ``UserSimulator._parse_response`` from the
    ``preference_checks`` entries of the simulator LLM's JSON output.
    """
    preference_id: str  # ID of the preference being checked
    topic: str  # Topic group the preference belongs to
    relevant: bool  # Whether the preference applies to the current response
    satisfied: Optional[bool]  # None if not relevant
    violation_detail: str = ""  # Explanation text when the preference is violated

@dataclass
class UserSimulatorResponse:
    """Structured response from the user simulator for one turn.

    Produced by ``UserSimulator.respond`` (or its parse/fallback paths);
    carries both the user-visible reply and the simulator's internal state.
    """
    response: str                           # Text response to agent
    preference_checks: List[PreferenceCheck]  # Checked preferences
    any_violation: bool                     # Any preference violated?
    enforcement_needed: bool                # Need to enforce?
    draft_answer: str                       # Current draft answer
    should_terminate: bool                  # Should end conversation?
    reasoning: str                          # Internal reasoning
    raw_output: Dict[str, Any] = field(default_factory=dict)  # Raw parsed JSON (or error info)

class UserSimulator:
    """
    Simulates a user with preferences interacting with an agent.

    Wraps an OpenAI-compatible chat endpoint. Each call to :meth:`respond`
    returns a structured :class:`UserSimulatorResponse` with preference
    checks, the current draft answer, and a termination flag. When no
    client is available (missing ``openai`` package or failed init), a
    deterministic fallback keeps conversations runnable offline.
    """

    def __init__(
        self,
        model_name: str = "Llama-3.3-70B-Instruct",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.8,
        max_tokens: int = 2048,
    ):
        """
        Args:
            model_name: Chat model name served by the endpoint.
            api_base: Base URL; defaults to $USER_SIM_API_BASE or localhost:8004.
            api_key: API key; defaults to $USER_SIM_API_KEY or "EMPTY".
            temperature: Sampling temperature for simulated user turns.
            max_tokens: Per-turn generation limit.
        """
        self.model_name = model_name
        self.api_base = api_base or os.getenv("USER_SIM_API_BASE", "http://localhost:8004/v1")
        self.api_key = api_key or os.getenv("USER_SIM_API_KEY", "EMPTY")
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Per-task session state, populated by setup().
        self._profile: Optional[UserProfile] = None
        self._task_description: str = ""
        self._problem: str = ""
        self._solution: str = ""

        self._init_client()

    def _init_client(self):
        """Initialize the OpenAI client, or set it to None on failure.

        Deliberate best-effort: without a client, respond() falls back to
        _fallback_response() so the pipeline can run offline/in tests.
        """
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            print(f"Warning: Could not initialize OpenAI client for user simulator: {e}")
            self.client = None

    def setup(
        self,
        profile: UserProfile,
        task_description: str,
        problem: str,
        solution: str = "",
    ):
        """
        Set up the simulator for a new task.

        Args:
            profile: User profile with preferences
            task_description: Description of the task type
            problem: The specific problem to solve
            solution: Ground truth solution (for evaluation)
        """
        self._profile = profile
        self._task_description = task_description
        self._problem = problem
        self._solution = solution

    def _build_system_prompt(self) -> str:
        """Build the system prompt from the user profile and task.

        Raises:
            ValueError: If setup() has not been called yet.
        """
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        return USER_SYSTEM_PROMPT.format(
            task_description=self._task_description,
            problem=self._problem,
            persona=self._profile.persona,
            preferences_grouped=self._profile.format_preferences_grouped(),
        )

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Return the payload inside a markdown ```json / ``` fence, if any."""
        if "```json" in text:
            return text.split("```json")[1].split("```")[0]
        if "```" in text:
            return text.split("```")[1].split("```")[0]
        return text

    def _parse_response(self, raw_text: str) -> UserSimulatorResponse:
        """Parse LLM output into a structured UserSimulatorResponse.

        Tolerates markdown code fences around the JSON payload. On any
        parse failure, a neutral response carrying the raw text is
        returned so the conversation loop can continue instead of crashing.
        """
        # Guard: some OpenAI-compatible servers return None message content;
        # the original code crashed twice on that (in try AND in except).
        raw_text = raw_text or ""
        try:
            text = self._strip_code_fences(raw_text.strip())
            data = json.loads(text)

            # Missing keys in each check fall back to neutral values.
            pref_checks = [
                PreferenceCheck(
                    preference_id=check.get("preference_id", ""),
                    topic=check.get("topic", ""),
                    relevant=check.get("relevant", False),
                    satisfied=check.get("satisfied"),
                    violation_detail=check.get("violation_detail", ""),
                )
                for check in data.get("preference_checks", [])
            ]

            return UserSimulatorResponse(
                response=data.get("response", ""),
                preference_checks=pref_checks,
                any_violation=data.get("any_violation", False),
                enforcement_needed=data.get("enforcement_needed", False),
                draft_answer=data.get("draft_answer", "I don't know"),
                should_terminate=data.get("should_terminate", False),
                reasoning=data.get("reasoning", ""),
                raw_output=data,
            )

        except Exception as e:
            print(f"Error parsing user simulator response: {e}")
            print(f"Raw text: {raw_text[:500]}...")

            # Neutral response so the dialogue loop can proceed.
            return UserSimulatorResponse(
                response=raw_text if len(raw_text) < 500 else "Could you please continue?",
                preference_checks=[],
                any_violation=False,
                enforcement_needed=False,
                draft_answer="I don't know",
                should_terminate=False,
                reasoning="Parse error",
                raw_output={"error": str(e), "raw": raw_text},
            )

    def respond(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """
        Generate user response based on conversation.

        Args:
            conversation_history: List of {"role": "user/assistant", "content": "..."}

        Returns:
            UserSimulatorResponse with user's reply and preference status

        Raises:
            ValueError: If setup() has not been called yet.
        """
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        system_prompt = self._build_system_prompt()

        # Build messages - reverse roles (user simulator sees itself as user):
        # the agent's "assistant" turns become "user" input to the simulator.
        messages = [{"role": "system", "content": system_prompt}]
        for msg in conversation_history:
            if msg["role"] == "assistant":
                messages.append({"role": "user", "content": msg["content"]})
            else:
                messages.append({"role": "assistant", "content": msg["content"]})

        if self.client is None:
            # Fallback for testing / offline runs.
            return self._fallback_response(conversation_history)

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

            # content may be None (e.g. refusals / tool-only replies);
            # normalize to "" so parsing degrades gracefully.
            raw_text = response.choices[0].message.content or ""
            return self._parse_response(raw_text)

        except Exception as e:
            print(f"Error calling user simulator LLM: {e}")
            return self._fallback_response(conversation_history)

    def _fallback_response(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """Generate a simple scripted response for testing/offline mode.

        Turn 0 presents the problem, the next two turns ask for more,
        then the conversation terminates.
        """
        num_turns = len([m for m in conversation_history if m["role"] == "assistant"])

        if num_turns == 0:
            # First turn - present the problem
            response = f"Hi, I need help with this: {self._problem[:200]}..."
        elif num_turns < 3:
            response = "Thanks, that helps. Can you explain more?"
        else:
            response = "Got it, I think I understand now. TERMINATE"

        return UserSimulatorResponse(
            response=response,
            preference_checks=[],
            any_violation=False,
            enforcement_needed=False,
            draft_answer="Draft answer from fallback",
            should_terminate="TERMINATE" in response,
            reasoning="Fallback mode",
            raw_output={},
        )

    def get_solution(self) -> str:
        """Get the ground truth solution (set via setup())."""
        return self._solution

    def get_profile(self) -> Optional[UserProfile]:
        """Get the current user profile, or None before setup()."""
        return self._profile