1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
|
"""
Custom reward functions for VERL GRPO training
Compatible with VERL's reward function signature
"""
import openai
from json_repair import repair_json
import concurrent.futures
import time

# Initialize judge client.
# Points at a locally served OpenAI-compatible endpoint (e.g. vLLM) on port 8004;
# the api_key is a placeholder since the local server does not authenticate.
client = openai.OpenAI(base_url="http://localhost:8004/v1", api_key="EMPTY")

# Global tracker for reflection scores.
# "scores" accumulates every reflection score seen by compute_score;
# "batch_count" is reserved for per-batch bookkeeping (never incremented in
# this file — presumably updated by an external caller; TODO confirm).
reflection_scores_tracker = {
    "scores": [],
    "batch_count": 0
}
def extract_json_answer(text: str) -> str:
    """Extract the ``agent_notes`` field from a (possibly malformed) JSON response.

    Uses ``json_repair`` so slightly broken JSON emitted by the model can still
    be parsed.

    Args:
        text: Raw model completion expected to contain a JSON object.

    Returns:
        The ``agent_notes`` value as a string, or ``""`` when the text cannot
        be parsed, does not parse to a JSON object, or lacks the key.
    """
    try:
        parsed = repair_json(text, return_objects=True)
        # repair_json may return a plain str or list for non-object input;
        # only a dict can carry the field we need.
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        answer = parsed["agent_notes"]
    except Exception as e:
        print(f"Error extracting JSON answer: {e}")
        return ""
    # Guarantee the declared return type even when the model emitted a
    # non-string value (e.g. a list of notes).
    return answer if isinstance(answer, str) else str(answer)
def ask_judge(prompt, system_prompt=None, max_retries=3):
    """Query the judge model, retrying on transient API failures.

    Args:
        prompt: User-turn content sent to the judge.
        system_prompt: Optional system message prepended to the conversation.
        max_retries: Number of attempts before giving up (must be >= 1).

    Returns:
        str: The judge's reply, stripped of surrounding whitespace.

    Raises:
        ValueError: If ``max_retries`` is not positive (the original code
            would silently return None in that case).
        Exception: The last API error when every attempt fails.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")

    messages = [{"role": "user", "content": prompt}]
    if system_prompt:
        messages.insert(0, {"role": "system", "content": system_prompt})

    for attempt in range(max_retries):
        try:
            chat_completion = client.chat.completions.create(
                model="meta-llama/Llama-3.1-70B-Instruct",
                messages=messages,
                max_tokens=2048,
            )
            return chat_completion.choices[0].message.content.strip()
        except Exception:
            if attempt == max_retries - 1:
                raise  # bare raise preserves the original traceback
            time.sleep(1)  # brief pause before retrying a transient failure
def evaluate_reflection(completion, gold_response, responses_that_enforce_preferences):
    """Score a single reflection completion on a 0-3 scale via the judge model.

    Args:
        completion: Raw model output expected to contain an ``agent_notes`` JSON field.
        gold_response: Reference ("gold") reflection for the same conversation.
        responses_that_enforce_preferences: User messages in which preferences
            were explicitly enforced; shown to the judge as evidence.

    Returns:
        float: Judge-assigned reflection score clamped to [0, 3], or 0 when the
        completion is malformed or the judge's response cannot be parsed.
    """
    completion_text = extract_json_answer(completion)
    if completion_text == "":
        print(f"Poorly formatted completion: {completion}")
        return 0

    user_messages_where_they_enforce_preferences = "".join(
        f"User message #{i+1}: {response}\n"
        for i, response in enumerate(responses_that_enforce_preferences)
    )

    reflection_evaluation_prompt = f"""You are an expert evaluator analyzing a conversational agent's reflection of a conversation, where they analyze the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.
Throughout the conversation, the user explicitly enforces their preferences whenever necessary. The agent analyzes the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.
# Your Task:
Evaluate whether the agent's reflection succesfully captures the user's preferences and provides actionable notes to help them satisfy these preferences in future conversations.
# Agent's Reflection:
{completion_text}
# User Messages Where They Enforce Their Preferences:
{user_messages_where_they_enforce_preferences}
# Gold Reflection:
Here is a gold reflection for the same conversation. Use this as a reference to evaluate the agent's reflection.
{gold_response}
# Evaluation Criteria:
Assess the reflection on four dimensions:
- **Coverage (Completeness):** Does the agent's reflection capture all of the user's preferences?
- **Actionability (Quality):** Does the agent's reflection provide actionable notes and details that help the agent satisfy these preferences in future conversations?
- **Accuracy (No Hallucination):** Are all points grounded in actual user statements? Does the reflection avoid inventing preferences or misrepresenting user statements?
- **Clarity:** Is the reflection well-organized and clearly formatted? Does the reflection avoid redundancy, with each preference stated once without repetitive or overlapping notes?
You will output a score from 0-3, where:
- 0: Does not effectively capture user preferences: gaps in converage, or significant hallucinations
- 1: Captures some preferences with limited actionable notes, may hallucinate some preferences
- 2: Captures most preferences with actionable notes, may have some slight hallucinations
- 3: Comprehensively captures all preferences with highly actionable notes and no hallucinations
# Output Format:
{{
"reasoning": # Brief explanation of your decision
"reflection_score": # 0-3
}}
Output a properly formatted JSON response, as specified by the Output Format."""

    try:
        reflection_score_response = ask_judge(reflection_evaluation_prompt)
        parsed = repair_json(reflection_score_response, return_objects=True)
        # The judge may emit the score as a string ("2"), as a float, or omit
        # it entirely; coerce to float and clamp to the documented 0-3 range
        # rather than letting a KeyError/TypeError crash the training loop.
        reflection_score = float(parsed["reflection_score"])
        reflection_score = max(0.0, min(3.0, reflection_score))
    except Exception as e:
        print(f"Error parsing judge response: {e}")
        return 0

    print(f"Reflection Score: {reflection_score}")
    return reflection_score
def soft_format_reward(solution_str):
    """Award 0.5 when the completion parses to a JSON object with both required fields.

    Args:
        solution_str: Raw model completion to check.

    Returns:
        float: 0.5 if the (repaired) JSON is an object containing both
        ``agent_notes`` and ``user_preferences_reasoning``; 0.0 otherwise.
    """
    reward = 0.0
    try:
        parsed_json = repair_json(solution_str, return_objects=True)
        # repair_json can return a plain str for non-JSON input, in which case
        # the original `in` check became a substring match and awarded credit
        # to free-form text that merely mentioned the field names. Require a
        # real JSON object before checking for the keys.
        if (
            isinstance(parsed_json, dict)
            and "agent_notes" in parsed_json
            and "user_preferences_reasoning" in parsed_json
        ):
            reward = 0.5
    except Exception:
        # Best-effort: unparseable output simply earns no format reward.
        pass
    print(f"Soft Format Reward: {reward}")
    return reward
# VERL reward function signature: (data_source, solution_str, ground_truth, extra_info)
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """
    Main reward function for VERL (named 'compute_score' for default VERL compatibility).
    This matches the signature expected by VERL's reward managers.
    Args:
        data_source: Source identifier for the data (unused here, required by VERL)
        solution_str: The model's generated completion
        ground_truth: Not used directly, passed in extra_info
        extra_info: Dictionary containing 'gold_response' and 'responses_that_enforce_preferences'
    Returns:
        float: Combined reward score (format reward + reflection score)
    """
    if extra_info is None:
        print("Warning: extra_info is None")
        return 0.0
    gold_response = extra_info.get('gold_response', '')
    responses_that_enforce_preferences = extra_info.get('responses_that_enforce_preferences', [])

    # Soft format reward: 0.5 for well-formed JSON with the required fields
    format_reward = soft_format_reward(solution_str)

    # Reflection quality reward: 0-3 from the judge model
    reflection_score = evaluate_reflection(solution_str, gold_response, responses_that_enforce_preferences)
    # The judge may hand back a numeric string (e.g. "2"); coerce defensively
    # so the addition below cannot raise a TypeError mid-training.
    try:
        reflection_score = float(reflection_score)
    except (TypeError, ValueError):
        print(f"Warning: non-numeric reflection score {reflection_score!r}; using 0.0")
        reflection_score = 0.0

    total_reward = format_reward + reflection_score
    reflection_scores_tracker["scores"].append(reflection_score)
    print(f"Total Reward: {total_reward} (Format: {format_reward}, Reflection: {reflection_score})")
    return total_reward
|