""" vLLM Client wrapper for high-performance inference. This module provides a unified interface to vLLM servers, replacing the slow transformers-based inference with vLLM's optimized serving. Usage: # Start vLLM servers first: # CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve model --port 8004 --tensor-parallel-size 4 client = VLLMClient(base_url="http://localhost:8004/v1") response = client.chat(messages=[{"role": "user", "content": "Hello"}]) """ import os import time import json from typing import List, Dict, Any, Optional from dataclasses import dataclass import requests from concurrent.futures import ThreadPoolExecutor, as_completed @dataclass class VLLMConfig: """Configuration for vLLM client.""" base_url: str = "http://localhost:8004/v1" model: str = None # Auto-discover from server if None api_key: str = "EMPTY" timeout: int = 120 max_retries: int = 3 class VLLMClient: """ Client for vLLM OpenAI-compatible API. Much faster than raw transformers due to: - Continuous batching - PagedAttention - Optimized CUDA kernels """ def __init__(self, config: VLLMConfig = None, base_url: str = None): if config: self.config = config else: self.config = VLLMConfig(base_url=base_url or "http://localhost:8004/v1") self._session = requests.Session() self._session.headers.update({ "Authorization": f"Bearer {self.config.api_key}", "Content-Type": "application/json" }) # Auto-discover model name if not provided if self.config.model is None: self._discover_model() def _discover_model(self): """Auto-discover the model name from the vLLM server.""" try: response = self._session.get( f"{self.config.base_url}/models", timeout=10 ) response.raise_for_status() models = response.json() if models.get("data") and len(models["data"]) > 0: self.config.model = models["data"][0]["id"] print(f"[VLLMClient] Auto-discovered model: {self.config.model}") else: self.config.model = "default" print("[VLLMClient] Warning: No models found, using 'default'") except Exception as e: self.config.model = "default" print(f"[VLLMClient] Warning: Could not discover model ({e}), using 'default'") def chat( self, messages: List[Dict[str, str]], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, stop: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Send a chat completion request to vLLM server. Args: messages: List of {"role": str, "content": str} max_tokens: Maximum tokens to generate temperature: Sampling temperature top_p: Top-p sampling stop: Stop sequences Returns: Dict with 'content', 'usage', 'latency_ms' """ url = f"{self.config.base_url}/chat/completions" payload = { "model": self.config.model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, } if stop: payload["stop"] = stop start_time = time.time() for attempt in range(self.config.max_retries): try: response = self._session.post( url, json=payload, timeout=self.config.timeout ) response.raise_for_status() result = response.json() latency_ms = (time.time() - start_time) * 1000 return { "content": result["choices"][0]["message"]["content"], "usage": result.get("usage", {}), "latency_ms": latency_ms, "finish_reason": result["choices"][0].get("finish_reason"), } except requests.exceptions.RequestException as e: if attempt < self.config.max_retries - 1: time.sleep(1 * (attempt + 1)) # Exponential backoff continue raise RuntimeError(f"vLLM request failed after {self.config.max_retries} attempts: {e}") def generate( self, prompt: str, max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, stop: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Send a completion request (non-chat) to vLLM server. """ url = f"{self.config.base_url}/completions" payload = { "model": self.config.model, "prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, } if stop: payload["stop"] = stop start_time = time.time() for attempt in range(self.config.max_retries): try: response = self._session.post( url, json=payload, timeout=self.config.timeout ) response.raise_for_status() result = response.json() latency_ms = (time.time() - start_time) * 1000 return { "content": result["choices"][0]["text"], "usage": result.get("usage", {}), "latency_ms": latency_ms, } except requests.exceptions.RequestException as e: if attempt < self.config.max_retries - 1: time.sleep(1 * (attempt + 1)) continue raise RuntimeError(f"vLLM request failed: {e}") def health_check(self) -> bool: """Check if vLLM server is healthy.""" try: # Try the models endpoint response = self._session.get( f"{self.config.base_url}/models", timeout=5 ) return response.status_code == 200 except: return False def get_model_info(self) -> Dict[str, Any]: """Get information about loaded model.""" try: response = self._session.get( f"{self.config.base_url}/models", timeout=5 ) response.raise_for_status() return response.json() except Exception as e: return {"error": str(e)} class VLLMUserSimulator: """ User simulator using vLLM for fast inference. Drop-in replacement for LocalUserAgent. """ TERMINATION_SIGNAL = "TERMINATE" SYSTEM_PROMPT = """You are a user simulator collaborating with an agent to solve a problem. You will be provided with a problem description, and you must get the agent to help you solve it. You will also be provided with user preferences, which you must actively enforce throughout the conversation. # Problem Description {problem} Note: the agent cannot see this problem description. # User Persona {user_persona} # User Preferences {user_preferences} These preferences are NON-NEGOTIABLE that define how you prefer the agent to behave. They must be strictly enforced: - **Enforce immediately**: Every agent response must satisfy your preferences before you can proceed. - **Never proceed without compliance**: Do NOT move forward until the agent follows your preferences. # Draft Answer Management - Maintain a draft answer throughout the conversation. Start with "I don't know". - Update your draft based on agent responses that follow your preferences. # Conversation Termination When ready to terminate (draft answer is good OR agent cannot help), respond with "TERMINATE". # Output Format: Respond with a JSON object: {{ "reasoning": "Brief reasoning about the agent's response and your preferences", "draft_answer": "Your current working draft answer", "should_terminate": true/false, "response": "Your response to the agent" }}""" def __init__( self, problem: str, user_persona: str, user_preferences: str, vllm_client: VLLMClient, ): self.problem = problem self.user_persona = user_persona self.user_preferences = user_preferences self.client = vllm_client self.system_prompt = self.SYSTEM_PROMPT.format( problem=problem, user_persona=user_persona, user_preferences=user_preferences, ) def generate_user_response( self, conversation: List[Dict[str, str]] ) -> Optional[Dict[str, Any]]: """Generate user response given conversation history.""" # Build messages with reversed roles (from user simulator's perspective) messages = [{"role": "system", "content": self.system_prompt}] for msg in conversation: # Reverse roles: agent's messages become "user", user's become "assistant" role = "user" if msg["role"] == "assistant" else "assistant" messages.append({"role": role, "content": msg["content"]}) try: response = self.client.chat( messages=messages, max_tokens=512, temperature=0.7, ) content = response["content"] # Parse JSON response try: from json_repair import repair_json parsed = repair_json(content, return_objects=True) if isinstance(parsed, dict) and all(k in parsed for k in ["reasoning", "draft_answer", "should_terminate", "response"]): return parsed except: pass # Fallback if self.TERMINATION_SIGNAL in content: return { "reasoning": "Ending conversation", "draft_answer": "", "should_terminate": True, "response": self.TERMINATION_SIGNAL } return { "reasoning": "", "draft_answer": "", "should_terminate": False, "response": content } except Exception as e: print(f"[VLLMUserSimulator] Error: {e}") return None class VLLMAgentAdapter: """ Base agent adapter using vLLM for fast inference. Can be extended for different methods (vanilla, rag, etc.) """ def __init__(self, vllm_client: VLLMClient, system_prompt: str = None): self.client = vllm_client self.system_prompt = system_prompt or "You are a helpful assistant." self.conversation_history: List[Dict[str, str]] = [] def reset(self): """Reset conversation history.""" self.conversation_history = [] def generate_response( self, user_message: str, additional_context: str = None, ) -> Dict[str, Any]: """Generate agent response.""" self.conversation_history.append({"role": "user", "content": user_message}) system = self.system_prompt if additional_context: system = f"{system}\n\n{additional_context}" messages = [{"role": "system", "content": system}] messages.extend(self.conversation_history) response = self.client.chat( messages=messages, max_tokens=1024, temperature=0.7, ) assistant_content = response["content"] self.conversation_history.append({"role": "assistant", "content": assistant_content}) return { "response": assistant_content, "usage": response["usage"], "latency_ms": response["latency_ms"], } def benchmark_vllm( client: VLLMClient, n_requests: int = 10, concurrent: bool = False, n_workers: int = 4, ) -> Dict[str, Any]: """ Benchmark vLLM server throughput. Args: client: VLLMClient instance n_requests: Number of requests to send concurrent: Whether to send requests concurrently n_workers: Number of concurrent workers Returns: Dict with benchmark results """ test_messages = [ {"role": "user", "content": "What is the capital of France? Answer briefly."} ] latencies = [] errors = 0 start_time = time.time() if concurrent: with ThreadPoolExecutor(max_workers=n_workers) as executor: futures = [ executor.submit(client.chat, test_messages, 64, 0.1) for _ in range(n_requests) ] for future in as_completed(futures): try: result = future.result() latencies.append(result["latency_ms"]) except Exception as e: errors += 1 print(f"Error: {e}") else: for _ in range(n_requests): try: result = client.chat(test_messages, 64, 0.1) latencies.append(result["latency_ms"]) except Exception as e: errors += 1 print(f"Error: {e}") total_time = time.time() - start_time if latencies: return { "n_requests": n_requests, "concurrent": concurrent, "n_workers": n_workers if concurrent else 1, "total_time_s": total_time, "throughput_req_per_s": len(latencies) / total_time, "avg_latency_ms": sum(latencies) / len(latencies), "min_latency_ms": min(latencies), "max_latency_ms": max(latencies), "errors": errors, } else: return {"error": "All requests failed", "errors": errors} if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Test vLLM client") parser.add_argument("--url", default="http://localhost:8004/v1", help="vLLM server URL") parser.add_argument("--benchmark", action="store_true", help="Run benchmark") parser.add_argument("-n", type=int, default=10, help="Number of requests") parser.add_argument("--concurrent", action="store_true", help="Run concurrent benchmark") args = parser.parse_args() client = VLLMClient(base_url=args.url) # Health check print(f"Checking vLLM server at {args.url}...") if client.health_check(): print("✓ Server is healthy") print(f"Model info: {client.get_model_info()}") else: print("✗ Server is not responding") exit(1) if args.benchmark: print(f"\nRunning benchmark with {args.n} requests (concurrent={args.concurrent})...") results = benchmark_vllm(client, args.n, args.concurrent) print(json.dumps(results, indent=2)) else: # Simple test print("\nTesting chat completion...") response = client.chat([{"role": "user", "content": "Hello, who are you?"}]) print(f"Response: {response['content'][:200]}...") print(f"Latency: {response['latency_ms']:.1f}ms")