#!/usr/bin/env python3
"""
Test script for ArmoRM reward model.

Loads the reward model, then prints scores for three groups of hand-written
examples: single-response scoring, preference-compliance estimation from a
user follow-up, and a pairwise response comparison. The "expected" ranges
are printed for manual inspection only; nothing is asserted.

Usage:
    python scripts/test_armo_reward.py [--device cuda:0]
"""

import argparse
import sys
import os

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from personalization.feedback.armo_reward import (
    ArmoRMRewardModel,
    ArmoRewardConfig,
    create_armo_reward_model,
)


# Horizontal rule used by every section banner.
_RULE = "=" * 70

_SORT_EXPLANATION = """There are several ways to sort a list in Python:

1. **Using sorted()** - Returns a new sorted list:
```python
my_list = [3, 1, 4, 1, 5]
sorted_list = sorted(my_list) # [1, 1, 3, 4, 5]
```

2. **Using list.sort()** - Sorts in place:
```python
my_list = [3, 1, 4, 1, 5]
my_list.sort() # my_list is now [1, 1, 3, 4, 5]
```

3. **Reverse sorting**:
```python
sorted_list = sorted(my_list, reverse=True)
```

4. **Custom key function**:
```python
words = ['apple', 'Banana', 'cherry']
sorted_words = sorted(words, key=str.lower)
```"""

_RECURSION_WITH_CODE = """Recursion is when a function calls itself. Here's a classic example - calculating factorial:

```python
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n - 1)

print(factorial(5)) # Output: 120
```

The function calls itself with a smaller value until it reaches the base case (n <= 1)."""

# Each case is a chat transcript handed to ``model.score_response``.
# "expected" is a human-readable hint that is printed, never checked.
_SCORING_CASES = [
    {
        "name": "Good factual answer",
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital of France is Paris. Paris is not only the capital but also the largest city in France, located in the north-central part of the country along the Seine River."},
        ],
        "expected": "high score (>0.7)",
    },
    {
        "name": "Minimal answer",
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris."},
        ],
        "expected": "medium score (0.4-0.7)",
    },
    {
        "name": "Wrong answer",
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital of France is London."},
        ],
        "expected": "low score (<0.4)",
    },
    {
        "name": "Helpful detailed explanation",
        "messages": [
            {"role": "user", "content": "Explain how to sort a list in Python"},
            {"role": "assistant", "content": _SORT_EXPLANATION},
        ],
        "expected": "high score (>0.7)",
    },
    {
        "name": "Unhelpful response",
        "messages": [
            {"role": "user", "content": "Explain how to sort a list in Python"},
            {"role": "assistant", "content": "Just use sort."},
        ],
        "expected": "low score (<0.4)",
    },
]

# Each case is a (query, response, follow-up) triple handed to
# ``model.estimate_preference_compliance``.
_COMPLIANCE_CASES = [
    {
        "name": "User satisfied (preference followed)",
        "query": "Can you explain recursion? I prefer examples with code.",
        "response": _RECURSION_WITH_CODE,
        "followup": "Perfect! That's exactly what I needed. Can you show me another example with Fibonacci?",
    },
    {
        "name": "User dissatisfied (preference NOT followed)",
        "query": "Can you explain recursion? I prefer examples with code.",
        "response": "Recursion is a programming concept where a function calls itself to solve smaller instances of the same problem.",
        "followup": "I specifically asked for code examples. Please show me some actual code demonstrating recursion.",
    },
    {
        "name": "User correcting format preference",
        "query": "List 5 benefits of meditation. Use bullet points please.",
        "response": "Meditation has many benefits. First, it reduces stress. Second, it improves focus. Third, it promotes emotional health. Fourth, it enhances self-awareness. Fifth, it can reduce anxiety.",
        "followup": "I asked for bullet points, not numbered sentences. Can you reformat that?",
    },
]

# Inputs for the pairwise comparison: a terse answer versus a detailed one.
_COMPARISON_QUERY = "What are the health benefits of drinking water?"
_COMPARISON_RESPONSE_A = "Water is good for health."
_COMPARISON_RESPONSE_B = """Drinking adequate water provides numerous health benefits:

1. **Hydration**: Maintains fluid balance for bodily functions
2. **Digestion**: Aids in breaking down food and nutrient absorption
3. **Skin Health**: Keeps skin moisturized and may reduce wrinkles
4. **Kidney Function**: Helps flush out toxins and prevents kidney stones
5. **Energy**: Prevents fatigue caused by dehydration
6. **Weight Management**: Can reduce appetite when consumed before meals
7. **Joint Health**: Lubricates and cushions joints

The general recommendation is 8 glasses (64 oz) per day, though needs vary by individual."""


def _parse_args():
    """Parse the ``--device`` and ``--model-id`` command-line options."""
    parser = argparse.ArgumentParser(description="Test ArmoRM reward model")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use")
    parser.add_argument("--model-id", type=str, default="RLHFlow/ArmoRM-Llama3-8B-v0.1")
    return parser.parse_args()


def _print_banner(title: str) -> None:
    """Print *title* framed above and below by the section rule."""
    print(_RULE)
    print(title)
    print(_RULE)


def _print_result(result) -> None:
    """Print the score, reward and update flag of a model result, then a blank line.

    ``result`` is whatever the model's scoring methods return; only its
    ``score``, ``reward`` and ``should_update`` attributes are read here.
    """
    print(f"Score: {result.score:.4f}")
    print(f"Reward: {result.reward:.2f}")
    print(f"Should Update: {result.should_update}")
    print()


def _run_scoring_tests(model) -> None:
    """Score every transcript in ``_SCORING_CASES`` and print the results."""
    print("Running test cases...\n")
    for i, tc in enumerate(_SCORING_CASES, 1):
        print(f"--- Test {i}: {tc['name']} ---")
        print(f"Expected: {tc['expected']}")
        _print_result(model.score_response(tc["messages"]))


def _run_compliance_tests(model) -> None:
    """Estimate preference compliance for every case in ``_COMPLIANCE_CASES``."""
    for i, tc in enumerate(_COMPLIANCE_CASES, 1):
        print(f"--- Compliance Test {i}: {tc['name']} ---")
        print(f"Query: {tc['query'][:60]}...")
        print(f"Followup: {tc['followup'][:60]}...")
        result = model.estimate_preference_compliance(
            query=tc["query"],
            response=tc["response"],
            user_followup=tc["followup"],
        )
        _print_result(result)


def _run_comparison_test(model) -> None:
    """Compare the terse and detailed water answers and print both scores."""
    print(f"Query: {_COMPARISON_QUERY}")
    print(f"Response A: {_COMPARISON_RESPONSE_A}")
    print(f"Response B: {_COMPARISON_RESPONSE_B[:100]}...")
    # compare_responses is unpacked as (score_a, score_b, winner); winner is
    # treated as a string label.
    score_a, score_b, winner = model.compare_responses(
        _COMPARISON_QUERY, _COMPARISON_RESPONSE_A, _COMPARISON_RESPONSE_B
    )
    print(f"\nScore A: {score_a:.4f}")
    print(f"Score B: {score_b:.4f}")
    print(f"Winner: {winner.upper()}")


def main():
    """Load the reward model and run all three demo suites, printing results.

    The model is released with ``model.cleanup()`` even when one of the
    suites raises, so a failing test does not skip the cleanup step.
    """
    args = _parse_args()

    _print_banner("ArmoRM Reward Model Test")
    print(f"Device: {args.device}")
    print(f"Model: {args.model_id}")
    print()

    # Create model
    print("Loading model...")
    config = ArmoRewardConfig(
        model_id=args.model_id,
        device=args.device,
    )
    model = ArmoRMRewardModel(config)
    model.load()
    print("Model loaded!\n")

    try:
        _run_scoring_tests(model)

        # Test preference compliance
        _print_banner("Testing Preference Compliance Scenarios")
        print()
        _run_compliance_tests(model)

        # Test response comparison
        _print_banner("Testing Response Comparison")
        print()
        _run_comparison_test(model)

        print("\n" + _RULE)
        print("All tests complete!")
        print(_RULE)
    finally:
        # Cleanup must run on failure too, not only after a clean pass.
        model.cleanup()


if __name__ == "__main__":
    main()