summaryrefslogtreecommitdiff
path: root/scripts/test_local_reward_batch.py
blob: 7afb834875116ca0b879ca28a30064012567cfe7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#!/usr/bin/env python3
"""
Test batch reward estimation with local LLM via vLLM.

Verifies that:
1. LocalLLMRewardClient works with vLLM server
2. Batch processing is efficient (concurrent, not sequential)
3. Results match expected labels
"""
import argparse
import asyncio
import sys
import os
import time

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from personalization.feedback.schemas import TurnSample
from personalization.feedback.local_llm_reward import (
    LocalLLMRewardClient,
    LocalLLMRewardConfig,
)

# Hand-labelled reward-classification cases. Each dict holds:
#   name     - short description, printed (truncated to 45 chars) in the report
#   query_t  - user query at turn t
#   answer_t - assistant answer at turn t
#   query_t1 - the user's follow-up message, i.e. the feedback being classified
#   expected - label the reward model should return; compared to result.label
# Labels exercised here: neg_constraint_restate, neg_correction, neg_confusion,
# pos_praise, pos_progress, neutral, topic_shift.
TEST_CASES = [
    {
        "name": "neg_constraint_restate - format preference",
        "query_t": "Explain how sorting works in Python. Please use bullet points.",
        "answer_t": "Sorting in Python can be done using the sorted() function or the list.sort() method. The sorted() function returns a new sorted list, while sort() modifies the list in place. Both accept a key parameter for custom sorting and a reverse parameter for descending order.",
        "query_t1": "I asked for bullet points. Can you reformat that with bullet points please?",
        "expected": "neg_constraint_restate",
    },
    {
        "name": "neg_constraint_restate - step by step",
        "query_t": "Solve x^2 - 5x + 6 = 0. Show step by step.",
        "answer_t": "The solutions are x = 2 and x = 3.",
        "query_t1": "As I said, I need to see the step-by-step solution, not just the answer.",
        "expected": "neg_constraint_restate",
    },
    {
        "name": "neg_correction - wrong answer",
        "query_t": "What is the capital of Australia?",
        "answer_t": "The capital of Australia is Sydney.",
        "query_t1": "That's incorrect. Sydney is not the capital of Australia.",
        "expected": "neg_correction",
    },
    {
        "name": "neg_confusion - unclear explanation",
        "query_t": "What is recursion in programming?",
        "answer_t": "Recursion is when a function calls itself in a self-similar way to solve problems.",
        "query_t1": "I'm confused. What do you mean by 'self-similar way'? Can you explain more clearly?",
        "expected": "neg_confusion",
    },
    {
        "name": "pos_praise - explicit thanks",
        "query_t": "How do I center a div in CSS?",
        "answer_t": "You can center a div using flexbox: set the parent to `display: flex; justify-content: center; align-items: center;`. Alternatively, use `margin: 0 auto;` for horizontal centering with a defined width.",
        "query_t1": "Perfect, thank you! That's exactly what I needed.",
        "expected": "pos_praise",
    },
    {
        "name": "pos_praise - great explanation",
        "query_t": "Explain how photosynthesis works.",
        "answer_t": "Photosynthesis is the process by which plants convert sunlight, water, and CO2 into glucose and oxygen. It occurs in chloroplasts, with light-dependent reactions in the thylakoid membrane and the Calvin cycle in the stroma.",
        "query_t1": "Great explanation! This really helped me understand the concept.",
        "expected": "pos_praise",
    },
    {
        "name": "pos_progress - follow-up question",
        "query_t": "What is a binary search tree?",
        "answer_t": "A binary search tree (BST) is a data structure where each node has at most two children. The left subtree contains only nodes with values less than the parent, and the right subtree only nodes with values greater than the parent.",
        "query_t1": "Interesting! How would I implement insertion into a BST?",
        "expected": "pos_progress",
    },
    {
        "name": "pos_progress - extension",
        "query_t": "How do I read a file in Python?",
        "answer_t": "Use `with open('file.txt', 'r') as f: content = f.read()`. The 'with' statement ensures the file is properly closed.",
        "query_t1": "Got it. What if I want to read it line by line instead?",
        "expected": "pos_progress",
    },
    {
        "name": "neutral - minimal response",
        "query_t": "What's 2 + 2?",
        "answer_t": "2 + 2 = 4",
        "query_t1": "Ok.",
        "expected": "neutral",
    },
    {
        "name": "topic_shift - new topic",
        "query_t": "What is the Pythagorean theorem?",
        "answer_t": "The Pythagorean theorem states that in a right triangle, a² + b² = c², where c is the hypotenuse.",
        "query_t1": "By the way, can you help me write a poem about nature?",
        "expected": "topic_shift",
    },
    # NOTE(review): despite its name, this case restates a "simple terms"
    # constraint rather than a language preference.
    {
        "name": "neg_constraint_restate - language preference",
        "query_t": "Explain machine learning in simple terms.",
        "answer_t": "Machine learning is a subset of artificial intelligence that uses statistical techniques to enable computers to learn from data. It involves training models on datasets to make predictions or decisions without being explicitly programmed for specific tasks.",
        "query_t1": "Remember I asked for simple terms? That's too technical. Can you explain like I'm 5?",
        "expected": "neg_constraint_restate",
    },
    {
        "name": "neg_correction - incomplete answer",
        "query_t": "List all the planets in our solar system.",
        "answer_t": "The planets are Mercury, Venus, Earth, Mars, Jupiter, and Saturn.",
        "query_t1": "You're missing Uranus and Neptune. There are 8 planets, not 6.",
        "expected": "neg_correction",
    },
]


def main(argv=None):
    """Run the labelled TEST_CASES through the local reward model in one batch.

    Sends the first ``--batch-size`` cases to ``LocalLLMRewardClient.judge_batch``
    in a single call. Prints the per-case label, confidence, reward and update
    flag, then an accuracy/throughput summary, and finally the misclassified
    cases.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps normal command-line behaviour; passing a list
            lets the script be driven programmatically.

    Exits through ``parser.error`` when ``--batch-size`` is not positive, and
    through ``sys.exit`` when the client returns a different number of results
    than samples were sent.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vllm-url",
        type=str,
        default="http://localhost:8005/v1",
        help="vLLM server URL for reward model",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        # Derived from TEST_CASES so the default and help text never go stale.
        default=len(TEST_CASES),
        help=f"Batch size (default: all {len(TEST_CASES)} test cases)",
    )
    args = parser.parse_args(argv)
    if args.batch_size < 1:
        # 0 would divide by zero in the summary; a negative value would
        # silently slice cases off the *end* of TEST_CASES.
        parser.error("--batch-size must be a positive integer")

    cases = TEST_CASES[:args.batch_size]

    print("=" * 70)
    print("Local LLM Reward Model Batch Test")
    print("=" * 70)
    print(f"vLLM URL: {args.vllm_url}")
    print()

    # Create client
    config = LocalLLMRewardConfig(
        vllm_url=args.vllm_url,
        max_tokens=256,
        temperature=0.1,
        max_concurrent=50,
    )
    client = LocalLLMRewardClient(config)
    print(f"Model: {client._model_name}")
    print()

    # Convert test cases to TurnSamples
    samples = [
        TurnSample(
            user_id="test_user",
            session_id="test_session",
            turn_id=i,
            query_t=tc["query_t"],
            answer_t=tc["answer_t"],
            query_t1=tc["query_t1"],
            memories=[],  # Not needed for reward classification
        )
        for i, tc in enumerate(cases)
    ]

    # Run batch inference. perf_counter is monotonic and high-resolution,
    # unlike time.time(), so it is the right clock for measuring intervals.
    print(f"Running batch inference on {len(samples)} samples...")
    t0 = time.perf_counter()
    # Materialize once: results are iterated for the report and need len().
    results = list(client.judge_batch(samples))
    elapsed = time.perf_counter() - t0

    if len(results) != len(samples):
        # zip() below would silently drop unmatched cases and skew accuracy.
        sys.exit(
            f"judge_batch returned {len(results)} results for {len(samples)} samples"
        )

    # Guard against a zero-length interval on very coarse clocks.
    throughput = len(samples) / elapsed if elapsed > 0 else float("inf")

    print(f"Completed in {elapsed:.2f}s ({throughput:.1f} samples/sec)")
    print()

    # Analyze results, collecting misclassifications in the same pass.
    correct = 0
    errors = []
    for i, (tc, result) in enumerate(zip(cases, results), start=1):
        is_correct = result.label == tc["expected"]
        if is_correct:
            correct += 1
        else:
            errors.append((tc, result))
        status = "OK" if is_correct else "WRONG"

        print(f"[{i:2d}] {tc['name'][:45]:45s}")
        print(f"     Expected: {tc['expected']:25s} Got: {result.label:25s} [{status}]")
        print(f"     Confidence: {result.confidence:.2f}, Reward: {result.reward:+.1f}, Update: {result.should_update}")
        print()

    # Summary
    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)
    accuracy = correct / len(samples) * 100
    print(f"Accuracy: {accuracy:.1f}% ({correct}/{len(samples)})")
    print(f"Time: {elapsed:.2f}s")
    print(f"Throughput: {throughput:.1f} samples/sec")
    print(f"Avg latency: {elapsed/len(samples)*1000:.0f}ms per sample (batched)")
    print()

    # Errors
    if errors:
        print(f"Errors ({len(errors)}):")
        for tc, result in errors:
            print(f"  - {tc['name']}: Got {result.label}, Expected {tc['expected']}")


# Run the batch test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()