#!/usr/bin/env python3
"""
Benchmark inference speed: Transformers vs vLLM.

This script helps diagnose the 100x slowdown issue by comparing:
1. Raw transformers inference (current implementation)
2. vLLM server inference (target implementation)

Usage:
    # First, start vLLM server:
    # CUDA_VISIBLE_DEVICES=0 vllm serve /path/to/model --port 8003

    # Then run benchmark (-n and --n are equivalent):
    python benchmark_inference.py --mode both --n 20
    python benchmark_inference.py --mode vllm --url http://localhost:8003/v1 --n 50
    python benchmark_inference.py --mode vllm --concurrent --workers 8 --n 50
    python benchmark_inference.py --mode transformers --model /path/to/model --n 10
"""

import argparse
import json
import time
import sys
from pathlib import Path
from typing import List, Dict, Any
from dataclasses import dataclass

# Add paths
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))

# Assumed number of turns in one conversation; used to turn request throughput
# into an estimated conversations-per-hour figure (conv/hr = req/s * 3600 / turns).
TURNS_PER_CONVERSATION = 10

# Conversations-per-hour figure reported in the paper, used as the comparison target.
PAPER_TARGET_CONV_PER_HR = 2000


@dataclass
class BenchmarkResult:
    """Aggregated timing statistics for one benchmark run."""

    mode: str
    n_requests: int
    total_time_s: float
    avg_latency_ms: float
    min_latency_ms: float
    max_latency_ms: float
    throughput_req_per_s: float
    throughput_conv_per_hr: float  # Estimated conversations per hour
    errors: int


def _test_messages() -> List[Dict[str, str]]:
    """Return a fresh copy of the chat prompt used by the single-request benchmarks.

    The prompt simulates a typical user-simulator turn. A new list is built on
    every call so no benchmark can observe mutations made by another.
    """
    return [
        {"role": "system", "content": "You are a user simulator. Output JSON with reasoning, draft_answer, should_terminate, and response fields."},
        {"role": "user", "content": "The agent said: 'Hello, how can I help you today?' Respond as the user."},
    ]


def summarize_latencies(
    mode: str,
    n_requests: int,
    latencies: List[float],
    total_time: float,
    errors: int,
) -> BenchmarkResult:
    """Aggregate per-request latencies into a BenchmarkResult.

    Args:
        mode: Label shown in the results table.
        n_requests: Number of requests that were attempted.
        latencies: Latencies in milliseconds of the requests that succeeded.
        total_time: Wall-clock duration of the whole run, in seconds.
        errors: Number of failed requests.

    Returns:
        A BenchmarkResult. When no request succeeded, all latency and
        throughput fields are 0 and only total time / errors are reported.
    """
    if not latencies:
        return BenchmarkResult(
            mode=mode,
            n_requests=n_requests,
            total_time_s=total_time,
            avg_latency_ms=0,
            min_latency_ms=0,
            max_latency_ms=0,
            throughput_req_per_s=0,
            throughput_conv_per_hr=0,
            errors=errors,
        )

    # Guard against a zero-length interval so a degenerate run cannot crash.
    throughput = len(latencies) / total_time if total_time > 0 else 0.0
    return BenchmarkResult(
        mode=mode,
        n_requests=n_requests,
        total_time_s=total_time,
        avg_latency_ms=sum(latencies) / len(latencies),
        min_latency_ms=min(latencies),
        max_latency_ms=max(latencies),
        throughput_req_per_s=throughput,
        throughput_conv_per_hr=throughput * 3600 / TURNS_PER_CONVERSATION,
        errors=errors,
    )


def benchmark_transformers(
    model_path: str,
    n_requests: int = 10,
    device: str = "cuda:0",
) -> BenchmarkResult:
    """Benchmark raw transformers inference with sequential `generate` calls.

    Loads the model in bfloat16 onto `device`, then times `n_requests`
    tokenize + generate (max 256 new tokens, sampled) + decode round trips on
    the same prompt. Model loading is timed separately and excluded from the
    benchmark. Per-request exceptions are counted as errors, not raised.

    Args:
        model_path: Local path or hub id passed to `from_pretrained`.
        n_requests: Number of generation requests to time.
        device: Value passed as `device_map` when loading the model.

    Returns:
        A BenchmarkResult with mode "transformers".
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from {model_path}...")
    load_start = time.perf_counter()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    # Some tokenizers ship without a pad token; generate() needs one.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Model loaded in {time.perf_counter() - load_start:.1f}s")

    prompt = tokenizer.apply_chat_template(
        _test_messages(), tokenize=False, add_generation_prompt=True
    )

    latencies: List[float] = []
    errors = 0

    print(f"Running {n_requests} inference requests...")
    start_time = time.perf_counter()

    for i in range(n_requests):
        try:
            req_start = time.perf_counter()

            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                )

            # Decode only the generated continuation. Decoding is kept inside
            # the timed region because a real caller always pays for it.
            input_len = inputs["input_ids"].shape[1]
            tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)

            latency_ms = (time.perf_counter() - req_start) * 1000
            latencies.append(latency_ms)

            if (i + 1) % 5 == 0:
                print(f"  Completed {i + 1}/{n_requests}, last latency: {latency_ms:.0f}ms")

        except Exception as e:
            errors += 1
            print(f"  Error on request {i + 1}: {e}")

    total_time = time.perf_counter() - start_time
    return summarize_latencies("transformers", n_requests, latencies, total_time, errors)


def benchmark_vllm(
    base_url: str = "http://localhost:8003/v1",
    n_requests: int = 10,
    concurrent: bool = False,
    n_workers: int = 4,
) -> BenchmarkResult:
    """Benchmark inference against a running vLLM server.

    Sends `n_requests` chat requests (max 256 tokens, temperature 0.7) either
    one after another or through a thread pool. Per-request latency is taken
    from the `latency_ms` field returned by `VLLMClient.chat`. Failed requests
    are counted as errors, not raised.

    Args:
        base_url: OpenAI-compatible base URL of the vLLM server.
        n_requests: Number of chat requests to send.
        concurrent: If True, issue requests from a ThreadPoolExecutor.
        n_workers: Thread-pool size when `concurrent` is True.

    Returns:
        A BenchmarkResult with mode "vllm" or "vllm_concurrent". If the health
        check fails, every request is counted as an error and no requests are sent.
    """
    from utils.vllm_client import VLLMClient

    mode = "vllm" + ("_concurrent" if concurrent else "")
    client = VLLMClient(base_url=base_url)

    if not client.health_check():
        print(f"ERROR: vLLM server at {base_url} is not responding")
        return summarize_latencies(mode, n_requests, [], 0, n_requests)

    print(f"vLLM server healthy: {client.get_model_info()}")

    messages = _test_messages()
    latencies: List[float] = []
    errors = 0

    print(f"Running {n_requests} inference requests (concurrent={concurrent})...")
    start_time = time.perf_counter()

    if concurrent:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = [
                executor.submit(client.chat, messages, 256, 0.7)
                for _ in range(n_requests)
            ]
            # as_completed yields in completion order, so `i` counts finished
            # requests (successes and failures), not submission order.
            for i, future in enumerate(as_completed(futures)):
                try:
                    result = future.result()
                    latencies.append(result["latency_ms"])
                    if (i + 1) % 10 == 0:
                        print(f"  Completed {i + 1}/{n_requests}")
                except Exception as e:
                    errors += 1
                    print(f"  Error: {e}")
    else:
        for i in range(n_requests):
            try:
                result = client.chat(messages, 256, 0.7)
                latencies.append(result["latency_ms"])
                if (i + 1) % 5 == 0:
                    print(f"  Completed {i + 1}/{n_requests}, last latency: {result['latency_ms']:.0f}ms")
            except Exception as e:
                errors += 1
                print(f"  Error on request {i + 1}: {e}")

    total_time = time.perf_counter() - start_time
    return summarize_latencies(mode, n_requests, latencies, total_time, errors)


def benchmark_full_conversation(
    vllm_url_70b: str,
    vllm_url_8b: str,
    n_conversations: int = 5,
    max_turns: int = 10,
) -> Dict[str, Any]:
    """
    Benchmark a full multi-turn conversation with user simulator and agent.

    This simulates the actual experiment loop: the user simulator (served at
    `vllm_url_70b`) and the agent (served at `vllm_url_8b`) alternate turns
    until the simulator returns None, asks to terminate, or `max_turns` is hit.

    Returns:
        A dict of aggregate timings and throughputs, or {"error": ...} when
        either server fails its health check.
    """
    from utils.vllm_client import VLLMClient, VLLMUserSimulator, VLLMAgentAdapter

    user_client = VLLMClient(base_url=vllm_url_70b)
    agent_client = VLLMClient(base_url=vllm_url_8b)

    if not user_client.health_check():
        print(f"ERROR: 70B server at {vllm_url_70b} not responding")
        return {"error": "70B server not available"}
    if not agent_client.health_check():
        print(f"ERROR: 8B server at {vllm_url_8b} not responding")
        return {"error": "8B server not available"}

    print(f"Running {n_conversations} full conversations (max {max_turns} turns each)...")

    conversation_times: List[float] = []
    total_turns = 0
    start_time = time.perf_counter()

    for conv_idx in range(n_conversations):
        conv_start = time.perf_counter()

        user_sim = VLLMUserSimulator(
            problem="What is 2 + 2? Explain your reasoning step by step.",
            user_persona="A student learning math",
            user_preferences="- I prefer step-by-step explanations\n- Always show your work",
            vllm_client=user_client,
        )
        agent = VLLMAgentAdapter(
            vllm_client=agent_client,
            system_prompt="You are a helpful math tutor. Explain concepts clearly."
        )

        conversation = [{"role": "assistant", "content": "How can I help you today?"}]

        for _turn in range(max_turns):
            # User turn
            user_response = user_sim.generate_user_response(conversation)
            if user_response is None:
                break
            conversation.append({"role": "user", "content": user_response["response"]})
            if user_response.get("should_terminate", False):
                break

            # Agent turn
            agent_response = agent.generate_response(user_response["response"])
            conversation.append({"role": "assistant", "content": agent_response["response"]})

            # Only complete user+agent exchanges are counted as turns.
            total_turns += 1

        conv_time = time.perf_counter() - conv_start
        conversation_times.append(conv_time)
        print(f"  Conversation {conv_idx + 1}/{n_conversations}: {len(conversation)} messages, {conv_time:.1f}s")

    total_time = time.perf_counter() - start_time
    per_hour = 3600 / total_time if total_time > 0 else 0.0

    return {
        "n_conversations": n_conversations,
        "total_turns": total_turns,
        "total_time_s": total_time,
        "avg_conv_time_s": sum(conversation_times) / len(conversation_times) if conversation_times else 0,
        "throughput_conv_per_hr": n_conversations * per_hour,
        "throughput_turns_per_hr": total_turns * per_hour,
    }


def print_results(results: List[BenchmarkResult]):
    """Print benchmark results as a table, plus speedup and paper-target comparisons."""
    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80)

    print(f"\n{'Mode':<20} {'Requests':<10} {'Avg Latency':<12} {'Throughput':<15} {'Conv/hr':<12} {'Errors':<8}")
    print("-" * 80)

    for r in results:
        print(f"{r.mode:<20} {r.n_requests:<10} {r.avg_latency_ms:>8.0f}ms   {r.throughput_req_per_s:>10.2f}/s   {r.throughput_conv_per_hr:>8.0f}     {r.errors:<8}")

    print("-" * 80)

    # Speedup is only meaningful when a transformers run and a successful vLLM run both exist.
    if len(results) >= 2:
        transformers_result = next((r for r in results if r.mode == "transformers"), None)
        vllm_result = next((r for r in results if "vllm" in r.mode and r.throughput_req_per_s > 0), None)

        if transformers_result and vllm_result and transformers_result.throughput_req_per_s > 0:
            speedup = vllm_result.throughput_req_per_s / transformers_result.throughput_req_per_s
            print(f"\nvLLM speedup over transformers: {speedup:.1f}x")

    # A run "passes" (✓) when it reaches at least half of the paper's throughput.
    for r in results:
        if r.throughput_conv_per_hr > 0:
            ratio = r.throughput_conv_per_hr / PAPER_TARGET_CONV_PER_HR
            status = "✓" if ratio >= 0.5 else "✗"
            print(f"{status} {r.mode}: {r.throughput_conv_per_hr:.0f} conv/hr ({ratio:.1%} of paper's {PAPER_TARGET_CONV_PER_HR} conv/hr)")


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Benchmark inference speed")
    parser.add_argument("--mode", choices=["transformers", "vllm", "both", "conversation"],
                        default="vllm")
    parser.add_argument("--model", type=str,
                        default="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct",
                        help="Model path for transformers benchmark")
    parser.add_argument("--url", type=str, default="http://localhost:8003/v1",
                        help="vLLM server URL")
    parser.add_argument("--url-70b", type=str, default="http://localhost:8004/v1",
                        help="vLLM server URL for 70B model (user simulator)")
    parser.add_argument("--url-8b", type=str, default="http://localhost:8003/v1",
                        help="vLLM server URL for 8B model (agent)")
    # "--n" is accepted as well because the usage examples use it; argparse
    # does not treat "--n" as an abbreviation of "-n".
    parser.add_argument("-n", "--n", dest="n", type=int, default=20,
                        help="Number of requests")
    parser.add_argument("--concurrent", action="store_true",
                        help="Run vLLM benchmark with concurrent requests")
    parser.add_argument("--workers", type=int, default=4,
                        help="Thread-pool size for the concurrent vLLM benchmark")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Device for transformers")
    return parser


def main():
    """Parse arguments, run the selected benchmarks and print their results."""
    args = build_arg_parser().parse_args()

    results: List[BenchmarkResult] = []

    if args.mode == "conversation":
        conv_results = benchmark_full_conversation(
            args.url_70b,
            args.url_8b,
            n_conversations=args.n,
        )
        print("\n" + "=" * 80)
        print("FULL CONVERSATION BENCHMARK")
        print("=" * 80)
        print(json.dumps(conv_results, indent=2))

        # Missing when a server health check failed ({"error": ...} result).
        if "throughput_conv_per_hr" in conv_results:
            target = PAPER_TARGET_CONV_PER_HR
            actual = conv_results["throughput_conv_per_hr"]
            print(f"\nTarget: {target} conv/hr (paper)")
            print(f"Actual: {actual:.0f} conv/hr ({actual/target:.1%} of target)")
        return

    if args.mode in ["transformers", "both"]:
        print("\n" + "=" * 40)
        print("TRANSFORMERS BENCHMARK")
        print("=" * 40)
        results.append(benchmark_transformers(args.model, args.n, args.device))

    if args.mode in ["vllm", "both"]:
        print("\n" + "=" * 40)
        print("vLLM BENCHMARK (sequential)")
        print("=" * 40)
        results.append(benchmark_vllm(args.url, args.n, concurrent=False))

        # The concurrent run is an add-on to the sequential vLLM run.
        if args.concurrent:
            print("\n" + "=" * 40)
            print("vLLM BENCHMARK (concurrent)")
            print("=" * 40)
            results.append(benchmark_vllm(args.url, args.n, concurrent=True, n_workers=args.workers))

    print_results(results)


if __name__ == "__main__":
    main()