diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| commit | dc801c07cf38b0c495686463e6ca6f871a64440e (patch) | |
| tree | 599f03114775921dbc472403c701f4a3a8ea188a /src/personalization/evaluation/user_simulator | |
| parent | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff) | |
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/evaluation/user_simulator')
| -rw-r--r-- | src/personalization/evaluation/user_simulator/__init__.py | 5 | ||||
| -rw-r--r-- | src/personalization/evaluation/user_simulator/simulator.py | 310 |
2 files changed, 315 insertions, 0 deletions
"""
User Simulator

Simulates a user with specific preferences who:
1. Presents problems to the agent
2. Checks if agent responses satisfy their preferences
3. Enforces preferences when violated
4. Tracks draft answer and decides when to terminate
"""

import json
import os
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional

from ..profiles.generator import UserProfile
from ..preference_bank.schemas import PreferenceItem
# NOTE(review): PreferenceItem appears unused in this module — confirm
# against the rest of the package before removing the import.

# NOTE(review): the same commit adds user_simulator/__init__.py, which
# re-exports UserSimulator and UserSimulatorResponse via __all__.

# System prompt template for the simulated user. Placeholders are filled by
# UserSimulator._build_system_prompt(); the literal JSON braces in the output
# format section are doubled ({{ }}) so str.format() leaves them intact.
USER_SYSTEM_PROMPT = """You are simulating a user who is collaborating with an AI assistant to solve a problem. You have specific preferences about how the assistant should respond.

# Problem to Solve
{task_description}
{problem}
Note: The assistant cannot see this problem description directly. You need to communicate with them.

# Your Persona
{persona}

# Your Preferences (Grouped by Topic)
{preferences_grouped}

# Preference Enforcement Rules
- For each assistant response, check which of YOUR preferences are RELEVANT to the current context
- A preference is relevant if the assistant's response touches on that topic/condition
- If a relevant preference is VIOLATED, you MUST enforce it before proceeding
- Do NOT update your draft answer or proceed until violated preferences are fixed
- Only check preferences that apply to the current response (e.g., coding preferences for code responses)

# Draft Answer Management
- Maintain a working draft answer to the problem
- Start with "I don't know"
- Update it based on helpful information from the assistant
- Do NOT update if you're enforcing preferences

# Conversation Guidelines
- Be somewhat vague initially, let the assistant ask clarifying questions
- Respond naturally like a real user
- Do not copy the problem description directly

# Termination
Terminate when:
- Your draft answer seems correct and complete
- The assistant cannot help further

When ready to terminate, include "TERMINATE" in your response.

# Output Format (JSON)
{{
    "preference_checks": [
        {{
            "preference_id": str,
            "topic": str,
            "relevant": bool,
            "satisfied": bool or null,
            "violation_detail": str
        }}
    ],
    "any_violation": bool,
    "enforcement_needed": bool,
    "reasoning": str,
    "draft_answer": str,
    "should_terminate": bool,
    "response": str
}}

IMPORTANT: Only include preferences that are RELEVANT to the current assistant response in preference_checks.
Output valid JSON only, no other text."""
@dataclass
class PreferenceCheck:
    """Result of checking a single preference against one assistant response."""
    preference_id: str  # ID of the preference item that was checked
    topic: str  # topic/group the preference belongs to
    relevant: bool  # did the preference apply to this response?
    satisfied: Optional[bool]  # None if not relevant
    violation_detail: str = ""  # description of the violation, if any


@dataclass
class UserSimulatorResponse:
    """Structured result of one simulated-user turn."""
    response: str  # text response sent back to the agent
    preference_checks: List[PreferenceCheck]  # preferences checked this turn
    any_violation: bool  # was any relevant preference violated?
    enforcement_needed: bool  # must the user push back before proceeding?
    draft_answer: str  # user's current working answer to the problem
    should_terminate: bool  # should the conversation end?
    reasoning: str  # simulator's internal reasoning (not shown to the agent)
    raw_output: Dict[str, Any] = field(default_factory=dict)  # raw parsed LLM JSON, or error info


class UserSimulator:
    """
    Simulates a user with preferences interacting with an agent.

    The simulator presents a problem, checks each agent reply against the
    profile's preferences, enforces violated preferences, maintains a draft
    answer, and decides when the conversation should terminate. When no LLM
    client is available it degrades to a deterministic fallback mode, which
    keeps test runs working offline.
    """

    def __init__(
        self,
        model_name: str = "Llama-3.3-70B-Instruct",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.8,
        max_tokens: int = 2048,
    ):
        """
        Args:
            model_name: Chat model used to simulate the user.
            api_base: OpenAI-compatible endpoint; falls back to the
                USER_SIM_API_BASE env var, then a localhost default.
            api_key: API key; falls back to USER_SIM_API_KEY, then "EMPTY"
                (the no-auth convention for local OpenAI-compatible servers).
            temperature: Sampling temperature for simulated user replies.
            max_tokens: Maximum tokens per simulated user reply.
        """
        self.model_name = model_name
        self.api_base = api_base or os.getenv("USER_SIM_API_BASE", "http://localhost:8004/v1")
        self.api_key = api_key or os.getenv("USER_SIM_API_KEY", "EMPTY")
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Per-session state, populated by setup().
        self._profile: Optional["UserProfile"] = None
        self._task_description: str = ""
        self._problem: str = ""
        self._solution: str = ""

        self._init_client()

    def _init_client(self):
        """Initialize the OpenAI-compatible client; fall back to offline mode on failure."""
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            # With client=None, respond() uses the deterministic fallback path.
            print(f"Warning: Could not initialize OpenAI client for user simulator: {e}")
            self.client = None

    def setup(
        self,
        profile: "UserProfile",
        task_description: str,
        problem: str,
        solution: str = "",
    ):
        """
        Set up the simulator for a new task.

        Args:
            profile: User profile with preferences
            task_description: Description of the task type
            problem: The specific problem to solve
            solution: Ground truth solution (for evaluation)
        """
        self._profile = profile
        self._task_description = task_description
        self._problem = problem
        self._solution = solution

    def _build_system_prompt(self) -> str:
        """Build the system prompt from the profile, task description, and problem.

        Raises:
            ValueError: If setup() has not been called yet.
        """
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        return USER_SYSTEM_PROMPT.format(
            task_description=self._task_description,
            problem=self._problem,
            persona=self._profile.persona,
            preferences_grouped=self._profile.format_preferences_grouped(),
        )

    def _parse_response(self, raw_text: Optional[str]) -> UserSimulatorResponse:
        """
        Parse LLM output into a structured UserSimulatorResponse.

        Tolerates a markdown code fence around the JSON. On any parse failure
        it returns a permissive default response instead of raising, so one
        malformed model reply cannot abort a whole simulation run.
        """
        # The chat API may return None content; normalize to "" so the
        # error-reporting path below cannot raise a secondary TypeError
        # (raw_text[:500] on None) that would mask the real parse error.
        raw_text = raw_text or ""
        try:
            text = raw_text.strip()

            # Strip a surrounding markdown fence, with or without a language
            # tag (```json ... ``` or ``` ... ```), case-insensitively.
            if "```" in text:
                fenced = text.split("```", 2)[1]
                first_line, _, rest = fenced.partition("\n")
                if first_line.strip().lower() in ("", "json"):
                    fenced = rest
                text = fenced

            data = json.loads(text)

            # Convert per-preference check dicts into dataclass instances,
            # defaulting each missing field rather than failing.
            pref_checks = []
            for check in data.get("preference_checks", []):
                pref_checks.append(PreferenceCheck(
                    preference_id=check.get("preference_id", ""),
                    topic=check.get("topic", ""),
                    relevant=check.get("relevant", False),
                    satisfied=check.get("satisfied"),
                    violation_detail=check.get("violation_detail", ""),
                ))

            return UserSimulatorResponse(
                response=data.get("response", ""),
                preference_checks=pref_checks,
                any_violation=data.get("any_violation", False),
                enforcement_needed=data.get("enforcement_needed", False),
                draft_answer=data.get("draft_answer", "I don't know"),
                should_terminate=data.get("should_terminate", False),
                reasoning=data.get("reasoning", ""),
                raw_output=data,
            )

        except Exception as e:
            print(f"Error parsing user simulator response: {e}")
            print(f"Raw text: {raw_text[:500]}...")

            # Return a basic response so the conversation can continue; short
            # raw text is passed through verbatim as the user's reply.
            return UserSimulatorResponse(
                response=raw_text if len(raw_text) < 500 else "Could you please continue?",
                preference_checks=[],
                any_violation=False,
                enforcement_needed=False,
                draft_answer="I don't know",
                should_terminate=False,
                reasoning="Parse error",
                raw_output={"error": str(e), "raw": raw_text},
            )

    def respond(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """
        Generate the user's next reply given the conversation so far.

        Args:
            conversation_history: List of {"role": "user/assistant", "content": "..."}
                messages from the agent's point of view.

        Returns:
            UserSimulatorResponse with user's reply and preference status

        Raises:
            ValueError: If setup() has not been called yet.
        """
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        system_prompt = self._build_system_prompt()

        # Build messages with roles flipped: the agent's "assistant" messages
        # are what this simulator's LLM reads as user input, and vice versa.
        messages = [{"role": "system", "content": system_prompt}]

        for msg in conversation_history:
            if msg["role"] == "assistant":
                messages.append({"role": "user", "content": msg["content"]})
            else:
                messages.append({"role": "assistant", "content": msg["content"]})

        if self.client is None:
            # No LLM client available (e.g. during tests) - use fallback.
            return self._fallback_response(conversation_history)

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

            raw_text = response.choices[0].message.content
            return self._parse_response(raw_text)

        except Exception as e:
            # Never let a transport/API error abort the simulation loop.
            print(f"Error calling user simulator LLM: {e}")
            return self._fallback_response(conversation_history)

    def _fallback_response(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """Generate a simple deterministic response for offline/testing use."""
        # Count agent turns to decide where we are in the scripted exchange.
        num_turns = len([m for m in conversation_history if m["role"] == "assistant"])

        if num_turns == 0:
            # First turn - present (a truncated view of) the problem.
            response = f"Hi, I need help with this: {self._problem[:200]}..."
        elif num_turns < 3:
            response = "Thanks, that helps. Can you explain more?"
        else:
            response = "Got it, I think I understand now. TERMINATE"

        return UserSimulatorResponse(
            response=response,
            preference_checks=[],
            any_violation=False,
            enforcement_needed=False,
            draft_answer="Draft answer from fallback",
            should_terminate="TERMINATE" in response,
            reasoning="Fallback mode",
            raw_output={},
        )

    def get_solution(self) -> str:
        """Get the ground truth solution."""
        return self._solution

    def get_profile(self) -> Optional["UserProfile"]:
        """Get the current user profile."""
        return self._profile
