diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 12:15:45 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 12:15:45 -0600 |
| commit | 680513b7771a29f27cbbb3ffb009a69a913de6f9 (patch) | |
| tree | a0d60aef9ade1b2953b915f535b990c0de95e493 /src/personalization | |
| parent | c06ec2f3b80f8968f09eb801b69237495b055ec1 (diff) | |
local reward model
Diffstat (limited to 'src/personalization')
| -rw-r--r-- | src/personalization/feedback/armo_reward.py | 373 | ||||
| -rw-r--r-- | src/personalization/feedback/local_llm_reward.py | 342 | ||||
| -rw-r--r-- | src/personalization/serving/personalized_llm.py | 20 |
3 files changed, 730 insertions, 5 deletions
diff --git a/src/personalization/feedback/armo_reward.py b/src/personalization/feedback/armo_reward.py new file mode 100644 index 0000000..20e9474 --- /dev/null +++ b/src/personalization/feedback/armo_reward.py @@ -0,0 +1,373 @@ +""" +ArmoRM-Llama3-8B-v0.1 local reward model. + +Replaces OpenAI-based LLM judge with local ArmoRM for faster inference. +ArmoRM outputs a preference score (0-1) indicating response quality. + +Score interpretation: +- > 0.7: Good response (positive reward) +- 0.4-0.7: Neutral response +- < 0.4: Poor response (negative reward) + +For preference compliance checking, we compare scores between: +1. Agent response following preferences +2. What the user's follow-up suggests about satisfaction +""" +from __future__ import annotations + +import hashlib +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer + + +@dataclass +class ArmoRewardConfig: + model_id: str = "RLHFlow/ArmoRM-Llama3-8B-v0.1" + device: str = "cuda" + torch_dtype: str = "bfloat16" + max_length: int = 4096 + truncation: bool = True + # Score thresholds for reward mapping + positive_threshold: float = 0.7 # Score >= this → positive reward + negative_threshold: float = 0.4 # Score <= this → negative reward + # Reward values + positive_reward: float = 0.8 + neutral_reward: float = 0.0 + negative_reward: float = -0.8 + # Gating + confidence_threshold: float = 0.3 # Skip update if score variance is too high + enable_cache: bool = True + + +@dataclass +class ArmoRewardResult: + score: float # Raw ArmoRM score (0-1) + reward: float # Mapped reward value + should_update: bool + rationale: str = "" + + +class ArmoRMRewardModel: + """ + Local reward model using ArmoRM-Llama3-8B-v0.1. + + ArmoRM is trained on preference data and outputs a score indicating + how good a response is. We use this to estimate implicit user feedback. + """ + + def __init__(self, config: Optional[ArmoRewardConfig] = None): + self.config = config or ArmoRewardConfig() + self._model = None + self._tokenizer = None + self._cache: Dict[str, ArmoRewardResult] = {} + self._loaded = False + + def load(self): + """Load model and tokenizer (lazy loading).""" + if self._loaded: + return + + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + torch_dtype = dtype_map.get(self.config.torch_dtype, torch.bfloat16) + + print(f"[ArmoRM] Loading model {self.config.model_id} on {self.config.device}...") + self._model = AutoModelForSequenceClassification.from_pretrained( + self.config.model_id, + device_map=self.config.device, + trust_remote_code=True, + torch_dtype=torch_dtype, + ) + self._tokenizer = AutoTokenizer.from_pretrained( + self.config.model_id, + use_fast=True, + ) + self._loaded = True + print(f"[ArmoRM] Model loaded successfully.") + + def _cache_key(self, messages: List[Dict[str, str]]) -> str: + """Deterministic hash of messages.""" + content = str(messages) + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + def _score_to_reward(self, score: float) -> Tuple[float, bool]: + """Convert ArmoRM score to reward value with gating.""" + if score >= self.config.positive_threshold: + reward = self.config.positive_reward + should_update = True + elif score <= self.config.negative_threshold: + reward = self.config.negative_reward + should_update = True + else: + reward = self.config.neutral_reward + should_update = False # Ambiguous signal, skip update + + return reward, should_update + + def score_response( + self, + messages: List[Dict[str, str]], + ) -> ArmoRewardResult: + """ + Score a conversation using ArmoRM. + + Args: + messages: List of {"role": "user"/"assistant", "content": "..."} + + Returns: + ArmoRewardResult with score, reward, and should_update + """ + if not self._loaded: + self.load() + + # Cache lookup + if self.config.enable_cache: + key = self._cache_key(messages) + if key in self._cache: + return self._cache[key] + + # Tokenize and score + input_ids = self._tokenizer.apply_chat_template( + messages, + return_tensors="pt", + padding=True, + truncation=self.config.truncation, + max_length=self.config.max_length, + ).to(self._model.device) + + with torch.no_grad(): + output = self._model(input_ids) + score = output.score.float().item() + + # Convert to reward + reward, should_update = self._score_to_reward(score) + + result = ArmoRewardResult( + score=score, + reward=reward, + should_update=should_update, + rationale=f"ArmoRM score: {score:.3f}", + ) + + # Cache store + if self.config.enable_cache: + self._cache[key] = result + + return result + + def score_batch( + self, + messages_batch: List[List[Dict[str, str]]], + ) -> List[ArmoRewardResult]: + """Score a batch of conversations.""" + return [self.score_response(msgs) for msgs in messages_batch] + + def estimate_preference_compliance( + self, + query: str, + response: str, + user_followup: str, + preferences: Optional[List[str]] = None, + ) -> ArmoRewardResult: + """ + Estimate if the response followed user preferences based on follow-up. + + Strategy: Score the conversation quality. A satisfied user (whose + preferences were followed) will have a more positive follow-up, + leading to higher scores. + + Args: + query: User's original query (q_t) + response: Agent's response (a_t) + user_followup: User's next message (q_{t+1}) + preferences: Optional list of user preferences (for context) + + Returns: + ArmoRewardResult indicating preference compliance + """ + # Build conversation for scoring + # Include the follow-up to capture user satisfaction signal + messages = [ + {"role": "user", "content": query}, + {"role": "assistant", "content": response}, + {"role": "user", "content": user_followup}, + ] + + return self.score_response(messages) + + def compare_responses( + self, + query: str, + response_a: str, + response_b: str, + ) -> Tuple[float, float, str]: + """ + Compare two responses and return which is better. + + Returns: + (score_a, score_b, winner) where winner is 'a', 'b', or 'tie' + """ + messages_a = [ + {"role": "user", "content": query}, + {"role": "assistant", "content": response_a}, + ] + messages_b = [ + {"role": "user", "content": query}, + {"role": "assistant", "content": response_b}, + ] + + result_a = self.score_response(messages_a) + result_b = self.score_response(messages_b) + + if abs(result_a.score - result_b.score) < 0.05: + winner = "tie" + elif result_a.score > result_b.score: + winner = "a" + else: + winner = "b" + + return result_a.score, result_b.score, winner + + def cleanup(self): + """Free GPU memory.""" + if self._model is not None: + del self._model + self._model = None + if self._tokenizer is not None: + del self._tokenizer + self._tokenizer = None + self._loaded = False + torch.cuda.empty_cache() + + +# --- Convenience Functions --- + +def create_armo_reward_model( + device: str = "cuda", + model_id: str = "RLHFlow/ArmoRM-Llama3-8B-v0.1", +) -> ArmoRMRewardModel: + """Create and load ArmoRM reward model.""" + config = ArmoRewardConfig( + model_id=model_id, + device=device, + ) + model = ArmoRMRewardModel(config) + model.load() + return model + + +# --- Integration with existing eval_step interface --- + +async def eval_step_armo( + q_t: str, + answer_t: str, + q_t1: str, + armo_model: ArmoRMRewardModel, + memories_t: Optional[List[str]] = None, +) -> Tuple[float, float]: + """ + Drop-in replacement for eval_step_llm using ArmoRM. + + Args: + q_t: User query at turn t + answer_t: Agent response at turn t + q_t1: User follow-up at turn t+1 + armo_model: Loaded ArmoRMRewardModel instance + memories_t: Retrieved memories (not used by ArmoRM, kept for API compat) + + Returns: + (reward, gating) tuple compatible with existing interface + """ + result = armo_model.estimate_preference_compliance( + query=q_t, + response=answer_t, + user_followup=q_t1, + ) + + # Gating: 1.0 if should_update, 0.0 otherwise + gating = 1.0 if result.should_update else 0.0 + + return result.reward, gating + + +# --- Test Script --- + +if __name__ == "__main__": + print("=" * 60) + print("ArmoRM Reward Model Test") + print("=" * 60) + + # Create model + model = create_armo_reward_model(device="cuda") + + # Test 1: Basic response scoring + print("\n--- Test 1: Basic Response Scoring ---") + messages = [ + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + ] + result = model.score_response(messages) + print(f"Query: What is the capital of France?") + print(f"Response: The capital of France is Paris.") + print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}") + + # Test 2: Good response with satisfied user + print("\n--- Test 2: Good Response (User Satisfied) ---") + result = model.estimate_preference_compliance( + query="Can you explain how photosynthesis works?", + response="Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen. It occurs in the chloroplasts, primarily in the leaves. The light-dependent reactions capture solar energy, while the Calvin cycle uses that energy to fix carbon dioxide into sugars.", + user_followup="Great explanation! Can you tell me more about the Calvin cycle?", + ) + print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}") + + # Test 3: Bad response with dissatisfied user + print("\n--- Test 3: Bad Response (User Dissatisfied) ---") + result = model.estimate_preference_compliance( + query="Can you explain how photosynthesis works?", + response="Plants make food.", + user_followup="That's not helpful at all. I asked for an explanation of how photosynthesis works, not a one-liner.", + ) + print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}") + + # Test 4: Preference enforcement scenario + print("\n--- Test 4: Preference Enforcement Scenario ---") + result = model.estimate_preference_compliance( + query="Solve x^2 - 5x + 6 = 0", + response="x = 2 or x = 3", + user_followup="I asked you to show step-by-step work. Please solve it again showing each step.", + ) + print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}") + + # Test 5: Compare two responses + print("\n--- Test 5: Response Comparison ---") + score_a, score_b, winner = model.compare_responses( + query="What are the benefits of exercise?", + response_a="Exercise is good for you.", + response_b="Exercise offers numerous benefits including improved cardiovascular health, stronger muscles and bones, better mental health through endorphin release, weight management, increased energy levels, and better sleep quality. Regular physical activity also reduces the risk of chronic diseases like diabetes and heart disease.", + ) + print(f"Response A (short): {score_a:.3f}") + print(f"Response B (detailed): {score_b:.3f}") + print(f"Winner: {winner}") + + # Test 6: Batch scoring + print("\n--- Test 6: Batch Scoring ---") + batch = [ + [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there! How can I help you today?"}], + [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "k"}], + ] + results = model.score_batch(batch) + for i, r in enumerate(results): + print(f" Conversation {i+1}: Score={r.score:.3f}, Reward={r.reward:.2f}") + + print("\n" + "=" * 60) + print("Tests complete!") + print("=" * 60) + + # Cleanup + model.cleanup() diff --git a/src/personalization/feedback/local_llm_reward.py b/src/personalization/feedback/local_llm_reward.py new file mode 100644 index 0000000..9837ff0 --- /dev/null +++ b/src/personalization/feedback/local_llm_reward.py @@ -0,0 +1,342 @@ +""" +Local LLM reward model using vLLM server for batch inference. + +Drop-in replacement for LLMRewardClient when you want to use a local model +(e.g., Llama-3.1-8B-Instruct) instead of OpenAI API. + +Uses BatchVLLMClient for efficient concurrent requests - vLLM's continuous +batching will process them together for high throughput. +""" +from __future__ import annotations + +import asyncio +import hashlib +import json +from dataclasses import dataclass +from typing import Dict, List, Optional + +import aiohttp + +from personalization.feedback.schemas import TurnSample +from personalization.feedback.llm_reward import ( + REWARD_MAP, + VALID_LABELS, + JUDGE_SYSTEM_PROMPT, + JUDGE_USER_TEMPLATE, + RewardResult, +) + + +@dataclass +class LocalLLMRewardConfig: + """Configuration for local LLM reward model.""" + vllm_url: str = "http://localhost:8005/v1" # vLLM server URL + model_name: Optional[str] = None # Auto-discovered if None + max_tokens: int = 256 + temperature: float = 0.1 + max_concurrent: int = 100 # High concurrency for vLLM batching + timeout: Optional[int] = 60 # Per-request timeout in seconds + confidence_threshold: float = 0.6 # tau_c: skip update if confidence < this + enable_cache: bool = True # Cache by hash of (q_t, a_t, q_{t+1}) + + +class LocalLLMRewardClient: + """ + Local LLM reward client using vLLM server. + + Designed for batch processing - uses async HTTP requests that vLLM + batches together via continuous batching for high throughput. + """ + + def __init__(self, config: Optional[LocalLLMRewardConfig] = None): + self.config = config or LocalLLMRewardConfig() + self._model_name = self.config.model_name + self._cache: Dict[str, RewardResult] = {} + + # Discover model name if not provided + if self._model_name is None: + self._discover_model_sync() + + def _discover_model_sync(self): + """Synchronously discover model name from vLLM server.""" + import requests + try: + response = requests.get( + f"{self.config.vllm_url}/models", + timeout=10 + ) + response.raise_for_status() + models = response.json() + if models.get("data") and len(models["data"]) > 0: + self._model_name = models["data"][0]["id"] + else: + self._model_name = "default" + except Exception as e: + print(f"[LocalLLMReward] Warning: Could not discover model ({e})") + self._model_name = "default" + + def _cache_key(self, query_t: str, answer_t: str, query_t1: str) -> str: + """Deterministic hash of the judge input triple.""" + content = f"{query_t}\x00{answer_t}\x00{query_t1}" + return hashlib.sha256(content.encode("utf-8")).hexdigest() + + def _build_messages(self, sample: TurnSample) -> List[dict]: + """Construct the judge prompt from (q_t, a_t, q_{t+1}).""" + user_content = JUDGE_USER_TEMPLATE.format( + query_t=sample.query_t, + answer_t=sample.answer_t, + query_t1=sample.query_t1, + ) + return [ + {"role": "system", "content": JUDGE_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + def _parse_result(self, raw: Optional[str]) -> RewardResult: + """Parse structured JSON output into RewardResult.""" + if raw is None: + return RewardResult( + label="neutral", + confidence=0.0, + rationale="request_failed", + reward=0.0, + should_update=False, + ) + + try: + # Handle markdown code blocks + text = raw.strip() + if text.startswith("```"): + lines = text.split("\n") + text = "\n".join( + lines[1:-1] if lines[-1].strip() == "```" else lines[1:] + ) + + parsed = json.loads(text) + label = parsed.get("label", "neutral") + confidence = float(parsed.get("confidence", 0.0)) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + + # Confidence gating and topic_shift skip + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except (json.JSONDecodeError, KeyError, TypeError, ValueError): + # Try to extract JSON from text + import re + match = re.search(r'\{[^}]+\}', raw, re.DOTALL) + if match: + try: + parsed = json.loads(match.group()) + label = parsed.get("label", "neutral") + confidence = float(parsed.get("confidence", 0.0)) + rationale = parsed.get("rationale", "") + + if label not in VALID_LABELS: + label = "neutral" + confidence = 0.0 + + reward = REWARD_MAP[label] + should_update = ( + confidence >= self.config.confidence_threshold + and label != "topic_shift" + ) + if not should_update: + reward = 0.0 + + return RewardResult( + label=label, + confidence=confidence, + rationale=rationale, + reward=reward, + should_update=should_update, + ) + except: + pass + + return RewardResult( + label="neutral", + confidence=0.0, + rationale="parse_failure", + reward=0.0, + should_update=False, + ) + + async def _single_request( + self, + session: aiohttp.ClientSession, + messages: List[dict], + idx: int, + ) -> tuple: + """Make a single async request to vLLM server.""" + payload = { + "model": self._model_name, + "messages": messages, + "max_tokens": self.config.max_tokens, + "temperature": self.config.temperature, + "response_format": {"type": "json_object"}, + } + + for attempt in range(3): + try: + timeout_config = ( + aiohttp.ClientTimeout(total=self.config.timeout) + if self.config.timeout else None + ) + async with session.post( + f"{self.config.vllm_url}/chat/completions", + json=payload, + timeout=timeout_config, + ) as response: + if response.status == 200: + result = await response.json() + content = result["choices"][0]["message"]["content"] + return (idx, content, None) + elif response.status == 429: + # Rate limit - wait and retry + await asyncio.sleep(2 ** attempt) + continue + else: + error_text = await response.text() + return (idx, None, f"HTTP {response.status}: {error_text[:200]}") + except asyncio.TimeoutError: + if attempt < 2: + continue + return (idx, None, "Timeout") + except Exception as e: + return (idx, None, str(e)) + + return (idx, None, "Max retries exceeded") + + async def judge_batch_async( + self, + samples: List[TurnSample], + show_progress: bool = False, + ) -> List[RewardResult]: + """ + Judge a batch of turns using concurrent vLLM requests. + + vLLM's continuous batching will process these together for + high throughput. + """ + n_samples = len(samples) + results = [None] * n_samples + + # Check cache and build request list + to_request = [] # (original_idx, messages) + for i, sample in enumerate(samples): + if self.config.enable_cache: + key = self._cache_key(sample.query_t, sample.answer_t, sample.query_t1) + if key in self._cache: + results[i] = self._cache[key] + continue + + messages = self._build_messages(sample) + to_request.append((i, messages)) + + if not to_request: + return results + + # Make concurrent requests + semaphore = asyncio.Semaphore(self.config.max_concurrent) + + async def limited_request(session, messages, idx): + async with semaphore: + return await self._single_request(session, messages, idx) + + connector = aiohttp.TCPConnector(limit=self.config.max_concurrent) + headers = {"Content-Type": "application/json"} + + async with aiohttp.ClientSession( + connector=connector, headers=headers + ) as session: + tasks = [ + limited_request(session, messages, orig_idx) + for orig_idx, messages in to_request + ] + + completed = 0 + for coro in asyncio.as_completed(tasks): + orig_idx, content, error = await coro + completed += 1 + + if error: + print(f"[LocalLLMReward] Request {orig_idx} failed: {error}") + + result = self._parse_result(content) + results[orig_idx] = result + + # Cache the result + if self.config.enable_cache: + sample = samples[orig_idx] + key = self._cache_key( + sample.query_t, sample.answer_t, sample.query_t1 + ) + self._cache[key] = result + + if show_progress and completed % 10 == 0: + print(f" [LocalLLMReward {completed}/{len(to_request)}] completed") + + return results + + async def judge_async(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (async).""" + results = await self.judge_batch_async([sample]) + return results[0] + + def judge_batch(self, samples: List[TurnSample]) -> List[RewardResult]: + """ + Judge a batch of turns (sync wrapper). + + This is the main entry point for batch reward estimation. + """ + return asyncio.run(self.judge_batch_async(samples)) + + def judge(self, sample: TurnSample) -> RewardResult: + """Judge a single turn (sync wrapper).""" + return asyncio.run(self.judge_async(sample)) + + +# --- Convenience Functions --- + +def estimate_reward_local( + sample: TurnSample, + config: Optional[LocalLLMRewardConfig] = None, +) -> tuple: + """ + Synchronous single-sample reward estimation using local LLM. + Returns (reward, should_update). + """ + client = LocalLLMRewardClient(config) + result = client.judge(sample) + return result.reward, result.should_update + + +def estimate_rewards_batch_local( + samples: List[TurnSample], + config: Optional[LocalLLMRewardConfig] = None, +) -> List[tuple]: + """ + Synchronous batch reward estimation using local LLM. + Returns list of (reward, should_update) tuples. + """ + client = LocalLLMRewardClient(config) + results = client.judge_batch(samples) + return [(r.reward, r.should_update) for r in results] diff --git a/src/personalization/serving/personalized_llm.py b/src/personalization/serving/personalized_llm.py index 733ff87..45d002b 100644 --- a/src/personalization/serving/personalized_llm.py +++ b/src/personalization/serving/personalized_llm.py @@ -282,8 +282,9 @@ class PersonalizedLLM: use_shared_models: bool = False, # Use shared singleton models for multi-threaded efficiency reranker_type: str = "qwen3", # "qwen3" (8B) or "bge" (278M) best_of_n: int = 1, # Generate N responses and pick best (for RAG methods) - reward_mode: str = "keyword", # "keyword" (legacy heuristic) or "llm" (GPT-5-nano judge) + reward_mode: str = "keyword", # "keyword", "llm" (GPT-4o-mini), or "llm_local" (local vLLM) llm_reward_config: Optional["LLMRewardConfig"] = None, # Config for LLM judge + reward_vllm_url: Optional[str] = None, # vLLM URL for local reward model (when reward_mode="llm_local") ): """ Initialize the PersonalizedLLM. @@ -317,12 +318,21 @@ class PersonalizedLLM: self.eval_mode = eval_mode # True = greedy, False = sample self.reranker_type = reranker_type # "qwen3" or "bge" self.best_of_n = best_of_n # Generate N responses and pick best - self.reward_mode = reward_mode # "keyword" or "llm" + self.reward_mode = reward_mode # "keyword", "llm", or "llm_local" # Initialize LLM reward client if using LLM judge - self._llm_reward_client: Optional[LLMRewardClient] = None + self._llm_reward_client = None # Can be LLMRewardClient or LocalLLMRewardClient if reward_mode == "llm": self._llm_reward_client = LLMRewardClient(llm_reward_config or LLMRewardConfig()) + elif reward_mode == "llm_local": + from personalization.feedback.local_llm_reward import ( + LocalLLMRewardClient, + LocalLLMRewardConfig, + ) + local_config = LocalLLMRewardConfig( + vllm_url=reward_vllm_url or "http://localhost:8005/v1", + ) + self._llm_reward_client = LocalLLMRewardClient(local_config) # Multi-GPU device assignment self._device_assignment = device_assignment or { @@ -743,7 +753,7 @@ class PersonalizedLLM: } # Auto-compute reward via LLM judge if enabled - if self.reward_mode == "llm" and self._llm_reward_client is not None: + if self._llm_reward_client is not None: import asyncio try: reward, gating = asyncio.run(eval_step_llm( @@ -974,7 +984,7 @@ class PersonalizedLLM: } # Auto-compute reward via LLM judge if enabled - if self.reward_mode == "llm" and self._llm_reward_client is not None: + if self._llm_reward_client is not None: import asyncio try: reward, gating = asyncio.run(eval_step_llm( |
