diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 12:15:45 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 12:15:45 -0600 |
| commit | 680513b7771a29f27cbbb3ffb009a69a913de6f9 (patch) | |
| tree | a0d60aef9ade1b2953b915f535b990c0de95e493 /scripts/test_armo_reward.py | |
| parent | c06ec2f3b80f8968f09eb801b69237495b055ec1 (diff) | |
local reward model
Diffstat (limited to 'scripts/test_armo_reward.py')
| -rw-r--r-- | scripts/test_armo_reward.py | 216 |
1 files changed, 216 insertions, 0 deletions
diff --git a/scripts/test_armo_reward.py b/scripts/test_armo_reward.py new file mode 100644 index 0000000..534ab84 --- /dev/null +++ b/scripts/test_armo_reward.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +""" +Test script for ArmoRM reward model. + +Usage: + python scripts/test_armo_reward.py [--device cuda:0] +""" +import argparse +import sys +import os + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from personalization.feedback.armo_reward import ( + ArmoRMRewardModel, + ArmoRewardConfig, + create_armo_reward_model, +) + + +def main(): + parser = argparse.ArgumentParser(description="Test ArmoRM reward model") + parser.add_argument("--device", type=str, default="cuda", help="Device to use") + parser.add_argument("--model-id", type=str, default="RLHFlow/ArmoRM-Llama3-8B-v0.1") + args = parser.parse_args() + + print("=" * 70) + print("ArmoRM Reward Model Test") + print("=" * 70) + print(f"Device: {args.device}") + print(f"Model: {args.model_id}") + print() + + # Create model + print("Loading model...") + config = ArmoRewardConfig( + model_id=args.model_id, + device=args.device, + ) + model = ArmoRMRewardModel(config) + model.load() + print("Model loaded!\n") + + # Test cases + test_cases = [ + { + "name": "Good factual answer", + "messages": [ + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris. Paris is not only the capital but also the largest city in France, located in the north-central part of the country along the Seine River."}, + ], + "expected": "high score (>0.7)", + }, + { + "name": "Minimal answer", + "messages": [ + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "Paris."}, + ], + "expected": "medium score (0.4-0.7)", + }, + { + "name": "Wrong answer", + "messages": [ + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is London."}, + ], + "expected": "low score (<0.4)", + }, + { + "name": "Helpful detailed explanation", + "messages": [ + {"role": "user", "content": "Explain how to sort a list in Python"}, + {"role": "assistant", "content": """There are several ways to sort a list in Python: + +1. **Using sorted()** - Returns a new sorted list: + ```python + my_list = [3, 1, 4, 1, 5] + sorted_list = sorted(my_list) # [1, 1, 3, 4, 5] + ``` + +2. **Using list.sort()** - Sorts in place: + ```python + my_list = [3, 1, 4, 1, 5] + my_list.sort() # my_list is now [1, 1, 3, 4, 5] + ``` + +3. **Reverse sorting**: + ```python + sorted_list = sorted(my_list, reverse=True) + ``` + +4. **Custom key function**: + ```python + words = ['apple', 'Banana', 'cherry'] + sorted_words = sorted(words, key=str.lower) + ```"""}, + ], + "expected": "high score (>0.7)", + }, + { + "name": "Unhelpful response", + "messages": [ + {"role": "user", "content": "Explain how to sort a list in Python"}, + {"role": "assistant", "content": "Just use sort."}, + ], + "expected": "low score (<0.4)", + }, + ] + + print("Running test cases...\n") + for i, tc in enumerate(test_cases, 1): + print(f"--- Test {i}: {tc['name']} ---") + print(f"Expected: {tc['expected']}") + + result = model.score_response(tc["messages"]) + + print(f"Score: {result.score:.4f}") + print(f"Reward: {result.reward:.2f}") + print(f"Should Update: {result.should_update}") + print() + + # Test preference compliance + print("=" * 70) + print("Testing Preference Compliance Scenarios") + print("=" * 70) + print() + + compliance_tests = [ + { + "name": "User satisfied (preference followed)", + "query": "Can you explain recursion? I prefer examples with code.", + "response": """Recursion is when a function calls itself. Here's a classic example - calculating factorial: + +```python +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n - 1) + +print(factorial(5)) # Output: 120 +``` + +The function calls itself with a smaller value until it reaches the base case (n <= 1).""", + "followup": "Perfect! That's exactly what I needed. Can you show me another example with Fibonacci?", + }, + { + "name": "User dissatisfied (preference NOT followed)", + "query": "Can you explain recursion? I prefer examples with code.", + "response": "Recursion is a programming concept where a function calls itself to solve smaller instances of the same problem.", + "followup": "I specifically asked for code examples. Please show me some actual code demonstrating recursion.", + }, + { + "name": "User correcting format preference", + "query": "List 5 benefits of meditation. Use bullet points please.", + "response": "Meditation has many benefits. First, it reduces stress. Second, it improves focus. Third, it promotes emotional health. Fourth, it enhances self-awareness. Fifth, it can reduce anxiety.", + "followup": "I asked for bullet points, not numbered sentences. Can you reformat that?", + }, + ] + + for i, tc in enumerate(compliance_tests, 1): + print(f"--- Compliance Test {i}: {tc['name']} ---") + print(f"Query: {tc['query'][:60]}...") + print(f"Followup: {tc['followup'][:60]}...") + + result = model.estimate_preference_compliance( + query=tc["query"], + response=tc["response"], + user_followup=tc["followup"], + ) + + print(f"Score: {result.score:.4f}") + print(f"Reward: {result.reward:.2f}") + print(f"Should Update: {result.should_update}") + print() + + # Test response comparison + print("=" * 70) + print("Testing Response Comparison") + print("=" * 70) + print() + + query = "What are the health benefits of drinking water?" + response_a = "Water is good for health." + response_b = """Drinking adequate water provides numerous health benefits: + +1. **Hydration**: Maintains fluid balance for bodily functions +2. **Digestion**: Aids in breaking down food and nutrient absorption +3. **Skin Health**: Keeps skin moisturized and may reduce wrinkles +4. **Kidney Function**: Helps flush out toxins and prevents kidney stones +5. **Energy**: Prevents fatigue caused by dehydration +6. **Weight Management**: Can reduce appetite when consumed before meals +7. **Joint Health**: Lubricates and cushions joints + +The general recommendation is 8 glasses (64 oz) per day, though needs vary by individual.""" + + print(f"Query: {query}") + print(f"Response A: {response_a}") + print(f"Response B: {response_b[:100]}...") + + score_a, score_b, winner = model.compare_responses(query, response_a, response_b) + print(f"\nScore A: {score_a:.4f}") + print(f"Score B: {score_b:.4f}") + print(f"Winner: {winner.upper()}") + + print("\n" + "=" * 70) + print("All tests complete!") + print("=" * 70) + + # Cleanup + model.cleanup() + + +if __name__ == "__main__": + main() |
