diff options
Diffstat (limited to 'collaborativeagents/scripts/test_batch_50.py')
| -rw-r--r-- | collaborativeagents/scripts/test_batch_50.py | 98 |
1 file changed, 98 insertions, 0 deletions
#!/usr/bin/env python
"""
Test batch processing with 50 conversations (matching paper's setup).

Usage:
    test_batch_50.py [user_url] [agent_url] [batch_size] [max_turns]

All arguments are optional positionals; defaults reproduce the paper's
configuration (batch of 50 conversations, 10 turns each).
"""

import sys
import time

# Make the project package importable when the script is run directly.
sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents')

from agents.batch_vllm_agent import BatchConversationGenerator


def main():
    """Run a batched conversation-generation benchmark and print throughput.

    Returns:
        The list of per-conversation results from the generator; entries are
        ``None`` for failed conversations, otherwise dicts with (at least) a
        ``'conversation'`` key holding the alternating message list.
    """
    # Positional CLI overrides, falling back to the paper's defaults.
    user_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8004/v1"
    agent_url = sys.argv[2] if len(sys.argv) > 2 else "http://localhost:8003/v1"
    batch_size = int(sys.argv[3]) if len(sys.argv) > 3 else 50
    max_turns = int(sys.argv[4]) if len(sys.argv) > 4 else 10

    print(f"\n{'='*60}")
    print(f"Batch Processing Test (Paper Configuration)")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"User URL: {user_url}")
    print(f"Agent URL: {agent_url}")
    print()

    # Create samples (simulating MMLU-style questions).
    samples = [
        {
            "problem": f"Question {i+1}: What is the capital of country number {i+1}? "
                       f"A) City A B) City B C) City C D) City D. "
                       f"Please explain your reasoning step by step.",
            "solution": "City A"
        }
        for i in range(batch_size)
    ]

    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=max_turns,
        user_max_tokens=512,
        agent_max_tokens=1024,
        temperature=0.7,
    )

    print(f"Starting batch generation of {batch_size} conversations...")
    print(f"Expected: ~{batch_size * max_turns * 2} total LLM calls batched into ~{max_turns * 2} batch requests")
    print()

    # perf_counter is the monotonic clock intended for elapsed-time measurement
    # (time.time can jump if the wall clock is adjusted mid-run).
    start = time.perf_counter()
    results = generator.generate_batch(
        samples=samples,
        user_persona="A curious student seeking help with exam questions.",
        user_preferences="1. Explain your reasoning step by step\n2. Be concise but thorough\n3. Highlight the key concept",
        agent_system_prompt="You are a helpful tutor. Answer questions clearly and explain your reasoning.",
    )
    # Clamp to a tiny positive value so the throughput divisions below can
    # never raise ZeroDivisionError on a degenerate (instant) run.
    elapsed = max(time.perf_counter() - start, 1e-9)

    # Failed conversations come back as None; each turn is a user+agent pair.
    successes = sum(1 for r in results if r is not None)
    total_turns = sum(
        len(r['conversation']) // 2 if r else 0
        for r in results
    )

    print(f"\n{'='*60}")
    print(f"RESULTS")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"Successes: {successes}/{batch_size}")
    print(f"Total conversation turns: {total_turns}")
    print(f"Time: {elapsed:.1f}s")
    print()
    print(f"Throughput: {successes * 3600 / elapsed:.0f} conversations/hr")
    print(f"Sessions/hr (3 sessions/profile): {successes * 3 * 3600 / elapsed:.0f}")
    print()

    # Compare with paper's claimed performance.
    paper_sessions = 2000  # sessions per hour claimed
    our_sessions = successes * 3 * 3600 / elapsed
    print(f"Paper's claimed throughput: ~{paper_sessions} sessions/hr")
    print(f"Our throughput: {our_sessions:.0f} sessions/hr")
    print(f"Ratio: {our_sessions / paper_sessions * 100:.1f}% of paper's performance")
    print()

    # Show a sample conversation. Guard against an empty result list (the
    # original indexed results[0] unconditionally, which raised IndexError
    # when the generator produced nothing).
    if results and results[0]:
        print(f"Sample conversation (first 4 messages):")
        for msg in results[0]['conversation'][:4]:
            role = msg['role'].upper()
            content = msg['content'][:100] + "..." if len(msg['content']) > 100 else msg['content']
            print(f"  [{role}]: {content}")

    return results


if __name__ == "__main__":
    main()
