summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/test_batch_50.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/scripts/test_batch_50.py')
-rw-r--r--collaborativeagents/scripts/test_batch_50.py98
1 file changed, 98 insertions, 0 deletions
diff --git a/collaborativeagents/scripts/test_batch_50.py b/collaborativeagents/scripts/test_batch_50.py
new file mode 100644
index 0000000..b3f1c37
--- /dev/null
+++ b/collaborativeagents/scripts/test_batch_50.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+"""
+Test batch processing with 50 conversations (matching paper's setup).
+"""
+
+import sys
+import time
+sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents')
+
+from agents.batch_vllm_agent import BatchConversationGenerator
+
def main():
    """Batch-generation smoke test: run `batch_size` conversations, report throughput.

    Command-line arguments (all optional, positional):
        1. user_url    -- vLLM endpoint for the simulated user  (default http://localhost:8004/v1)
        2. agent_url   -- vLLM endpoint for the tutor agent     (default http://localhost:8003/v1)
        3. batch_size  -- number of conversations to generate   (default 50, matching the paper)
        4. max_turns   -- maximum user/agent turn pairs         (default 10)

    Returns:
        The list of per-conversation results from the generator; entries are
        None for conversations that failed.
    """
    user_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8004/v1"
    agent_url = sys.argv[2] if len(sys.argv) > 2 else "http://localhost:8003/v1"
    batch_size = int(sys.argv[3]) if len(sys.argv) > 3 else 50
    max_turns = int(sys.argv[4]) if len(sys.argv) > 4 else 10

    print(f"\n{'='*60}")
    print("Batch Processing Test (Paper Configuration)")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"User URL: {user_url}")
    print(f"Agent URL: {agent_url}")
    print()

    # Create samples (simulating MMLU-style questions).
    samples = [
        {
            "problem": f"Question {i+1}: What is the capital of country number {i+1}? "
                       f"A) City A B) City B C) City C D) City D. "
                       f"Please explain your reasoning step by step.",
            "solution": "City A"
        }
        for i in range(batch_size)
    ]

    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=max_turns,
        user_max_tokens=512,
        agent_max_tokens=1024,
        temperature=0.7,
    )

    print(f"Starting batch generation of {batch_size} conversations...")
    print(f"Expected: ~{batch_size * max_turns * 2} total LLM calls batched into ~{max_turns * 2} batch requests")
    print()

    # perf_counter() is monotonic and high-resolution, the right clock for
    # measuring elapsed intervals (time.time() can jump on NTP adjustments).
    start = time.perf_counter()
    results = generator.generate_batch(
        samples=samples,
        user_persona="A curious student seeking help with exam questions.",
        user_preferences="1. Explain your reasoning step by step\n2. Be concise but thorough\n3. Highlight the key concept",
        agent_system_prompt="You are a helpful tutor. Answer questions clearly and explain your reasoning.",
    )
    # Clamp to a tiny positive value so the throughput divisions below can
    # never raise ZeroDivisionError (e.g. batch_size == 0 finishes instantly).
    elapsed = max(time.perf_counter() - start, 1e-9)

    successes = sum(1 for r in results if r is not None)
    # A conversation alternates user/agent messages, so turn pairs = messages // 2.
    total_turns = sum(
        len(r['conversation']) // 2
        for r in results
        if r is not None
    )

    print(f"\n{'='*60}")
    print("RESULTS")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"Successes: {successes}/{batch_size}")
    print(f"Total conversation turns: {total_turns}")
    print(f"Time: {elapsed:.1f}s")
    print()
    print(f"Throughput: {successes * 3600 / elapsed:.0f} conversations/hr")
    print(f"Sessions/hr (3 sessions/profile): {successes * 3 * 3600 / elapsed:.0f}")
    print()

    # Compare with paper's claimed performance.
    paper_sessions = 2000  # sessions per hour claimed
    our_sessions = successes * 3 * 3600 / elapsed
    print(f"Paper's claimed throughput: ~{paper_sessions} sessions/hr")
    print(f"Our throughput: {our_sessions:.0f} sessions/hr")
    print(f"Ratio: {our_sessions / paper_sessions * 100:.1f}% of paper's performance")
    print()

    # Show a sample conversation. Guard `results` itself: indexing an empty
    # list (batch_size == 0, or a generator returning nothing) would raise.
    if results and results[0]:
        print("Sample conversation (first 4 messages):")
        for msg in results[0]['conversation'][:4]:
            role = msg['role'].upper()
            content = msg['content']
            if len(content) > 100:
                content = content[:100] + "..."
            print(f"  [{role}]: {content}")

    return results
+
# Script entry point; main()'s return value (the results list) is
# intentionally discarded when run from the command line.
if __name__ == "__main__":
    main()