summaryrefslogtreecommitdiff
path: root/collaborativeagents/utils
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/utils')
-rw-r--r--collaborativeagents/utils/__init__.py1
-rw-r--r--collaborativeagents/utils/vllm_client.py467
2 files changed, 468 insertions, 0 deletions
diff --git a/collaborativeagents/utils/__init__.py b/collaborativeagents/utils/__init__.py
new file mode 100644
index 0000000..28eb7ae
--- /dev/null
+++ b/collaborativeagents/utils/__init__.py
@@ -0,0 +1 @@
+# Utils module for collaborative agents
diff --git a/collaborativeagents/utils/vllm_client.py b/collaborativeagents/utils/vllm_client.py
new file mode 100644
index 0000000..0403364
--- /dev/null
+++ b/collaborativeagents/utils/vllm_client.py
@@ -0,0 +1,467 @@
+"""
+vLLM Client wrapper for high-performance inference.
+
+This module provides a unified interface to vLLM servers, replacing the slow
+transformers-based inference with vLLM's optimized serving.
+
+Usage:
+ # Start vLLM servers first:
+ # CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve model --port 8004 --tensor-parallel-size 4
+
+ client = VLLMClient(base_url="http://localhost:8004/v1")
+ response = client.chat(messages=[{"role": "user", "content": "Hello"}])
+"""
+
+import os
+import time
+import json
+from typing import List, Dict, Any, Optional
+from dataclasses import dataclass
+import requests
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
@dataclass
class VLLMConfig:
    """Configuration for the vLLM client.

    Defaults target a locally-served vLLM OpenAI-compatible endpoint.
    """

    # Base URL of the vLLM OpenAI-compatible endpoint (including /v1).
    base_url: str = "http://localhost:8004/v1"
    # Model name; when None, VLLMClient auto-discovers it from the server.
    # (Was annotated `str` with a None default — fixed to Optional[str].)
    model: Optional[str] = None
    # vLLM ignores the key by default; "EMPTY" is the conventional placeholder.
    api_key: str = "EMPTY"
    # Per-request timeout in seconds.
    timeout: int = 120
    # Number of attempts before a request is declared failed.
    max_retries: int = 3
+
+
class VLLMClient:
    """
    Client for a vLLM OpenAI-compatible API.

    Much faster than raw transformers due to:
    - Continuous batching
    - PagedAttention
    - Optimized CUDA kernels
    """

    def __init__(self, config: Optional[VLLMConfig] = None, base_url: Optional[str] = None):
        """Initialize the client.

        Args:
            config: Full configuration; takes precedence over ``base_url``.
            base_url: Server URL used to build a default config when
                ``config`` is not given.
        """
        if config:
            self.config = config
        else:
            self.config = VLLMConfig(base_url=base_url or "http://localhost:8004/v1")

        # One pooled session for all requests; vLLM typically ignores the
        # API key, but the OpenAI-compatible route still expects the header.
        self._session = requests.Session()
        self._session.headers.update({
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        })

        # Auto-discover model name if not provided.
        if self.config.model is None:
            self._discover_model()

    def _discover_model(self):
        """Auto-discover the model name from the vLLM server.

        Falls back to "default" (with a warning) when the server is
        unreachable or reports no models; never raises.
        """
        try:
            response = self._session.get(
                f"{self.config.base_url}/models",
                timeout=10,
            )
            response.raise_for_status()
            models = response.json()
            if models.get("data") and len(models["data"]) > 0:
                self.config.model = models["data"][0]["id"]
                print(f"[VLLMClient] Auto-discovered model: {self.config.model}")
            else:
                self.config.model = "default"
                print("[VLLMClient] Warning: No models found, using 'default'")
        except Exception as e:
            self.config.model = "default"
            print(f"[VLLMClient] Warning: Could not discover model ({e}), using 'default'")

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Send a chat completion request to the vLLM server.

        Args:
            messages: List of {"role": str, "content": str}
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling
            stop: Stop sequences

        Returns:
            Dict with 'content', 'usage', 'latency_ms', 'finish_reason'.

        Raises:
            RuntimeError: If the request still fails after
                ``config.max_retries`` attempts.
        """
        url = f"{self.config.base_url}/chat/completions"

        payload = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if stop:
            payload["stop"] = stop

        start_time = time.time()

        for attempt in range(self.config.max_retries):
            try:
                response = self._session.post(
                    url,
                    json=payload,
                    timeout=self.config.timeout,
                )
                response.raise_for_status()

                result = response.json()
                # Latency includes any earlier failed attempts by design.
                latency_ms = (time.time() - start_time) * 1000

                return {
                    "content": result["choices"][0]["message"]["content"],
                    "usage": result.get("usage", {}),
                    "latency_ms": latency_ms,
                    "finish_reason": result["choices"][0].get("finish_reason"),
                }

            except requests.exceptions.RequestException as e:
                if attempt < self.config.max_retries - 1:
                    # Linear backoff: 1s, 2s, 3s, ... (the old comment
                    # claimed "exponential", which this is not).
                    time.sleep(1 * (attempt + 1))
                    continue
                raise RuntimeError(
                    f"vLLM request failed after {self.config.max_retries} attempts: {e}"
                )

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Send a completion request (non-chat) to the vLLM server.

        Returns:
            Dict with 'content', 'usage', 'latency_ms'.

        Raises:
            RuntimeError: If the request still fails after
                ``config.max_retries`` attempts.
        """
        url = f"{self.config.base_url}/completions"

        payload = {
            "model": self.config.model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if stop:
            payload["stop"] = stop

        start_time = time.time()

        for attempt in range(self.config.max_retries):
            try:
                response = self._session.post(
                    url,
                    json=payload,
                    timeout=self.config.timeout,
                )
                response.raise_for_status()

                result = response.json()
                latency_ms = (time.time() - start_time) * 1000

                return {
                    "content": result["choices"][0]["text"],
                    "usage": result.get("usage", {}),
                    "latency_ms": latency_ms,
                }

            except requests.exceptions.RequestException as e:
                if attempt < self.config.max_retries - 1:
                    # Linear backoff, same policy as chat().
                    time.sleep(1 * (attempt + 1))
                    continue
                # Message made consistent with chat()'s failure message.
                raise RuntimeError(
                    f"vLLM request failed after {self.config.max_retries} attempts: {e}"
                )

    def health_check(self) -> bool:
        """Return True if the vLLM server answers the /models endpoint."""
        try:
            response = self._session.get(
                f"{self.config.base_url}/models",
                timeout=5,
            )
            return response.status_code == 200
        except requests.exceptions.RequestException:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer swallowed here.
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model.

        Returns the server's /models payload, or {"error": ...} on failure.
        """
        try:
            response = self._session.get(
                f"{self.config.base_url}/models",
                timeout=5,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            return {"error": str(e)}
+
+
class VLLMUserSimulator:
    """
    User simulator using vLLM for fast inference.
    Drop-in replacement for LocalUserAgent.
    """

    # Plain-text marker the simulated user emits to end the conversation.
    TERMINATION_SIGNAL = "TERMINATE"

    # Template rendered once per scenario; {{...}} braces are escaped so the
    # JSON output-format example survives str.format().
    SYSTEM_PROMPT = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with user preferences, which you must actively enforce throughout the conversation.

# Problem Description
{problem}
Note: the agent cannot see this problem description.

# User Persona
{user_persona}

# User Preferences
{user_preferences}
These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced:
- **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed.
- **Never proceed without compliance**: Do NOT move forward until the agent follows your preferences.

# Draft Answer Management
- Maintain a draft answer throughout the conversation. Start with "I don't know".
- Update your draft based on agent responses that follow your preferences.

# Conversation Termination
When ready to terminate (draft answer is good OR agent cannot help), respond with "TERMINATE".

# Output Format:
Respond with a JSON object:
{{
    "reasoning": "Brief reasoning about the agent's response and your preferences",
    "draft_answer": "Your current working draft answer",
    "should_terminate": true/false,
    "response": "Your response to the agent"
}}"""

    def __init__(
        self,
        problem: str,
        user_persona: str,
        user_preferences: str,
        vllm_client: "VLLMClient",
    ):
        """Store scenario parameters and render the system prompt once.

        Args:
            problem: Problem statement (hidden from the agent).
            user_persona: Description of the simulated user.
            user_preferences: Preferences the simulator must enforce.
            vllm_client: Client used to generate user turns.
        """
        self.problem = problem
        self.user_persona = user_persona
        self.user_preferences = user_preferences
        self.client = vllm_client

        self.system_prompt = self.SYSTEM_PROMPT.format(
            problem=problem,
            user_persona=user_persona,
            user_preferences=user_preferences,
        )

    def generate_user_response(
        self,
        conversation: List[Dict[str, str]]
    ) -> Optional[Dict[str, Any]]:
        """Generate a user turn given the conversation history.

        Returns a dict with keys "reasoning", "draft_answer",
        "should_terminate", "response", or None if the request failed.
        """
        # Build messages with reversed roles (from the user simulator's
        # perspective, the agent is the "user" it replies to).
        messages = [{"role": "system", "content": self.system_prompt}]

        for msg in conversation:
            # Reverse roles: agent's messages become "user", user's become "assistant"
            role = "user" if msg["role"] == "assistant" else "assistant"
            messages.append({"role": role, "content": msg["content"]})

        try:
            response = self.client.chat(
                messages=messages,
                max_tokens=512,
                temperature=0.7,
            )

            content = response["content"]

            # Best-effort structured parse; fall through to the plain-text
            # fallbacks on ANY failure, including json_repair being absent.
            try:
                from json_repair import repair_json
                parsed = repair_json(content, return_objects=True)

                if isinstance(parsed, dict) and all(
                    k in parsed
                    for k in ["reasoning", "draft_answer", "should_terminate", "response"]
                ):
                    return parsed
            except Exception:
                # Was a bare `except:`; narrowed so interrupts propagate
                # while keeping the deliberate best-effort behavior.
                pass

            # Fallback: honor a plain-text termination signal.
            if self.TERMINATION_SIGNAL in content:
                return {
                    "reasoning": "Ending conversation",
                    "draft_answer": "",
                    "should_terminate": True,
                    "response": self.TERMINATION_SIGNAL
                }

            # Fallback: wrap raw content as a non-terminating response.
            return {
                "reasoning": "",
                "draft_answer": "",
                "should_terminate": False,
                "response": content
            }

        except Exception as e:
            print(f"[VLLMUserSimulator] Error: {e}")
            return None
+
+
class VLLMAgentAdapter:
    """
    Base agent adapter backed by a VLLMClient for fast inference.

    Subclass this to implement specific methods (vanilla, rag, etc.).
    """

    def __init__(self, vllm_client: VLLMClient, system_prompt: str = None):
        """Keep a client reference, a system prompt, and an empty history."""
        self.client = vllm_client
        # Fall back to a generic prompt when none (or an empty one) is given.
        self.system_prompt = system_prompt if system_prompt else "You are a helpful assistant."
        self.conversation_history: List[Dict[str, str]] = []

    def reset(self):
        """Discard all recorded conversation turns."""
        self.conversation_history = []

    def generate_response(
        self,
        user_message: str,
        additional_context: str = None,
    ) -> Dict[str, Any]:
        """Record the user turn, query the model, and record its reply.

        Returns a dict with 'response', 'usage', and 'latency_ms'.
        """
        self.conversation_history.append({"role": "user", "content": user_message})

        # Optionally extend the system prompt for this call only.
        system_text = (
            f"{self.system_prompt}\n\n{additional_context}"
            if additional_context
            else self.system_prompt
        )

        chat_input = [{"role": "system", "content": system_text}, *self.conversation_history]

        result = self.client.chat(
            messages=chat_input,
            max_tokens=1024,
            temperature=0.7,
        )

        reply = result["content"]
        self.conversation_history.append({"role": "assistant", "content": reply})

        return {
            "response": reply,
            "usage": result["usage"],
            "latency_ms": result["latency_ms"],
        }
+
+
def benchmark_vllm(
    client: VLLMClient,
    n_requests: int = 10,
    concurrent: bool = False,
    n_workers: int = 4,
) -> Dict[str, Any]:
    """
    Benchmark vLLM server throughput.

    Args:
        client: VLLMClient instance
        n_requests: Number of requests to send
        concurrent: Whether to send requests concurrently
        n_workers: Number of concurrent workers (ignored when sequential)

    Returns:
        Dict with benchmark results, or an error dict if every request failed.
    """
    probe = [
        {"role": "user", "content": "What is the capital of France? Answer briefly."}
    ]

    latencies: List[float] = []
    errors = 0
    t0 = time.time()

    if concurrent:
        # Fan the identical probe out over a thread pool.
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            pending = [
                pool.submit(client.chat, probe, 64, 0.1)
                for _ in range(n_requests)
            ]
            for done in as_completed(pending):
                try:
                    latencies.append(done.result()["latency_ms"])
                except Exception as exc:
                    errors += 1
                    print(f"Error: {exc}")
    else:
        # Sequential baseline: one request at a time.
        for _ in range(n_requests):
            try:
                latencies.append(client.chat(probe, 64, 0.1)["latency_ms"])
            except Exception as exc:
                errors += 1
                print(f"Error: {exc}")

    elapsed = time.time() - t0

    if not latencies:
        return {"error": "All requests failed", "errors": errors}

    return {
        "n_requests": n_requests,
        "concurrent": concurrent,
        "n_workers": n_workers if concurrent else 1,
        "total_time_s": elapsed,
        "throughput_req_per_s": len(latencies) / elapsed,
        "avg_latency_ms": sum(latencies) / len(latencies),
        "min_latency_ms": min(latencies),
        "max_latency_ms": max(latencies),
        "errors": errors,
    }
+
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test vLLM client")
    parser.add_argument("--url", default="http://localhost:8004/v1", help="vLLM server URL")
    parser.add_argument("--benchmark", action="store_true", help="Run benchmark")
    parser.add_argument("-n", type=int, default=10, help="Number of requests")
    parser.add_argument("--concurrent", action="store_true", help="Run concurrent benchmark")

    args = parser.parse_args()

    client = VLLMClient(base_url=args.url)

    # Fail fast if the server is unreachable.
    print(f"Checking vLLM server at {args.url}...")
    if client.health_check():
        print("✓ Server is healthy")
        print(f"Model info: {client.get_model_info()}")
    else:
        print("✗ Server is not responding")
        # `exit()` is a site-module helper and may not exist under
        # `python -S`; raising SystemExit is the portable way to set
        # the process exit code.
        raise SystemExit(1)

    if args.benchmark:
        print(f"\nRunning benchmark with {args.n} requests (concurrent={args.concurrent})...")
        results = benchmark_vllm(client, args.n, args.concurrent)
        print(json.dumps(results, indent=2))
    else:
        # Simple smoke test of chat completion.
        print("\nTesting chat completion...")
        response = client.chat([{"role": "user", "content": "Hello, who are you?"}])
        print(f"Response: {response['content'][:200]}...")
        print(f"Latency: {response['latency_ms']:.1f}ms")