summaryrefslogtreecommitdiff
path: root/scripts/test_armo_reward.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 12:15:45 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 12:15:45 -0600
commit680513b7771a29f27cbbb3ffb009a69a913de6f9 (patch)
treea0d60aef9ade1b2953b915f535b990c0de95e493 /scripts/test_armo_reward.py
parentc06ec2f3b80f8968f09eb801b69237495b055ec1 (diff)
local reward model
Diffstat (limited to 'scripts/test_armo_reward.py')
-rw-r--r--scripts/test_armo_reward.py216
1 files changed, 216 insertions, 0 deletions
diff --git a/scripts/test_armo_reward.py b/scripts/test_armo_reward.py
new file mode 100644
index 0000000..534ab84
--- /dev/null
+++ b/scripts/test_armo_reward.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+"""
+Test script for ArmoRM reward model.
+
+Usage:
+ python scripts/test_armo_reward.py [--device cuda:0]
+"""
+import argparse
+import sys
+import os
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+
+from personalization.feedback.armo_reward import (
+ ArmoRMRewardModel,
+ ArmoRewardConfig,
+ create_armo_reward_model,
+)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Test ArmoRM reward model")
+ parser.add_argument("--device", type=str, default="cuda", help="Device to use")
+ parser.add_argument("--model-id", type=str, default="RLHFlow/ArmoRM-Llama3-8B-v0.1")
+ args = parser.parse_args()
+
+ print("=" * 70)
+ print("ArmoRM Reward Model Test")
+ print("=" * 70)
+ print(f"Device: {args.device}")
+ print(f"Model: {args.model_id}")
+ print()
+
+ # Create model
+ print("Loading model...")
+ config = ArmoRewardConfig(
+ model_id=args.model_id,
+ device=args.device,
+ )
+ model = ArmoRMRewardModel(config)
+ model.load()
+ print("Model loaded!\n")
+
+ # Test cases
+ test_cases = [
+ {
+ "name": "Good factual answer",
+ "messages": [
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris. Paris is not only the capital but also the largest city in France, located in the north-central part of the country along the Seine River."},
+ ],
+ "expected": "high score (>0.7)",
+ },
+ {
+ "name": "Minimal answer",
+ "messages": [
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "Paris."},
+ ],
+ "expected": "medium score (0.4-0.7)",
+ },
+ {
+ "name": "Wrong answer",
+ "messages": [
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is London."},
+ ],
+ "expected": "low score (<0.4)",
+ },
+ {
+ "name": "Helpful detailed explanation",
+ "messages": [
+ {"role": "user", "content": "Explain how to sort a list in Python"},
+ {"role": "assistant", "content": """There are several ways to sort a list in Python:
+
+1. **Using sorted()** - Returns a new sorted list:
+ ```python
+ my_list = [3, 1, 4, 1, 5]
+ sorted_list = sorted(my_list) # [1, 1, 3, 4, 5]
+ ```
+
+2. **Using list.sort()** - Sorts in place:
+ ```python
+ my_list = [3, 1, 4, 1, 5]
+ my_list.sort() # my_list is now [1, 1, 3, 4, 5]
+ ```
+
+3. **Reverse sorting**:
+ ```python
+ sorted_list = sorted(my_list, reverse=True)
+ ```
+
+4. **Custom key function**:
+ ```python
+ words = ['apple', 'Banana', 'cherry']
+ sorted_words = sorted(words, key=str.lower)
+ ```"""},
+ ],
+ "expected": "high score (>0.7)",
+ },
+ {
+ "name": "Unhelpful response",
+ "messages": [
+ {"role": "user", "content": "Explain how to sort a list in Python"},
+ {"role": "assistant", "content": "Just use sort."},
+ ],
+ "expected": "low score (<0.4)",
+ },
+ ]
+
+ print("Running test cases...\n")
+ for i, tc in enumerate(test_cases, 1):
+ print(f"--- Test {i}: {tc['name']} ---")
+ print(f"Expected: {tc['expected']}")
+
+ result = model.score_response(tc["messages"])
+
+ print(f"Score: {result.score:.4f}")
+ print(f"Reward: {result.reward:.2f}")
+ print(f"Should Update: {result.should_update}")
+ print()
+
+ # Test preference compliance
+ print("=" * 70)
+ print("Testing Preference Compliance Scenarios")
+ print("=" * 70)
+ print()
+
+ compliance_tests = [
+ {
+ "name": "User satisfied (preference followed)",
+ "query": "Can you explain recursion? I prefer examples with code.",
+ "response": """Recursion is when a function calls itself. Here's a classic example - calculating factorial:
+
+```python
+def factorial(n):
+ if n <= 1:
+ return 1
+ return n * factorial(n - 1)
+
+print(factorial(5)) # Output: 120
+```
+
+The function calls itself with a smaller value until it reaches the base case (n <= 1).""",
+ "followup": "Perfect! That's exactly what I needed. Can you show me another example with Fibonacci?",
+ },
+ {
+ "name": "User dissatisfied (preference NOT followed)",
+ "query": "Can you explain recursion? I prefer examples with code.",
+ "response": "Recursion is a programming concept where a function calls itself to solve smaller instances of the same problem.",
+ "followup": "I specifically asked for code examples. Please show me some actual code demonstrating recursion.",
+ },
+ {
+ "name": "User correcting format preference",
+ "query": "List 5 benefits of meditation. Use bullet points please.",
+ "response": "Meditation has many benefits. First, it reduces stress. Second, it improves focus. Third, it promotes emotional health. Fourth, it enhances self-awareness. Fifth, it can reduce anxiety.",
+ "followup": "I asked for bullet points, not numbered sentences. Can you reformat that?",
+ },
+ ]
+
+ for i, tc in enumerate(compliance_tests, 1):
+ print(f"--- Compliance Test {i}: {tc['name']} ---")
+ print(f"Query: {tc['query'][:60]}...")
+ print(f"Followup: {tc['followup'][:60]}...")
+
+ result = model.estimate_preference_compliance(
+ query=tc["query"],
+ response=tc["response"],
+ user_followup=tc["followup"],
+ )
+
+ print(f"Score: {result.score:.4f}")
+ print(f"Reward: {result.reward:.2f}")
+ print(f"Should Update: {result.should_update}")
+ print()
+
+ # Test response comparison
+ print("=" * 70)
+ print("Testing Response Comparison")
+ print("=" * 70)
+ print()
+
+ query = "What are the health benefits of drinking water?"
+ response_a = "Water is good for health."
+ response_b = """Drinking adequate water provides numerous health benefits:
+
+1. **Hydration**: Maintains fluid balance for bodily functions
+2. **Digestion**: Aids in breaking down food and nutrient absorption
+3. **Skin Health**: Keeps skin moisturized and may reduce wrinkles
+4. **Kidney Function**: Helps flush out toxins and prevents kidney stones
+5. **Energy**: Prevents fatigue caused by dehydration
+6. **Weight Management**: Can reduce appetite when consumed before meals
+7. **Joint Health**: Lubricates and cushions joints
+
+The general recommendation is 8 glasses (64 oz) per day, though needs vary by individual."""
+
+ print(f"Query: {query}")
+ print(f"Response A: {response_a}")
+ print(f"Response B: {response_b[:100]}...")
+
+ score_a, score_b, winner = model.compare_responses(query, response_a, response_b)
+ print(f"\nScore A: {score_a:.4f}")
+ print(f"Score B: {score_b:.4f}")
+ print(f"Winner: {winner.upper()}")
+
+ print("\n" + "=" * 70)
+ print("All tests complete!")
+ print("=" * 70)
+
+ # Cleanup
+ model.cleanup()
+
+
+if __name__ == "__main__":
+ main()