#!/usr/bin/env python3
"""
Benchmark inference speed: Transformers vs vLLM.

This script helps diagnose the 100x slowdown issue by comparing:
1. Raw transformers inference (current implementation)
2. vLLM server inference (target implementation)

Usage:
    # First, start vLLM server:
    # CUDA_VISIBLE_DEVICES=0 vllm serve /path/to/model --port 8003

    # Then run benchmark:
    python benchmark_inference.py --mode both --n 20
    python benchmark_inference.py --mode vllm --url http://localhost:8003/v1 --n 50
    python benchmark_inference.py --mode transformers --model /path/to/model --n 10
"""

# Provenance: added in commit dc801c07cf38b0c495686463e6ca6f871a64440e
# ("Add collaborativeagents module and update gitignore", YurenHao0426,
# 2026-01-27) as collaborativeagents/scripts/benchmark_inference.py.

import argparse
import json
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

# Add paths so that the project-local ``utils`` package (and ``src``) resolve
# when this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))

# Assumed number of model requests (turns) that make up one conversation;
# used to convert request throughput into an estimated conversations/hour.
TURNS_PER_CONVERSATION = 10

# Reference throughput (conversations/hour) that results are compared against.
TARGET_CONV_PER_HR = 2000

# Generation settings shared by both backends so the comparison is like-for-like.
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7

# Test prompt (simulating a typical user simulator turn).
TEST_MESSAGES = [
    {"role": "system", "content": "You are a user simulator. Output JSON with reasoning, draft_answer, should_terminate, and response fields."},
    {"role": "user", "content": "The agent said: 'Hello, how can I help you today?' Respond as the user."},
]


@dataclass
class BenchmarkResult:
    """Aggregated timing statistics for one benchmark run."""

    mode: str  # backend label, e.g. "transformers", "vllm", "vllm_concurrent"
    n_requests: int  # number of requests that were attempted
    total_time_s: float  # duration of the whole request loop
    avg_latency_ms: float
    min_latency_ms: float
    max_latency_ms: float
    throughput_req_per_s: float  # successful requests per second
    throughput_conv_per_hr: float  # Estimated conversations per hour
    errors: int  # number of requests that raised


def summarize_latencies(
    mode: str,
    n_requests: int,
    latencies: List[float],
    total_time_s: float,
    errors: int,
) -> BenchmarkResult:
    """Fold per-request latencies into a ``BenchmarkResult``.

    Args:
        mode: Backend label stored on the result.
        n_requests: Number of requests that were attempted.
        latencies: Latency in milliseconds of every successful request.
        total_time_s: Duration of the whole run in seconds.
        errors: Number of failed requests.

    Returns:
        A result whose latency statistics are 0 when ``latencies`` is empty
        and whose throughput figures are 0 when nothing succeeded or the
        measured duration is not positive (avoids dividing by zero).
    """
    if latencies:
        avg_latency = sum(latencies) / len(latencies)
        min_latency = min(latencies)
        max_latency = max(latencies)
    else:
        avg_latency = min_latency = max_latency = 0.0

    if latencies and total_time_s > 0:
        throughput = len(latencies) / total_time_s
    else:
        throughput = 0.0

    # Estimate: conv/hr = (req/s) * 3600 / turns-per-conversation.
    conv_per_hr = throughput * 3600 / TURNS_PER_CONVERSATION

    return BenchmarkResult(
        mode=mode,
        n_requests=n_requests,
        total_time_s=total_time_s,
        avg_latency_ms=avg_latency,
        min_latency_ms=min_latency,
        max_latency_ms=max_latency,
        throughput_req_per_s=throughput,
        throughput_conv_per_hr=conv_per_hr,
        errors=errors,
    )


def benchmark_transformers(
    model_path: str,
    n_requests: int = 10,
    device: str = "cuda:0",
) -> BenchmarkResult:
    """Benchmark raw transformers inference.

    Loads the model once, then issues ``n_requests`` sequential ``generate``
    calls on the shared test prompt. Each measured latency covers
    tokenization, generation and decoding of that request.

    Args:
        model_path: Path or identifier passed to ``from_pretrained``.
        n_requests: Number of sequential generations to time.
        device: Value passed as ``device_map`` when loading the model.

    Returns:
        Timing statistics labelled ``"transformers"``. Requests that raise
        are counted in ``errors`` and excluded from the latency statistics.
    """
    # Imported lazily so the vLLM-only modes do not require torch/transformers.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from {model_path}...")
    load_start = time.perf_counter()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    # generate() is given an explicit pad token below; fall back to EOS when
    # the tokenizer does not define one.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_time = time.perf_counter() - load_start
    print(f"Model loaded in {load_time:.1f}s")

    prompt = tokenizer.apply_chat_template(TEST_MESSAGES, tokenize=False, add_generation_prompt=True)

    latencies: List[float] = []
    errors = 0

    print(f"Running {n_requests} inference requests...")
    start_time = time.perf_counter()

    for i in range(n_requests):
        try:
            req_start = time.perf_counter()

            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=True,
                    temperature=TEMPERATURE,
                    top_p=0.9,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                )

            # Decode only the newly generated tokens. The text itself is not
            # used, but decoding stays inside the timed region so the latency
            # covers the full request path.
            input_len = inputs["input_ids"].shape[1]
            gen_ids = outputs[0][input_len:]
            tokenizer.decode(gen_ids, skip_special_tokens=True)

            latency_ms = (time.perf_counter() - req_start) * 1000
            latencies.append(latency_ms)

            if (i + 1) % 5 == 0:
                print(f" Completed {i + 1}/{n_requests}, last latency: {latency_ms:.0f}ms")

        except Exception as e:
            # A single failed request must not abort the benchmark; count it.
            errors += 1
            print(f" Error on request {i + 1}: {e}")

    total_time = time.perf_counter() - start_time

    return summarize_latencies("transformers", n_requests, latencies, total_time, errors)


def benchmark_vllm(
    base_url: str = "http://localhost:8003/v1",
    n_requests: int = 10,
    concurrent: bool = False,
    n_workers: int = 4,
) -> BenchmarkResult:
    """Benchmark vLLM server inference.

    Args:
        base_url: Base URL handed to ``VLLMClient``.
        n_requests: Number of chat requests to issue.
        concurrent: When true, submit all requests to a thread pool instead
            of issuing them one after another.
        n_workers: Thread pool size used when ``concurrent`` is true.

    Returns:
        Timing statistics labelled ``"vllm"`` or ``"vllm_concurrent"``. If the
        health check fails, every request is reported as an error and all
        timings are 0. Per-request latency is read from the ``"latency_ms"``
        key of the value returned by ``client.chat``.
    """
    from utils.vllm_client import VLLMClient

    mode = "vllm" + ("_concurrent" if concurrent else "")
    client = VLLMClient(base_url=base_url)

    # Check health
    if not client.health_check():
        print(f"ERROR: vLLM server at {base_url} is not responding")
        return summarize_latencies(mode, n_requests, [], 0.0, n_requests)

    print(f"vLLM server healthy: {client.get_model_info()}")

    latencies: List[float] = []
    errors = 0

    print(f"Running {n_requests} inference requests (concurrent={concurrent})...")
    start_time = time.perf_counter()

    if concurrent:
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = [
                executor.submit(client.chat, TEST_MESSAGES, MAX_NEW_TOKENS, TEMPERATURE)
                for _ in range(n_requests)
            ]
            # Results are consumed in completion order, so ``i`` counts
            # finished requests rather than submission index.
            for i, future in enumerate(as_completed(futures)):
                try:
                    result = future.result()
                    latencies.append(result["latency_ms"])
                    if (i + 1) % 10 == 0:
                        print(f" Completed {i + 1}/{n_requests}")
                except Exception as e:
                    errors += 1
                    print(f" Error: {e}")
    else:
        for i in range(n_requests):
            try:
                result = client.chat(TEST_MESSAGES, MAX_NEW_TOKENS, TEMPERATURE)
                latencies.append(result["latency_ms"])

                if (i + 1) % 5 == 0:
                    print(f" Completed {i + 1}/{n_requests}, last latency: {result['latency_ms']:.0f}ms")

            except Exception as e:
                errors += 1
                print(f" Error on request {i + 1}: {e}")

    total_time = time.perf_counter() - start_time

    return summarize_latencies(mode, n_requests, latencies, total_time, errors)


def benchmark_full_conversation(
    vllm_url_70b: str,
    vllm_url_8b: str,
    n_conversations: int = 5,
    max_turns: int = 10,
) -> Dict[str, Any]:
    """
    Benchmark a full multi-turn conversation with user simulator and agent.
    This simulates the actual experiment loop.

    Args:
        vllm_url_70b: Server URL used for the user simulator client.
        vllm_url_8b: Server URL used for the agent client.
        n_conversations: Number of conversations to run back to back.
        max_turns: Upper bound on user/agent exchanges per conversation.

    Returns:
        ``{"error": ...}`` if either server fails its health check; otherwise
        a dict with conversation/turn counts, total and average conversation
        time, and per-hour throughput (0 when no time was measured).
    """
    from utils.vllm_client import VLLMClient, VLLMUserSimulator, VLLMAgentAdapter

    user_client = VLLMClient(base_url=vllm_url_70b)
    agent_client = VLLMClient(base_url=vllm_url_8b)

    if not user_client.health_check():
        print(f"ERROR: 70B server at {vllm_url_70b} not responding")
        return {"error": "70B server not available"}

    if not agent_client.health_check():
        print(f"ERROR: 8B server at {vllm_url_8b} not responding")
        return {"error": "8B server not available"}

    print(f"Running {n_conversations} full conversations (max {max_turns} turns each)...")

    conversation_times: List[float] = []
    total_turns = 0

    start_time = time.perf_counter()

    for conv_idx in range(n_conversations):
        conv_start = time.perf_counter()

        # Create user simulator
        user_sim = VLLMUserSimulator(
            problem="What is 2 + 2? Explain your reasoning step by step.",
            user_persona="A student learning math",
            user_preferences="- I prefer step-by-step explanations\n- Always show your work",
            vllm_client=user_client,
        )

        # Create agent
        agent = VLLMAgentAdapter(
            vllm_client=agent_client,
            system_prompt="You are a helpful math tutor. Explain concepts clearly."
        )

        # Run conversation
        conversation = [{"role": "assistant", "content": "How can I help you today?"}]

        for _ in range(max_turns):
            # User turn
            user_response = user_sim.generate_user_response(conversation)
            if user_response is None:
                break

            conversation.append({"role": "user", "content": user_response["response"]})

            if user_response.get("should_terminate", False):
                break

            # Agent turn
            agent_response = agent.generate_response(user_response["response"])
            conversation.append({"role": "assistant", "content": agent_response["response"]})

            # Only exchanges where both sides spoke count as a turn.
            total_turns += 1

        conv_time = time.perf_counter() - conv_start
        conversation_times.append(conv_time)
        print(f" Conversation {conv_idx + 1}/{n_conversations}: {len(conversation)} messages, {conv_time:.1f}s")

    total_time = time.perf_counter() - start_time

    # Guard the rate computations: with zero conversations the measured
    # duration can be 0 on coarse clocks.
    per_hour = 3600 / total_time if total_time > 0 else 0.0

    return {
        "n_conversations": n_conversations,
        "total_turns": total_turns,
        "total_time_s": total_time,
        "avg_conv_time_s": sum(conversation_times) / len(conversation_times) if conversation_times else 0,
        "throughput_conv_per_hr": n_conversations * per_hour,
        "throughput_turns_per_hr": total_turns * per_hour,
    }


def print_results(results: List[BenchmarkResult]):
    """Print benchmark results in a nice table.

    After the table, prints the vLLM-over-transformers speedup when both a
    ``"transformers"`` result and a vLLM result with non-zero throughput are
    present, followed by each result's share of ``TARGET_CONV_PER_HR``.
    """
    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80)

    print(f"\n{'Mode':<20} {'Requests':<10} {'Avg Latency':<12} {'Throughput':<15} {'Conv/hr':<12} {'Errors':<8}")
    print("-" * 80)

    for r in results:
        # Format each cell (including its unit) first, then pad it to the
        # header's column width so rows line up under the headings.
        latency = f"{r.avg_latency_ms:.0f}ms"
        throughput = f"{r.throughput_req_per_s:.2f}/s"
        conv_per_hr = f"{r.throughput_conv_per_hr:.0f}"
        print(f"{r.mode:<20} {r.n_requests:<10} {latency:<12} {throughput:<15} {conv_per_hr:<12} {r.errors:<8}")

    print("-" * 80)

    # Compare speedup
    if len(results) >= 2:
        transformers_result = next((r for r in results if r.mode == "transformers"), None)
        vllm_result = next((r for r in results if "vllm" in r.mode and r.throughput_req_per_s > 0), None)

        if transformers_result and vllm_result and transformers_result.throughput_req_per_s > 0:
            speedup = vllm_result.throughput_req_per_s / transformers_result.throughput_req_per_s
            print(f"\nvLLM speedup over transformers: {speedup:.1f}x")

    # Target comparison
    for r in results:
        if r.throughput_conv_per_hr > 0:
            ratio = r.throughput_conv_per_hr / TARGET_CONV_PER_HR
            # A run passes when it reaches at least half of the target.
            status = "✓" if ratio >= 0.5 else "✗"
            print(f"{status} {r.mode}: {r.throughput_conv_per_hr:.0f} conv/hr ({ratio:.1%} of paper's {TARGET_CONV_PER_HR} conv/hr)")


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Benchmark inference speed")
    parser.add_argument("--mode", choices=["transformers", "vllm", "both", "conversation"], default="vllm")
    parser.add_argument("--model", type=str, default="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct",
                        help="Model path for transformers benchmark")
    parser.add_argument("--url", type=str, default="http://localhost:8003/v1",
                        help="vLLM server URL")
    parser.add_argument("--url-70b", type=str, default="http://localhost:8004/v1",
                        help="vLLM server URL for 70B model (user simulator)")
    parser.add_argument("--url-8b", type=str, default="http://localhost:8003/v1",
                        help="vLLM server URL for 8B model (agent)")
    # Accept both "-n" and "--n": the usage examples in the module docstring
    # use the long spelling, which a lone "-n" flag would reject.
    parser.add_argument("-n", "--n", dest="n", type=int, default=20, help="Number of requests")
    parser.add_argument("--concurrent", action="store_true", help="Run vLLM benchmark with concurrent requests")
    parser.add_argument("--workers", type=int, default=4,
                        help="Worker threads for the concurrent vLLM benchmark")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device for transformers")
    return parser


def main(argv: Optional[List[str]] = None) -> None:
    """Parse arguments and run the selected benchmark(s).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    if args.workers < 1:
        parser.error("--workers must be >= 1")

    if args.mode == "conversation":
        # Full conversation benchmark
        conv_results = benchmark_full_conversation(
            args.url_70b,
            args.url_8b,
            n_conversations=args.n,
        )
        print("\n" + "=" * 80)
        print("FULL CONVERSATION BENCHMARK")
        print("=" * 80)
        print(json.dumps(conv_results, indent=2))

        # The key is absent when a server health check failed.
        if "throughput_conv_per_hr" in conv_results:
            target = TARGET_CONV_PER_HR
            actual = conv_results["throughput_conv_per_hr"]
            print(f"\nTarget: {target} conv/hr (paper)")
            print(f"Actual: {actual:.0f} conv/hr ({actual/target:.1%} of target)")
        return

    results: List[BenchmarkResult] = []

    if args.mode in ["transformers", "both"]:
        print("\n" + "=" * 40)
        print("TRANSFORMERS BENCHMARK")
        print("=" * 40)
        results.append(benchmark_transformers(args.model, args.n, args.device))

    if args.mode in ["vllm", "both"]:
        print("\n" + "=" * 40)
        print("vLLM BENCHMARK (sequential)")
        print("=" * 40)
        results.append(benchmark_vllm(args.url, args.n, concurrent=False))

        if args.concurrent:
            print("\n" + "=" * 40)
            print("vLLM BENCHMARK (concurrent)")
            print("=" * 40)
            results.append(benchmark_vllm(args.url, args.n, concurrent=True, n_workers=args.workers))

    print_results(results)


if __name__ == "__main__":
    main()