"""
OpenAI-based User Agent for user simulation with GPT-5.

Drop-in replacement for VLLMUserAgent — same interface, uses OpenAI API.
Supports reasoning models (GPT-5, o-series) that require max_completion_tokens.
"""

import os
import time
from typing import List, Dict, Any, Optional
from copy import deepcopy

from json_repair import repair_json
from openai import OpenAI, RateLimitError, APITimeoutError, APIConnectionError

from agents.vllm_user_agent import (
    USER_SYSTEM_PROMPT_WITH_PREFERENCES,
    USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES,
    TERMINATION_SIGNAL,
)


class OpenAIUserAgent:
    """
    User Agent that uses the OpenAI API (GPT-5) for user simulation.

    Key differences from VLLMUserAgent:
    - Uses openai.OpenAI client
    - GPT-5 is a reasoning model: no temperature, uses max_completion_tokens
    - Higher quality simulation at the cost of API calls
    """

    # Model-name prefixes treated as reasoning models. These are called
    # without `temperature` / `response_format` (see `_generate`).
    _REASONING_PREFIXES = ("o1", "o3", "o4", "gpt-5")

    # Keys a parsed JSON reply must contain to be returned as-is.
    _REQUIRED_KEYS = ("reasoning", "draft_answer", "should_terminate", "response")

    def __init__(
        self,
        user_task_description: str,
        problem: str,
        user_persona: Optional[str] = None,
        user_preferences: Optional[str] = None,
        model: str = "gpt-5",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        num_retries: int = 3,
        max_completion_tokens: int = 4096,  # High for reasoning models
        max_context_length: int = 128000,  # GPT-5 context window
        retry_base_delay: float = 1.0,
    ):
        """
        Configure the simulated user and build its system prompt.

        Args:
            user_task_description: Task description inserted into the system prompt.
            problem: Problem statement inserted into the system prompt.
            user_persona: Persona text. Defaults to a generic helpful user.
            user_preferences: Optional preferences. When given, the
                "with preferences" prompt template is used.
            model: OpenAI model name.
            api_key: API key. Falls back to the OPENAI_API_KEY environment variable.
            base_url: Optional alternative API endpoint.
            num_retries: Attempts used both per API call (`_generate`) and
                per user turn (`generate_user_response`).
            max_completion_tokens: Completion budget sent with every request
                and reserved when truncating history.
            max_context_length: Context size assumed when truncating history.
            retry_base_delay: Base delay in seconds for exponential backoff
                on transient API errors.
        """
        self.user_task_description = user_task_description
        self.problem = problem
        self.user_persona = user_persona or "A helpful user seeking assistance."
        self.user_preferences = user_preferences
        self.model = model
        self.num_retries = num_retries
        self.max_completion_tokens = max_completion_tokens
        self.max_context_length = max_context_length
        self.retry_base_delay = retry_base_delay

        # Initialize OpenAI client (120 s per-request timeout).
        self._client = OpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url,
            timeout=120.0,
        )

        # Build system prompt (same format as VLLMUserAgent).
        if user_preferences:
            self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
                user_task_description=user_task_description,
                problem=problem,
                user_persona=self.user_persona,
                user_preferences=user_preferences,
                termination_signal=TERMINATION_SIGNAL,
            )
        else:
            self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
                user_task_description=user_task_description,
                problem=problem,
                user_persona=self.user_persona,
                termination_signal=TERMINATION_SIGNAL,
            )

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (~3.5 chars/token)."""
        return int(len(text) / 3.5)

    def _message_tokens(self, message: Dict[str, str]) -> int:
        """Estimate a message's token count, treating missing/None content as empty."""
        return self._estimate_tokens(message.get("content") or "")

    def _truncate_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Truncate messages to fit within context, keeping recent messages.

        A leading system message is always kept. The budget for the rest is
        `max_context_length`, minus the system prompt, minus
        `max_completion_tokens`, minus a 200-token safety margin. Token counts
        are character-based estimates, not real tokenizer counts.

        Args:
            messages: Chat messages, optionally starting with a system message.

        Returns:
            The input unchanged if it fits. Otherwise the system message (if
            any) followed by the longest suffix of the conversation that fits.
            The suffix may be empty if the newest message alone is too large.
        """
        if not messages:
            return messages

        system_msg = messages[0] if messages[0]["role"] == "system" else None
        conversation = messages[1:] if system_msg else messages

        system_tokens = self._message_tokens(system_msg) if system_msg else 0
        available_tokens = (
            self.max_context_length - system_tokens - self.max_completion_tokens - 200
        )

        total_conv_tokens = sum(self._message_tokens(m) for m in conversation)
        if total_conv_tokens <= available_tokens:
            return messages

        # Walk backwards from the newest message and stop at the first one
        # that no longer fits, so the kept history stays contiguous.
        truncated = []
        current_tokens = 0
        for msg in reversed(conversation):
            msg_tokens = self._message_tokens(msg)
            if current_tokens + msg_tokens > available_tokens:
                break
            truncated.append(msg)
            current_tokens += msg_tokens
        truncated.reverse()

        if len(truncated) < len(conversation):
            print(f"[OpenAIUserAgent] Truncated: kept {len(truncated)}/{len(conversation)} turns")

        return [system_msg] + truncated if system_msg else truncated

    def _generate(self, messages: List[Dict[str, str]]) -> str:
        """
        Generate response using OpenAI API with retry.

        Transient errors (rate limit, timeout, connection) are retried with
        exponential backoff. On the last attempt they are re-raised. Any
        other API error propagates immediately.

        Returns:
            The stripped completion text, or "" if no content was produced.
        """
        messages = self._truncate_messages(messages)

        for attempt in range(self.num_retries):
            try:
                params = {
                    "model": self.model,
                    "messages": messages,
                    "max_completion_tokens": self.max_completion_tokens,
                }
                # Non-reasoning models support temperature and response_format.
                if not self._is_reasoning_model():
                    params["temperature"] = 0.7
                    params["response_format"] = {"type": "json_object"}

                response = self._client.chat.completions.create(**params)
                choice = response.choices[0]
                content = choice.message.content

                if content:
                    return content.strip()

                # A reasoning model may spend the whole budget on hidden
                # reasoning and return no visible content; try again.
                if choice.finish_reason == "length":
                    print(f"[OpenAIUserAgent] Response truncated (length), attempt {attempt+1}")
                    continue

                return ""

            except (RateLimitError, APITimeoutError, APIConnectionError) as e:
                if attempt == self.num_retries - 1:
                    raise
                delay = self.retry_base_delay * (2 ** attempt)
                print(f"[OpenAIUserAgent] API error ({type(e).__name__}), retrying in {delay:.1f}s...")
                time.sleep(delay)

        return ""

    def _is_reasoning_model(self) -> bool:
        """Check if the model is a reasoning model (no temperature/response_format support)."""
        return self.model.startswith(self._REASONING_PREFIXES)

    def get_system_prompt(self) -> str:
        """Get the system prompt."""
        return self.system_prompt

    def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Reverse roles for user perspective.

        "assistant" messages become "user" and every other role becomes
        "assistant", so the simulator sees the assistant's turns as input.
        Returns new dicts; the input is not modified.
        """
        conversation = deepcopy(conversation)
        return [
            {"role": "user" if msg["role"] == "assistant" else "assistant", "content": msg["content"]}
            for msg in conversation
        ]

    def _fallback_response(self, response_text: str) -> Dict[str, Any]:
        """Wrap a reply that is not a JSON object, using the raw text as the response."""
        if TERMINATION_SIGNAL in response_text:
            return {
                "reasoning": "Ending conversation",
                "draft_answer": "",
                "should_terminate": True,
                "response": TERMINATION_SIGNAL,
            }
        return {
            "reasoning": "",
            "draft_answer": "",
            "should_terminate": False,
            "response": response_text,
        }

    def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
        """
        Generate user response given the conversation history.

        Args:
            conversation: List of {"role": "user"|"assistant", "content": str}

        Returns:
            Dict with keys: reasoning, draft_answer, should_terminate, response.
            Or None if all retries failed.
        """
        for attempt in range(self.num_retries):
            try:
                messages = [{"role": "system", "content": self.system_prompt}]
                messages.extend(self.reverse_roles(conversation))

                response_text = self._generate(messages)

                if not response_text:
                    print(f"[OpenAIUserAgent] Empty response, attempt {attempt+1}")
                    continue

                try:
                    parsed = repair_json(response_text, return_objects=True)
                except Exception:
                    return self._fallback_response(response_text)

                # repair_json may yield "" / a list / a scalar for non-JSON
                # text. Membership tests on those are not key lookups, so
                # anything but a dict is treated as raw text.
                if not isinstance(parsed, dict):
                    return self._fallback_response(response_text)

                missing = [k for k in self._REQUIRED_KEYS if k not in parsed]
                if missing:
                    print(f"[OpenAIUserAgent] Missing keys: {missing}, attempt {attempt+1}")
                    continue

                return parsed

            except Exception as e:
                print(f"[OpenAIUserAgent] Error: {e}, attempt {attempt+1}")
                continue

        print(f"[OpenAIUserAgent] Failed after {self.num_retries} attempts")
        return None