1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
|
"""
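ArmoRM-Llama3-8B-v0.1 local reward model.
Replaces the OpenAI-based LLM judge with a local ArmoRM model for faster inference.
ArmoRM outputs a scalar preference score indicating response quality. The
default thresholds assume a score roughly in [0, 1] (NOTE(review): confirm the
actual score range of this checkpoint and calibrate the thresholds accordingly).
Default score interpretation (see ArmoRewardConfig):
- >= 0.7: Good response (positive reward)
- between 0.4 and 0.7: Neutral response (no update)
- <= 0.4: Poor response (negative reward)
For preference-compliance checking, the whole exchange (user query, agent
response, and the user's follow-up) is scored as one conversation, so that the
score can reflect both:
1. Whether the agent response followed the user's preferences
2. What the user's follow-up suggests about their satisfaction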
"""
from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@dataclass
class ArmoRewardConfig:
    """Settings for loading ArmoRM and mapping its score to a reward.

    A score ``>= positive_threshold`` maps to ``positive_reward``; a score
    ``<= negative_threshold`` maps to ``negative_reward``; anything in between
    maps to ``neutral_reward`` and is flagged as not worth an update.

    Raises:
        ValueError: if ``negative_threshold`` exceeds ``positive_threshold``,
            which would make the positive and negative bands overlap.
    """

    # Model loading
    model_id: str = "RLHFlow/ArmoRM-Llama3-8B-v0.1"
    device: str = "cuda"  # used as the model's device_map
    # Recognised: "bfloat16", "float16", "float32"; anything else falls back
    # to bfloat16 when the model is loaded.
    torch_dtype: str = "bfloat16"
    max_length: int = 4096  # token cap applied when truncation is enabled
    truncation: bool = True
    # Score thresholds for reward mapping
    positive_threshold: float = 0.7  # score >= this -> positive reward
    negative_threshold: float = 0.4  # score <= this -> negative reward
    # Reward values
    positive_reward: float = 0.8
    neutral_reward: float = 0.0
    negative_reward: float = -0.8
    # Gating
    # NOTE(review): nothing in ArmoRMRewardModel reads this value; the update
    # gate is decided solely by the thresholds above. Kept for compatibility.
    confidence_threshold: float = 0.3
    enable_cache: bool = True  # memoise results per conversation

    def __post_init__(self) -> None:
        # Overlapping bands would silently make the mapping order-dependent.
        if self.negative_threshold > self.positive_threshold:
            raise ValueError(
                f"negative_threshold ({self.negative_threshold}) must not exceed "
                f"positive_threshold ({self.positive_threshold})"
            )
@dataclass
class ArmoRewardResult:
    """Outcome of scoring a single conversation with ArmoRM.

    Attributes:
        score: Raw preference score returned by the model.
        reward: The score mapped onto the configured reward values.
        should_update: Whether the reward is decisive enough to act on.
        rationale: Human-readable explanation (empty by default).
    """

    score: float
    reward: float
    should_update: bool
    rationale: str = ""
class ArmoRMRewardModel:
    """
    Local reward model using ArmoRM-Llama3-8B-v0.1.

    ArmoRM is trained on preference data and outputs a score indicating
    how good a response is. We use this to estimate implicit user feedback.

    The model and tokenizer are loaded lazily on the first scoring call (or
    explicitly via ``load()``). When ``config.enable_cache`` is true, results
    are memoised per conversation in an unbounded in-memory dict; call
    ``clear_cache()`` to release it in long-running jobs.
    """

    def __init__(self, config: Optional[ArmoRewardConfig] = None):
        """Store *config* (or a default ``ArmoRewardConfig``); nothing is loaded yet."""
        self.config = config or ArmoRewardConfig()
        self._model = None
        self._tokenizer = None
        self._cache: Dict[str, ArmoRewardResult] = {}
        self._loaded = False

    def load(self):
        """Load model and tokenizer (lazy loading; no-op if already loaded).

        An unrecognised ``config.torch_dtype`` silently falls back to bfloat16.
        """
        if self._loaded:
            return
        dtype_map = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
        }
        torch_dtype = dtype_map.get(self.config.torch_dtype, torch.bfloat16)
        print(f"[ArmoRM] Loading model {self.config.model_id} on {self.config.device}...")
        self._model = AutoModelForSequenceClassification.from_pretrained(
            self.config.model_id,
            device_map=self.config.device,
            # Executes Python code shipped in the model repository; only use
            # this with a trusted model_id.
            trust_remote_code=True,
            torch_dtype=torch_dtype,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_id,
            use_fast=True,
        )
        self._loaded = True
        print("[ArmoRM] Model loaded successfully.")

    def _cache_key(self, messages: List[Dict[str, str]]) -> str:
        """Return a deterministic SHA-256 hex digest identifying *messages*.

        Serialising with ``sort_keys=True`` makes two conversations that differ
        only in per-message dict key order share one cache entry. Non-JSON
        values are stringified, mirroring the previous ``str()``-based key.
        """
        content = json.dumps(messages, sort_keys=True, default=str)
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _score_to_reward(self, score: float) -> Tuple[float, bool]:
        """Map a raw ArmoRM score to ``(reward, should_update)``.

        Scores at or above ``positive_threshold`` give ``positive_reward``,
        scores at or below ``negative_threshold`` give ``negative_reward``;
        both are flagged for update. Anything in between is treated as an
        ambiguous signal: ``neutral_reward`` with ``should_update=False``.
        NOTE(review): the thresholds assume a particular score range — confirm
        it against real ArmoRM outputs before relying on the defaults.
        """
        if score >= self.config.positive_threshold:
            return self.config.positive_reward, True
        if score <= self.config.negative_threshold:
            return self.config.negative_reward, True
        return self.config.neutral_reward, False

    def score_response(
        self,
        messages: List[Dict[str, str]],
    ) -> ArmoRewardResult:
        """
        Score a conversation using ArmoRM.

        Loads the model on first use and consults the result cache when
        ``config.enable_cache`` is true.

        Args:
            messages: List of {"role": "user"/"assistant", "content": "..."}
        Returns:
            ArmoRewardResult with score, reward, and should_update
        """
        if not self._loaded:
            self.load()

        key = self._cache_key(messages) if self.config.enable_cache else None
        if key is not None and key in self._cache:
            return self._cache[key]

        # NOTE(review): if the tokenizer truncates from the end (the usual
        # default truncation_side), an over-long conversation is scored on a
        # prefix that no longer ends where the conversation does — confirm.
        input_ids = self._tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            padding=True,
            truncation=self.config.truncation,
            max_length=self.config.max_length,
        ).to(self._model.device)

        with torch.no_grad():
            output = self._model(input_ids)
            score = output.score.float().item()

        reward, should_update = self._score_to_reward(score)
        result = ArmoRewardResult(
            score=score,
            reward=reward,
            should_update=should_update,
            rationale=f"ArmoRM score: {score:.3f}",
        )

        if key is not None:
            self._cache[key] = result
        return result

    def score_batch(
        self,
        messages_batch: List[List[Dict[str, str]]],
    ) -> List[ArmoRewardResult]:
        """Score a batch of conversations (sequentially, one forward pass each)."""
        return [self.score_response(msgs) for msgs in messages_batch]

    def estimate_preference_compliance(
        self,
        query: str,
        response: str,
        user_followup: str,
        preferences: Optional[List[str]] = None,
    ) -> ArmoRewardResult:
        """
        Estimate if the response followed user preferences based on follow-up.

        Strategy: score the full exchange including the follow-up, on the
        expectation that a satisfied user (whose preferences were followed)
        writes a more positive follow-up and so yields a higher score.
        NOTE(review): the scored conversation ends on a user turn rather than
        an assistant turn; confirm the reward model scores that meaningfully.

        Args:
            query: User's original query (q_t)
            response: Agent's response (a_t)
            user_followup: User's next message (q_{t+1})
            preferences: Accepted for API compatibility but currently NOT used
                in scoring.
        Returns:
            ArmoRewardResult indicating preference compliance
        """
        messages = [
            {"role": "user", "content": query},
            {"role": "assistant", "content": response},
            {"role": "user", "content": user_followup},
        ]
        return self.score_response(messages)

    def compare_responses(
        self,
        query: str,
        response_a: str,
        response_b: str,
        tie_margin: float = 0.05,
    ) -> Tuple[float, float, str]:
        """
        Compare two responses to the same query and report which scores higher.

        Args:
            query: User query shared by both responses.
            response_a: First candidate response.
            response_b: Second candidate response.
            tie_margin: Score differences strictly below this are reported as
                a tie (defaults to the previously hard-coded 0.05).
        Returns:
            (score_a, score_b, winner) where winner is 'a', 'b', or 'tie'
        """
        messages_a = [
            {"role": "user", "content": query},
            {"role": "assistant", "content": response_a},
        ]
        messages_b = [
            {"role": "user", "content": query},
            {"role": "assistant", "content": response_b},
        ]
        result_a = self.score_response(messages_a)
        result_b = self.score_response(messages_b)
        if abs(result_a.score - result_b.score) < tie_margin:
            winner = "tie"
        elif result_a.score > result_b.score:
            winner = "a"
        else:
            winner = "b"
        return result_a.score, result_b.score, winner

    def clear_cache(self) -> None:
        """Drop all memoised scoring results."""
        self._cache.clear()

    def cleanup(self):
        """Free GPU memory by dropping the model and tokenizer.

        The result cache is kept; a later call reloads the model lazily.
        """
        self._model = None
        self._tokenizer = None
        self._loaded = False
        torch.cuda.empty_cache()
# --- Convenience Functions ---
def create_armo_reward_model(
    device: str = "cuda",
    model_id: str = "RLHFlow/ArmoRM-Llama3-8B-v0.1",
) -> ArmoRMRewardModel:
    """Build an ArmoRM reward model for *model_id* on *device* and load it eagerly.

    All other settings keep their ``ArmoRewardConfig`` defaults.
    """
    reward_model = ArmoRMRewardModel(
        ArmoRewardConfig(model_id=model_id, device=device)
    )
    reward_model.load()  # pay the loading cost now rather than on first score
    return reward_model
# --- Integration with existing eval_step interface ---
async def eval_step_armo(
    q_t: str,
    answer_t: str,
    q_t1: str,
    armo_model: ArmoRMRewardModel,
    memories_t: Optional[List[str]] = None,
) -> Tuple[float, float]:
    """
    Score turn t with ArmoRM; drop-in replacement for ``eval_step_llm``.

    Args:
        q_t: User query at turn t.
        answer_t: Agent response at turn t.
        q_t1: User follow-up at turn t+1.
        armo_model: Loaded ArmoRMRewardModel instance.
        memories_t: Retrieved memories. Ignored by ArmoRM; accepted only so the
            signature matches the existing interface.

    Returns:
        ``(reward, gating)``: gating is 1.0 when the result asks for an update
        and 0.0 otherwise.

    Note:
        Scoring is a synchronous call inside this coroutine, so the event loop
        is blocked until it finishes.
    """
    verdict = armo_model.estimate_preference_compliance(
        query=q_t,
        response=answer_t,
        user_followup=q_t1,
    )
    if verdict.should_update:
        return verdict.reward, 1.0
    return verdict.reward, 0.0
# --- Test Script ---
if __name__ == "__main__":
    # Manual smoke test: loads the real model and prints scores for a few
    # hand-written conversations. Requires a CUDA device.
    print("=" * 60)
    print("ArmoRM Reward Model Test")
    print("=" * 60)

    # Create model
    model = create_armo_reward_model(device="cuda")
    try:
        # Test 1: Basic response scoring
        print("\n--- Test 1: Basic Response Scoring ---")
        query = "What is the capital of France?"
        answer = "The capital of France is Paris."
        messages = [
            {"role": "user", "content": query},
            {"role": "assistant", "content": answer},
        ]
        result = model.score_response(messages)
        print(f"Query: {query}")
        print(f"Response: {answer}")
        print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}")

        # Test 2: Good response with satisfied user
        print("\n--- Test 2: Good Response (User Satisfied) ---")
        result = model.estimate_preference_compliance(
            query="Can you explain how photosynthesis works?",
            response="Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen. It occurs in the chloroplasts, primarily in the leaves. The light-dependent reactions capture solar energy, while the Calvin cycle uses that energy to fix carbon dioxide into sugars.",
            user_followup="Great explanation! Can you tell me more about the Calvin cycle?",
        )
        print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}")

        # Test 3: Bad response with dissatisfied user
        print("\n--- Test 3: Bad Response (User Dissatisfied) ---")
        result = model.estimate_preference_compliance(
            query="Can you explain how photosynthesis works?",
            response="Plants make food.",
            user_followup="That's not helpful at all. I asked for an explanation of how photosynthesis works, not a one-liner.",
        )
        print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}")

        # Test 4: Preference enforcement scenario
        print("\n--- Test 4: Preference Enforcement Scenario ---")
        result = model.estimate_preference_compliance(
            query="Solve x^2 - 5x + 6 = 0",
            response="x = 2 or x = 3",
            user_followup="I asked you to show step-by-step work. Please solve it again showing each step.",
        )
        print(f"Score: {result.score:.3f}, Reward: {result.reward:.2f}, Update: {result.should_update}")

        # Test 5: Compare two responses
        print("\n--- Test 5: Response Comparison ---")
        score_a, score_b, winner = model.compare_responses(
            query="What are the benefits of exercise?",
            response_a="Exercise is good for you.",
            response_b="Exercise offers numerous benefits including improved cardiovascular health, stronger muscles and bones, better mental health through endorphin release, weight management, increased energy levels, and better sleep quality. Regular physical activity also reduces the risk of chronic diseases like diabetes and heart disease.",
        )
        print(f"Response A (short): {score_a:.3f}")
        print(f"Response B (detailed): {score_b:.3f}")
        print(f"Winner: {winner}")

        # Test 6: Batch scoring
        print("\n--- Test 6: Batch Scoring ---")
        batch = [
            [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there! How can I help you today?"}],
            [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "k"}],
        ]
        results = model.score_batch(batch)
        for i, r in enumerate(results):
            print(f" Conversation {i+1}: Score={r.score:.3f}, Reward={r.reward:.2f}")

        print("\n" + "=" * 60)
        print("Tests complete!")
        print("=" * 60)
    finally:
        # Release GPU memory even if one of the demo cases raises.
        model.cleanup()
|