1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
|
import concurrent.futures
import os

import openai
from json_repair import repair_json
from json_repair import repair_json
from tenacity import retry
from tenacity import stop_after_attempt, wait_random_exponential
from together import Together
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# Process-wide settings. VLLM_USE_V1='0' presumably selects vLLM's legacy engine if
# the trainer uses vLLM (TODO confirm); WANDB_PROJECT routes the "wandb" reports.
os.environ.update({
    "VLLM_USE_V1": '0',
    "WANDB_PROJECT": "collaborative-agent-reflection-grpo",
})

# Cumulative record of judge scores; reflection_reward_func extends "scores" and
# bumps "batch_count" once per call so it can print a running average.
reflection_scores_tracker = {"scores": [], "batch_count": 0}

## Step 1: Load up Llama 3.1 8B Instruct and set parameters
print("Step 1: Load up Llama 3.1 8B Instruct and set parameters")

# Token budget: prompts may use up to max_prompt_length tokens and completions get
# the remainder of max_seq_length (3072 - 2048 = 1024 tokens).
max_seq_length = 3072  # Can increase for longer reasoning traces
max_prompt_length = 2048
max_completion_length = max_seq_length - max_prompt_length

# Local checkpoint GRPO starts from (the path suggests an SFT run on
# session-level reflections).
model_id = "/shared/storage-01/users/mehri2/LLaMA-Factory/saves/llama-3.1-8b-instruct/full/sft_session_level_reflection/checkpoint-628"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Start from the checkpoint's own generation config, then switch on nucleus
# sampling and cap the total sequence length.
generation_config = GenerationConfig.from_pretrained(model_id)
generation_config.do_sample = True
generation_config.top_p = 0.9
generation_config.max_length = max_seq_length

# Build the model with that config attached; dtype and device placement are
# chosen automatically by transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    generation_config=generation_config,
    device_map="auto",
    torch_dtype="auto",
)
# Step 2: Data preparation
import re
import json
# Load and prep dataset
def extract_json_answer(text: str) -> str:
    """Return the ``agent_notes`` field of the (possibly malformed) JSON in ``text``.

    ``repair_json`` is used so slightly broken model output still parses. Any
    failure (the repaired value is not indexable by ``"agent_notes"``, the key is
    missing, or repair itself raises) is printed together with the offending text
    and mapped to ``""``, which callers treat as a badly formatted reflection.

    Note: the field's value is returned unchanged, so it is only a ``str`` when the
    JSON actually stored a string there.
    """
    try:
        notes = repair_json(text, return_objects=True)["agent_notes"]
    except Exception as err:
        print(f"Error extracting JSON answer: {err}")
        print("Text: ", text)
        notes = ""
    return notes
def get_conversation_data(path="/shared/storage-01/users/mehri2/mem/collaborativeagents/training/grpo/training_data/session_level_reflection_grpo_data.jsonl"):
    """Load the session-level reflection data and convert it into GRPO examples.

    Each record must provide ``messages`` (``messages[0]["content"]`` is the prompt,
    ``messages[1]["content"]`` the gold JSON reflection) and
    ``responses_that_enforce_preferences``. Records whose prompt tokenizes to more
    than ``max_prompt_length`` tokens (module-level tokenizer) are skipped.

    Args:
        path: Data file, either a single JSON document holding a list of records or
            JSON Lines (one record per line, as the ``.jsonl`` extension suggests).

    Returns:
        A list of dicts with keys ``prompt``, ``gold_response`` (the gold
        ``agent_notes``, or "" if it could not be extracted) and
        ``responses_that_enforce_preferences``. The trainer forwards the last two
        to the reward functions as keyword arguments.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw_text = f.read()
    try:
        records = json.loads(raw_text)
    except json.JSONDecodeError:
        # A real JSON Lines file is not one JSON document (json.load reports
        # "Extra data"), so parse it one non-blank line at a time.
        records = [json.loads(line) for line in raw_text.splitlines() if line.strip()]
    if isinstance(records, dict):
        # A single-record file parses to one object rather than a list of them.
        records = [records]

    grpo_data = []
    for elem in records:
        prompt = elem['messages'][0]['content']
        # Filter first so over-long rows skip the gold-answer extraction (and its
        # error logging) entirely.
        if len(tokenizer.tokenize(prompt)) > max_prompt_length:
            continue
        grpo_data.append({
            'prompt': prompt,
            'gold_response': extract_json_answer(elem['messages'][1]['content']),
            'responses_that_enforce_preferences': elem['responses_that_enforce_preferences'],
        })
    return grpo_data
dataset = get_conversation_data()
if not dataset:
    # Fail loudly here instead of with a ZeroDivisionError / empty max() below.
    raise ValueError(
        "No GRPO examples were loaded: check the data file and the max_prompt_length filter."
    )
# These statistics are in characters, not tokens; the token-length filter is
# applied inside get_conversation_data.
prompt_char_lengths = [len(elem["prompt"]) for elem in dataset]
avg_len = sum(prompt_char_lengths) / len(prompt_char_lengths)
print("Avg len: ", avg_len)
print("Max len: ", max(prompt_char_lengths))
print("MESM! dataset size:", len(dataset))
# Reward functions
print("Step 3: Reward Functions")
# OpenAI-compatible endpoint that serves the judge model on localhost:8004.
# "EMPTY" is presumably a placeholder because the local server does not check keys.
client = openai.OpenAI(base_url="http://localhost:8004/v1", api_key="EMPTY")
# Bounded, jittered exponential backoff. A bare @retry retries forever with no
# delay, which hammers a down judge server and hangs training silently.
# reraise=True surfaces the last real exception instead of tenacity.RetryError.
@retry(
    stop=stop_after_attempt(10),
    wait=wait_random_exponential(multiplier=1, max=60),
    reraise=True,
)
def ask_judge(prompt, system_prompt=None):
    """Send ``prompt`` to the Llama-3.3-70B-Instruct judge and return its reply text.

    The request goes through the module-level OpenAI-compatible ``client`` with a
    2048-token output cap.

    Args:
        prompt: User-turn content.
        system_prompt: Optional system message, prepended when truthy.

    Returns:
        The judge's reply with surrounding whitespace stripped.

    Raises:
        The last API error, or ValueError if the reply has no content, once the
        retry budget is exhausted.
    """
    messages = [{"role": "user", "content": prompt}]
    if system_prompt:
        messages.insert(0, {"role": "system", "content": system_prompt})
    chat_completion = client.chat.completions.create(
        model="Llama-3.3-70B-Instruct",
        messages=messages,
        max_tokens=2048,
    )
    content = chat_completion.choices[0].message.content
    if content is None:
        # Raise explicitly (rather than an AttributeError on .strip()) so the
        # retry has a meaningful error to report.
        raise ValueError("Judge returned a message without content")
    return content.strip()
def reflection_reward_func(prompts, completions, gold_response, responses_that_enforce_preferences, **kwargs) -> list[float]:
    """GRPO reward: have the LLM judge score each completion's reflection (0-3).

    ``gold_response`` and ``responses_that_enforce_preferences`` are the extra
    dataset columns built by get_conversation_data. They are indexed in step with
    ``completions`` and ``prompts``.

    Each completion's ``agent_notes`` is extracted first; the judge calls then run
    in parallel. Scores are also appended to ``reflection_scores_tracker`` so a
    running average across all rollouts can be printed.

    Returns:
        One score per completion, in the same order.
    """
    args_list = []
    for i, completion in enumerate(completions):
        completion_copy = extract_json_answer(completion)
        args_list.append((prompts[i], completion_copy, gold_response[i], responses_that_enforce_preferences[i]))
    # The per-completion work is a network call to the judge, so threads give the
    # same overlap as processes. They also avoid forking the large training process
    # on every batch, and avoid re-importing this unguarded script in each worker
    # under the spawn/forkserver start methods (which would reload the model and
    # restart training).
    with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
        rewards = list(executor.map(process_single_completion, args_list))
    # Update tracker with batch scores
    reflection_scores_tracker["scores"].extend(rewards)
    reflection_scores_tracker["batch_count"] += 1
    print("Rollout reflection scores: ", rewards)
    all_scores = reflection_scores_tracker["scores"]
    if all_scores:  # empty only if the trainer passed no completions at all
        print("Overall reflection score average (across all rollouts): ", sum(all_scores) / len(all_scores))
    return rewards
# Bounded backoff instead of a bare, infinite, zero-delay @retry. Any failure
# (judge error, unparseable or out-of-range score) re-asks the judge; after the
# budget is spent the last exception propagates rather than hanging training.
@retry(
    stop=stop_after_attempt(5),
    wait=wait_random_exponential(multiplier=1, max=30),
    reraise=True,
)
def process_single_completion(args):
    """Ask the LLM judge to score one extracted reflection on a 0-3 scale.

    Args:
        args: Tuple ``(prompt, completion, gold_response,
            responses_that_enforce_preferences)``.
            - ``prompt`` is unused.
            - ``completion`` and ``gold_response`` are ``agent_notes`` values as
              returned by extract_json_answer ("" when extraction failed).
            - ``responses_that_enforce_preferences`` is an iterable of the user
              messages in which preferences were enforced.

    Returns:
        The judge's ``reflection_score`` as a float in [0, 3]. Returns 0.0 without
        calling the judge when the reflection is empty or missing.

    Raises:
        The last judge/parsing error once the retry budget is exhausted.
    """
    _prompt, completion, gold_response, responses_that_enforce_preferences = args
    if not completion:
        # Covers "" from a failed extraction and empty/null agent_notes: nothing
        # for the judge to evaluate.
        print(f"Poorly formatted completion: {completion}")
        print("Reflection Score: 0")
        return 0.0
    user_messages_where_they_enforce_preferences = "".join(
        f"User message #{i+1}: {response}\n"
        for i, response in enumerate(responses_that_enforce_preferences)
    )
    reflection_evaluation_prompt = f"""You are an expert evaluator analyzing a conversational agent's reflection of a conversation, where they analyze the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.
Throughout the conversation, the user explicitly enforces their preferences whenever necessary. The agent analyzes the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.
# Your Task:
Evaluate whether the agent's reflection succesfully captures the user's preferences and provides actionable notes to help them satisfy these preferences in future conversations.
# Agent's Reflection:
{completion}
# User Messages Where They Enforce Their Preferences:
{user_messages_where_they_enforce_preferences}
# Gold Reflection:
Here is a gold reflection for the same conversation. Use this as a reference to evaluate the agent's reflection.
{gold_response}
# Evaluation Criteria:
Assess the reflection on four dimensions:
- **Coverage (Completeness):** Does the agent's reflection capture all of the user's preferences?
- **Actionability (Quality):** Does the agent's reflection provide actionable notes and details that help the agent satisfy these preferences in future conversations?
- **Accuracy (No Hallucination):** Are all points grounded in actual user statements? Does the reflection avoid inventing preferences or misrepresenting user statements?
- **Clarity:** Is the reflection well-organized and clearly formatted? Does the reflection avoid redundancy, with each preference stated once without repetitive or overlapping notes?
You will output a score from 0-3, where:
- 0: Does not effectively capture user preferences: gaps in converage, or significant hallucinations
- 1: Captures some preferences with limited actionable notes, may hallucinate some preferences
- 2: Captures most preferences with actionable notes, may have some slight hallucinations
- 3: Comprehensively captures all preferences with highly actionable notes and no hallucinations
# Output Format:
{{
"reasoning": # Brief explanation of your decision
"reflection_score": # 0-3
}}
Output a properly formatted JSON response, as specified by the Output Format."""
    judge_output = ask_judge(reflection_evaluation_prompt)
    reflection_score = _parse_reflection_score(judge_output)
    print(f"Reflection Evaluation Prompt: {reflection_evaluation_prompt}")
    print(f"Reflection Score: {reflection_score}")
    return reflection_score


def _parse_reflection_score(judge_output):
    """Extract the ``reflection_score`` from the judge's JSON reply as a float in [0, 3].

    Raises ValueError/TypeError for a non-object reply, a missing key, a
    non-numeric value, or an out-of-range score. These errors trigger the caller's
    retry instead of letting a string or None reach the reward tensor.
    """
    parsed = repair_json(judge_output, return_objects=True)
    if not isinstance(parsed, dict) or "reflection_score" not in parsed:
        raise ValueError(f"Judge reply has no reflection_score: {judge_output!r}")
    score = float(parsed["reflection_score"])
    if not 0 <= score <= 3:  # also rejects NaN
        raise ValueError(f"Judge reflection_score out of range 0-3: {score}")
    return score
def soft_format_reward_func(prompts, completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has JSON format with agent_notes and user_preferences_reasoning fields.

    Each completion is passed through ``repair_json``. It earns 0.5 only when the
    result is a JSON object containing both keys; otherwise it earns 0.0.
    Requiring a dict matters: for a string or list result, ``in`` would be a
    substring/element test, rewarding non-object output that merely mentions the
    field names.

    Returns:
        One reward per completion, in the same order.
    """
    required_fields = ("agent_notes", "user_preferences_reasoning")
    rewards = []
    for response in completions:
        try:
            parsed_json = repair_json(response, return_objects=True)
        except Exception:
            # Best effort: output json_repair cannot handle just earns no format reward.
            parsed_json = None
        is_well_formed = isinstance(parsed_json, dict) and all(
            field in parsed_json for field in required_fields
        )
        rewards.append(0.5 if is_well_formed else 0.0)
    for reward in rewards:
        print("Soft Format Reward: ", reward)
    return rewards
# Train the model
from trl import GRPOConfig, GRPOTrainer
# GRPO run configuration. Optimizer settings not listed here (adam betas, weight
# decay, warmup) are left at GRPOConfig's defaults.
training_args = GRPOConfig(
    output_dir = "./outputs_lamma3_reflection_grpo_v2",
    report_to = "wandb",
    # Optimisation.
    learning_rate = 5e-6,
    gradient_accumulation_steps = 4,
    bf16 = True,
    # Token budgets computed in Step 1.
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # Completions sampled per prompt; each is scored by every function in reward_funcs.
    num_generations = 4,
    # Schedule, logging and checkpointing.
    # NOTE(review): both max_steps and num_train_epochs are set; in the HF Trainer
    # a positive max_steps takes precedence. Confirm 2000 steps is the intent.
    num_train_epochs = 1,
    max_steps = 2000,
    logging_steps = 1,
    save_steps = 50,
)
trainer = GRPOTrainer(
    model = model,
    args = training_args,
    train_dataset = dataset,
    processing_class = tokenizer,
    reward_funcs = [soft_format_reward_func, reflection_reward_func],
)
trainer.train()
# python3 llama_grpo.py >> llama_grpo_with_reflection_v1.out 2>&1
|