Diffstat (limited to 'src/personalization/feedback/armo_reward.py')
| -rw-r--r-- | src/personalization/feedback/armo_reward.py | 373 |
1 file changed, 373 insertions, 0 deletions
diff --git a/src/personalization/feedback/armo_reward.py b/src/personalization/feedback/armo_reward.py
new file mode 100644
index 0000000..20e9474
--- /dev/null
+++ b/src/personalization/feedback/armo_reward.py
@@ -0,0 +1,373 @@
+"""
+ArmoRM-Llama3-8B-v0.1 local reward model.
+
+Replaces OpenAI-based LLM judge with local ArmoRM for faster inference.
+ArmoRM outputs a preference score (0-1) indicating response quality.
+
+Score interpretation:
+- > 0.7: Good response (positive reward)
+- 0.4-0.7: Neutral response
+- < 0.4: Poor response (negative reward)
+
+For preference compliance checking, we compare scores between:
+1. Agent response following preferences
+2. What the user's follow-up suggests about satisfaction
+"""
+from __future__ import annotations
+
+import hashlib
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+
+@dataclass
+class ArmoRewardConfig:
+    model_id: str = "RLHFlow/ArmoRM-Llama3-8B-v0.1"
+    device: str = "cuda"
+    torch_dtype: str = "bfloat16"
+    max_length: int = 4096
+    truncation: bool = True
+    # Score thresholds for reward mapping
+    positive_threshold: float = 0.7  # Score >= this → positive reward
+    negative_threshold: float = 0.4  # Score <= this → negative reward
+    # Reward values
+    positive_reward: float = 0.8
+    neutral_reward: float = 0.0
+    negative_reward: float = -0.8
+    # Gating
+    confidence_threshold: float = 0.3  # Skip update if score variance is too high
+    enable_cache: bool = True
+
+
+@dataclass
+class ArmoRewardResult:
+    score: float  # Raw ArmoRM score (0-1)
+    reward: float  # Mapped reward value
+    should_update: bool
+    rationale: str = ""
+
+
+class ArmoRMRewardModel:
+    """
+    Local reward model using ArmoRM-Llama3-8B-v0.1.
+
+    ArmoRM is trained on preference data and outputs a score indicating
+    how good a response is. We use this to estimate implicit user feedback.
+    """
+
+    def __init__(self, config: Optional[ArmoRewardConfig] = None):
+        self.config = config or ArmoRewardConfig()
+        self._model = None
+        self._tokenizer = None
+        self._cache: Dict[str, ArmoRewardResult] = {}
+        self._loaded = False
+
+    def load(self):
+        """Load model and tokenizer (lazy loading)."""
+        if self._loaded:
+            return
+
+        dtype_map = {
+            "bfloat16": torch.bfloat16,
+            "float16": torch.float16,
+            "float32": torch.float32,
+        }
+        torch_dtype = dtype_map.get(self.config.torch_dtype, torch.bfloat16)
+
+        print(f"[ArmoRM] Loading model {self.config.model_id} on {self.config.device}...")
+        self._model = AutoModelForSequenceClassification.from_pretrained(
+            self.config.model_id,
+            device_map=self.config.device,
+            trust_remote_code=True,
+            torch_dtype=torch_dtype,
+        )
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.config.model_id,
+            use_fast=True,
+        )
+        self._loaded = True
+        print("[ArmoRM] Model loaded successfully.")
+
+    def _cache_key(self, messages: List[Dict[str, str]]) -> str:
+        """Deterministic hash of messages."""
+        content = str(messages)
+        return hashlib.sha256(content.encode("utf-8")).hexdigest()
+
+    def _score_to_reward(self, score: float) -> Tuple[float, bool]:
+        """Convert ArmoRM score to reward value with gating."""
+        if score >= self.config.positive_threshold:
+            reward = self.config.positive_reward
+            should_update = True
+        elif score <= self.config.negative_threshold:
+            reward = self.config.negative_reward
+            should_update = True
+        else:
+            reward = self.config.neutral_reward
+            should_update = False  # Ambiguous signal, skip update
+
+        return reward, should_update
+
+    def score_response(
+        self,
+        messages: List[Dict[str, str]],
+    ) -> ArmoRewardResult:
+        """
+        Score a conversation using ArmoRM.
+
+        Args:
+            messages: List of {"role": "user"/"assistant", "content": "..."}
+
+        Returns:
+            ArmoRewardResult with score, reward, and should_update
+        """
+        if not self._loaded:
+            self.load()
+
+        # Cache lookup
+        if self.config.enable_cache:
+            key = self._cache_key(messages)
+            if key in self._cache:
+                return self._cache[key]
+
+        # Tokenize and score
+        input_ids = self._tokenizer.apply_chat_template(
+            messages,
+            return_tensors="pt",
+            padding=True,
+            truncation=self.config.truncation,
+            max_length=self.config.max_length,
+        ).to(self._model.device)
+
+        with torch.no_grad():
+            output = self._model(input_ids)
+            score = output.score.float().item()
+
+        # Convert to reward
+        reward, should_update = self._score_to_reward(score)
+
+        result = ArmoRewardResult(
+            score=score,
+            reward=reward,
+            should_update=should_update,
+            rationale=f"ArmoRM score: {score:.3f}",
+        )
+
+        # Cache store
+        if self.config.enable_cache:
+            self._cache[key] = result
+
+        return result
+
+    def score_batch(
+        self,
+        messages_batch: List[List[Dict[str, str]]],
+    ) -> List[ArmoRewardResult]:
+        """Score a batch of conversations."""
+        return [self.score_response(msgs) for msgs in messages_batch]
+
+    def estimate_preference_compliance(
+        self,
+        query: str,
+        response: str,
+        user_followup: str,
+        preferences: Optional[List[str]] = None,
+    ) -> ArmoRewardResult:
+        """
+        Estimate if the response followed user preferences based on follow-up.
+
+        Strategy: Score the conversation quality. A satisfied user (whose
+        preferences were followed) will have a more positive follow-up,
+        leading to higher scores.
+
+        Args:
+            query: User's original query (q_t)
+            response: Agent's response (a_t)
+            user_followup: User's next message (q_{t+1})
+            preferences: Optional list of user preferences (for context)
+
+        Returns:
+            ArmoRewardResult indicating preference compliance
+        """
+        # Build conversation for scoring
+        # Include the follow-up to capture user satisfaction signal
+        messages = [
+            {"role": "user", "content": query},
+            {"role": "assistant", "content": response},
+            {"role": "user", "content": user_followup},
+        ]
+
+        return self.score_response(messages)
+
+    def compare_responses(
+        self,
+        query: str,
+        response_a: str,
+        response_b: str,
+    ) -> Tuple[float, float, str]:
+        """
+        Compare two responses and return which is better.
+
+        Returns:
+            (score_a, score_b, winner) where winner is 'a', 'b', or 'tie'
+        """
+        messages_a = [
+            {"role": "user", "content": query},
+            {"role": "assistant", "content": response_a},
+        ]
+        messages_b = [
+            {"role": "user", "content": query},
+            {"role": "assistant", "content": response_b},
+        ]
+
+        result_a = self.score_response(messages_a)
+        result_b = self.score_response(messages_b)
+
+        if abs(result_a.score - result_b.score) < 0.05:
+            winner = "tie"
+        elif result_a.score > result_b.score:
+            winner = "a"
+        else:
+            winner = "b"
+
+        return result_a.score, result_b.score, winner
+
+    def cleanup(self):
+        """Free GPU memory."""
+        if self._model is not None:
+            del self._model
+            self._model = None
+        if self._tokenizer is not None:
+            del self._tokenizer
+            self._tokenizer = None
+        self._loaded = False
+        torch.cuda.empty_cache()
+
+
+# --- Convenience Functions ---
+
+def create_armo_reward_model(
+    device: str = "cuda",
+    model_id: str = "RLHFlow/ArmoRM-Llama3-8B-v0.1",
+) -> ArmoRMRewardModel:
+    """Create and load ArmoRM reward model."""
+    config = ArmoRewardConfig(
+        model_id=model_id,
+        device=device,
+    )
+    model = ArmoRMRewardModel(config)
+    model.load()
+    return model
+
+
+# --- Integration with existing eval_step interface ---
+
+async def eval_step_armo(
+    q_t: str,
+    answer_t: str,
+    q_t1: str,
+    armo_model: ArmoRMRewardModel,
+    memories_t: Optional[List[str]] = None,
+) -> Tuple[float, float]:
+    """
+    Drop-in replacement for eval_step_llm using ArmoRM.
+
+    Args:
+        q_t: User query at turn t
+        answer_t: Agent response at turn t
+        q_t1: User follow-up at turn t+1
+        armo_model: Loaded ArmoRMRewardModel instance
+        memories_t: Retrieved memories (not used by ArmoRM, kept for API compat)
+
+    Returns:
+        (reward, gating) tuple compatible with existing interface
+    """
+    result = armo_model.estimate_preference_compliance(
+        query=q_t,
+        response=answer_t,
+        user_followup=q_t1,
+    )
+
+    # Gating: 1.0 if should_update, 0.0 otherwise
+    gating = 1.0 if result.should_update else 0.0
+
+    return result.reward, gating
+
+
+# --- Test Script ---
+
+if __name__ == "__main__":
+    print("=" * 60)
+    print("ArmoRM Reward Model Test")
+    print("=" * 60)
+
+    # Create model
+    model = create_armo_reward_model(device="cuda")
+
+    # Test 1: Basic response scoring
+    print("\n--- Test 1: Basic Response Scoring ---")
+    messages = [
+        {"role": "user", "content": "What is the capital of France?"},
+        {"role": "assistant", "content": "The capital of France is Paris."},
+    ]
+    result = model.score_response(messages)
+    print("Query: What is the capital of France?")
+    print("Response: The capital of France is Paris.")
+    print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}")
+
+    # Test 2: Good response with satisfied user
+    print("\n--- Test 2: Good Response (User Satisfied) ---")
+    result = model.estimate_preference_compliance(
+        query="Can you explain how photosynthesis works?",
+        response="Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen. It occurs in the chloroplasts, primarily in the leaves. The light-dependent reactions capture solar energy, while the Calvin cycle uses that energy to fix carbon dioxide into sugars.",
+        user_followup="Great explanation! Can you tell me more about the Calvin cycle?",
+    )
+    print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}")
+
+    # Test 3: Bad response with dissatisfied user
+    print("\n--- Test 3: Bad Response (User Dissatisfied) ---")
+    result = model.estimate_preference_compliance(
+        query="Can you explain how photosynthesis works?",
+        response="Plants make food.",
+        user_followup="That's not helpful at all. I asked for an explanation of how photosynthesis works, not a one-liner.",
+    )
+    print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}")
+
+    # Test 4: Preference enforcement scenario
+    print("\n--- Test 4: Preference Enforcement Scenario ---")
+    result = model.estimate_preference_compliance(
+        query="Solve x^2 - 5x + 6 = 0",
+        response="x = 2 or x = 3",
+        user_followup="I asked you to show step-by-step work. Please solve it again showing each step.",
+    )
+    print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}")
+
+    # Test 5: Compare two responses
+    print("\n--- Test 5: Response Comparison ---")
+    score_a, score_b, winner = model.compare_responses(
+        query="What are the benefits of exercise?",
+        response_a="Exercise is good for you.",
+        response_b="Exercise offers numerous benefits including improved cardiovascular health, stronger muscles and bones, better mental health through endorphin release, weight management, increased energy levels, and better sleep quality. Regular physical activity also reduces the risk of chronic diseases like diabetes and heart disease.",
+    )
+    print(f"Response A (short): {score_a:.3f}")
+    print(f"Response B (detailed): {score_b:.3f}")
+    print(f"Winner: {winner}")
+
+    # Test 6: Batch scoring
+    print("\n--- Test 6: Batch Scoring ---")
+    batch = [
+        [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there! How can I help you today?"}],
+        [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "k"}],
+    ]
+    results = model.score_batch(batch)
+    for i, r in enumerate(results):
+        print(f"  Conversation {i+1}: Score={r.score:.3f}, Reward={r.reward:.2f}")
+
+    print("\n" + "=" * 60)
+    print("Tests complete!")
+    print("=" * 60)
+
+    # Cleanup
+    model.cleanup()
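The threshold-to-reward mapping in ArmoRewardConfig is worth seeing concretely. A minimal sketch, exercising only _score_to_reward so the 8B model never loads; the sample scores are made up for illustration, and the module is assumed to be importable as armo_reward:

from armo_reward import ArmoRewardConfig, ArmoRMRewardModel

model = ArmoRMRewardModel(ArmoRewardConfig())
for score in (0.85, 0.55, 0.20):
    reward, should_update = model._score_to_reward(score)
    print(f"score={score:.2f} -> reward={reward:+.1f}, update={should_update}")

# With the default thresholds (positive 0.7, negative 0.4) this prints:
# score=0.85 -> reward=+0.8, update=True
# score=0.55 -> reward=+0.0, update=False
# score=0.20 -> reward=-0.8, update=True

Note that the mid-band score yields gating False: ambiguous follow-up signals are deliberately dropped rather than pushed into the update step.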
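A hedged sketch of how eval_step_armo might slot in for the OpenAI-based judge it replaces. The turn triples are illustrative only, and the commented-out apply_update stands in for whatever memory-update step the caller already has; neither is part of this commit:

import asyncio

from armo_reward import create_armo_reward_model, eval_step_armo

async def main():
    armo = create_armo_reward_model(device="cuda")
    # (q_t, a_t, q_{t+1}) triples; illustrative only
    turns = [
        ("Solve x^2 - 5x + 6 = 0",
         "x = 2 or x = 3",
         "I asked you to show step-by-step work."),
    ]
    for q_t, a_t, q_t1 in turns:
        reward, gating = await eval_step_armo(q_t, a_t, q_t1, armo)
        if gating > 0.0:
            # apply_update(reward) would go here in a real loop;
            # gating == 0.0 means the score fell in the ambiguous 0.4-0.7 band
            print(f"would update with reward {reward:+.1f}")
    armo.cleanup()

asyncio.run(main())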
