summaryrefslogtreecommitdiff
path: root/collaborativeagents/agents/vllm_user_agent.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/agents/vllm_user_agent.py')
-rw-r--r--collaborativeagents/agents/vllm_user_agent.py381
1 files changed, 381 insertions, 0 deletions
diff --git a/collaborativeagents/agents/vllm_user_agent.py b/collaborativeagents/agents/vllm_user_agent.py
new file mode 100644
index 0000000..3d73637
--- /dev/null
+++ b/collaborativeagents/agents/vllm_user_agent.py
@@ -0,0 +1,381 @@
+"""
+vLLM-based User Agent for high-performance user simulation.
+
+This replaces the local transformers-based user agent with a vLLM client
+for much faster inference when running parallel experiments.
+"""
+
+import requests
+from typing import List, Dict, Any, Optional
+from copy import deepcopy
+from json_repair import repair_json
+
+# Termination signal from CollaborativeAgents
+TERMINATION_SIGNAL = "TERMINATE"
+
+# User system prompt with preferences (CollaborativeAgents style)
+USER_SYSTEM_PROMPT_WITH_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with conversation guidelines and user preferences, which you must follow and actively enforce throughout the conversation.
+
+# Problem Description
+{user_task_description}
+{problem}
+Note: the agent cannot see this problem description.
+
+# User Persona
+{user_persona}
+
+# User Preferences
+{user_preferences}
+These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced once the problem is understood:
+ - **Answer clarifying questions**: The agent may ask clarifying questions before attempting an answer. Answer such questions, and do not enforce preferences about answer format or content while the agent is clarifying.
+ - **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed. Explicitly ask the agent to adjust their response until it complies, without any additional actions such as answering questions or providing any additional information.
+ - **Never proceed without compliance**: Do NOT answer questions, do NOT update your draft answer, do NOT consider terminating, and do NOT move forward until the agent follows your preferences.
+Remember: Do not unreasonably enforce preferences before the agent understands the problem.
+
+# Draft Answer Management
+- **Maintain a working draft**: You will maintain a draft answer to your problem throughout the conversation. Start with an empty draft (e.g., "I don't know"). Update your draft answer based on what you learn from agent responses.
+- **Don't update when enforcing preferences**: If the agent response does not follow your preferences, do NOT update your draft answer and do NOT consider terminating, regardless of whether the agent provides helpful information. Wait until they adjust their approach and satisfy your preferences.
+
+# Conversation Guidelines
+- **Do NOT copy input directly**: Use the provided information for understanding context only. Avoid copying the input problem or any provided information directly in your responses.
+- **Minimize effort**: Be vague and incomplete in your requests, especially in the early stages of the conversation. Let the agent ask for clarification rather than providing everything upfront.
+- **Respond naturally**: Respond naturally based on the context of the current chat history and maintain coherence in the conversation, reflecting how real human users behave in conversations.
+
+# Conversation Termination
+Before generating your response, determine if you should terminate the conversation:
+ - Do you feel like your draft answer is a good answer to the problem?
+ - Do you feel like the agent cannot help further?
+If the agent response does not follow your preferences, you must NOT terminate - instead, enforce the preferences.
+When ready to terminate, respond with "{termination_signal}".
+
+# Output Format:
+{{
+ "preferences_check": str, # For EACH of your preferences that is relevant to this response, evaluate: is it satisfied? List each relevant preference and whether it was followed.
+ "enforce_preferences": bool, # Whether you have to enforce any of your preferences?
+ "reasoning": str, # Brief reasoning (2-3 sentences max). Does the agent response follow all of your preferences? If no, you must enforce them and not proceed. If yes, how should you update your draft answer? Are you satisfied with your current answer and ready to terminate the conversation?
+ "draft_answer": str, # Your current working draft answer to the problem. Start with "I don't know". Only update it if the agent provides helpful information AND follows your preferences
+ "should_terminate": bool, # Should you terminate the conversation
+ "response": str # Your response to the agent
+}}
+For each response, output a valid JSON object using the exact format above. Use double quotes, escape any double quotes within strings using backslashes, escape newlines as \\n, and do not include any text before or after the JSON object.
+"""
+
+USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with conversation guidelines, which you must follow throughout the conversation.
+
+# Problem Description
+{user_task_description}
+{problem}
+Note: the agent cannot see this problem description.
+
+# User Persona
+{user_persona}
+
+# Draft Answer Management
+- **Maintain a working draft**: You will maintain a draft answer to your problem throughout the conversation. Start with an empty draft (e.g., "I don't know"). Update your draft answer based on what you learn from agent responses.
+
+# Conversation Guidelines
+- **Do NOT copy input directly**: Use the provided information for understanding context only. Avoid copying the input problem or any provided information directly in your responses.
+- **Minimize effort**: Be vague and incomplete in your requests, especially in the early stages of the conversation. Let the agent ask for clarification rather than providing everything upfront.
+- **Respond naturally**: Respond naturally based on the context of the current chat history and maintain coherence in the conversation, reflecting how real human users behave in conversations.
+
+# Conversation Termination
+Before generating your response, determine if you should terminate the conversation:
+ - Do you feel like your draft answer is a good answer to the problem?
+ - Do you feel like the agent cannot help further?
+When ready to terminate, respond with "{termination_signal}".
+
+# Output Format:
+{{
+ "reasoning": str, # Brief reasoning (2-3 sentences max). How should you update your draft answer? Are you satisfied with your current answer and ready to terminate the conversation?
+ "draft_answer": str, # Your current working draft answer to the problem. Start with "I don't know". Update it if the agent provides helpful information
+ "should_terminate": bool, # Should you terminate the conversation
+ "response": str # Your response to the agent
+}}
+For each response, output a valid JSON object using the exact format above. Use double quotes, escape any double quotes within strings using backslashes, escape newlines as \\n, and do not include any text before or after the JSON object.
+"""
+
+
+class VLLMUserAgent:
+ """
+ User Agent that uses a vLLM server for fast inference.
+
+ Benefits:
+ - Much faster than local transformers (continuous batching)
+ - Can handle concurrent requests from multiple profiles
+ - Supports AWQ/quantized models efficiently
+ """
+
+ def __init__(
+ self,
+ user_task_description: str,
+ problem: str,
+ user_persona: str = None,
+ user_preferences: str = None,
+ vllm_url: str = "http://localhost:8004/v1",
+ model_name: str = None, # Auto-discovered from server
+ num_retries: int = 3,
+ max_tokens: int = 512,
+ temperature: float = 0.7,
+ max_context_length: int = 16384, # Context limit for truncation
+ # For compatibility with LocalUserAgent interface
+ model_path: str = None,
+ api_base: str = None,
+ api_key: str = None,
+ ):
+ self.user_task_description = user_task_description
+ self.problem = problem
+ self.user_persona = user_persona or "A helpful user seeking assistance."
+ self.user_preferences = user_preferences
+ self.num_retries = num_retries
+ self.max_tokens = max_tokens
+ self.temperature = temperature
+ self.max_context_length = max_context_length
+
+ # vLLM configuration
+ self.vllm_url = vllm_url.rstrip('/')
+ self.model_name = model_name
+
+ # Build system prompt
+ if user_preferences:
+ self.system_prompt = USER_SYSTEM_PROMPT_WITH_PREFERENCES.format(
+ user_task_description=user_task_description,
+ problem=problem,
+ user_persona=self.user_persona,
+ user_preferences=user_preferences,
+ termination_signal=TERMINATION_SIGNAL
+ )
+ else:
+ self.system_prompt = USER_SYSTEM_PROMPT_WITHOUT_PREFERENCES.format(
+ user_task_description=user_task_description,
+ problem=problem,
+ user_persona=self.user_persona,
+ termination_signal=TERMINATION_SIGNAL
+ )
+
+ # Auto-discover model name if not provided
+ if self.model_name is None:
+ self._discover_model()
+
+ def _discover_model(self):
+ """Auto-discover the model name from the vLLM server."""
+ try:
+ response = requests.get(f"{self.vllm_url}/models", timeout=10)
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self.model_name = models["data"][0]["id"]
+ else:
+ self.model_name = "default"
+ except Exception as e:
+ print(f"[VLLMUserAgent] Warning: Could not discover model ({e}), using 'default'")
+ self.model_name = "default"
+
+ def _estimate_tokens(self, text: str) -> int:
+ """Estimate token count using character-based heuristic (~3.5 chars/token)."""
+ return int(len(text) / 3.5)
+
+ def _truncate_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ """Truncate messages to fit within max_context_length, keeping recent messages."""
+ if not messages:
+ return messages
+
+ # System message is always first and always kept
+ system_msg = messages[0] if messages[0]["role"] == "system" else None
+ conversation = messages[1:] if system_msg else messages
+
+ system_tokens = self._estimate_tokens(system_msg["content"]) if system_msg else 0
+ available_tokens = self.max_context_length - system_tokens - self.max_tokens - 100
+
+ # Check if truncation is needed
+ total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation)
+
+ if total_conv_tokens <= available_tokens:
+ return messages
+
+ # Truncate from the beginning (keep recent messages)
+ truncated = []
+ current_tokens = 0
+ for msg in reversed(conversation):
+ msg_tokens = self._estimate_tokens(msg["content"])
+ if current_tokens + msg_tokens <= available_tokens:
+ truncated.insert(0, msg)
+ current_tokens += msg_tokens
+ else:
+ break
+
+ if len(truncated) < len(conversation):
+ print(f"[VLLMUserAgent] Truncated: kept {len(truncated)}/{len(conversation)} turns")
+
+ return [system_msg] + truncated if system_msg else truncated
+
+ def _generate(self, messages: List[Dict[str, str]]) -> str:
+ """Generate response using vLLM server with auto-truncation."""
+ # Truncate messages if context is too long
+ messages = self._truncate_messages(messages)
+
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ "top_p": 0.9,
+ }
+
+ try:
+ response = requests.post(
+ f"{self.vllm_url}/chat/completions",
+ json=payload,
+ timeout=120
+ )
+ response.raise_for_status()
+ result = response.json()
+ return result["choices"][0]["message"]["content"]
+ except Exception as e:
+ raise RuntimeError(f"vLLM request failed: {e}")
+
+ def get_system_prompt(self) -> str:
+ """Get the system prompt."""
+ return self.system_prompt
+
+ def reverse_roles(self, conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
+ """Reverse roles for user perspective (agent becomes user, user becomes assistant)."""
+ conversation = deepcopy(conversation)
+ return [
+ {"role": "user" if msg["role"] == "assistant" else "assistant", "content": msg["content"]}
+ for msg in conversation
+ ]
+
+ def generate_user_response(self, conversation: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
+ """
+ Generate user response given the conversation history.
+
+ Args:
+ conversation: List of {"role": "user"|"assistant", "content": str}
+ From user perspective: agent messages are "assistant"
+
+ Returns:
+ Dict with keys: reasoning, draft_answer, should_terminate, response
+ Or None if failed
+ """
+ for attempt in range(self.num_retries):
+ try:
+ # Build messages: system prompt + reversed conversation
+ messages = [{"role": "system", "content": self.system_prompt}]
+ messages.extend(self.reverse_roles(conversation))
+
+ # Generate response
+ response_text = self._generate(messages)
+
+ # Try to parse as JSON
+ try:
+ parsed = repair_json(response_text, return_objects=True)
+
+ # Check for required keys
+ required_keys = ["reasoning", "draft_answer", "should_terminate", "response"]
+ missing = [k for k in required_keys if k not in parsed]
+
+ if missing:
+ print(f"[VLLMUserAgent] Missing keys: {missing}, attempt {attempt+1}")
+ continue
+
+ return parsed
+
+ except Exception as e:
+ # Fallback: return the raw text as the response
+ if TERMINATION_SIGNAL in response_text:
+ return {
+ "reasoning": "Ending conversation",
+ "draft_answer": "",
+ "should_terminate": True,
+ "response": TERMINATION_SIGNAL
+ }
+ else:
+ return {
+ "reasoning": "",
+ "draft_answer": "",
+ "should_terminate": False,
+ "response": response_text
+ }
+
+ except Exception as e:
+ print(f"[VLLMUserAgent] Error: {e}, attempt {attempt+1}")
+ continue
+
+ print(f"[VLLMUserAgent] Failed after {self.num_retries} attempts")
+ return None
+
+
+class VLLMAgentClient:
+ """
+ Simple vLLM client for agent responses.
+
+ Used by baseline methods that don't have their own vLLM integration.
+ """
+
+ def __init__(
+ self,
+ vllm_url: str = "http://localhost:8003/v1",
+ model_name: str = None,
+ system_prompt: str = None,
+ max_tokens: int = 1024,
+ temperature: float = 0.7,
+ ):
+ self.vllm_url = vllm_url.rstrip('/')
+ self.model_name = model_name
+ self.system_prompt = system_prompt or "You are a helpful AI assistant."
+ self.max_tokens = max_tokens
+ self.temperature = temperature
+
+ if self.model_name is None:
+ self._discover_model()
+
+ def _discover_model(self):
+ """Auto-discover the model name from the vLLM server."""
+ try:
+ response = requests.get(f"{self.vllm_url}/models", timeout=10)
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self.model_name = models["data"][0]["id"]
+ else:
+ self.model_name = "default"
+ except Exception as e:
+ self.model_name = "default"
+
+ def generate_response(self, query: str, conversation_history: List[Dict[str, str]] = None) -> Dict[str, Any]:
+ """Generate agent response."""
+ messages = [{"role": "system", "content": self.system_prompt}]
+
+ if conversation_history:
+ messages.extend(conversation_history)
+
+ messages.append({"role": "user", "content": query})
+
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ }
+
+ try:
+ response = requests.post(
+ f"{self.vllm_url}/chat/completions",
+ json=payload,
+ timeout=120
+ )
+ response.raise_for_status()
+ result = response.json()
+ content = result["choices"][0]["message"]["content"]
+ return {"response": content, "reasoning": ""}
+ except Exception as e:
+ return {"response": f"[Error: {e}]", "reasoning": ""}
+
+ def __call__(self, conversation: List[Dict[str, str]]) -> str:
+ """Callable interface for compatibility."""
+ # Get last user message
+ for msg in reversed(conversation):
+ if msg["role"] == "user":
+ result = self.generate_response(msg["content"], conversation[:-1])
+ return result.get("response", "")
+ return ""