"""GRPO training of a session-level reflection model (Llama 3.1 8B Instruct SFT checkpoint).

The policy writes a JSON reflection about a conversation. Each rollout is scored by:

* ``soft_format_reward_func`` -- 0.5 when the completion parses to a JSON object
  containing both ``agent_notes`` and ``user_preferences_reasoning``, else 0.0.
* ``reflection_reward_func`` -- a 0-3 score from an LLM judge
  ("Llama-3.3-70B-Instruct" behind the OpenAI-compatible endpoint at
  http://localhost:8004/v1) comparing the completion's ``agent_notes`` with a gold
  reflection and with the user messages in which preferences were enforced.

Usage:
    python3 llama_grpo.py >> llama_grpo_with_reflection_v1.out 2>&1
"""
import concurrent.futures
import json
import os
import re

import openai
from json_repair import repair_json
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
from together import Together
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

os.environ["VLLM_USE_V1"] = '0'
os.environ["WANDB_PROJECT"] = "collaborative-agent-reflection-grpo"

# Global tracker for reflection scores: every judge score ever produced and the
# number of reward batches seen, used to log a running average.
reflection_scores_tracker = {
    "scores": [],
    "batch_count": 0
}

## Step 1: Load up Llama 3.1 8B Instruct and set parameters
print("Step 1: Load up Llama 3.1 8B Instruct and set parameters")
max_seq_length = 3072  # Can increase for longer reasoning traces
max_prompt_length = 2048
max_completion_length = max_seq_length - max_prompt_length

model_id = "/shared/storage-01/users/mehri2/LLaMA-Factory/saves/llama-3.1-8b-instruct/full/sft_session_level_reflection/checkpoint-628"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Create the generation config first so the model is built with it.
generation_config = GenerationConfig.from_pretrained(model_id)
generation_config.max_length = max_seq_length
generation_config.do_sample = True
generation_config.top_p = 0.9

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    generation_config=generation_config
)


# Step 2: Data preparation
GRPO_DATA_PATH = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/grpo/training_data/session_level_reflection_grpo_data.jsonl"


def extract_json_answer(text: str) -> str:
    """Return the ``agent_notes`` value from a (possibly malformed) JSON string.

    ``repair_json`` tolerates minor formatting mistakes in model output. Any
    failure (unparseable text, a non-object result, a missing key) is logged and
    mapped to ``""``, which the reward code treats as a badly formatted answer.
    The value is returned as stored in the JSON, so it is a string only when the
    model wrote ``agent_notes`` as a string.
    """
    try:
        answer = repair_json(text, return_objects=True)
        answer = answer["agent_notes"]
    except Exception as e:
        print(f"Error extracting JSON answer: {e}")
        print("Text: ", text)
        return ""
    return answer


def _load_json_records(path):
    """Load a list of records from ``path``, accepting a JSON array or JSON Lines.

    The whole file is first parsed as a single JSON document (a JSON array).
    If that fails, it is parsed as one JSON object per non-blank line, which is
    what the ``.jsonl`` extension suggests.
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    try:
        records = json.loads(text)
    except json.JSONDecodeError:
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    # A JSON Lines file with a single line parses as one object; wrap it.
    return [records] if isinstance(records, dict) else records


def get_conversation_data():
    """Build GRPO examples from the session-level reflection data file.

    For each record, the first message's content is the prompt and the
    ``agent_notes`` of the second message's content is the gold reflection.
    Records whose prompt exceeds ``max_prompt_length`` tokenizer tokens are
    skipped.

    Returns:
        A list of dicts with keys ``prompt``, ``gold_response`` and
        ``responses_that_enforce_preferences``. The trainer forwards the last two
        as keyword arguments to the reward functions.
    """
    # NOTE(review): prompts are plain strings rather than chat messages, so the
    # trainer presumably does not apply the chat template to them. Confirm this
    # matches how the SFT checkpoint was trained.
    grpo_data = []
    for elem in _load_json_records(GRPO_DATA_PATH):
        prompt = elem['messages'][0]['content']
        gold_response = extract_json_answer(elem['messages'][1]['content'])
        responses_that_enforce_preferences = elem['responses_that_enforce_preferences']
        if len(tokenizer.tokenize(prompt)) > max_prompt_length:
            continue
        grpo_data.append({
            'prompt': prompt,
            'gold_response': gold_response,
            'responses_that_enforce_preferences': responses_that_enforce_preferences
        })
    return grpo_data


dataset = get_conversation_data()
if not dataset:
    raise ValueError(
        f"No GRPO examples with a prompt of at most {max_prompt_length} tokens "
        f"were found in {GRPO_DATA_PATH}"
    )
# Prompt lengths below are measured in characters, not tokens.
prompt_char_lengths = [len(elem["prompt"]) for elem in dataset]
avg_len = sum(prompt_char_lengths) / len(prompt_char_lengths)
print("Avg len: ", avg_len)
print("Max len: ", max(prompt_char_lengths))
print("MESM! dataset size:", len(dataset))


# Reward functions
print("Step 3: Reward Functions")
client = openai.OpenAI(base_url="http://localhost:8004/v1", api_key="EMPTY")

# Upper bound on judge API calls per request before the last error is re-raised.
JUDGE_MAX_ATTEMPTS = 10
# How many times a judge reply without a usable score is re-requested before the
# completion is scored 0.
JUDGE_PARSE_ATTEMPTS = 5
# Maximum number of concurrent judge requests per reward call.
REFLECTION_JUDGE_WORKERS = 20


@retry(
    wait=wait_exponential(multiplier=1, min=1, max=60),
    stop=stop_after_attempt(JUDGE_MAX_ATTEMPTS),
    reraise=True,
)
def ask_judge(prompt, system_prompt=None):
    """Send ``prompt`` (plus an optional system prompt) to the judge model.

    Failed calls are retried with exponential backoff (1s doubling up to 60s).
    After ``JUDGE_MAX_ATTEMPTS`` failures the last exception is re-raised, so an
    unreachable judge stops training loudly instead of spinning forever.

    Returns:
        The judge's reply text, stripped of surrounding whitespace.
    """
    messages = [{"role": "user", "content": prompt}]
    if system_prompt:
        messages.insert(0, {"role": "system", "content": system_prompt})
    chat_completion = client.chat.completions.create(
        model="Llama-3.3-70B-Instruct",
        messages=messages,
        max_tokens=2048
    )
    return chat_completion.choices[0].message.content.strip()


def _build_reflection_evaluation_prompt(completion, gold_response, responses_that_enforce_preferences):
    """Render the judge prompt for one reflection."""
    user_messages_where_they_enforce_preferences = "".join(
        f"User message #{i}: {response}\n"
        for i, response in enumerate(responses_that_enforce_preferences, start=1)
    )
    return f"""You are an expert evaluator analyzing a conversational agent's reflection of a conversation, where they analyze the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.

Throughout the conversation, the user explicitly enforces their preferences whenever necessary. The agent analyzes the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.

# Your Task:
Evaluate whether the agent's reflection succesfully captures the user's preferences and provides actionable notes to help them satisfy these preferences in future conversations.

# Agent's Reflection:
{completion}

# User Messages Where They Enforce Their Preferences:
{user_messages_where_they_enforce_preferences}

# Gold Reflection:
Here is a gold reflection for the same conversation. Use this as a reference to evaluate the agent's reflection.
{gold_response}

# Evaluation Criteria:
Assess the reflection on four dimensions:
- **Coverage (Completeness):** Does the agent's reflection capture all of the user's preferences?
- **Actionability (Quality):** Does the agent's reflection provide actionable notes and details that help the agent satisfy these preferences in future conversations?
- **Accuracy (No Hallucination):** Are all points grounded in actual user statements? Does the reflection avoid inventing preferences or misrepresenting user statements?
- **Clarity:** Is the reflection well-organized and clearly formatted? Does the reflection avoid redundancy, with each preference stated once without repetitive or overlapping notes?

You will output a score from 0-3, where:
- 0: Does not effectively capture user preferences: gaps in converage, or significant hallucinations
- 1: Captures some preferences with limited actionable notes, may hallucinate some preferences
- 2: Captures most preferences with actionable notes, may have some slight hallucinations
- 3: Comprehensively captures all preferences with highly actionable notes and no hallucinations

# Output Format:
{{
    "reasoning": # Brief explanation of your decision
    "reflection_score": # 0-3
}}

Output a properly formatted JSON response, as specified by the Output Format."""


def _parse_reflection_score(judge_output):
    """Extract ``reflection_score`` from the judge's reply as a float in [0, 3].

    Raises:
        ValueError: if the reply is not a JSON object with a numeric
            ``reflection_score`` inside the 0-3 range.
    """
    parsed = repair_json(judge_output, return_objects=True)
    if not isinstance(parsed, dict) or "reflection_score" not in parsed:
        raise ValueError(f"Judge reply has no 'reflection_score': {judge_output!r}")
    try:
        score = float(parsed["reflection_score"])
    except (TypeError, ValueError) as err:
        raise ValueError(f"Non-numeric reflection_score: {parsed['reflection_score']!r}") from err
    # The comparison is also False for NaN, which therefore gets rejected too.
    if not 0.0 <= score <= 3.0:
        raise ValueError(f"reflection_score out of range 0-3: {score}")
    return score


def _score_after_parse_failures(retry_state):
    """tenacity fallback: score a completion 0 when no judge reply was parseable."""
    print(
        f"Could not parse a reflection score from the judge after "
        f"{retry_state.attempt_number} attempts: {retry_state.outcome.exception()}"
    )
    print("Reflection Score: 0")
    return 0.0


@retry(
    retry=retry_if_exception_type(ValueError),
    stop=stop_after_attempt(JUDGE_PARSE_ATTEMPTS),
    retry_error_callback=_score_after_parse_failures,
)
def process_single_completion(args):
    """Score one extracted reflection with the LLM judge.

    Args:
        args: ``(prompt, completion, gold_response,
            responses_that_enforce_preferences)``. ``completion`` is the
            ``agent_notes`` already pulled out of the model output.

    Returns:
        The judge's score as a float in [0, 3]. Returns 0.0 for an empty or
        null reflection, and 0.0 when ``JUDGE_PARSE_ATTEMPTS`` judge replies in
        a row contain no usable score. Judge connection errors are not caught
        here; they propagate once ``ask_judge`` gives up.
    """
    prompt, completion, gold_response, responses_that_enforce_preferences = args
    if completion is None or completion == "":
        print(f"Poorly formatted completion: {completion}")
        print(f"Reflection Score: 0")
        return 0.0

    reflection_evaluation_prompt = _build_reflection_evaluation_prompt(
        completion, gold_response, responses_that_enforce_preferences
    )
    # A ValueError from parsing triggers a fresh judge request (see decorator).
    reflection_score = _parse_reflection_score(ask_judge(reflection_evaluation_prompt))
    print(f"Reflection Evaluation Prompt: {reflection_evaluation_prompt}")
    print(f"Reflection Score: {reflection_score}")
    return reflection_score


def reflection_reward_func(prompts, completions, gold_response, responses_that_enforce_preferences, **kwargs) -> list[float]:
    """Judge-based reward: one 0-3 score per completion.

    ``gold_response`` and ``responses_that_enforce_preferences`` are the
    dataset columns of the same name, aligned index-by-index with
    ``completions``. Scores are also appended to ``reflection_scores_tracker``
    so a running average can be logged.
    """
    args_list = []
    for i, completion in enumerate(completions):
        args_list.append((prompts[i], extract_json_answer(completion), gold_response[i], responses_that_enforce_preferences[i]))
    if not args_list:
        return []

    # Judge calls are network-bound, so threads are sufficient. A process pool
    # would start new worker processes from this (model-holding) process on
    # every reward call.
    max_workers = min(REFLECTION_JUDGE_WORKERS, len(args_list))
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        rewards = list(executor.map(process_single_completion, args_list))

    reflection_scores_tracker["scores"].extend(rewards)
    reflection_scores_tracker["batch_count"] += 1

    print("Rollout reflection scores: ", rewards)
    all_scores = reflection_scores_tracker["scores"]
    print("Overall reflection score average (across all rollouts): ", sum(all_scores) / len(all_scores))
    return rewards


def soft_format_reward_func(prompts, completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has JSON format with agent_notes and user_preferences_reasoning fields.

    Gives 0.5 when the (repaired) completion is a JSON object with both keys,
    else 0.0. The result must be a dict: for a bare string, ``in`` would only be
    a substring test.
    """
    rewards = []
    for response in completions:
        try:
            parsed_json = repair_json(response, return_objects=True)
        except Exception:
            # Deliberately best-effort: unparseable output earns no format reward.
            parsed_json = None
        has_required_fields = (
            isinstance(parsed_json, dict)
            and "agent_notes" in parsed_json
            and "user_preferences_reasoning" in parsed_json
        )
        rewards.append(0.5 if has_required_fields else 0.0)

    for reward in rewards:
        print("Soft Format Reward: ", reward)
    return rewards


# Train the model.
# trl is imported here, after the VLLM_USE_V1 / WANDB_PROJECT environment
# variables above have been set.
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    learning_rate = 5e-6,
    # adam_beta1 = 0.9,
    # adam_beta2 = 0.99,
    # weight_decay = 0.1,
    # warmup_ratio = 0.1,
    gradient_accumulation_steps = 4,
    num_train_epochs = 1,
    bf16 = True,
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    report_to = "wandb",
    logging_steps = 1,
    num_generations = 4,
    max_steps = 2000,
    save_steps = 50,
    output_dir = "./outputs_lamma3_reflection_grpo_v2",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        soft_format_reward_func,
        reflection_reward_func
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

# python3 llama_grpo.py >> llama_grpo_with_reflection_v1.out 2>&1