summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/benchmark_inference.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/scripts/benchmark_inference.py')
-rwxr-xr-xcollaborativeagents/scripts/benchmark_inference.py429
1 file changed, 429 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/benchmark_inference.py b/collaborativeagents/scripts/benchmark_inference.py
new file mode 100755
index 0000000..6a2ee13
--- /dev/null
+++ b/collaborativeagents/scripts/benchmark_inference.py
@@ -0,0 +1,429 @@
+#!/usr/bin/env python3
+"""
+Benchmark inference speed: Transformers vs vLLM.
+
+This script helps diagnose the 100x slowdown issue by comparing:
+1. Raw transformers inference (current implementation)
+2. vLLM server inference (target implementation)
+
+Usage:
+ # First, start vLLM server:
+ # CUDA_VISIBLE_DEVICES=0 vllm serve /path/to/model --port 8003
+
+ # Then run benchmark:
+ python benchmark_inference.py --mode both --n 20
+ python benchmark_inference.py --mode vllm --url http://localhost:8003/v1 --n 50
+ python benchmark_inference.py --mode transformers --model /path/to/model --n 10
+"""
+
+import argparse
+import json
+import time
+import sys
+from pathlib import Path
+from typing import List, Dict, Any
+from dataclasses import dataclass
+
# Make the package root and its sibling "src" directory importable so that
# project-local modules (e.g. utils.vllm_client) resolve when this script is
# executed directly rather than as part of an installed package.
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
+
+
@dataclass
class BenchmarkResult:
    """Aggregated latency/throughput statistics for one benchmark run."""

    mode: str  # e.g. "transformers", "vllm", "vllm_concurrent"
    n_requests: int  # number of requests attempted (including failures)
    total_time_s: float  # wall-clock duration of the whole run
    avg_latency_ms: float  # mean per-request latency over successful requests
    min_latency_ms: float  # fastest successful request
    max_latency_ms: float  # slowest successful request
    throughput_req_per_s: float  # successful requests per second
    throughput_conv_per_hr: float  # Estimated conversations per hour
    errors: int  # requests that raised an exception
+
+
def benchmark_transformers(
    model_path: str,
    n_requests: int = 10,
    device: str = "cuda:0",
) -> BenchmarkResult:
    """Benchmark raw HuggingFace transformers inference (current implementation).

    Loads the model once (load time is reported but not counted in the stats),
    then issues `n_requests` sequential generate() calls with a fixed chat
    prompt and measures end-to-end per-request latency.

    Args:
        model_path: Local path or hub id of the causal LM to load.
        n_requests: Number of sequential generation requests to time.
        device: Device map passed to `from_pretrained` (e.g. "cuda:0").

    Returns:
        BenchmarkResult for mode "transformers"; latency/throughput fields are
        zeroed when every request failed.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from {model_path}...")
    load_start = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    # Some checkpoints (e.g. Llama) ship without a pad token; generate() needs one.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_time = time.time() - load_start
    print(f"Model loaded in {load_time:.1f}s")

    # Test prompt (simulating a typical user simulator turn)
    test_messages = [
        {"role": "system", "content": "You are a user simulator. Output JSON with reasoning, draft_answer, should_terminate, and response fields."},
        {"role": "user", "content": "The agent said: 'Hello, how can I help you today?' Respond as the user."},
    ]

    prompt = tokenizer.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=True)

    latencies = []
    errors = 0

    print(f"Running {n_requests} inference requests...")
    start_time = time.time()

    for i in range(n_requests):
        try:
            req_start = time.time()

            # Tokenization and decoding stay inside the timed section on
            # purpose: they are part of the per-request cost in the real loop.
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                )

            # Decode only the newly generated tokens; the text itself is
            # discarded — we only care that decoding happened.
            input_len = inputs["input_ids"].shape[1]
            tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)

            latency_ms = (time.time() - req_start) * 1000
            latencies.append(latency_ms)

            if (i + 1) % 5 == 0:
                print(f"  Completed {i + 1}/{n_requests}, last latency: {latency_ms:.0f}ms")

        except Exception as e:
            errors += 1
            print(f"  Error on request {i + 1}: {e}")

    total_time = time.time() - start_time

    # Single construction path (previously duplicated for the all-errors case):
    # guard each statistic so zero successes degrades to zeroed metrics.
    n_ok = len(latencies)
    throughput = n_ok / total_time if total_time > 0 else 0
    return BenchmarkResult(
        mode="transformers",
        n_requests=n_requests,
        total_time_s=total_time,
        avg_latency_ms=sum(latencies) / n_ok if n_ok else 0,
        min_latency_ms=min(latencies) if n_ok else 0,
        max_latency_ms=max(latencies) if n_ok else 0,
        throughput_req_per_s=throughput if n_ok else 0,
        # Estimate: ~10 turns per conversation, so conv/hr = (req/s) * 3600 / 10
        throughput_conv_per_hr=(throughput * 3600 / 10) if n_ok else 0,
        errors=errors,
    )
+
+
def benchmark_vllm(
    base_url: str = "http://localhost:8003/v1",
    n_requests: int = 10,
    concurrent: bool = False,
    n_workers: int = 4,
) -> BenchmarkResult:
    """Benchmark vLLM server inference (target implementation).

    Sends `n_requests` chat completions to a running vLLM server, either
    sequentially or fanned out over a thread pool, and measures the latency
    reported by the client for each request.

    Args:
        base_url: OpenAI-compatible base URL of the vLLM server.
        n_requests: Number of requests to issue.
        concurrent: If True, submit requests via a thread pool instead of
            sequentially (mode is reported as "vllm_concurrent").
        n_workers: Thread-pool size when `concurrent` is True.

    Returns:
        BenchmarkResult; if the server health check fails, all requests are
        reported as errors and stats are zeroed.
    """
    from utils.vllm_client import VLLMClient

    client = VLLMClient(base_url=base_url)

    # Fail fast when the server is unreachable: count every request as failed.
    if not client.health_check():
        print(f"ERROR: vLLM server at {base_url} is not responding")
        return BenchmarkResult(
            mode="vllm",
            n_requests=n_requests,
            total_time_s=0,
            avg_latency_ms=0,
            min_latency_ms=0,
            max_latency_ms=0,
            throughput_req_per_s=0,
            throughput_conv_per_hr=0,
            errors=n_requests,
        )

    print(f"vLLM server healthy: {client.get_model_info()}")

    # Test messages
    test_messages = [
        {"role": "system", "content": "You are a user simulator. Output JSON with reasoning, draft_answer, should_terminate, and response fields."},
        {"role": "user", "content": "The agent said: 'Hello, how can I help you today?' Respond as the user."},
    ]

    latencies = []
    errors = 0

    print(f"Running {n_requests} inference requests (concurrent={concurrent})...")
    start_time = time.time()

    if concurrent:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = [
                executor.submit(client.chat, test_messages, 256, 0.7)
                for _ in range(n_requests)
            ]
            for i, future in enumerate(as_completed(futures)):
                try:
                    result = future.result()
                    latencies.append(result["latency_ms"])
                    if (i + 1) % 10 == 0:
                        print(f"  Completed {i + 1}/{n_requests}")
                except Exception as e:
                    errors += 1
                    print(f"  Error: {e}")
    else:
        for i in range(n_requests):
            try:
                result = client.chat(test_messages, 256, 0.7)
                latencies.append(result["latency_ms"])

                if (i + 1) % 5 == 0:
                    print(f"  Completed {i + 1}/{n_requests}, last latency: {result['latency_ms']:.0f}ms")

            except Exception as e:
                errors += 1
                print(f"  Error on request {i + 1}: {e}")

    total_time = time.time() - start_time

    # Single construction path (previously duplicated for the all-errors case):
    # guard each statistic so zero successes degrades to zeroed metrics.
    n_ok = len(latencies)
    throughput = n_ok / total_time if n_ok else 0
    return BenchmarkResult(
        mode="vllm" + ("_concurrent" if concurrent else ""),
        n_requests=n_requests,
        total_time_s=total_time,
        avg_latency_ms=sum(latencies) / n_ok if n_ok else 0,
        min_latency_ms=min(latencies) if n_ok else 0,
        max_latency_ms=max(latencies) if n_ok else 0,
        throughput_req_per_s=throughput,
        # Estimate: ~10 turns per conversation, so conv/hr = (req/s) * 3600 / 10
        throughput_conv_per_hr=throughput * 3600 / 10,
        errors=errors,
    )
+
+
def benchmark_full_conversation(
    vllm_url_70b: str,
    vllm_url_8b: str,
    n_conversations: int = 5,
    max_turns: int = 10,
) -> Dict[str, Any]:
    """
    Benchmark a full multi-turn conversation with user simulator and agent.
    This simulates the actual experiment loop.

    Args:
        vllm_url_70b: Base URL of the vLLM server hosting the 70B user-simulator model.
        vllm_url_8b: Base URL of the vLLM server hosting the 8B agent model.
        n_conversations: Number of complete conversations to run.
        max_turns: Maximum user/agent turn pairs attempted per conversation.

    Returns:
        Dict of aggregate timing stats, or {"error": ...} if either server
        fails its health check.
    """
    from utils.vllm_client import VLLMClient, VLLMUserSimulator, VLLMAgentAdapter

    user_client = VLLMClient(base_url=vllm_url_70b)
    agent_client = VLLMClient(base_url=vllm_url_8b)

    # Fail fast with a structured error if either backing server is down.
    if not user_client.health_check():
        print(f"ERROR: 70B server at {vllm_url_70b} not responding")
        return {"error": "70B server not available"}

    if not agent_client.health_check():
        print(f"ERROR: 8B server at {vllm_url_8b} not responding")
        return {"error": "8B server not available"}

    print(f"Running {n_conversations} full conversations (max {max_turns} turns each)...")

    conversation_times = []
    total_turns = 0

    start_time = time.time()

    for conv_idx in range(n_conversations):
        conv_start = time.time()

        # Create user simulator (fresh state per conversation)
        user_sim = VLLMUserSimulator(
            problem="What is 2 + 2? Explain your reasoning step by step.",
            user_persona="A student learning math",
            user_preferences="- I prefer step-by-step explanations\n- Always show your work",
            vllm_client=user_client,
        )

        # Create agent
        agent = VLLMAgentAdapter(
            vllm_client=agent_client,
            system_prompt="You are a helpful math tutor. Explain concepts clearly."
        )

        # Run conversation; the agent opens with a fixed greeting.
        conversation = [{"role": "assistant", "content": "How can I help you today?"}]

        for turn in range(max_turns):
            # User turn: a None response ends the conversation early.
            user_response = user_sim.generate_user_response(conversation)
            if user_response is None:
                break

            conversation.append({"role": "user", "content": user_response["response"]})

            if user_response.get("should_terminate", False):
                break

            # Agent turn.
            # NOTE(review): only the latest user message is passed, not the full
            # transcript — presumably VLLMAgentAdapter keeps history internally;
            # confirm against its implementation.
            agent_response = agent.generate_response(user_response["response"])
            conversation.append({"role": "assistant", "content": agent_response["response"]})

            # A "turn" is counted only when both user and agent responded;
            # terminated/failed user turns are excluded from total_turns.
            total_turns += 1

        conv_time = time.time() - conv_start
        conversation_times.append(conv_time)
        print(f"  Conversation {conv_idx + 1}/{n_conversations}: {len(conversation)} messages, {conv_time:.1f}s")

    total_time = time.time() - start_time

    return {
        "n_conversations": n_conversations,
        "total_turns": total_turns,
        "total_time_s": total_time,
        "avg_conv_time_s": sum(conversation_times) / len(conversation_times) if conversation_times else 0,
        # NOTE(review): assumes total_time > 0; holds whenever this line is
        # reached since at least the time.time() calls above have elapsed.
        "throughput_conv_per_hr": n_conversations / total_time * 3600,
        "throughput_turns_per_hr": total_turns / total_time * 3600,
    }
+
+
def print_results(results: List[BenchmarkResult]):
    """Render a summary table of benchmark runs, then speedup and target comparisons."""
    bar = "=" * 80
    rule = "-" * 80

    print("\n" + bar)
    print("BENCHMARK RESULTS")
    print(bar)

    print(f"\n{'Mode':<20} {'Requests':<10} {'Avg Latency':<12} {'Throughput':<15} {'Conv/hr':<12} {'Errors':<8}")
    print(rule)

    for res in results:
        print(f"{res.mode:<20} {res.n_requests:<10} {res.avg_latency_ms:>8.0f}ms {res.throughput_req_per_s:>10.2f}/s {res.throughput_conv_per_hr:>8.0f} {res.errors:<8}")

    print(rule)

    # Speedup: compare the transformers run against any vLLM run that
    # actually completed requests.
    if len(results) >= 2:
        tf_res = next((res for res in results if res.mode == "transformers"), None)
        vllm_res = next((res for res in results if "vllm" in res.mode and res.throughput_req_per_s > 0), None)

        if tf_res and vllm_res and tf_res.throughput_req_per_s > 0:
            ratio = vllm_res.throughput_req_per_s / tf_res.throughput_req_per_s
            print(f"\nvLLM speedup over transformers: {ratio:.1f}x")

    # Compare each run against the paper's target throughput.
    target_conv_per_hr = 2000
    for res in results:
        if res.throughput_conv_per_hr <= 0:
            continue
        ratio = res.throughput_conv_per_hr / target_conv_per_hr
        marker = "✓" if ratio >= 0.5 else "✗"
        print(f"{marker} {res.mode}: {res.throughput_conv_per_hr:.0f} conv/hr ({ratio:.1%} of paper's 2000 conv/hr)")
+
+
def main():
    """CLI entry point: parse arguments and dispatch to the selected benchmark(s)."""
    parser = argparse.ArgumentParser(description="Benchmark inference speed")
    parser.add_argument("--mode", choices=["transformers", "vllm", "both", "conversation"], default="vllm")
    parser.add_argument("--model", type=str, default="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct",
                        help="Model path for transformers benchmark")
    parser.add_argument("--url", type=str, default="http://localhost:8003/v1",
                        help="vLLM server URL")
    parser.add_argument("--url-70b", type=str, default="http://localhost:8004/v1",
                        help="vLLM server URL for 70B model (user simulator)")
    parser.add_argument("--url-8b", type=str, default="http://localhost:8003/v1",
                        help="vLLM server URL for 8B model (agent)")
    parser.add_argument("-n", type=int, default=20, help="Number of requests")
    parser.add_argument("--concurrent", action="store_true", help="Run vLLM benchmark with concurrent requests")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device for transformers")

    args = parser.parse_args()

    def banner(title, width):
        # Section header in the same format the benchmarks used originally.
        print("\n" + "=" * width)
        print(title)
        print("=" * width)

    if args.mode == "conversation":
        # Full conversation benchmark: runs first, then reports.
        conv_results = benchmark_full_conversation(
            args.url_70b,
            args.url_8b,
            n_conversations=args.n,
        )
        banner("FULL CONVERSATION BENCHMARK", 80)
        print(json.dumps(conv_results, indent=2))

        if "throughput_conv_per_hr" in conv_results:
            target = 2000
            actual = conv_results["throughput_conv_per_hr"]
            print(f"\nTarget: {target} conv/hr (paper)")
            print(f"Actual: {actual:.0f} conv/hr ({actual/target:.1%} of target)")
        return

    results = []

    if args.mode in ("transformers", "both"):
        banner("TRANSFORMERS BENCHMARK", 40)
        results.append(benchmark_transformers(args.model, args.n, args.device))

    if args.mode in ("vllm", "both"):
        banner("vLLM BENCHMARK (sequential)", 40)
        results.append(benchmark_vllm(args.url, args.n, concurrent=False))

        if args.concurrent:
            banner("vLLM BENCHMARK (concurrent)", 40)
            results.append(benchmark_vllm(args.url, args.n, concurrent=True, n_workers=4))

    print_results(results)
+
+
+if __name__ == "__main__":
+ main()