diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 09:57:37 -0600 |
| commit | dc801c07cf38b0c495686463e6ca6f871a64440e (patch) | |
| tree | 599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/agents/openai_user_agent.py | |
| parent | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff) | |
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules
- Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/agents/openai_user_agent.py')
| -rw-r--r-- | collaborativeagents/agents/openai_user_agent.py | 229 |
1 file changed, 229 insertions, 0 deletions
"""
OpenAI-based User Agent for user simulation with GPT-5.

Drop-in replacement for VLLMUserAgent — same interface, uses OpenAI API.
Supports reasoning models (GPT-5, o-series) that require max_completion_tokens.
"""

import os
import time
from copy import deepcopy
from typing import List, Dict, Any, Optional

from json_repair import repair_json
from openai import OpenAI, RateLimitError, APITimeoutError, APIConnectionError

from agents.vllm_user_agent import (
    USER_SYSTEM_PROMPT_WITH_PREFERENCES,
    USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES,
    TERMINATION_SIGNAL,
)


class OpenAIUserAgent:
    """
    User Agent that uses the OpenAI API (GPT-5) for user simulation.

    Key differences from VLLMUserAgent:
    - Uses openai.OpenAI client
    - GPT-5 is a reasoning model: no temperature, uses max_completion_tokens
    - Higher quality simulation at the cost of API calls
    """

    def __init__(
        self,
        user_task_description: str,
        problem: str,
        user_persona: Optional[str] = None,
        user_preferences: Optional[str] = None,
        model: str = "gpt-5",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        num_retries: int = 3,
        max_completion_tokens: int = 4096,  # High for reasoning models
        max_context_length: int = 128000,  # GPT-5 context window
        retry_base_delay: float = 1.0,
    ):
        """
        Build the simulated-user agent and its system prompt.

        Args:
            user_task_description: Description of the task the user wants done.
            problem: The concrete problem instance being discussed.
            user_persona: Persona text; defaults to a generic helpful user.
            user_preferences: Optional preference text; selects which system
                prompt template is used.
            model: OpenAI model name (reasoning models get special handling).
            api_key: API key; falls back to the OPENAI_API_KEY env var.
            base_url: Optional alternative API endpoint.
            num_retries: Retry budget for both API errors and bad outputs.
            max_completion_tokens: Per-call completion cap (reasoning models
                spend part of this budget on hidden reasoning tokens).
            max_context_length: Model context window used for truncation.
            retry_base_delay: Base seconds for exponential backoff.
        """
        self.user_task_description = user_task_description
        self.problem = problem
        self.user_persona = user_persona or "A helpful user seeking assistance."
        self.user_preferences = user_preferences
        self.model = model
        self.num_retries = num_retries
        self.max_completion_tokens = max_completion_tokens
        self.max_context_length = max_context_length
        self.retry_base_delay = retry_base_delay

        # Initialize OpenAI client
        self._client = OpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url,
            timeout=120.0,
        )

        # Build system prompt (same format as VLLMUserAgent)
        if user_preferences:
            self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
                user_task_description=user_task_description,
                problem=problem,
                user_persona=self.user_persona,
                user_preferences=user_preferences,
                termination_signal=TERMINATION_SIGNAL,
            )
        else:
            self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
                user_task_description=user_task_description,
                problem=problem,
                user_persona=self.user_persona,
                termination_signal=TERMINATION_SIGNAL,
            )

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (~3.5 chars/token heuristic, no tokenizer)."""
        return int(len(text) / 3.5)

    def _truncate_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Truncate messages to fit within the context window, keeping the most
        recent conversation turns (and the system message, if present).

        Always keeps at least the newest conversation message, even if its
        estimated size exceeds the budget, so the API call is never empty.
        """
        if not messages:
            return messages

        system_msg = messages[0] if messages[0]["role"] == "system" else None
        conversation = messages[1:] if system_msg else messages

        system_tokens = self._estimate_tokens(system_msg["content"]) if system_msg else 0
        # Reserve room for the completion plus a small safety margin.
        available_tokens = self.max_context_length - system_tokens - self.max_completion_tokens - 200

        total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation)
        if total_conv_tokens <= available_tokens:
            return messages

        # Walk backwards so the most recent turns are retained.
        truncated = []
        current_tokens = 0
        for msg in reversed(conversation):
            msg_tokens = self._estimate_tokens(msg["content"])
            if current_tokens + msg_tokens <= available_tokens:
                truncated.insert(0, msg)
                current_tokens += msg_tokens
            else:
                break

        # Never drop everything: an empty message list is an invalid request,
        # so keep the newest turn even if it is over budget.
        if not truncated and conversation:
            truncated = [conversation[-1]]

        if len(truncated) < len(conversation):
            print(f"[OpenAIUserAgent] Truncated: kept {len(truncated)}/{len(conversation)} turns")

        return [system_msg] + truncated if system_msg else truncated

    def _generate(self, messages: List[Dict[str, str]]) -> str:
        """
        Generate a completion via the OpenAI API with exponential-backoff retry.

        Returns the stripped response text, or "" if every attempt produced
        an empty/truncated response.
        """
        messages = self._truncate_messages(messages)

        for attempt in range(self.num_retries):
            try:
                # Build API call params
                params = {
                    "model": self.model,
                    "messages": messages,
                    "max_completion_tokens": self.max_completion_tokens,
                }

                # Non-reasoning models support temperature and response_format
                if not self._is_reasoning_model():
                    params["temperature"] = 0.7
                    params["response_format"] = {"type": "json_object"}

                response = self._client.chat.completions.create(**params)

                content = response.choices[0].message.content
                if content:
                    return content.strip()

                # Reasoning model may exhaust tokens on hidden reasoning;
                # retry in that case, otherwise give up on this call.
                if response.choices[0].finish_reason == "length":
                    print(f"[OpenAIUserAgent] Response truncated (length), attempt {attempt+1}")
                    continue

                return ""

            except (RateLimitError, APITimeoutError, APIConnectionError) as e:
                if attempt == self.num_retries - 1:
                    raise
                delay = self.retry_base_delay * (2 ** attempt)
                print(f"[OpenAIUserAgent] API error ({type(e).__name__}), retrying in {delay:.1f}s...")
                time.sleep(delay)

        return ""

    def _is_reasoning_model(self) -> bool:
        """Check if the model is a reasoning model (no temperature/response_format support)."""
        reasoning_prefixes = ("o1", "o3", "gpt-5")
        return any(self.model.startswith(p) for p in reasoning_prefixes)

    def get_system_prompt(self) -> str:
        """Get the system prompt."""
        return self.system_prompt

    def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Reverse roles so the transcript reads from the user's perspective
        (assistant turns become user turns and vice versa).

        Returns a new list of new dicts; the input is not mutated.
        """
        # No deepcopy needed: fresh dicts are built and str values are immutable.
        return [
            {
                "role": "user" if msg["role"] == "assistant" else "assistant",
                "content": msg["content"],
            }
            for msg in conversation
        ]

    def _raw_text_result(self, response_text: str) -> Dict[str, Any]:
        """Fallback result built from raw (non-JSON) model output."""
        if TERMINATION_SIGNAL in response_text:
            return {
                "reasoning": "Ending conversation",
                "draft_answer": "",
                "should_terminate": True,
                "response": TERMINATION_SIGNAL,
            }
        return {
            "reasoning": "",
            "draft_answer": "",
            "should_terminate": False,
            "response": response_text,
        }

    def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
        """
        Generate user response given the conversation history.

        Args:
            conversation: List of {"role": "user"|"assistant", "content": str}

        Returns:
            Dict with keys: reasoning, draft_answer, should_terminate, response
            Or None if all retries failed.
        """
        for attempt in range(self.num_retries):
            try:
                messages = [{"role": "system", "content": self.system_prompt}]
                messages.extend(self.reverse_roles(conversation))

                response_text = self._generate(messages)

                if not response_text:
                    print(f"[OpenAIUserAgent] Empty response, attempt {attempt+1}")
                    continue

                # Parse JSON response
                try:
                    parsed = repair_json(response_text, return_objects=True)

                    # repair_json may return a str/list for malformed input;
                    # treat anything that is not a dict as raw text rather
                    # than burning a retry on a bogus key-membership check.
                    if not isinstance(parsed, dict):
                        return self._raw_text_result(response_text)

                    required_keys = ["reasoning", "draft_answer", "should_terminate", "response"]
                    missing = [k for k in required_keys if k not in parsed]

                    if missing:
                        print(f"[OpenAIUserAgent] Missing keys: {missing}, attempt {attempt+1}")
                        continue

                    return parsed

                except Exception:
                    # Fallback: raw text as response
                    return self._raw_text_result(response_text)

            except Exception as e:
                print(f"[OpenAIUserAgent] Error: {e}, attempt {attempt+1}")
                continue

        print(f"[OpenAIUserAgent] Failed after {self.num_retries} attempts")
        return None
