summaryrefslogtreecommitdiff
path: root/scripts/test_local_reward_batch.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 12:15:45 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 12:15:45 -0600
commit680513b7771a29f27cbbb3ffb009a69a913de6f9 (patch)
treea0d60aef9ade1b2953b915f535b990c0de95e493 /scripts/test_local_reward_batch.py
parentc06ec2f3b80f8968f09eb801b69237495b055ec1 (diff)
local reward model
Diffstat (limited to 'scripts/test_local_reward_batch.py')
-rw-r--r--scripts/test_local_reward_batch.py206
1 files changed, 206 insertions, 0 deletions
diff --git a/scripts/test_local_reward_batch.py b/scripts/test_local_reward_batch.py
new file mode 100644
index 0000000..7afb834
--- /dev/null
+++ b/scripts/test_local_reward_batch.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+"""
+Test batch reward estimation with local LLM via vLLM.
+
+Verifies that:
+1. LocalLLMRewardClient works with vLLM server
+2. Batch processing is efficient (concurrent, not sequential)
+3. Results match expected labels
+"""
+import argparse
+import asyncio
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+
+from personalization.feedback.schemas import TurnSample
+from personalization.feedback.local_llm_reward import (
+ LocalLLMRewardClient,
+ LocalLLMRewardConfig,
+)
+
# Test cases with expected labels.
# Each case is one two-turn exchange fed to the reward classifier:
#   name     - human-readable description printed in the per-case report
#   query_t  - the user's request at turn t
#   answer_t - the assistant's reply at turn t (the response being judged)
#   query_t1 - the user's follow-up at turn t+1; its tone/content implies the label
#   expected - the reward label the classifier should produce for this exchange
TEST_CASES = [
    {
        "name": "neg_constraint_restate - format preference",
        "query_t": "Explain how sorting works in Python. Please use bullet points.",
        "answer_t": "Sorting in Python can be done using the sorted() function or the list.sort() method. The sorted() function returns a new sorted list, while sort() modifies the list in place. Both accept a key parameter for custom sorting and a reverse parameter for descending order.",
        "query_t1": "I asked for bullet points. Can you reformat that with bullet points please?",
        "expected": "neg_constraint_restate",
    },
    {
        "name": "neg_constraint_restate - step by step",
        "query_t": "Solve x^2 - 5x + 6 = 0. Show step by step.",
        "answer_t": "The solutions are x = 2 and x = 3.",
        "query_t1": "As I said, I need to see the step-by-step solution, not just the answer.",
        "expected": "neg_constraint_restate",
    },
    {
        "name": "neg_correction - wrong answer",
        "query_t": "What is the capital of Australia?",
        "answer_t": "The capital of Australia is Sydney.",
        "query_t1": "That's incorrect. Sydney is not the capital of Australia.",
        "expected": "neg_correction",
    },
    {
        "name": "neg_confusion - unclear explanation",
        "query_t": "What is recursion in programming?",
        "answer_t": "Recursion is when a function calls itself in a self-similar way to solve problems.",
        "query_t1": "I'm confused. What do you mean by 'self-similar way'? Can you explain more clearly?",
        "expected": "neg_confusion",
    },
    {
        "name": "pos_praise - explicit thanks",
        "query_t": "How do I center a div in CSS?",
        "answer_t": "You can center a div using flexbox: set the parent to `display: flex; justify-content: center; align-items: center;`. Alternatively, use `margin: 0 auto;` for horizontal centering with a defined width.",
        "query_t1": "Perfect, thank you! That's exactly what I needed.",
        "expected": "pos_praise",
    },
    {
        "name": "pos_praise - great explanation",
        "query_t": "Explain how photosynthesis works.",
        "answer_t": "Photosynthesis is the process by which plants convert sunlight, water, and CO2 into glucose and oxygen. It occurs in chloroplasts, with light-dependent reactions in the thylakoid membrane and the Calvin cycle in the stroma.",
        "query_t1": "Great explanation! This really helped me understand the concept.",
        "expected": "pos_praise",
    },
    {
        "name": "pos_progress - follow-up question",
        "query_t": "What is a binary search tree?",
        "answer_t": "A binary search tree (BST) is a data structure where each node has at most two children. The left subtree contains only nodes with values less than the parent, and the right subtree only nodes with values greater than the parent.",
        "query_t1": "Interesting! How would I implement insertion into a BST?",
        "expected": "pos_progress",
    },
    {
        "name": "pos_progress - extension",
        "query_t": "How do I read a file in Python?",
        "answer_t": "Use `with open('file.txt', 'r') as f: content = f.read()`. The 'with' statement ensures the file is properly closed.",
        "query_t1": "Got it. What if I want to read it line by line instead?",
        "expected": "pos_progress",
    },
    {
        "name": "neutral - minimal response",
        "query_t": "What's 2 + 2?",
        "answer_t": "2 + 2 = 4",
        "query_t1": "Ok.",
        "expected": "neutral",
    },
    {
        "name": "topic_shift - new topic",
        "query_t": "What is the Pythagorean theorem?",
        "answer_t": "The Pythagorean theorem states that in a right triangle, a² + b² = c², where c is the hypotenuse.",
        "query_t1": "By the way, can you help me write a poem about nature?",
        "expected": "topic_shift",
    },
    {
        "name": "neg_constraint_restate - language preference",
        "query_t": "Explain machine learning in simple terms.",
        "answer_t": "Machine learning is a subset of artificial intelligence that uses statistical techniques to enable computers to learn from data. It involves training models on datasets to make predictions or decisions without being explicitly programmed for specific tasks.",
        "query_t1": "Remember I asked for simple terms? That's too technical. Can you explain like I'm 5?",
        "expected": "neg_constraint_restate",
    },
    {
        "name": "neg_correction - incomplete answer",
        "query_t": "List all the planets in our solar system.",
        "answer_t": "The planets are Mercury, Venus, Earth, Mars, Jupiter, and Saturn.",
        "query_t1": "You're missing Uranus and Neptune. There are 8 planets, not 6.",
        "expected": "neg_correction",
    },
]
+
+
def main():
    """Run the batch reward-model test against a local vLLM server.

    Parses CLI flags, converts the first ``--batch-size`` entries of
    TEST_CASES into TurnSamples, sends them through
    ``LocalLLMRewardClient.judge_batch`` in one batch, then prints a
    per-case comparison against the expected labels plus an
    accuracy/throughput summary.

    Exits via ``parser.error`` (status 2) if ``--batch-size`` is < 1,
    which would otherwise divide by zero in the summary math.
    """
    parser = argparse.ArgumentParser(
        description="Batch test for the local LLM reward model via vLLM."
    )
    parser.add_argument(
        "--vllm-url",
        type=str,
        default="http://localhost:8005/v1",
        help="vLLM server URL for reward model",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=12,
        help="Batch size (default: all 12 test cases)",
    )
    args = parser.parse_args()

    # Guard: a non-positive batch size would make `samples` empty and the
    # accuracy / latency divisions below raise ZeroDivisionError.
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")

    print("=" * 70)
    print("Local LLM Reward Model Batch Test")
    print("=" * 70)
    print(f"vLLM URL: {args.vllm_url}")
    print()

    # Create client. Low temperature keeps label output near-deterministic.
    config = LocalLLMRewardConfig(
        vllm_url=args.vllm_url,
        max_tokens=256,
        temperature=0.1,
        max_concurrent=50,
    )
    client = LocalLLMRewardClient(config)
    print(f"Model: {client._model_name}")
    print()

    # Slice once and reuse; the original sliced TEST_CASES three times.
    cases = TEST_CASES[:args.batch_size]

    # Convert test cases to TurnSamples
    samples = [
        TurnSample(
            user_id="test_user",
            session_id="test_session",
            turn_id=i,
            query_t=tc["query_t"],
            answer_t=tc["answer_t"],
            query_t1=tc["query_t1"],
            memories=[],  # Not needed for reward classification
        )
        for i, tc in enumerate(cases)
    ]

    # Run batch inference. perf_counter is monotonic, unlike time.time(),
    # so the elapsed interval cannot be skewed by wall-clock adjustments.
    print(f"Running batch inference on {len(samples)} samples...")
    t0 = time.perf_counter()
    results = client.judge_batch(samples)
    elapsed = time.perf_counter() - t0

    print(f"Completed in {elapsed:.2f}s ({len(samples)/elapsed:.1f} samples/sec)")
    print()

    # Analyze results; collect mislabeled cases in the same pass instead of
    # re-zipping and re-comparing afterwards.
    correct = 0
    errors = []
    for i, (tc, result) in enumerate(zip(cases, results)):
        is_correct = result.label == tc["expected"]
        if is_correct:
            correct += 1
        else:
            errors.append((tc, result))
        status = "OK" if is_correct else "WRONG"

        print(f"[{i+1:2d}] {tc['name'][:45]:45s}")
        print(f" Expected: {tc['expected']:25s} Got: {result.label:25s} [{status}]")
        print(f" Confidence: {result.confidence:.2f}, Reward: {result.reward:+.1f}, Update: {result.should_update}")
        print()

    # Summary
    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)
    accuracy = correct / len(samples) * 100
    print(f"Accuracy: {accuracy:.1f}% ({correct}/{len(samples)})")
    print(f"Time: {elapsed:.2f}s")
    print(f"Throughput: {len(samples)/elapsed:.1f} samples/sec")
    print(f"Avg latency: {elapsed/len(samples)*1000:.0f}ms per sample (batched)")
    print()

    # Errors (already gathered during the analysis loop)
    if errors:
        print(f"Errors ({len(errors)}):")
        for tc, result in errors:
            print(f"  - {tc['name']}: Got {result.label}, Expected {tc['expected']}")


if __name__ == "__main__":
    main()