summaryrefslogtreecommitdiff
path: root/collaborativeagents/agents
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/agents
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/agents')
-rw-r--r--collaborativeagents/agents/__init__.py7
-rw-r--r--collaborativeagents/agents/batch_vllm_agent.py643
-rw-r--r--collaborativeagents/agents/local_user_agent.py334
-rw-r--r--collaborativeagents/agents/openai_user_agent.py229
-rw-r--r--collaborativeagents/agents/vllm_user_agent.py381
5 files changed, 1594 insertions, 0 deletions
diff --git a/collaborativeagents/agents/__init__.py b/collaborativeagents/agents/__init__.py
new file mode 100644
index 0000000..8105ab0
--- /dev/null
+++ b/collaborativeagents/agents/__init__.py
@@ -0,0 +1,7 @@
+"""
+Local agents for personalization experiments.
+"""
+
+from .local_user_agent import LocalUserAgent, SharedLocalUserAgent, TERMINATION_SIGNAL
+
+__all__ = ["LocalUserAgent", "SharedLocalUserAgent", "TERMINATION_SIGNAL"]
diff --git a/collaborativeagents/agents/batch_vllm_agent.py b/collaborativeagents/agents/batch_vllm_agent.py
new file mode 100644
index 0000000..f57bd78
--- /dev/null
+++ b/collaborativeagents/agents/batch_vllm_agent.py
@@ -0,0 +1,643 @@
+"""
+Batch processing for high-throughput conversation generation.
+
+This implements turn-synchronous batch processing:
+- vLLM servers for local models (agent)
+- OpenAI async SDK for API-based models (user simulator)
+
+Key insight: Process ALL conversations at the same turn level together,
+maximizing throughput via concurrent async requests.
+"""
+
+import asyncio
+import aiohttp
+import os
+from typing import List, Dict, Any, Optional
+from copy import deepcopy
+from json_repair import repair_json
+import time
+
+TERMINATION_SIGNAL = "TERMINATE"
+
+
+class BatchVLLMClient:
+ """
+ Async batch client for vLLM server.
+
+ Sends multiple requests concurrently and gathers results,
+ allowing vLLM's continuous batching to process them together.
+ """
+
+ def __init__(
+ self,
+ vllm_url: str,
+ model_name: str = None,
+ max_tokens: int = 512,
+ temperature: float = 0.7,
+ timeout: int = None, # None = infinite timeout
+ max_concurrent: int = 100,
+ api_key: str = None,
+ is_reasoning_model: bool = False,
+ json_mode: bool = False, # Enable JSON output mode for vLLM
+ ):
+ self.vllm_url = vllm_url.rstrip('/')
+ self.model_name = model_name
+ self.max_tokens = max_tokens
+ self.temperature = temperature
+ self.timeout = timeout # None for infinite
+ self.max_concurrent = max_concurrent
+ self.api_key = api_key
+ self.is_reasoning_model = is_reasoning_model
+ self.json_mode = json_mode
+
+ if self.model_name is None and not self.api_key:
+ self._discover_model_sync()
+
+ def _discover_model_sync(self):
+ """Synchronously discover model name."""
+ import requests
+ try:
+ response = requests.get(f"{self.vllm_url}/models", timeout=10)
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self.model_name = models["data"][0]["id"]
+ else:
+ self.model_name = "default"
+ except Exception as e:
+ print(f"[BatchVLLMClient] Warning: Could not discover model ({e})")
+ self.model_name = "default"
+
+ async def _single_completion(
+ self,
+ session: aiohttp.ClientSession,
+ messages: List[Dict[str, str]],
+ idx: int,
+ retry_with_lower_tokens: bool = True
+ ) -> tuple:
+ """Make a single async completion request with retry logic."""
+ max_tokens = self.max_tokens
+
+ for attempt in range(3): # Up to 3 attempts
+ if self.is_reasoning_model:
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_completion_tokens": max_tokens,
+ "response_format": {"type": "json_object"},
+ }
+ elif self.api_key:
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": max_tokens,
+ "temperature": self.temperature,
+ "response_format": {"type": "json_object"},
+ }
+ else:
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": max_tokens,
+ "temperature": self.temperature,
+ "top_p": 0.9,
+ }
+ # Add JSON mode for local vLLM if requested
+ if self.json_mode:
+ payload["response_format"] = {"type": "json_object"}
+
+ try:
+ # Use None for infinite timeout
+ timeout_config = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else None
+ async with session.post(
+ f"{self.vllm_url}/chat/completions",
+ json=payload,
+ timeout=timeout_config
+ ) as response:
+ if response.status == 200:
+ result = await response.json()
+ content = result["choices"][0]["message"]["content"]
+ # Reasoning models may exhaust tokens on internal reasoning
+ if not content and self.is_reasoning_model:
+ finish = result["choices"][0].get("finish_reason", "")
+ if finish == "length":
+ max_tokens = min(max_tokens * 2, 8192)
+ continue
+ return (idx, content, None)
+ elif response.status == 400:
+ error_text = await response.text()
+ if "max_tokens" in error_text and retry_with_lower_tokens:
+ max_tokens = max(64, max_tokens // 2)
+ continue
+ return (idx, None, f"HTTP 400: {error_text[:200]}")
+ elif response.status == 429:
+ # Rate limit — wait and retry
+ await asyncio.sleep(2 ** attempt)
+ continue
+ else:
+ error_text = await response.text()
+ return (idx, None, f"HTTP {response.status}: {error_text[:200]}")
+ except asyncio.TimeoutError:
+ if attempt < 2:
+ continue
+ return (idx, None, "Timeout")
+ except Exception as e:
+ return (idx, None, str(e))
+
+ return (idx, None, "Max retries exceeded")
+
+ async def batch_completion_async(
+ self,
+ messages_list: List[List[Dict[str, str]]],
+ show_progress: bool = False
+ ) -> List[Optional[str]]:
+ """
+ Send multiple completion requests concurrently.
+
+ vLLM's continuous batching will automatically batch these together.
+ Uses semaphore to limit concurrent requests.
+ """
+ results = [None] * len(messages_list)
+ errors = [None] * len(messages_list)
+ completed = 0
+
+ # Use semaphore to limit concurrent requests
+ semaphore = asyncio.Semaphore(self.max_concurrent)
+
+ async def limited_completion(session, messages, idx):
+ async with semaphore:
+ return await self._single_completion(session, messages, idx)
+
+ connector = aiohttp.TCPConnector(limit=self.max_concurrent)
+ headers = {"Content-Type": "application/json"}
+ if self.api_key:
+ headers["Authorization"] = f"Bearer {self.api_key}"
+ async with aiohttp.ClientSession(connector=connector, headers=headers) as session:
+ tasks = [
+ limited_completion(session, messages, idx)
+ for idx, messages in enumerate(messages_list)
+ ]
+
+ for coro in asyncio.as_completed(tasks):
+ idx, content, error = await coro
+ completed += 1
+ if error:
+ errors[idx] = error
+ results[idx] = content
+ if show_progress and completed % 10 == 0:
+ print(f" [{completed}/{len(messages_list)}] completed")
+
+ # Summary of errors
+ error_count = sum(1 for e in errors if e is not None)
+ if error_count > 0:
+ print(f"[BatchVLLMClient] {error_count}/{len(messages_list)} requests failed")
+
+ return results
+
+ def batch_completion(
+ self,
+ messages_list: List[List[Dict[str, str]]]
+ ) -> List[Optional[str]]:
+ """Synchronous wrapper for batch completion."""
+ return asyncio.run(self.batch_completion_async(messages_list))
+
+
+class BatchOpenAIClient:
+ """
+ Async batch client using the OpenAI Python SDK.
+
+ Drop-in replacement for BatchVLLMClient when targeting OpenAI API.
+ Uses AsyncOpenAI for proper SSL, auth, and reasoning model support.
+ """
+
+ def __init__(
+ self,
+ model: str = "gpt-5",
+ max_tokens: int = 4096,
+ max_concurrent: int = 32,
+ timeout: float = 120.0,
+ api_key: str = None,
+ ):
+ from openai import AsyncOpenAI
+ self.model = model
+ self.max_tokens = max_tokens
+ self.max_concurrent = max_concurrent
+ self.timeout = timeout
+ self._is_reasoning = any(model.startswith(p) for p in ("o1", "o3", "gpt-5"))
+ self._client = AsyncOpenAI(
+ api_key=api_key or os.environ.get("OPENAI_API_KEY"),
+ timeout=timeout,
+ max_retries=2,
+ )
+
+ async def _single_completion(
+ self,
+ messages: List[Dict[str, str]],
+ idx: int,
+ ) -> tuple:
+ """Single async completion via OpenAI SDK."""
+ max_tokens = self.max_tokens
+
+ for attempt in range(3):
+ try:
+ kwargs = {
+ "model": self.model,
+ "messages": messages,
+ "response_format": {"type": "json_object"},
+ }
+ if self._is_reasoning:
+ kwargs["max_completion_tokens"] = max_tokens
+ else:
+ kwargs["max_tokens"] = max_tokens
+ kwargs["temperature"] = 0.7
+
+ response = await self._client.chat.completions.create(**kwargs)
+ content = response.choices[0].message.content
+
+ # Reasoning models may exhaust tokens on internal reasoning
+ if not content and self._is_reasoning:
+ if response.choices[0].finish_reason == "length":
+ max_tokens = min(max_tokens * 2, 16384)
+ continue
+
+ return (idx, content, None)
+
+ except Exception as e:
+ if attempt < 2:
+ await asyncio.sleep(1 * (attempt + 1))
+ continue
+ return (idx, None, f"{type(e).__name__}: {e}")
+
+ return (idx, None, "Max retries exceeded")
+
+ async def batch_completion_async(
+ self,
+ messages_list: List[List[Dict[str, str]]],
+ show_progress: bool = False
+ ) -> List[Optional[str]]:
+ """Send multiple completion requests concurrently via AsyncOpenAI."""
+ results = [None] * len(messages_list)
+ errors = [None] * len(messages_list)
+ completed = 0
+ semaphore = asyncio.Semaphore(self.max_concurrent)
+
+ async def limited_completion(messages, idx):
+ async with semaphore:
+ return await self._single_completion(messages, idx)
+
+ tasks = [
+ limited_completion(messages, idx)
+ for idx, messages in enumerate(messages_list)
+ ]
+
+ for coro in asyncio.as_completed(tasks):
+ idx, content, error = await coro
+ completed += 1
+ if error:
+ errors[idx] = error
+ results[idx] = content
+ if show_progress and completed % 10 == 0:
+ print(f" [BatchOpenAI {completed}/{len(messages_list)}] completed")
+
+ error_count = sum(1 for e in errors if e is not None)
+ if error_count > 0:
+ error_samples = [e for e in errors if e is not None][:3]
+ print(f"[BatchOpenAIClient] {error_count}/{len(messages_list)} requests failed. Examples: {error_samples}")
+
+ return results
+
+ def batch_completion(
+ self,
+ messages_list: List[List[Dict[str, str]]]
+ ) -> List[Optional[str]]:
+ """Synchronous wrapper for batch completion."""
+ return asyncio.run(self.batch_completion_async(messages_list))
+
+
+class BatchConversationGenerator:
+ """
+ Generate conversations using turn-synchronous batch processing.
+
+ This processes ALL samples at the same turn together, maximizing
+ vLLM's continuous batching efficiency.
+ """
+
+ USER_SYSTEM_PROMPT = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with user preferences, which you must follow and actively enforce throughout the conversation.
+
+# Problem Description
+{problem}
+Note: the agent cannot see this problem description.
+
+# User Persona
+{user_persona}
+
+# User Preferences
+{user_preferences}
+These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced:
+ - **Answer clarifying questions**: The agent may ask clarifying questions before attempting an answer. Answer such questions, and do not enforce preferences about answer format or content while the agent is clarifying.
+ - **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed. Explicitly ask the agent to adjust their response until it complies.
+ - **Never proceed without compliance**: Do NOT update your draft answer, do NOT consider terminating, and do NOT move forward until the agent follows your preferences.
+
+# Draft Answer Management
+- **Maintain a working draft**: Start with "I don't know". Update your draft answer based on what you learn from agent responses.
+- **Don't update when enforcing preferences**: If the agent response does not follow your preferences, do NOT update your draft answer, regardless of whether the agent provides helpful information.
+
+# Conversation Termination
+Before generating your response, determine if you should terminate:
+ - Do you feel like your draft answer is a good answer to the problem?
+ - Do you feel like the agent cannot help further?
+If the agent response does not follow your preferences, you must NOT terminate - instead, enforce the preferences.
+When ready to terminate, respond with "TERMINATE".
+
+# Output Format:
+{{
+ "preferences_check": "For EACH relevant preference, evaluate: is it satisfied?",
+ "enforce_preferences": true/false,
+ "reasoning": "Brief reasoning (2-3 sentences). Does agent follow preferences? If no, enforce. If yes, update draft.",
+ "draft_answer": "Your current working draft answer",
+ "should_terminate": true/false,
+ "response": "Your response to the agent"
+}}
+"""
+
+ USER_SYSTEM_PROMPT_NO_PREF = """You are a user simulator collaborating with an agent to solve a problem.
+
+# Problem Description
+{problem}
+
+# User Persona
+{user_persona}
+
+# Output Format:
+{{
+ "reasoning": "Brief reasoning",
+ "draft_answer": "Your current draft answer",
+ "should_terminate": true/false,
+ "response": "Your response"
+}}
+"""
+
+ def __init__(
+ self,
+ user_vllm_url: str,
+ agent_vllm_url: str,
+ max_turns: int = 10,
+ user_max_tokens: int = 512,
+ agent_max_tokens: int = 1024,
+ temperature: float = 0.7,
+ user_api_key: str = None,
+ user_model_name: str = None,
+ user_is_reasoning: bool = False,
+ ):
+ if user_api_key:
+ # Use OpenAI SDK client for API-based user models
+ self.user_client = BatchOpenAIClient(
+ model=user_model_name or "gpt-5",
+ max_tokens=user_max_tokens,
+ max_concurrent=32,
+ timeout=120.0,
+ api_key=user_api_key,
+ )
+ else:
+ self.user_client = BatchVLLMClient(
+ vllm_url=user_vllm_url,
+ model_name=user_model_name,
+ max_tokens=user_max_tokens,
+ temperature=temperature,
+ )
+ self.agent_client = BatchVLLMClient(
+ vllm_url=agent_vllm_url,
+ max_tokens=agent_max_tokens,
+ temperature=temperature,
+ )
+ self.max_turns = max_turns
+
+ def _build_user_system_prompt(
+ self,
+ problem: str,
+ user_persona: str,
+ user_preferences: str = None
+ ) -> str:
+ if user_preferences:
+ return self.USER_SYSTEM_PROMPT.format(
+ problem=problem,
+ user_persona=user_persona,
+ user_preferences=user_preferences
+ )
+ else:
+ return self.USER_SYSTEM_PROMPT_NO_PREF.format(
+ problem=problem,
+ user_persona=user_persona
+ )
+
+ def _reverse_roles(self, conversation: List[Dict]) -> List[Dict]:
+ """Reverse roles for user perspective."""
+ return [
+ {"role": "user" if msg["role"] == "assistant" else "assistant",
+ "content": msg["content"]}
+ for msg in conversation
+ ]
+
+ def _parse_user_response(self, content: str) -> Optional[Dict]:
+ """Parse user response JSON."""
+ if content is None:
+ return None
+ try:
+ parsed = repair_json(content, return_objects=True)
+ required_keys = ["reasoning", "draft_answer", "should_terminate", "response"]
+ if all(k in parsed for k in required_keys):
+ return parsed
+ # Fallback: treat as raw response
+ if TERMINATION_SIGNAL in content:
+ return {
+ "reasoning": "", "draft_answer": "",
+ "should_terminate": True, "response": TERMINATION_SIGNAL
+ }
+ return {
+ "reasoning": "", "draft_answer": "",
+ "should_terminate": False, "response": content
+ }
+ except:
+ if TERMINATION_SIGNAL in content:
+ return {
+ "reasoning": "", "draft_answer": "",
+ "should_terminate": True, "response": TERMINATION_SIGNAL
+ }
+ return {
+ "reasoning": "", "draft_answer": "",
+ "should_terminate": False, "response": content
+ }
+
+ def generate_batch(
+ self,
+ samples: List[Dict],
+ user_persona: str,
+ user_preferences: str = None,
+ agent_system_prompt: str = "You are a helpful AI assistant.",
+ ) -> List[Dict]:
+ """
+ Generate conversations for a batch of samples using turn-synchronous processing.
+
+ Args:
+ samples: List of dicts with 'problem' key
+ user_persona: User persona description
+ user_preferences: User preferences string (optional)
+ agent_system_prompt: System prompt for the agent
+
+ Returns:
+ List of conversation results
+ """
+ n_samples = len(samples)
+
+ # Initialize state for all conversations
+ conversations = [[{"role": "assistant", "content": "How can I help you?"}]
+ for _ in range(n_samples)]
+ full_logs = [[] for _ in range(n_samples)]
+
+ # Build user system prompts for each sample
+ user_system_prompts = [
+ self._build_user_system_prompt(
+ problem=sample['problem'],
+ user_persona=user_persona,
+ user_preferences=user_preferences
+ )
+ for sample in samples
+ ]
+
+ # Track active conversations (not terminated or failed)
+ active_indices = set(range(n_samples))
+ failed_indices = set()
+
+ for turn in range(self.max_turns):
+ if not active_indices:
+ break
+
+ # ========== USER TURN (BATCHED) ==========
+ user_indices = sorted(active_indices)
+
+ # Build batch of user messages
+ user_messages_batch = []
+ for idx in user_indices:
+ messages = [{"role": "system", "content": user_system_prompts[idx]}]
+ messages.extend(self._reverse_roles(conversations[idx]))
+ user_messages_batch.append(messages)
+
+ # Batch call to user model
+ user_responses_raw = self.user_client.batch_completion(user_messages_batch)
+
+ # Process user responses
+ for i, idx in enumerate(user_indices):
+ raw = user_responses_raw[i]
+ parsed = self._parse_user_response(raw)
+
+ if parsed is None:
+ active_indices.discard(idx)
+ failed_indices.add(idx)
+ continue
+
+ conversations[idx].append({
+ "role": "user",
+ "content": str(parsed["response"])
+ })
+ full_logs[idx].append(parsed)
+
+ # Check for termination
+ if parsed.get("should_terminate") or TERMINATION_SIGNAL in parsed["response"]:
+ active_indices.discard(idx)
+
+ if not active_indices:
+ break
+
+ # ========== AGENT TURN (BATCHED) ==========
+ agent_indices = sorted(active_indices)
+
+ # Build batch of agent messages
+ agent_messages_batch = []
+ for idx in agent_indices:
+ messages = [{"role": "system", "content": agent_system_prompt}]
+ messages.extend(conversations[idx])
+ agent_messages_batch.append(messages)
+
+ # Batch call to agent model
+ agent_responses_raw = self.agent_client.batch_completion(agent_messages_batch)
+
+ # Process agent responses
+ for i, idx in enumerate(agent_indices):
+ raw = agent_responses_raw[i]
+
+ if raw is None:
+ active_indices.discard(idx)
+ failed_indices.add(idx)
+ continue
+
+ conversations[idx].append({
+ "role": "assistant",
+ "content": raw
+ })
+ full_logs[idx].append({"response": raw, "reasoning": ""})
+
+ # Build results
+ results = []
+ for i, sample in enumerate(samples):
+ if i in failed_indices:
+ results.append(None)
+ else:
+ results.append({
+ "sample": sample,
+ "conversation": conversations[i],
+ "full_conversation_log": full_logs[i]
+ })
+
+ return results
+
+
+def benchmark_batch_generation(
+ user_url: str,
+ agent_url: str,
+ n_samples: int = 20,
+ max_turns: int = 5,
+):
+ """Quick benchmark of batch generation."""
+
+ # Create dummy samples
+ samples = [
+ {"problem": f"What is {i+1} + {i+2}? Show your work.", "solution": str(2*i+3)}
+ for i in range(n_samples)
+ ]
+
+ generator = BatchConversationGenerator(
+ user_vllm_url=user_url,
+ agent_vllm_url=agent_url,
+ max_turns=max_turns,
+ )
+
+ start = time.time()
+ results = generator.generate_batch(
+ samples=samples,
+ user_persona="A curious student learning math.",
+ user_preferences="1. Show step-by-step working\n2. Explain clearly",
+ )
+ elapsed = time.time() - start
+
+ successes = sum(1 for r in results if r is not None)
+ print(f"\n=== Batch Generation Benchmark ===")
+ print(f"Samples: {n_samples}, Max turns: {max_turns}")
+ print(f"Successes: {successes}/{n_samples}")
+ print(f"Time: {elapsed:.1f}s")
+ print(f"Throughput: {successes * 3600 / elapsed:.0f} conversations/hr")
+
+ return results
+
+
+if __name__ == "__main__":
+ import sys
+ if len(sys.argv) >= 3:
+ user_url = sys.argv[1]
+ agent_url = sys.argv[2]
+ n_samples = int(sys.argv[3]) if len(sys.argv) > 3 else 20
+ else:
+ user_url = "http://localhost:8004/v1"
+ agent_url = "http://localhost:8003/v1"
+ n_samples = 20
+
+ benchmark_batch_generation(user_url, agent_url, n_samples=n_samples)
diff --git a/collaborativeagents/agents/local_user_agent.py b/collaborativeagents/agents/local_user_agent.py
new file mode 100644
index 0000000..eae311e
--- /dev/null
+++ b/collaborativeagents/agents/local_user_agent.py
@@ -0,0 +1,334 @@
+"""
+Local User Agent - Uses local transformers model for user simulation.
+
+This replaces the litellm-based UserAgent with a local transformers implementation
+for running experiments without requiring an API server.
+"""
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from typing import List, Dict, Any, Optional
+from copy import deepcopy
+from json_repair import repair_json
+
+# Default model paths
+DEFAULT_MODEL_PATH_8B = "/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct"
+DEFAULT_MODEL_PATH_70B = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"
+
+# Use 70B by default for better user simulation
+DEFAULT_MODEL_PATH = DEFAULT_MODEL_PATH_70B
+
+# Termination signal from CollaborativeAgents
+TERMINATION_SIGNAL = "TERMINATE"
+
+# User system prompt with preferences (simplified version)
+USER_SYSTEM_PROMPT_WITH_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with user preferences, which you must actively enforce throughout the conversation.
+
+# Problem Description
+{problem}
+Note: the agent cannot see this problem description.
+
+# User Persona
+{user_persona}
+
+# User Preferences
+{user_preferences}
+These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced:
+- **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed.
+- **Never proceed without compliance**: Do NOT move forward until the agent follows your preferences.
+
+# Draft Answer Management
+- Maintain a draft answer throughout the conversation. Start with "I don't know".
+- Update your draft based on agent responses that follow your preferences.
+
+# Conversation Termination
+When ready to terminate (draft answer is good OR agent cannot help), respond with "{termination_signal}".
+
+# Output Format:
+Respond with a JSON object:
+{{
+ "reasoning": "Brief reasoning about the agent's response and your preferences",
+ "draft_answer": "Your current working draft answer",
+ "should_terminate": true/false,
+ "response": "Your response to the agent"
+}}
+"""
+
+USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem.
+
+# Problem Description
+{problem}
+
+# User Persona
+{user_persona}
+
+# Conversation Termination
+When ready to terminate, respond with "{termination_signal}".
+
+# Output Format:
+{{
+ "reasoning": "Brief reasoning",
+ "draft_answer": "Your current working draft answer",
+ "should_terminate": true/false,
+ "response": "Your response to the agent"
+}}
+"""
+
+
+class LocalUserAgent:
+ """
+ Local User Agent using transformers for user simulation.
+
+ Simulates a user who:
+ - Presents problems to the agent
+ - Enforces preferences throughout the conversation
+ - Decides when to terminate the conversation
+ """
+
+ def __init__(
+ self,
+ user_task_description: str,
+ problem: str,
+ user_persona: str = None,
+ user_preferences: str = None,
+ model_path: str = DEFAULT_MODEL_PATH,
+ num_retries: int = 3,
+ # For compatibility with original UserAgent interface
+ model_name: str = None,
+ api_base: str = None,
+ api_key: str = None,
+ ):
+ self.problem = problem
+ self.user_persona = user_persona or "A helpful user seeking assistance."
+ self.user_preferences = user_preferences
+ self.num_retries = num_retries
+ self.model_path = model_path
+
+ # Build system prompt
+ if user_preferences:
+ self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
+ problem=problem,
+ user_persona=self.user_persona,
+ user_preferences=user_preferences,
+ termination_signal=TERMINATION_SIGNAL
+ )
+ else:
+ self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
+ problem=problem,
+ user_persona=self.user_persona,
+ termination_signal=TERMINATION_SIGNAL
+ )
+
+ # Model components (loaded lazily)
+ self._model = None
+ self._tokenizer = None
+ self._initialized = False
+
+ def _ensure_initialized(self):
+ """Lazy initialization of model."""
+ if self._initialized:
+ return
+
+ import os
+ # Use project HF cache
+ cache_dir = os.environ.get("HF_HOME", "/projects/bfqt/users/yurenh2/hf_cache/huggingface")
+
+ print(f"[LocalUserAgent] Loading model from {self.model_path}...")
+ self._tokenizer = AutoTokenizer.from_pretrained(
+ self.model_path,
+ cache_dir=cache_dir,
+ trust_remote_code=True
+ )
+
+ # Check if this is an AWQ model
+ is_awq = "awq" in self.model_path.lower()
+
+ if is_awq:
+ # AWQ models use float16 and auto device map
+ self._model = AutoModelForCausalLM.from_pretrained(
+ self.model_path,
+ torch_dtype=torch.float16,
+ device_map="auto",
+ cache_dir=cache_dir,
+ trust_remote_code=True,
+ )
+ else:
+ # Standard model loading
+ self._model = AutoModelForCausalLM.from_pretrained(
+ self.model_path,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ cache_dir=cache_dir,
+ )
+
+ if self._tokenizer.pad_token_id is None:
+ self._tokenizer.pad_token = self._tokenizer.eos_token
+
+ self._initialized = True
+ print(f"[LocalUserAgent] Initialized (AWQ={is_awq})")
+
+ def _generate(self, messages: List[Dict[str, str]], max_new_tokens: int = 512) -> str:
+ """Generate response using local model."""
+ self._ensure_initialized()
+
+ # Apply chat template
+ prompt = self._tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+
+ inputs = self._tokenizer(
+ prompt,
+ return_tensors="pt",
+ truncation=True,
+ max_length=8192
+ ).to(self._model.device)
+
+ with torch.no_grad():
+ outputs = self._model.generate(
+ **inputs,
+ max_new_tokens=max_new_tokens,
+ do_sample=True,
+ temperature=0.7,
+ top_p=0.9,
+ eos_token_id=self._tokenizer.eos_token_id,
+ pad_token_id=self._tokenizer.pad_token_id,
+ )
+
+ # Extract only the generated part
+ input_len = inputs["input_ids"].shape[1]
+ gen_ids = outputs[0][input_len:]
+ response = self._tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
+
+ return response
+
+ def get_system_prompt(self) -> str:
+ """Get the system prompt."""
+ return self.system_prompt
+
+ def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ """Reverse roles for user perspective (agent becomes user, user becomes assistant)."""
+ conversation = deepcopy(conversation)
+ return [
+ {"role": "user" if msg["role"] == "assistant" else "assistant", "content": msg["content"]}
+ for msg in conversation
+ ]
+
+ def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
+ """
+ Generate user response given the conversation history.
+
+ Args:
+ conversation: List of {"role": "user"|"assistant", "content": str}
+ From user perspective: agent messages are "assistant"
+
+ Returns:
+ Dict with keys: reasoning, draft_answer, should_terminate, response
+ Or None if failed
+ """
+ for attempt in range(self.num_retries):
+ try:
+ # Build messages: system prompt + reversed conversation
+ messages = [{"role": "system", "content": self.system_prompt}]
+ messages.extend(self.reverse_roles(conversation))
+
+ # Generate response
+ response_text = self._generate(messages)
+
+ # Try to parse as JSON
+ try:
+ parsed = repair_json(response_text, return_objects=True)
+
+ # Check for required keys
+ required_keys = ["reasoning", "draft_answer", "should_terminate", "response"]
+ missing = [k for k in required_keys if k not in parsed]
+
+ if missing:
+ print(f"[LocalUserAgent] Missing keys: {missing}, attempt {attempt+1}")
+ continue
+
+ return parsed
+
+ except Exception as e:
+ # If JSON parsing fails, try to extract response directly
+ print(f"[LocalUserAgent] JSON parse failed: {e}, attempt {attempt+1}")
+
+ # Fallback: return the raw text as the response
+ if TERMINATION_SIGNAL in response_text:
+ return {
+ "reasoning": "Ending conversation",
+ "draft_answer": "",
+ "should_terminate": True,
+ "response": TERMINATION_SIGNAL
+ }
+ else:
+ return {
+ "reasoning": "",
+ "draft_answer": "",
+ "should_terminate": False,
+ "response": response_text
+ }
+
+ except Exception as e:
+ print(f"[LocalUserAgent] Error: {e}, attempt {attempt+1}")
+ continue
+
+ print(f"[LocalUserAgent] Failed after {self.num_retries} attempts")
+ return None
+
+
+# Singleton model for efficiency (shared across multiple LocalUserAgent instances)
+_shared_model = None
+_shared_tokenizer = None
+
+
+class SharedLocalUserAgent(LocalUserAgent):
+ """
+ LocalUserAgent that shares model across instances to save memory.
+ """
+
+ def _ensure_initialized(self):
+ """Use shared model instead of loading a new one."""
+ global _shared_model, _shared_tokenizer
+
+ if self._initialized:
+ return
+
+ if _shared_model is None:
+ import os
+ cache_dir = os.environ.get("HF_HOME", "/projects/bfqt/users/yurenh2/hf_cache/huggingface")
+
+ print(f"[SharedLocalUserAgent] Loading shared model from {self.model_path}...")
+ _shared_tokenizer = AutoTokenizer.from_pretrained(
+ self.model_path,
+ cache_dir=cache_dir,
+ trust_remote_code=True
+ )
+
+ # Check if this is an AWQ model
+ is_awq = "awq" in self.model_path.lower()
+
+ if is_awq:
+ _shared_model = AutoModelForCausalLM.from_pretrained(
+ self.model_path,
+ torch_dtype=torch.float16,
+ device_map="auto",
+ cache_dir=cache_dir,
+ trust_remote_code=True,
+ )
+ else:
+ _shared_model = AutoModelForCausalLM.from_pretrained(
+ self.model_path,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ cache_dir=cache_dir,
+ )
+
+ if _shared_tokenizer.pad_token_id is None:
+ _shared_tokenizer.pad_token = _shared_tokenizer.eos_token
+ print(f"[SharedLocalUserAgent] Shared model loaded (AWQ={is_awq})")
+
+ self._model = _shared_model
+ self._tokenizer = _shared_tokenizer
+ self._initialized = True
diff --git a/collaborativeagents/agents/openai_user_agent.py b/collaborativeagents/agents/openai_user_agent.py
new file mode 100644
index 0000000..b49cd57
--- /dev/null
+++ b/collaborativeagents/agents/openai_user_agent.py
@@ -0,0 +1,229 @@
+"""
+OpenAI-based User Agent for user simulation with GPT-5.
+
+Drop-in replacement for VLLMUserAgent — same interface, uses OpenAI API.
+Supports reasoning models (GPT-5, o-series) that require max_completion_tokens.
+"""
+
+import os
+from typing import List, Dict, Any, Optional
+from copy import deepcopy
+from json_repair import repair_json
+from openai import OpenAI, RateLimitError, APITimeoutError, APIConnectionError
+
+from agents.vllm_user_agent import (
+ USER_SYSTEM_PROMPT_WITH_PREFERENCES,
+ USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES,
+ TERMINATION_SIGNAL,
+)
+
+
+class OpenAIUserAgent:
+ """
+ User Agent that uses the OpenAI API (GPT-5) for user simulation.
+
+ Key differences from VLLMUserAgent:
+ - Uses openai.OpenAI client
+ - GPT-5 is a reasoning model: no temperature, uses max_completion_tokens
+ - Higher quality simulation at the cost of API calls
+ """
+
+ def __init__(
+ self,
+ user_task_description: str,
+ problem: str,
+ user_persona: str = None,
+ user_preferences: str = None,
+ model: str = "gpt-5",
+ api_key: Optional[str] = None,
+ base_url: Optional[str] = None,
+ num_retries: int = 3,
+ max_completion_tokens: int = 4096, # High for reasoning models
+ max_context_length: int = 128000, # GPT-5 context window
+ retry_base_delay: float = 1.0,
+ ):
+ self.user_task_description = user_task_description
+ self.problem = problem
+ self.user_persona = user_persona or "A helpful user seeking assistance."
+ self.user_preferences = user_preferences
+ self.model = model
+ self.num_retries = num_retries
+ self.max_completion_tokens = max_completion_tokens
+ self.max_context_length = max_context_length
+ self.retry_base_delay = retry_base_delay
+
+ # Initialize OpenAI client
+ self._client = OpenAI(
+ api_key=api_key or os.getenv("OPENAI_API_KEY"),
+ base_url=base_url,
+ timeout=120.0,
+ )
+
+ # Build system prompt (same format as VLLMUserAgent)
+ if user_preferences:
+ self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
+ user_task_description=user_task_description,
+ problem=problem,
+ user_persona=self.user_persona,
+ user_preferences=user_preferences,
+ termination_signal=TERMINATION_SIGNAL,
+ )
+ else:
+ self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
+ user_task_description=user_task_description,
+ problem=problem,
+ user_persona=self.user_persona,
+ termination_signal=TERMINATION_SIGNAL,
+ )
+
+ def _estimate_tokens(self, text: str) -> int:
+ """Estimate token count (~3.5 chars/token)."""
+ return int(len(text) / 3.5)
+
+ def _truncate_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ """Truncate messages to fit within context, keeping recent messages."""
+ if not messages:
+ return messages
+
+ system_msg = messages[0] if messages[0]["role"] == "system" else None
+ conversation = messages[1:] if system_msg else messages
+
+ system_tokens = self._estimate_tokens(system_msg["content"]) if system_msg else 0
+ available_tokens = self.max_context_length - system_tokens - self.max_completion_tokens - 200
+
+ total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation)
+ if total_conv_tokens <= available_tokens:
+ return messages
+
+ # Keep recent messages
+ truncated = []
+ current_tokens = 0
+ for msg in reversed(conversation):
+ msg_tokens = self._estimate_tokens(msg["content"])
+ if current_tokens + msg_tokens <= available_tokens:
+ truncated.insert(0, msg)
+ current_tokens += msg_tokens
+ else:
+ break
+
+ if len(truncated) < len(conversation):
+ print(f"[OpenAIUserAgent] Truncated: kept {len(truncated)}/{len(conversation)} turns")
+
+ return [system_msg] + truncated if system_msg else truncated
+
+ def _generate(self, messages: List[Dict[str, str]]) -> str:
+ """Generate response using OpenAI API with retry."""
+ messages = self._truncate_messages(messages)
+
+ import time
+ for attempt in range(self.num_retries):
+ try:
+ # Build API call params
+ params = {
+ "model": self.model,
+ "messages": messages,
+ "max_completion_tokens": self.max_completion_tokens,
+ }
+
+ # Non-reasoning models support temperature and response_format
+ if not self._is_reasoning_model():
+ params["temperature"] = 0.7
+ params["response_format"] = {"type": "json_object"}
+
+ response = self._client.chat.completions.create(**params)
+
+ content = response.choices[0].message.content
+ if content:
+ return content.strip()
+
+ # Reasoning model may exhaust tokens
+ if response.choices[0].finish_reason == "length":
+ print(f"[OpenAIUserAgent] Response truncated (length), attempt {attempt+1}")
+ continue
+
+ return ""
+
+ except (RateLimitError, APITimeoutError, APIConnectionError) as e:
+ if attempt == self.num_retries - 1:
+ raise
+ delay = self.retry_base_delay * (2 ** attempt)
+ print(f"[OpenAIUserAgent] API error ({type(e).__name__}), retrying in {delay:.1f}s...")
+ time.sleep(delay)
+
+ return ""
+
+ def _is_reasoning_model(self) -> bool:
+ """Check if the model is a reasoning model (no temperature/response_format support)."""
+ reasoning_prefixes = ("o1", "o3", "gpt-5")
+ return any(self.model.startswith(p) for p in reasoning_prefixes)
+
+ def get_system_prompt(self) -> str:
+ """Get the system prompt."""
+ return self.system_prompt
+
+ def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ """Reverse roles for user perspective."""
+ conversation = deepcopy(conversation)
+ return [
+ {"role": "user" if msg["role"] == "assistant" else "assistant", "content": msg["content"]}
+ for msg in conversation
+ ]
+
+ def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
+ """
+ Generate user response given the conversation history.
+
+ Args:
+ conversation: List of {"role": "user"|"assistant", "content": str}
+
+ Returns:
+ Dict with keys: reasoning, draft_answer, should_terminate, response
+ Or None if all retries failed.
+ """
+ for attempt in range(self.num_retries):
+ try:
+ messages = [{"role": "system", "content": self.system_prompt}]
+ messages.extend(self.reverse_roles(conversation))
+
+ response_text = self._generate(messages)
+
+ if not response_text:
+ print(f"[OpenAIUserAgent] Empty response, attempt {attempt+1}")
+ continue
+
+ # Parse JSON response
+ try:
+ parsed = repair_json(response_text, return_objects=True)
+
+ required_keys = ["reasoning", "draft_answer", "should_terminate", "response"]
+ missing = [k for k in required_keys if k not in parsed]
+
+ if missing:
+ print(f"[OpenAIUserAgent] Missing keys: {missing}, attempt {attempt+1}")
+ continue
+
+ return parsed
+
+ except Exception:
+ # Fallback: raw text as response
+ if TERMINATION_SIGNAL in response_text:
+ return {
+ "reasoning": "Ending conversation",
+ "draft_answer": "",
+ "should_terminate": True,
+ "response": TERMINATION_SIGNAL,
+ }
+ else:
+ return {
+ "reasoning": "",
+ "draft_answer": "",
+ "should_terminate": False,
+ "response": response_text,
+ }
+
+ except Exception as e:
+ print(f"[OpenAIUserAgent] Error: {e}, attempt {attempt+1}")
+ continue
+
+ print(f"[OpenAIUserAgent] Failed after {self.num_retries} attempts")
+ return None
diff --git a/collaborativeagents/agents/vllm_user_agent.py b/collaborativeagents/agents/vllm_user_agent.py
new file mode 100644
index 0000000..3d73637
--- /dev/null
+++ b/collaborativeagents/agents/vllm_user_agent.py
@@ -0,0 +1,381 @@
+"""
+vLLM-based User Agent for high-performance user simulation.
+
+This replaces the local transformers-based user agent with a vLLM client
+for much faster inference when running parallel experiments.
+"""
+
+import requests
+from typing import List, Dict, Any, Optional
+from copy import deepcopy
+from json_repair import repair_json
+
+# Termination signal from CollaborativeAgents
+TERMINATION_SIGNAL = "TERMINATE"
+
+# User system prompt with preferences (CollaborativeAgents style)
+USER_SYSTEM_PROMPT_WITH_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with conversation guidelines and user preferences, which you must follow and actively enforce throughout the conversation.
+
+# Problem Description
+{user_task_description}
+{problem}
+Note: the agent cannot see this problem description.
+
+# User Persona
+{user_persona}
+
+# User Preferences
+{user_preferences}
+These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced once the problem is understood:
+ - **Answer clarifying questions**: The agent may ask clarifying questions before attempting an answer. Answer such questions, and do not enforce preferences about answer format or content while the agent is clarifying.
+ - **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed. Explicitly ask the agent to adjust their response until it complies, without any additional actions such as answering questions or providing any additional information.
+ - **Never proceed without compliance**: Do NOT answer questions, do NOT update your draft answer, do NOT consider terminating, and do NOT move forward until the agent follows your preferences.
+Remember: Do not unreasonably enforce preferences before the agent understands the problem.
+
+# Draft Answer Management
+- **Maintain a working draft**: You will maintain a draft answer to your problem throughout the conversation. Start with an empty draft (e.g., "I don't know"). Update your draft answer based on what you learn from agent responses.
+- **Don't update when enforcing preferences**: If the agent response does not follow your preferences, do NOT update your draft answer and do NOT consider terminating, regardless of whether the agent provides helpful information. Wait until they adjust their approach and satisfy your preferences.
+
+# Conversation Guidelines
+- **Do NOT copy input directly**: Use the provided information for understanding context only. Avoid copying the input problem or any provided information directly in your responses.
+- **Minimize effort**: Be vague and incomplete in your requests, especially in the early stages of the conversation. Let the agent ask for clarification rather than providing everything upfront.
+- **Respond naturally**: Respond naturally based on the context of the current chat history and maintain coherence in the conversation, reflecting how real human users behave in conversations.
+
+# Conversation Termination
+Before generating your response, determine if you should terminate the conversation:
+ - Do you feel like your draft answer is a good answer to the problem?
+ - Do you feel like the agent cannot help further?
+If the agent response does not follow your preferences, you must NOT terminate - instead, enforce the preferences.
+When ready to terminate, respond with "{termination_signal}".
+
+# Output Format:
+{{
+ "preferences_check": str, # For EACH of your preferences that is relevant to this response, evaluate: is it satisfied? List each relevant preference and whether it was followed.
+ "enforce_preferences": bool, # Whether you have to enforce any of your preferences?
+ "reasoning": str, # Brief reasoning (2-3 sentences max). Does the agent response follow all of your preferences? If no, you must enforce them and not proceed. If yes, how should you update your draft answer? Are you satisfied with your current answer and ready to terminate the conversation?
+ "draft_answer": str, # Your current working draft answer to the problem. Start with "I don't know". Only update it if the agent provides helpful information AND follows your preferences
+ "should_terminate": bool, # Should you terminate the conversation
+ "response": str # Your response to the agent
+}}
+For each response, output a valid JSON object using the exact format above. Use double quotes, escape any double quotes within strings using backslashes, escape newlines as \\n, and do not include any text before or after the JSON object.
+"""
+
+USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with conversation guidelines, which you must follow throughout the conversation.
+
+# Problem Description
+{user_task_description}
+{problem}
+Note: the agent cannot see this problem description.
+
+# User Persona
+{user_persona}
+
+# Draft Answer Management
+- **Maintain a working draft**: You will maintain a draft answer to your problem throughout the conversation. Start with an empty draft (e.g., "I don't know"). Update your draft answer based on what you learn from agent responses.
+
+# Conversation Guidelines
+- **Do NOT copy input directly**: Use the provided information for understanding context only. Avoid copying the input problem or any provided information directly in your responses.
+- **Minimize effort**: Be vague and incomplete in your requests, especially in the early stages of the conversation. Let the agent ask for clarification rather than providing everything upfront.
+- **Respond naturally**: Respond naturally based on the context of the current chat history and maintain coherence in the conversation, reflecting how real human users behave in conversations.
+
+# Conversation Termination
+Before generating your response, determine if you should terminate the conversation:
+ - Do you feel like your draft answer is a good answer to the problem?
+ - Do you feel like the agent cannot help further?
+When ready to terminate, respond with "{termination_signal}".
+
+# Output Format:
+{{
+ "reasoning": str, # Brief reasoning (2-3 sentences max). How should you update your draft answer? Are you satisfied with your current answer and ready to terminate the conversation?
+ "draft_answer": str, # Your current working draft answer to the problem. Start with "I don't know". Update it if the agent provides helpful information
+ "should_terminate": bool, # Should you terminate the conversation
+ "response": str # Your response to the agent
+}}
+For each response, output a valid JSON object using the exact format above. Use double quotes, escape any double quotes within strings using backslashes, escape newlines as \\n, and do not include any text before or after the JSON object.
+"""
+
+
+class VLLMUserAgent:
+ """
+ User Agent that uses a vLLM server for fast inference.
+
+ Benefits:
+ - Much faster than local transformers (continuous batching)
+ - Can handle concurrent requests from multiple profiles
+ - Supports AWQ/quantized models efficiently
+ """
+
+ def __init__(
+ self,
+ user_task_description: str,
+ problem: str,
+ user_persona: str = None,
+ user_preferences: str = None,
+ vllm_url: str = "http://localhost:8004/v1",
+ model_name: str = None, # Auto-discovered from server
+ num_retries: int = 3,
+ max_tokens: int = 512,
+ temperature: float = 0.7,
+ max_context_length: int = 16384, # Context limit for truncation
+ # For compatibility with LocalUserAgent interface
+ model_path: str = None,
+ api_base: str = None,
+ api_key: str = None,
+ ):
+ self.user_task_description = user_task_description
+ self.problem = problem
+ self.user_persona = user_persona or "A helpful user seeking assistance."
+ self.user_preferences = user_preferences
+ self.num_retries = num_retries
+ self.max_tokens = max_tokens
+ self.temperature = temperature
+ self.max_context_length = max_context_length
+
+ # vLLM configuration
+ self.vllm_url = vllm_url.rstrip('/')
+ self.model_name = model_name
+
+ # Build system prompt
+ if user_preferences:
+ self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
+ user_task_description=user_task_description,
+ problem=problem,
+ user_persona=self.user_persona,
+ user_preferences=user_preferences,
+ termination_signal=TERMINATION_SIGNAL
+ )
+ else:
+ self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
+ user_task_description=user_task_description,
+ problem=problem,
+ user_persona=self.user_persona,
+ termination_signal=TERMINATION_SIGNAL
+ )
+
+ # Auto-discover model name if not provided
+ if self.model_name is None:
+ self._discover_model()
+
+ def _discover_model(self):
+ """Auto-discover the model name from the vLLM server."""
+ try:
+ response = requests.get(f"{self.vllm_url}/models", timeout=10)
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self.model_name = models["data"][0]["id"]
+ else:
+ self.model_name = "default"
+ except Exception as e:
+ print(f"[VLLMUserAgent] Warning: Could not discover model ({e}), using 'default'")
+ self.model_name = "default"
+
+ def _estimate_tokens(self, text: str) -> int:
+ """Estimate token count using character-based heuristic (~3.5 chars/token)."""
+ return int(len(text) / 3.5)
+
+ def _truncate_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ """Truncate messages to fit within max_context_length, keeping recent messages."""
+ if not messages:
+ return messages
+
+ # System message is always first and always kept
+ system_msg = messages[0] if messages[0]["role"] == "system" else None
+ conversation = messages[1:] if system_msg else messages
+
+ system_tokens = self._estimate_tokens(system_msg["content"]) if system_msg else 0
+ available_tokens = self.max_context_length - system_tokens - self.max_tokens - 100
+
+ # Check if truncation is needed
+ total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation)
+
+ if total_conv_tokens <= available_tokens:
+ return messages
+
+ # Truncate from the beginning (keep recent messages)
+ truncated = []
+ current_tokens = 0
+ for msg in reversed(conversation):
+ msg_tokens = self._estimate_tokens(msg["content"])
+ if current_tokens + msg_tokens <= available_tokens:
+ truncated.insert(0, msg)
+ current_tokens += msg_tokens
+ else:
+ break
+
+ if len(truncated) < len(conversation):
+ print(f"[VLLMUserAgent] Truncated: kept {len(truncated)}/{len(conversation)} turns")
+
+ return [system_msg] + truncated if system_msg else truncated
+
+ def _generate(self, messages: List[Dict[str, str]]) -> str:
+ """Generate response using vLLM server with auto-truncation."""
+ # Truncate messages if context is too long
+ messages = self._truncate_messages(messages)
+
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ "top_p": 0.9,
+ }
+
+ try:
+ response = requests.post(
+ f"{self.vllm_url}/chat/completions",
+ json=payload,
+ timeout=120
+ )
+ response.raise_for_status()
+ result = response.json()
+ return result["choices"][0]["message"]["content"]
+ except Exception as e:
+ raise RuntimeError(f"vLLM request failed: {e}")
+
+ def get_system_prompt(self) -> str:
+ """Get the system prompt."""
+ return self.system_prompt
+
+ def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ """Reverse roles for user perspective (agent becomes user, user becomes assistant)."""
+ conversation = deepcopy(conversation)
+ return [
+ {"role": "user" if msg["role"] == "assistant" else "assistant", "content": msg["content"]}
+ for msg in conversation
+ ]
+
+ def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
+ """
+ Generate user response given the conversation history.
+
+ Args:
+ conversation: List of {"role": "user"|"assistant", "content": str}
+ From user perspective: agent messages are "assistant"
+
+ Returns:
+ Dict with keys: reasoning, draft_answer, should_terminate, response
+ Or None if failed
+ """
+ for attempt in range(self.num_retries):
+ try:
+ # Build messages: system prompt + reversed conversation
+ messages = [{"role": "system", "content": self.system_prompt}]
+ messages.extend(self.reverse_roles(conversation))
+
+ # Generate response
+ response_text = self._generate(messages)
+
+ # Try to parse as JSON
+ try:
+ parsed = repair_json(response_text, return_objects=True)
+
+ # Check for required keys
+ required_keys = ["reasoning", "draft_answer", "should_terminate", "response"]
+ missing = [k for k in required_keys if k not in parsed]
+
+ if missing:
+ print(f"[VLLMUserAgent] Missing keys: {missing}, attempt {attempt+1}")
+ continue
+
+ return parsed
+
+ except Exception as e:
+ # Fallback: return the raw text as the response
+ if TERMINATION_SIGNAL in response_text:
+ return {
+ "reasoning": "Ending conversation",
+ "draft_answer": "",
+ "should_terminate": True,
+ "response": TERMINATION_SIGNAL
+ }
+ else:
+ return {
+ "reasoning": "",
+ "draft_answer": "",
+ "should_terminate": False,
+ "response": response_text
+ }
+
+ except Exception as e:
+ print(f"[VLLMUserAgent] Error: {e}, attempt {attempt+1}")
+ continue
+
+ print(f"[VLLMUserAgent] Failed after {self.num_retries} attempts")
+ return None
+
+
+class VLLMAgentClient:
+ """
+ Simple vLLM client for agent responses.
+
+ Used by baseline methods that don't have their own vLLM integration.
+ """
+
+ def __init__(
+ self,
+ vllm_url: str = "http://localhost:8003/v1",
+ model_name: str = None,
+ system_prompt: str = None,
+ max_tokens: int = 1024,
+ temperature: float = 0.7,
+ ):
+ self.vllm_url = vllm_url.rstrip('/')
+ self.model_name = model_name
+ self.system_prompt = system_prompt or "You are a helpful AI assistant."
+ self.max_tokens = max_tokens
+ self.temperature = temperature
+
+ if self.model_name is None:
+ self._discover_model()
+
+ def _discover_model(self):
+ """Auto-discover the model name from the vLLM server."""
+ try:
+ response = requests.get(f"{self.vllm_url}/models", timeout=10)
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self.model_name = models["data"][0]["id"]
+ else:
+ self.model_name = "default"
+ except Exception as e:
+ self.model_name = "default"
+
+ def generate_response(self, query: str, conversation_history: List[Dict[str, str]] = None) -> Dict[str, Any]:
+ """Generate agent response."""
+ messages = [{"role": "system", "content": self.system_prompt}]
+
+ if conversation_history:
+ messages.extend(conversation_history)
+
+ messages.append({"role": "user", "content": query})
+
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ }
+
+ try:
+ response = requests.post(
+ f"{self.vllm_url}/chat/completions",
+ json=payload,
+ timeout=120
+ )
+ response.raise_for_status()
+ result = response.json()
+ content = result["choices"][0]["message"]["content"]
+ return {"response": content, "reasoning": ""}
+ except Exception as e:
+ return {"response": f"[Error: {e}]", "reasoning": ""}
+
+ def __call__(self, conversation: List[Dict[str, str]]) -> str:
+ """Callable interface for compatibility."""
+ # Get last user message
+ for msg in reversed(conversation):
+ if msg["role"] == "user":
+ result = self.generate_response(msg["content"], conversation[:-1])
+ return result.get("response", "")
+ return ""