summaryrefslogtreecommitdiff
path: root/collaborativeagents/scripts/test_batch_50.py
blob: b3f1c37907b5052c8eeaae96cf6d041bd87fd082 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python
"""
Test batch processing with 50 conversations (matching paper's setup).
"""

import sys
import time
sys.path.insert(0, '/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents')

from agents.batch_vllm_agent import BatchConversationGenerator

def main():
    """Run a batch conversation-generation throughput test and print stats.

    Optional positional command-line arguments:
        1. user vLLM base URL   (default: http://localhost:8004/v1)
        2. agent vLLM base URL  (default: http://localhost:8003/v1)
        3. batch size           (default: 50, matching the paper's setup)
        4. max turns            (default: 10)

    Returns:
        The list of per-conversation results from the generator
        (entries may be None for failed conversations).
    """
    user_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8004/v1"
    agent_url = sys.argv[2] if len(sys.argv) > 2 else "http://localhost:8003/v1"
    batch_size = int(sys.argv[3]) if len(sys.argv) > 3 else 50
    max_turns = int(sys.argv[4]) if len(sys.argv) > 4 else 10

    print(f"\n{'='*60}")
    print(f"Batch Processing Test (Paper Configuration)")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"User URL: {user_url}")
    print(f"Agent URL: {agent_url}")
    print()

    # Create samples (simulating MMLU-style questions)
    samples = [
        {
            "problem": f"Question {i+1}: What is the capital of country number {i+1}? "
                      f"A) City A  B) City B  C) City C  D) City D. "
                      f"Please explain your reasoning step by step.",
            "solution": "City A"
        }
        for i in range(batch_size)
    ]

    generator = BatchConversationGenerator(
        user_vllm_url=user_url,
        agent_vllm_url=agent_url,
        max_turns=max_turns,
        user_max_tokens=512,
        agent_max_tokens=1024,
        temperature=0.7,
    )

    print(f"Starting batch generation of {batch_size} conversations...")
    print(f"Expected: ~{batch_size * max_turns * 2} total LLM calls batched into ~{max_turns * 2} batch requests")
    print()

    start = time.time()
    results = generator.generate_batch(
        samples=samples,
        user_persona="A curious student seeking help with exam questions.",
        user_preferences="1. Explain your reasoning step by step\n2. Be concise but thorough\n3. Highlight the key concept",
        agent_system_prompt="You are a helpful tutor. Answer questions clearly and explain your reasoning.",
    )
    elapsed = time.time() - start
    # Guard against a zero-resolution timer (e.g. very fast runs on coarse
    # clocks) so the throughput divisions below cannot raise ZeroDivisionError.
    rate_elapsed = max(elapsed, 1e-9)

    # Failed conversations come back as None; each conversation alternates
    # user/assistant messages, so message count // 2 gives the turn count.
    successes = sum(1 for r in results if r is not None)
    total_turns = sum(
        len(r['conversation']) // 2 if r else 0
        for r in results
    )

    print(f"\n{'='*60}")
    print(f"RESULTS")
    print(f"{'='*60}")
    print(f"Batch size: {batch_size}")
    print(f"Max turns: {max_turns}")
    print(f"Successes: {successes}/{batch_size}")
    print(f"Total conversation turns: {total_turns}")
    print(f"Time: {elapsed:.1f}s")
    print()
    print(f"Throughput: {successes * 3600 / rate_elapsed:.0f} conversations/hr")
    print(f"Sessions/hr (3 sessions/profile): {successes * 3 * 3600 / rate_elapsed:.0f}")
    print()

    # Compare with paper's claimed performance
    paper_sessions = 2000  # sessions per hour claimed
    our_sessions = successes * 3 * 3600 / rate_elapsed
    print(f"Paper's claimed throughput: ~{paper_sessions} sessions/hr")
    print(f"Our throughput: {our_sessions:.0f} sessions/hr")
    print(f"Ratio: {our_sessions / paper_sessions * 100:.1f}% of paper's performance")
    print()

    # Show sample conversation. Guard on `results` first: indexing an empty
    # list (e.g. batch_size=0 on the CLI) would raise IndexError.
    if results and results[0]:
        print(f"Sample conversation (first 4 messages):")
        for msg in results[0]['conversation'][:4]:
            role = msg['role'].upper()
            # Truncate long messages to keep the preview readable.
            content = (msg['content'][:100] + "...") if len(msg['content']) > 100 else msg['content']
            print(f"  [{role}]: {content}")

    return results

# Standard script entry guard: run the benchmark only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()