summaryrefslogtreecommitdiff
path: root/collaborativeagents/training/grpo/llama_grpo.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 09:57:37 -0600
commitdc801c07cf38b0c495686463e6ca6f871a64440e (patch)
tree599f03114775921dbc472403c701f4a3a8ea188a /collaborativeagents/training/grpo/llama_grpo.py
parente43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (diff)
Add collaborativeagents module and update gitignore
- Add collaborativeagents subproject with adapters, agents, and evaluation modules - Update .gitignore to exclude large binary files (.whl, .tar), wandb logs, and results Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'collaborativeagents/training/grpo/llama_grpo.py')
-rw-r--r--collaborativeagents/training/grpo/llama_grpo.py255
1 files changed, 255 insertions, 0 deletions
diff --git a/collaborativeagents/training/grpo/llama_grpo.py b/collaborativeagents/training/grpo/llama_grpo.py
new file mode 100644
index 0000000..10876ab
--- /dev/null
+++ b/collaborativeagents/training/grpo/llama_grpo.py
@@ -0,0 +1,255 @@
# Standard library
import concurrent.futures
import os

# Third-party
import openai
from json_repair import repair_json
from tenacity import retry, stop_after_attempt, wait_exponential
from together import Together
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Environment configuration: set before any vLLM / wandb initialization downstream.
os.environ["VLLM_USE_V1"] = '0'
os.environ["WANDB_PROJECT"] = "collaborative-agent-reflection-grpo"
+
+# Global tracker for reflection scores
+reflection_scores_tracker = {
+ "scores": [],
+ "batch_count": 0
+}
+
+
+## Step 1: Load up Llama 3.1 8B Instruct and set parameters
+print("Step 1: Load up Llama 3.1 8B Instruct and set parameters")
+max_seq_length = 3072 # Can increase for longer reasoning traces
+max_prompt_length = 2048 #
+max_completion_length = max_seq_length - max_prompt_length
+
+model_id = "/shared/storage-01/users/mehri2/LLaMA-Factory/saves/llama-3.1-8b-instruct/full/sft_session_level_reflection/checkpoint-628"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Create generation config first
+generation_config = GenerationConfig.from_pretrained(model_id)
+generation_config.max_length = max_seq_length
+generation_config.do_sample = True
+generation_config.top_p = 0.9
+
+# First create the model with the generation config
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype="auto",
+ device_map="auto",
+ generation_config=generation_config
+)
+
+
+
+# Step 2: Data preparation
+import re
+import json
+
+# Load and prep dataset
+def extract_json_answer(text: str) -> str:
+ try:
+ answer = repair_json(text, return_objects=True)
+ answer = answer["agent_notes"]
+ except Exception as e:
+ print(f"Error extracting JSON answer: {e}")
+ print("Text: ", text)
+ return ""
+ return answer
+
+
+def get_conversation_data():
+ with open("/shared/storage-01/users/mehri2/mem/collaborativeagents/training/grpo/training_data/session_level_reflection_grpo_data.jsonl", "r") as f:
+ session_level_reflection_data = json.load(f)
+
+ grpo_data = []
+ for elem in session_level_reflection_data:
+ prompt = elem['messages'][0]['content']
+ gold_response = extract_json_answer(elem['messages'][1]['content'])
+ responses_that_enforce_preferences = elem['responses_that_enforce_preferences']
+
+ tokens = tokenizer.tokenize(prompt)
+ if len(tokens) > max_prompt_length: continue
+
+ grpo_data.append({
+ 'prompt': prompt,
+ 'gold_response': gold_response,
+ 'responses_that_enforce_preferences': responses_that_enforce_preferences
+ })
+ return grpo_data
+
+dataset = get_conversation_data()
+
+
+avg_len = sum([len(elem["prompt"]) for elem in dataset]) / len(dataset)
+print("Avg len: ", avg_len)
+print("Max len: ", max([len(elem["prompt"]) for elem in dataset]))
+
+print("MESM! dataset size:", len(dataset))
+
+
+# Reward functions
+print("Step 3: Reward Functions")
+client = openai.OpenAI(base_url="http://localhost:8004/v1", api_key="EMPTY")
+@retry
+def ask_judge(prompt, system_prompt=None):
+ if system_prompt:
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": prompt}]
+ else:
+ messages = [
+ {"role": "user", "content": prompt}]
+
+ chat_completion = client.chat.completions.create(
+ model="Llama-3.3-70B-Instruct",
+ messages=messages,
+ max_tokens=2048
+ )
+ return chat_completion.choices[0].message.content.strip()
+
+def reflection_reward_func(prompts, completions, gold_response, responses_that_enforce_preferences, **kwargs) -> list[float]:
+ # Prepare arguments for parallel processing
+ args_list = []
+ for i, completion in enumerate(completions):
+ completion_copy = extract_json_answer(completion)
+ args_list.append((prompts[i], completion_copy, gold_response[i], responses_that_enforce_preferences[i]))
+
+ # Use ProcessPoolExecutor for parallel processing
+ with concurrent.futures.ProcessPoolExecutor(max_workers=20) as executor:
+ rewards = list(executor.map(process_single_completion, args_list))
+
+ # Update tracker with batch scores
+ reflection_scores_tracker["scores"].extend(rewards)
+ reflection_scores_tracker["batch_count"] += 1
+ print("Rollout reflection scores: ", rewards)
+ print("Overall reflection score average (across all rollouts): ", sum(reflection_scores_tracker["scores"]) / len(reflection_scores_tracker["scores"]))
+ return rewards
+
+@retry
+def process_single_completion(args):
+ prompt, completion, gold_response, responses_that_enforce_preferences = args
+
+ # completion = extract_json_answer(completion)
+ if completion == "":
+ print(f"Poorly formatted completion: {completion}")
+ print(f"Reflection Score: 0")
+
+ return 0
+
+ user_messages_where_they_enforce_preferences = ""
+ for i, response in enumerate(responses_that_enforce_preferences):
+ user_messages_where_they_enforce_preferences += f"User message #{i+1}: {response}\n"
+
+ reflection_evaluation_prompt = f"""You are an expert evaluator analyzing a conversational agent's reflection of a conversation, where they analyze the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.
+
+Throughout the conversation, the user explicitly enforces their preferences whenever necessary. The agent analyzes the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.
+
+# Your Task:
+Evaluate whether the agent's reflection succesfully captures the user's preferences and provides actionable notes to help them satisfy these preferences in future conversations.
+
+# Agent's Reflection:
+{completion}
+
+# User Messages Where They Enforce Their Preferences:
+{user_messages_where_they_enforce_preferences}
+
+# Gold Reflection:
+Here is a gold reflection for the same conversation. Use this as a reference to evaluate the agent's reflection.
+{gold_response}
+
+# Evaluation Criteria:
+Assess the reflection on four dimensions:
+- **Coverage (Completeness):** Does the agent's reflection capture all of the user's preferences?
+- **Actionability (Quality):** Does the agent's reflection provide actionable notes and details that help the agent satisfy these preferences in future conversations?
+- **Accuracy (No Hallucination):** Are all points grounded in actual user statements? Does the reflection avoid inventing preferences or misrepresenting user statements?
+- **Clarity:** Is the reflection well-organized and clearly formatted? Does the reflection avoid redundancy, with each preference stated once without repetitive or overlapping notes?
+
+You will output a score from 0-3, where:
+- 0: Does not effectively capture user preferences: gaps in converage, or significant hallucinations
+- 1: Captures some preferences with limited actionable notes, may hallucinate some preferences
+- 2: Captures most preferences with actionable notes, may have some slight hallucinations
+- 3: Comprehensively captures all preferences with highly actionable notes and no hallucinations
+
+# Output Format:
+{{
+ "reasoning": # Brief explanation of your decision
+ "reflection_score": # 0-3
+}}
+
+Output a properly formatted JSON response, as specified by the Output Format."""
+# TODO: maybe add something about hallucinating preferences!!
+
+
+ reflection_score = ask_judge(reflection_evaluation_prompt)
+ reflection_score = repair_json(reflection_score, return_objects=True)["reflection_score"]
+
+ print(f"Reflection Evaluation Prompt: {reflection_evaluation_prompt}")
+ print(f"Reflection Score: {reflection_score}")
+
+ return reflection_score
+
+def soft_format_reward_func(prompts, completions, **kwargs) -> list[float]:
+ """Reward function that checks if the completion has JSON format with agent_notes and user_preferences_reasoning fields."""
+ responses = [completion for completion in completions]
+ rewards = []
+
+ for response in responses:
+ reward = 0.0
+ try:
+ parsed_json = repair_json(response, return_objects=True)
+
+ if "agent_notes" in parsed_json and "user_preferences_reasoning" in parsed_json:
+ reward = 0.5
+ except Exception as e:
+ pass
+
+ rewards.append(reward)
+
+ for i, response in enumerate(responses):
+ print("Soft Format Reward: ", rewards[i])
+
+ return rewards
+
+
+
+# Train the model
+from trl import GRPOConfig, GRPOTrainer
+training_args = GRPOConfig(
+ learning_rate = 5e-6,
+ # adam_beta1 = 0.9,
+ # adam_beta2 = 0.99,
+ # weight_decay = 0.1,
+ # warmup_ratio = 0.1,
+ gradient_accumulation_steps = 4,
+ num_train_epochs = 1,
+ bf16 = True,
+ max_prompt_length = max_prompt_length,
+ max_completion_length = max_completion_length,
+ report_to = "wandb",
+ logging_steps = 1,
+ num_generations = 4,
+ max_steps = 2000,
+ save_steps = 50,
+ output_dir = "./outputs_lamma3_reflection_grpo_v2",
+)
+
+
+
+
+trainer = GRPOTrainer(
+ model = model,
+ processing_class = tokenizer,
+ reward_funcs = [
+ soft_format_reward_func,
+ reflection_reward_func
+ ],
+ args = training_args,
+ train_dataset = dataset,
+)
+trainer.train()
+
+
+
+# python3 llama_grpo.py >> llama_grpo_with_reflection_v1.out 2>&1 \ No newline at end of file