1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
#!/usr/bin/env python
"""
Test batch processing with 50 conversations (matching paper's setup).
"""
import sys
import time
sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents')
from agents.batch_vllm_agent import BatchConversationGenerator
def main():
    """Run a batch-generation throughput benchmark against the paper's setup.

    Positional CLI args (all optional):
        1. user vLLM base URL   (default "http://localhost:8004/v1")
        2. agent vLLM base URL  (default "http://localhost:8003/v1")
        3. batch size           (default 50, matching the paper)
        4. max conversation turns (default 10)

    Builds `batch_size` synthetic MMLU-style samples, runs
    BatchConversationGenerator.generate_batch over them, then prints
    timing, success counts, and a throughput comparison against the
    paper's claimed ~2000 sessions/hr.

    Returns:
        The list of per-conversation results from generate_batch;
        entries are None for conversations that failed.
    """
    user_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8004/v1"
    agent_url = sys.argv[2] if len(sys.argv) > 2 else "http://localhost:8003/v1"
    batch_size = int(sys.argv[3]) if len(sys.argv) > 3 else 50
    max_turns = int(sys.argv[4]) if len(sys.argv) > 4 else 10

    print(f"\n{'='*60}")
    print(f"Batch Processing Test (Paper Configuration)")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"User URL: {user_url}")
    print(f"Agent URL: {agent_url}")
    print()

    # Create samples (simulating MMLU-style questions)
    samples = [
        {
            "problem": f"Question {i+1}: What is the capital of country number {i+1}? "
                       f"A) City A B) City B C) City C D) City D. "
                       f"Please explain your reasoning step by step.",
            "solution": "City A",
        }
        for i in range(batch_size)
    ]

    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=max_turns,
        user_max_tokens=512,
        agent_max_tokens=1024,
        temperature=0.7,
    )

    print(f"Starting batch generation of {batch_size} conversations...")
    print(f"Expected: ~{batch_size * max_turns * 2} total LLM calls batched into ~{max_turns * 2} batch requests")
    print()

    start = time.time()
    results = generator.generate_batch(
        samples=samples,
        user_persona="A curious student seeking help with exam questions.",
        user_preferences="1. Explain your reasoning step by step\n2. Be concise but thorough\n3. Highlight the key concept",
        agent_system_prompt="You are a helpful tutor. Answer questions clearly and explain your reasoning.",
    )
    # Guard against a zero elapsed time (e.g. mocked/instant runs) so the
    # four throughput divisions below cannot raise ZeroDivisionError.
    elapsed = max(time.time() - start, 1e-9)

    successes = sum(1 for r in results if r is not None)
    # Each conversation turn is a user/agent message pair, hence // 2.
    total_turns = sum(
        len(r['conversation']) // 2 if r else 0
        for r in results
    )

    print(f"\n{'='*60}")
    print(f"RESULTS")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"Successes: {successes}/{batch_size}")
    print(f"Total conversation turns: {total_turns}")
    print(f"Time: {elapsed:.1f}s")
    print()
    print(f"Throughput: {successes * 3600 / elapsed:.0f} conversations/hr")
    print(f"Sessions/hr (3 sessions/profile): {successes * 3 * 3600 / elapsed:.0f}")
    print()

    # Compare with paper's claimed performance
    paper_sessions = 2000  # sessions per hour claimed
    our_sessions = successes * 3 * 3600 / elapsed
    print(f"Paper's claimed throughput: ~{paper_sessions} sessions/hr")
    print(f"Our throughput: {our_sessions:.0f} sessions/hr")
    print(f"Ratio: {our_sessions / paper_sessions * 100:.1f}% of paper's performance")
    print()

    # Show sample conversation; guard against an empty results list
    # (batch_size 0) before indexing results[0].
    if results and results[0]:
        print(f"Sample conversation (first 4 messages):")
        for msg in results[0]['conversation'][:4]:
            role = msg['role'].upper()
            content = msg['content'][:100] + "..." if len(msg['content']) > 100 else msg['content']
            print(f"  [{role}]: {content}")

    return results
# Script entry point: only run the benchmark when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|