summaryrefslogtreecommitdiff
path: root/src/personalization/evaluation/user_simulator/simulator.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /src/personalization/evaluation/user_simulator/simulator.py
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/evaluation/user_simulator/simulator.py')
-rw-r--r--src/personalization/evaluation/user_simulator/simulator.py310
1 files changed, 310 insertions, 0 deletions
diff --git a/src/personalization/evaluation/user_simulator/simulator.py b/src/personalization/evaluation/user_simulator/simulator.py
new file mode 100644
index 0000000..5f5f701
--- /dev/null
+++ b/src/personalization/evaluation/user_simulator/simulator.py
@@ -0,0 +1,310 @@
+"""
+User Simulator
+
+Simulates a user with specific preferences who:
+1. Presents problems to the agent
+2. Checks if agent responses satisfy their preferences
+3. Enforces preferences when violated
+4. Tracks draft answer and decides when to terminate
+"""
+
+import json
+import os
+from dataclasses import dataclass, field
+from typing import List, Dict, Any, Optional
+
+from ..profiles.generator import UserProfile
+from ..preference_bank.schemas import PreferenceItem
+
+
# System prompt template for the simulated user.
# Rendered by UserSimulator._build_system_prompt() via str.format() with:
#   {task_description}    - description of the task type
#   {problem}             - the concrete problem to solve
#   {persona}             - the user's persona text
#   {preferences_grouped} - the user's preferences grouped by topic
# Double braces ({{ / }}) are str.format() escapes so the JSON output
# schema below reaches the model with literal single braces.
USER_SYSTEM_PROMPT = """You are simulating a user who is collaborating with an AI assistant to solve a problem. You have specific preferences about how the assistant should respond.

# Problem to Solve
{task_description}
{problem}
Note: The assistant cannot see this problem description directly. You need to communicate with them.

# Your Persona
{persona}

# Your Preferences (Grouped by Topic)
{preferences_grouped}

# Preference Enforcement Rules
- For each assistant response, check which of YOUR preferences are RELEVANT to the current context
- A preference is relevant if the assistant's response touches on that topic/condition
- If a relevant preference is VIOLATED, you MUST enforce it before proceeding
- Do NOT update your draft answer or proceed until violated preferences are fixed
- Only check preferences that apply to the current response (e.g., coding preferences for code responses)

# Draft Answer Management
- Maintain a working draft answer to the problem
- Start with "I don't know"
- Update it based on helpful information from the assistant
- Do NOT update if you're enforcing preferences

# Conversation Guidelines
- Be somewhat vague initially, let the assistant ask clarifying questions
- Respond naturally like a real user
- Do not copy the problem description directly

# Termination
Terminate when:
- Your draft answer seems correct and complete
- The assistant cannot help further

When ready to terminate, include "TERMINATE" in your response.

# Output Format (JSON)
{{
  "preference_checks": [
    {{
      "preference_id": str,
      "topic": str,
      "relevant": bool,
      "satisfied": bool or null,
      "violation_detail": str
    }}
  ],
  "any_violation": bool,
  "enforcement_needed": bool,
  "reasoning": str,
  "draft_answer": str,
  "should_terminate": bool,
  "response": str
}}

IMPORTANT: Only include preferences that are RELEVANT to the current assistant response in preference_checks.
Output valid JSON only, no other text."""
+
+
@dataclass
class PreferenceCheck:
    """Result of checking one preference against a single agent response.

    `satisfied` is tri-state: True/False when the preference was relevant
    and judged, None when it did not apply to this response.
    """
    preference_id: str  # ID of the preference that was checked
    topic: str  # topic/category the preference belongs to
    relevant: bool  # did the agent's response touch this preference?
    satisfied: Optional[bool]  # None if not relevant
    violation_detail: str = ""  # description of the violation, empty when satisfied
+
+
@dataclass
class UserSimulatorResponse:
    """Structured result of one user-simulator turn.

    Mirrors the JSON schema requested in USER_SYSTEM_PROMPT; built either
    from parsed model output or from a fallback default on parse/API errors.
    """
    response: str  # Text response to agent
    preference_checks: List[PreferenceCheck]  # Checked preferences
    any_violation: bool  # Any preference violated?
    enforcement_needed: bool  # Need to enforce?
    draft_answer: str  # Current draft answer
    should_terminate: bool  # Should end conversation?
    reasoning: str  # Internal reasoning
    raw_output: Dict[str, Any] = field(default_factory=dict)  # raw parsed JSON (or error info)
+
+
class UserSimulator:
    """
    Simulates a user with preferences interacting with an agent.

    Lifecycle:
        sim = UserSimulator(...)
        sim.setup(profile, task_description, problem, solution)
        result = sim.respond(conversation_history)  # once per agent turn

    The simulator is driven by an LLM behind an OpenAI-compatible API. If
    the client cannot be created (e.g. the ``openai`` package is missing),
    respond() degrades to a deterministic scripted fallback so pipelines
    can still be smoke-tested offline.
    """

    def __init__(
        self,
        model_name: str = "Llama-3.3-70B-Instruct",
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float = 0.8,
        max_tokens: int = 2048,
    ):
        """
        Args:
            model_name: Chat model to query.
            api_base: OpenAI-compatible endpoint. Defaults to the
                USER_SIM_API_BASE env var, then a localhost address.
            api_key: API key. Defaults to USER_SIM_API_KEY, then "EMPTY"
                (conventional placeholder for local servers).
            temperature: Sampling temperature for user turns.
            max_tokens: Completion token budget per turn.
        """
        self.model_name = model_name
        self.api_base = api_base or os.getenv("USER_SIM_API_BASE", "http://localhost:8004/v1")
        self.api_key = api_key or os.getenv("USER_SIM_API_KEY", "EMPTY")
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Per-task session state, populated by setup().
        self._profile: Optional[UserProfile] = None
        self._task_description: str = ""
        self._problem: str = ""
        self._solution: str = ""

        self._init_client()

    def _init_client(self):
        """Initialize the OpenAI client, or set it to None on failure.

        A None client makes respond() use the scripted fallback instead of
        raising, which keeps offline/test runs working.
        """
        try:
            import openai
            self.client = openai.OpenAI(
                base_url=self.api_base,
                api_key=self.api_key,
            )
        except Exception as e:
            print(f"Warning: Could not initialize OpenAI client for user simulator: {e}")
            self.client = None

    def setup(
        self,
        profile: UserProfile,
        task_description: str,
        problem: str,
        solution: str = "",
    ):
        """
        Set up the simulator for a new task.

        Args:
            profile: User profile with persona and preferences.
            task_description: Description of the task type.
            problem: The specific problem to solve.
            solution: Ground truth solution. Stored only for evaluation
                (see get_solution()); never injected into the system prompt.
        """
        self._profile = profile
        self._task_description = task_description
        self._problem = problem
        self._solution = solution

    def _build_system_prompt(self) -> str:
        """Render USER_SYSTEM_PROMPT with the current profile and task.

        Raises:
            ValueError: If setup() has not been called yet.
        """
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        return USER_SYSTEM_PROMPT.format(
            task_description=self._task_description,
            problem=self._problem,
            persona=self._profile.persona,
            preferences_grouped=self._profile.format_preferences_grouped(),
        )

    def _parse_response(self, raw_text: str) -> UserSimulatorResponse:
        """Parse raw LLM output into a structured UserSimulatorResponse.

        Tolerates markdown code fences around the JSON and, as a last
        resort, prose wrapped around the outermost {...} span. On any
        failure it returns a safe default response instead of raising.
        """
        # Normalize once: the API can hand back None content, and the
        # original code would crash on raw_text[:500] in the except-path.
        raw = raw_text or ""
        try:
            text = raw.strip()

            # Strip markdown code fences if present.
            if "```json" in text:
                text = text.split("```json")[1].split("```")[0]
            elif "```" in text:
                text = text.split("```")[1].split("```")[0]

            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                # Some models wrap the JSON in prose without fences; retry
                # on the outermost brace-delimited span before giving up.
                start, end = text.find("{"), text.rfind("}")
                if start == -1 or end <= start:
                    raise
                data = json.loads(text[start:end + 1])

            # Per-preference check results (missing keys get safe defaults).
            pref_checks = []
            for check in data.get("preference_checks", []):
                pref_checks.append(PreferenceCheck(
                    preference_id=check.get("preference_id", ""),
                    topic=check.get("topic", ""),
                    relevant=check.get("relevant", False),
                    satisfied=check.get("satisfied"),  # None when absent / not relevant
                    violation_detail=check.get("violation_detail", ""),
                ))

            return UserSimulatorResponse(
                response=data.get("response", ""),
                preference_checks=pref_checks,
                any_violation=data.get("any_violation", False),
                enforcement_needed=data.get("enforcement_needed", False),
                draft_answer=data.get("draft_answer", "I don't know"),
                should_terminate=data.get("should_terminate", False),
                reasoning=data.get("reasoning", ""),
                raw_output=data,
            )

        except Exception as e:
            print(f"Error parsing user simulator response: {e}")
            print(f"Raw text: {raw[:500]}...")

            # Degrade gracefully: echo short outputs verbatim, otherwise
            # nudge the agent to continue.
            return UserSimulatorResponse(
                response=raw if len(raw) < 500 else "Could you please continue?",
                preference_checks=[],
                any_violation=False,
                enforcement_needed=False,
                draft_answer="I don't know",
                should_terminate=False,
                reasoning="Parse error",
                raw_output={"error": str(e), "raw": raw},
            )

    def respond(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """
        Generate the user's next turn based on the conversation so far.

        Args:
            conversation_history: List of {"role": "user"/"assistant",
                "content": "..."} messages from the AGENT's point of view.

        Returns:
            UserSimulatorResponse with the user's reply and preference status.

        Raises:
            ValueError: If setup() has not been called yet.
        """
        if self._profile is None:
            raise ValueError("User profile not set. Call setup() first.")

        system_prompt = self._build_system_prompt()

        # Build messages with roles flipped: from the simulator's point of
        # view, the agent's messages are incoming "user" input and the
        # simulated user's own past turns are its "assistant" output.
        messages = [{"role": "system", "content": system_prompt}]
        for msg in conversation_history:
            if msg["role"] == "assistant":
                messages.append({"role": "user", "content": msg["content"]})
            else:
                messages.append({"role": "assistant", "content": msg["content"]})

        if self.client is None:
            # No LLM client available - scripted fallback (testing mode).
            return self._fallback_response(conversation_history)

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            # message.content is Optional in the OpenAI SDK; normalize None
            # so a content-less completion doesn't masquerade as an API error.
            raw_text = response.choices[0].message.content or ""
            return self._parse_response(raw_text)

        except Exception as e:
            print(f"Error calling user simulator LLM: {e}")
            return self._fallback_response(conversation_history)

    def _fallback_response(
        self,
        conversation_history: List[Dict[str, str]],
    ) -> UserSimulatorResponse:
        """Deterministic scripted reply used when no LLM client is available.

        Presents the problem on the first turn, asks for elaboration for a
        couple of turns, then terminates.
        """
        num_turns = len([m for m in conversation_history if m["role"] == "assistant"])

        if num_turns == 0:
            # First turn - present the problem, truncated only when it is
            # actually longer than the 200-char preview.
            snippet = self._problem[:200]
            ellipsis = "..." if len(self._problem) > 200 else ""
            response = f"Hi, I need help with this: {snippet}{ellipsis}"
        elif num_turns < 3:
            response = "Thanks, that helps. Can you explain more?"
        else:
            response = "Got it, I think I understand now. TERMINATE"

        return UserSimulatorResponse(
            response=response,
            preference_checks=[],
            any_violation=False,
            enforcement_needed=False,
            draft_answer="Draft answer from fallback",
            should_terminate="TERMINATE" in response,
            reasoning="Fallback mode",
            raw_output={},
        )

    def get_solution(self) -> str:
        """Return the ground-truth solution stored by setup() (may be "")."""
        return self._solution

    def get_profile(self) -> Optional[UserProfile]:
        """Return the current user profile, or None before setup()."""
        return self._profile
+
+