diff options
Diffstat (limited to 'collaborativeagents/utils/vllm_client.py')
| -rw-r--r-- | collaborativeagents/utils/vllm_client.py | 467 |
1 file changed, 467 insertions, 0 deletions
"""
vLLM Client wrapper for high-performance inference.

This module provides a unified interface to vLLM servers, replacing the slow
transformers-based inference with vLLM's optimized serving.

Usage:
    # Start vLLM servers first:
    # CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve model --port 8004 --tensor-parallel-size 4

    client = VLLMClient(base_url="http://localhost:8004/v1")
    response = client.chat(messages=[{"role": "user", "content": "Hello"}])
"""

import json
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import requests


@dataclass
class VLLMConfig:
    """Configuration for vLLM client."""
    base_url: str = "http://localhost:8004/v1"
    # None means the model id is auto-discovered from the server's /models endpoint.
    model: Optional[str] = None
    # vLLM's OpenAI-compatible server accepts any key by default; "EMPTY" is its convention.
    api_key: str = "EMPTY"
    timeout: int = 120
    max_retries: int = 3


class VLLMClient:
    """
    Client for vLLM OpenAI-compatible API.

    Much faster than raw transformers due to:
    - Continuous batching
    - PagedAttention
    - Optimized CUDA kernels
    """

    def __init__(self, config: Optional[VLLMConfig] = None, base_url: Optional[str] = None):
        """Create a client from an explicit config, or from just a base URL.

        Args:
            config: Full configuration; takes precedence over *base_url*.
            base_url: Convenience shortcut when no config is given.
        """
        if config:
            self.config = config
        else:
            self.config = VLLMConfig(base_url=base_url or "http://localhost:8004/v1")

        # One Session so the auth header is set once and connections are pooled.
        self._session = requests.Session()
        self._session.headers.update({
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json"
        })

        # Auto-discover model name if not provided
        if self.config.model is None:
            self._discover_model()

    def _discover_model(self) -> None:
        """Auto-discover the model name from the vLLM server.

        Falls back to the literal name "default" (with a warning) when the
        server is unreachable or reports no models, so construction never fails.
        """
        try:
            response = self._session.get(
                f"{self.config.base_url}/models",
                timeout=10
            )
            response.raise_for_status()
            models = response.json()
            if models.get("data") and len(models["data"]) > 0:
                self.config.model = models["data"][0]["id"]
                print(f"[VLLMClient] Auto-discovered model: {self.config.model}")
            else:
                self.config.model = "default"
                print("[VLLMClient] Warning: No models found, using 'default'")
        except Exception as e:
            self.config.model = "default"
            print(f"[VLLMClient] Warning: Could not discover model ({e}), using 'default'")

    def _post_with_retries(self, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* to *url* with retries; return the decoded JSON body.

        Retries up to ``config.max_retries`` times on transport-level errors,
        sleeping with a linearly increasing backoff (1s, 2s, ...) between
        attempts.  Always raises RuntimeError when every attempt fails — the
        original per-method loops could fall through and return None when
        ``max_retries <= 0``.

        Raises:
            RuntimeError: when all attempts fail.
        """
        attempts = max(1, self.config.max_retries)
        last_error: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                response = self._session.post(
                    url,
                    json=payload,
                    timeout=self.config.timeout
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                last_error = e
                if attempt < attempts - 1:
                    # Linear backoff: 1s, 2s, 3s, ...  (the old comment said
                    # "exponential", which the code never was).
                    time.sleep(1 * (attempt + 1))
        raise RuntimeError(
            f"vLLM request failed after {self.config.max_retries} attempts: {last_error}"
        )

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Send a chat completion request to vLLM server.

        Args:
            messages: List of {"role": str, "content": str}
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling
            stop: Stop sequences

        Returns:
            Dict with 'content', 'usage', 'latency_ms', 'finish_reason'

        Raises:
            RuntimeError: if the request fails after all retries.
        """
        payload = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if stop:
            payload["stop"] = stop

        start_time = time.time()
        result = self._post_with_retries(
            f"{self.config.base_url}/chat/completions", payload
        )
        latency_ms = (time.time() - start_time) * 1000

        return {
            "content": result["choices"][0]["message"]["content"],
            "usage": result.get("usage", {}),
            "latency_ms": latency_ms,
            "finish_reason": result["choices"][0].get("finish_reason"),
        }

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Send a completion request (non-chat) to vLLM server.

        Same retry/latency semantics as :meth:`chat`; the result carries the
        raw completion text under 'content'.
        """
        payload = {
            "model": self.config.model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if stop:
            payload["stop"] = stop

        start_time = time.time()
        result = self._post_with_retries(
            f"{self.config.base_url}/completions", payload
        )
        latency_ms = (time.time() - start_time) * 1000

        return {
            "content": result["choices"][0]["text"],
            "usage": result.get("usage", {}),
            "latency_ms": latency_ms,
        }

    def health_check(self) -> bool:
        """Check if vLLM server is healthy (its /models endpoint answers 200)."""
        try:
            # Try the models endpoint; narrow except (was a bare `except:`).
            response = self._session.get(
                f"{self.config.base_url}/models",
                timeout=5
            )
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about loaded model, or {'error': ...} on failure."""
        try:
            response = self._session.get(
                f"{self.config.base_url}/models",
                timeout=5
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            return {"error": str(e)}


class VLLMUserSimulator:
    """
    User simulator using vLLM for fast inference.
    Drop-in replacement for LocalUserAgent.
    """

    TERMINATION_SIGNAL = "TERMINATE"

    SYSTEM_PROMPT = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with user preferences, which you must actively enforce throughout the conversation.

# Problem Description
{problem}
Note: the agent cannot see this problem description.

# User Persona
{user_persona}

# User Preferences
{user_preferences}
These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced:
- **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed.
- **Never proceed without compliance**: Do NOT move forward until the agent follows your preferences.

# Draft Answer Management
- Maintain a draft answer throughout the conversation. Start with "I don't know".
- Update your draft based on agent responses that follow your preferences.

# Conversation Termination
When ready to terminate (draft answer is good OR agent cannot help), respond with "TERMINATE".

# Output Format:
Respond with a JSON object:
{{
    "reasoning": "Brief reasoning about the agent's response and your preferences",
    "draft_answer": "Your current working draft answer",
    "should_terminate": true/false,
    "response": "Your response to the agent"
}}"""

    def __init__(
        self,
        problem: str,
        user_persona: str,
        user_preferences: str,
        vllm_client: VLLMClient,
    ):
        """Bind the scenario (problem/persona/preferences) into a system prompt."""
        self.problem = problem
        self.user_persona = user_persona
        self.user_preferences = user_preferences
        self.client = vllm_client

        self.system_prompt = self.SYSTEM_PROMPT.format(
            problem=problem,
            user_persona=user_persona,
            user_preferences=user_preferences,
        )

    def generate_user_response(
        self,
        conversation: List[Dict[str, str]]
    ) -> Optional[Dict[str, Any]]:
        """Generate user response given conversation history.

        Returns a dict with keys "reasoning", "draft_answer",
        "should_terminate", "response", or None on client failure.
        """
        # Build messages with reversed roles (from user simulator's perspective)
        messages = [{"role": "system", "content": self.system_prompt}]

        for msg in conversation:
            # Reverse roles: agent's messages become "user", user's become "assistant"
            role = "user" if msg["role"] == "assistant" else "assistant"
            messages.append({"role": role, "content": msg["content"]})

        try:
            response = self.client.chat(
                messages=messages,
                max_tokens=512,
                temperature=0.7,
            )

            content = response["content"]

            # Parse JSON response.  `except Exception` (not bare `except:`)
            # still deliberately covers ImportError when json_repair is absent.
            try:
                from json_repair import repair_json
                parsed = repair_json(content, return_objects=True)

                if isinstance(parsed, dict) and all(k in parsed for k in ["reasoning", "draft_answer", "should_terminate", "response"]):
                    return parsed
            except Exception:
                pass

            # Fallback: honor an explicit termination signal in free-form text.
            if self.TERMINATION_SIGNAL in content:
                return {
                    "reasoning": "Ending conversation",
                    "draft_answer": "",
                    "should_terminate": True,
                    "response": self.TERMINATION_SIGNAL
                }

            return {
                "reasoning": "",
                "draft_answer": "",
                "should_terminate": False,
                "response": content
            }

        except Exception as e:
            print(f"[VLLMUserSimulator] Error: {e}")
            return None


class VLLMAgentAdapter:
    """
    Base agent adapter using vLLM for fast inference.
    Can be extended for different methods (vanilla, rag, etc.)
    """

    def __init__(self, vllm_client: VLLMClient, system_prompt: Optional[str] = None):
        self.client = vllm_client
        self.system_prompt = system_prompt or "You are a helpful assistant."
        self.conversation_history: List[Dict[str, str]] = []

    def reset(self) -> None:
        """Reset conversation history."""
        self.conversation_history = []

    def generate_response(
        self,
        user_message: str,
        additional_context: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Generate agent response.

        Appends *user_message* and the model reply to the running history.

        Args:
            user_message: The latest user turn.
            additional_context: Optional text appended to the system prompt
                for this turn (e.g. retrieved documents).

        Returns:
            Dict with 'response', 'usage', 'latency_ms'.
        """
        self.conversation_history.append({"role": "user", "content": user_message})

        system = self.system_prompt
        if additional_context:
            system = f"{system}\n\n{additional_context}"

        messages = [{"role": "system", "content": system}]
        messages.extend(self.conversation_history)

        response = self.client.chat(
            messages=messages,
            max_tokens=1024,
            temperature=0.7,
        )

        assistant_content = response["content"]
        self.conversation_history.append({"role": "assistant", "content": assistant_content})

        return {
            "response": assistant_content,
            "usage": response["usage"],
            "latency_ms": response["latency_ms"],
        }


def benchmark_vllm(
    client: VLLMClient,
    n_requests: int = 10,
    concurrent: bool = False,
    n_workers: int = 4,
) -> Dict[str, Any]:
    """
    Benchmark vLLM server throughput.

    Args:
        client: VLLMClient instance
        n_requests: Number of requests to send
        concurrent: Whether to send requests concurrently
        n_workers: Number of concurrent workers

    Returns:
        Dict with benchmark results, or {'error': ..., 'errors': n} when
        every request failed.
    """
    test_messages = [
        {"role": "user", "content": "What is the capital of France? Answer briefly."}
    ]

    latencies: List[float] = []
    errors = 0

    start_time = time.time()

    if concurrent:
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            # Positional args map to chat(messages, max_tokens, temperature).
            futures = [
                executor.submit(client.chat, test_messages, 64, 0.1)
                for _ in range(n_requests)
            ]
            for future in as_completed(futures):
                try:
                    result = future.result()
                    latencies.append(result["latency_ms"])
                except Exception as e:
                    errors += 1
                    print(f"Error: {e}")
    else:
        for _ in range(n_requests):
            try:
                result = client.chat(test_messages, 64, 0.1)
                latencies.append(result["latency_ms"])
            except Exception as e:
                errors += 1
                print(f"Error: {e}")

    total_time = time.time() - start_time

    if latencies:
        return {
            "n_requests": n_requests,
            "concurrent": concurrent,
            "n_workers": n_workers if concurrent else 1,
            "total_time_s": total_time,
            "throughput_req_per_s": len(latencies) / total_time,
            "avg_latency_ms": sum(latencies) / len(latencies),
            "min_latency_ms": min(latencies),
            "max_latency_ms": max(latencies),
            "errors": errors,
        }
    else:
        return {"error": "All requests failed", "errors": errors}


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test vLLM client")
    parser.add_argument("--url", default="http://localhost:8004/v1", help="vLLM server URL")
    parser.add_argument("--benchmark", action="store_true", help="Run benchmark")
    parser.add_argument("-n", type=int, default=10, help="Number of requests")
    parser.add_argument("--concurrent", action="store_true", help="Run concurrent benchmark")

    args = parser.parse_args()

    client = VLLMClient(base_url=args.url)

    # Health check
    print(f"Checking vLLM server at {args.url}...")
    if client.health_check():
        print("✓ Server is healthy")
        print(f"Model info: {client.get_model_info()}")
    else:
        print("✗ Server is not responding")
        sys.exit(1)  # was builtin exit(); sys.exit is the documented form

    if args.benchmark:
        print(f"\nRunning benchmark with {args.n} requests (concurrent={args.concurrent})...")
        results = benchmark_vllm(client, args.n, args.concurrent)
        print(json.dumps(results, indent=2))
    else:
        # Simple test
        print("\nTesting chat completion...")
        response = client.chat([{"role": "user", "content": "Hello, who are you?"}])
        print(f"Response: {response['content'][:200]}...")
        print(f"Latency: {response['latency_ms']:.1f}ms")
