#!/usr/bin/env python
"""Test batch processing with 50 conversations (matching paper's setup)."""
import sys
import time

# NOTE(review): hard-coded absolute path — works only on this cluster; consider
# installing the package or using an environment variable instead.
sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents')

from agents.batch_vllm_agent import BatchConversationGenerator


def main():
    """Benchmark batched conversation generation and print throughput stats.

    Usage:
        script.py [user_url] [agent_url] [batch_size] [max_turns]

    Positional CLI arguments (all optional, with defaults):
        user_url   — base URL of the user-simulator vLLM server
        agent_url  — base URL of the agent vLLM server
        batch_size — number of conversations generated concurrently
        max_turns  — maximum user/agent turn pairs per conversation

    Returns:
        The list of per-conversation results from
        ``BatchConversationGenerator.generate_batch`` (entries may be None
        for failed conversations).
    """
    user_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8004/v1"
    agent_url = sys.argv[2] if len(sys.argv) > 2 else "http://localhost:8003/v1"
    batch_size = int(sys.argv[3]) if len(sys.argv) > 3 else 50
    max_turns = int(sys.argv[4]) if len(sys.argv) > 4 else 10

    print(f"\n{'='*60}")
    print("Batch Processing Test (Paper Configuration)")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"User URL: {user_url}")
    print(f"Agent URL: {agent_url}")
    print()

    # Create samples (simulating MMLU-style questions)
    samples = [
        {
            "problem": f"Question {i+1}: What is the capital of country number {i+1}? "
                       f"A) City A B) City B C) City C D) City D. "
                       f"Please explain your reasoning step by step.",
            "solution": "City A"
        }
        for i in range(batch_size)
    ]

    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=max_turns,
        user_max_tokens=512,
        agent_max_tokens=1024,
        temperature=0.7,
    )

    print(f"Starting batch generation of {batch_size} conversations...")
    print(f"Expected: ~{batch_size * max_turns * 2} total LLM calls batched into ~{max_turns * 2} batch requests")
    print()

    # perf_counter is monotonic — correct for elapsed-time measurement,
    # unlike time.time() which can jump on wall-clock adjustments.
    start = time.perf_counter()

    results = generator.generate_batch(
        samples=samples,
        user_persona="A curious student seeking help with exam questions.",
        user_preferences="1. Explain your reasoning step by step\n2. Be concise but thorough\n3. Highlight the key concept",
        # Fixed: this string literal was broken across a physical line in the
        # original source, which is a syntax error.
        agent_system_prompt="You are a helpful tutor. Answer questions clearly and explain your reasoning.",
    )

    elapsed = time.perf_counter() - start

    successes = sum(1 for r in results if r is not None)
    # Each turn pair is two messages (user + agent), hence the // 2.
    total_turns = sum(
        len(r['conversation']) // 2 if r else 0
        for r in results
    )

    print(f"\n{'='*60}")
    print("RESULTS")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"Successes: {successes}/{batch_size}")
    print(f"Total conversation turns: {total_turns}")
    print(f"Time: {elapsed:.1f}s")
    print()
    print(f"Throughput: {successes * 3600 / elapsed:.0f} conversations/hr")
    print(f"Sessions/hr (3 sessions/profile): {successes * 3 * 3600 / elapsed:.0f}")
    print()

    # Compare with paper's claimed performance
    paper_sessions = 2000  # sessions per hour claimed
    our_sessions = successes * 3 * 3600 / elapsed
    print(f"Paper's claimed throughput: ~{paper_sessions} sessions/hr")
    print(f"Our throughput: {our_sessions:.0f} sessions/hr")
    print(f"Ratio: {our_sessions / paper_sessions * 100:.1f}% of paper's performance")
    print()

    # Show sample conversation.
    # Fixed: guard against an empty results list (original `results[0]`
    # would raise IndexError when no conversations were generated).
    if results and results[0]:
        print("Sample conversation (first 4 messages):")
        for msg in results[0]['conversation'][:4]:
            role = msg['role'].upper()
            content = msg['content'][:100] + "..." if len(msg['content']) > 100 else msg['content']
            print(f"  [{role}]: {content}")

    return results


if __name__ == "__main__":
    main()