summaryrefslogtreecommitdiff
path: root/collaborativeagents/agents/openai_user_agent.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/agents/openai_user_agent.py
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/agents/openai_user_agent.py')
-rw-r--r--collaborativeagents/agents/openai_user_agent.py229
1 files changed, 229 insertions, 0 deletions
diff --git a/collaborativeagents/agents/openai_user_agent.py b/collaborativeagents/agents/openai_user_agent.py
new file mode 100644
index 0000000..b49cd57
--- /dev/null
+++ b/collaborativeagents/agents/openai_user_agent.py
@@ -0,0 +1,229 @@
+"""
+OpenAI-based User Agent for user simulation with GPT-5.
+
+Drop-in replacement for VLLMUserAgent — same interface, uses OpenAI API.
+Supports reasoning models (GPT-5, o-series) that require max_completion_tokens.
+"""
+
+import os
+from typing import List, Dict, Any, Optional
+from copy import deepcopy
+from json_repair import repair_json
+from openai import OpenAI, RateLimitError, APITimeoutError, APIConnectionError
+
+from agents.vllm_user_agent import (
+ USER_SYSTEM_PROMPT_WITH_PREFERENCES,
+ USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES,
+ TERMINATION_SIGNAL,
+)
+
+
class OpenAIUserAgent:
    """
    User Agent that uses the OpenAI API (GPT-5) for user simulation.

    Drop-in replacement for VLLMUserAgent — same interface.

    Key differences from VLLMUserAgent:
    - Uses openai.OpenAI client
    - GPT-5 is a reasoning model: no temperature, uses max_completion_tokens
    - Higher quality simulation at the cost of API calls
    """

    def __init__(
        self,
        user_task_description: str,
        problem: str,
        user_persona: Optional[str] = None,
        user_preferences: Optional[str] = None,
        model: str = "gpt-5",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        num_retries: int = 3,
        max_completion_tokens: int = 4096,  # High for reasoning models
        max_context_length: int = 128000,  # GPT-5 context window
        retry_base_delay: float = 1.0,
    ):
        """
        Args:
            user_task_description: Description of the task the simulated user wants done.
            problem: Concrete problem statement for this conversation.
            user_persona: Persona text; defaults to a generic helpful user.
            user_preferences: Optional preference text; its presence selects
                which system-prompt template is used.
            model: OpenAI model name (reasoning models detected by prefix).
            api_key: API key; falls back to the OPENAI_API_KEY env var.
            base_url: Optional alternate API endpoint.
            num_retries: Retry budget, used independently by the API layer
                (`_generate`) and the parse layer (`generate_user_response`).
            max_completion_tokens: Output token budget; reasoning tokens count
                against it, hence the high default.
            max_completion_tokens / max_context_length: Used together for
                context-window truncation.
            retry_base_delay: Base delay (seconds) for exponential backoff.
        """
        self.user_task_description = user_task_description
        self.problem = problem
        self.user_persona = user_persona or "A helpful user seeking assistance."
        self.user_preferences = user_preferences
        self.model = model
        self.num_retries = num_retries
        self.max_completion_tokens = max_completion_tokens
        self.max_context_length = max_context_length
        self.retry_base_delay = retry_base_delay

        # Initialize OpenAI client
        self._client = OpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url,
            timeout=120.0,
        )

        # Build system prompt (same format as VLLMUserAgent)
        if user_preferences:
            self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
                user_task_description=user_task_description,
                problem=problem,
                user_persona=self.user_persona,
                user_preferences=user_preferences,
                termination_signal=TERMINATION_SIGNAL,
            )
        else:
            self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
                user_task_description=user_task_description,
                problem=problem,
                user_persona=self.user_persona,
                termination_signal=TERMINATION_SIGNAL,
            )

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (~3.5 chars/token heuristic, no tokenizer)."""
        return int(len(text) / 3.5)

    def _truncate_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Truncate messages to fit within context, keeping recent messages.

        The system message (if first) is always preserved; conversation turns
        are dropped oldest-first. The 200-token margin pads the coarse
        character-based estimate. If even the newest turn exceeds the budget,
        only the system message is returned (no partial-message truncation).
        """
        if not messages:
            return messages

        system_msg = messages[0] if messages[0]["role"] == "system" else None
        conversation = messages[1:] if system_msg else messages

        system_tokens = self._estimate_tokens(system_msg["content"]) if system_msg else 0
        # Reserve room for the completion plus a safety margin for estimate error.
        available_tokens = self.max_context_length - system_tokens - self.max_completion_tokens - 200

        total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation)
        if total_conv_tokens <= available_tokens:
            return messages

        # Keep recent messages: walk backwards, stop at the first turn that
        # would overflow (preserves a contiguous suffix of the conversation).
        truncated = []
        current_tokens = 0
        for msg in reversed(conversation):
            msg_tokens = self._estimate_tokens(msg["content"])
            if current_tokens + msg_tokens <= available_tokens:
                truncated.insert(0, msg)
                current_tokens += msg_tokens
            else:
                break

        if len(truncated) < len(conversation):
            print(f"[OpenAIUserAgent] Truncated: kept {len(truncated)}/{len(conversation)} turns")

        return [system_msg] + truncated if system_msg else truncated

    def _generate(self, messages: List[Dict[str, str]]) -> str:
        """Generate a completion via the OpenAI API with retry.

        Retries transient API errors with exponential backoff, and retries
        once more when a reasoning model exhausts its token budget before
        emitting content. Returns "" when no content could be obtained.
        """
        messages = self._truncate_messages(messages)

        import time
        for attempt in range(self.num_retries):
            try:
                # Build API call params
                params = {
                    "model": self.model,
                    "messages": messages,
                    "max_completion_tokens": self.max_completion_tokens,
                }

                # Non-reasoning models support temperature and response_format;
                # reasoning models (GPT-5, o-series) reject both.
                if not self._is_reasoning_model():
                    params["temperature"] = 0.7
                    params["response_format"] = {"type": "json_object"}

                response = self._client.chat.completions.create(**params)

                content = response.choices[0].message.content
                if content:
                    return content.strip()

                # Reasoning model may exhaust tokens on internal reasoning,
                # leaving content empty with finish_reason == "length".
                if response.choices[0].finish_reason == "length":
                    print(f"[OpenAIUserAgent] Response truncated (length), attempt {attempt+1}")
                    continue

                return ""

            except (RateLimitError, APITimeoutError, APIConnectionError) as e:
                if attempt == self.num_retries - 1:
                    raise
                delay = self.retry_base_delay * (2 ** attempt)
                print(f"[OpenAIUserAgent] API error ({type(e).__name__}), retrying in {delay:.1f}s...")
                time.sleep(delay)

        return ""

    def _is_reasoning_model(self) -> bool:
        """Check if the model is a reasoning model (no temperature/response_format support)."""
        reasoning_prefixes = ("o1", "o3", "gpt-5")
        return any(self.model.startswith(p) for p in reasoning_prefixes)

    def get_system_prompt(self) -> str:
        """Get the system prompt."""
        return self.system_prompt

    def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Reverse roles so the conversation reads from the user's perspective.

        Returns a new list (input is deep-copied, never mutated) where every
        "assistant" turn becomes "user" and everything else becomes "assistant".
        """
        conversation = deepcopy(conversation)
        return [
            {"role": "user" if msg["role"] == "assistant" else "assistant", "content": msg["content"]}
            for msg in conversation
        ]

    def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
        """
        Generate user response given the conversation history.

        Args:
            conversation: List of {"role": "user"|"assistant", "content": str}

        Returns:
            Dict with keys: reasoning, draft_answer, should_terminate, response
            Or None if all retries failed.
        """
        for attempt in range(self.num_retries):
            try:
                messages = [{"role": "system", "content": self.system_prompt}]
                messages.extend(self.reverse_roles(conversation))

                response_text = self._generate(messages)

                if not response_text:
                    print(f"[OpenAIUserAgent] Empty response, attempt {attempt+1}")
                    continue

                # Parse JSON response
                try:
                    parsed = repair_json(response_text, return_objects=True)

                    # repair_json may return a non-dict (str/list/number) when
                    # the output is not a JSON object. A string would make the
                    # `in` checks below do substring matching and could leak a
                    # raw string out as the "parsed" dict — force the fallback
                    # path instead.
                    if not isinstance(parsed, dict):
                        raise ValueError("parsed JSON is not an object")

                    required_keys = ["reasoning", "draft_answer", "should_terminate", "response"]
                    missing = [k for k in required_keys if k not in parsed]

                    if missing:
                        print(f"[OpenAIUserAgent] Missing keys: {missing}, attempt {attempt+1}")
                        continue

                    return parsed

                except Exception:
                    # Fallback: raw text as response
                    if TERMINATION_SIGNAL in response_text:
                        return {
                            "reasoning": "Ending conversation",
                            "draft_answer": "",
                            "should_terminate": True,
                            "response": TERMINATION_SIGNAL,
                        }
                    else:
                        return {
                            "reasoning": "",
                            "draft_answer": "",
                            "should_terminate": False,
                            "response": response_text,
                        }

            except Exception as e:
                print(f"[OpenAIUserAgent] Error: {e}, attempt {attempt+1}")
                continue

        print(f"[OpenAIUserAgent] Failed after {self.num_retries} attempts")
        return None