summaryrefslogtreecommitdiff
path: root/collaborativeagents/training/grpo_verl/verl_reward_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/training/grpo_verl/verl_reward_functions.py')
-rw-r--r--collaborativeagents/training/grpo_verl/verl_reward_functions.py160
1 files changed, 160 insertions, 0 deletions
diff --git a/collaborativeagents/training/grpo_verl/verl_reward_functions.py b/collaborativeagents/training/grpo_verl/verl_reward_functions.py
new file mode 100644
index 0000000..fa38b8f
--- /dev/null
+++ b/collaborativeagents/training/grpo_verl/verl_reward_functions.py
@@ -0,0 +1,160 @@
+"""
+Custom reward functions for VERL GRPO training
+Compatible with VERL's reward function signature
+"""
+
+import openai
+from json_repair import repair_json
+import concurrent.futures
+import time
+
# Initialize judge client
# NOTE(review): points at a locally served OpenAI-compatible endpoint (vLLM
# convention uses api_key="EMPTY" when no auth is configured) — confirm that
# the judge server actually listens on port 8004 in this deployment.
client = openai.OpenAI(base_url="http://localhost:8004/v1", api_key="EMPTY")

# Global tracker for reflection scores
# "scores" accumulates every reflection score appended by compute_score();
# "batch_count" is never written in this file — presumably incremented by the
# surrounding training loop. TODO confirm.
reflection_scores_tracker = {
    "scores": [],
    "batch_count": 0
}
+
+
def extract_json_answer(text: str) -> str:
    """Extract the ``agent_notes`` field from a (possibly malformed) JSON response.

    Args:
        text: Raw model output expected to contain a JSON object with an
            ``agent_notes`` key. ``repair_json`` is used so mildly broken
            JSON is still accepted.

    Returns:
        The ``agent_notes`` value as a string, or ``""`` when the text cannot
        be parsed into a JSON object or the key is missing.
    """
    try:
        parsed = repair_json(text, return_objects=True)
        # repair_json returns a list/str (not a dict) for non-object input;
        # only a dict can actually carry the agent_notes field.
        if not isinstance(parsed, dict) or "agent_notes" not in parsed:
            print("Error extracting JSON answer: no 'agent_notes' object found")
            return ""
        answer = parsed["agent_notes"]
    except Exception as e:
        print(f"Error extracting JSON answer: {e}")
        return ""
    # Honor the declared -> str contract even if the model emitted a nested
    # structure (dict/list) under agent_notes.
    return answer if isinstance(answer, str) else str(answer)
+
+
def ask_judge(prompt, system_prompt=None, max_retries=3):
    """Query the judge model, retrying transient failures.

    Args:
        prompt: User-turn content sent to the judge.
        system_prompt: Optional system message prepended to the chat.
        max_retries: Number of attempts before giving up (must be >= 1).

    Returns:
        The judge's response text with surrounding whitespace stripped.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The last client error when every attempt fails.
    """
    if max_retries < 1:
        # Previously this fell through the loop and returned None, which
        # crashes downstream .strip()/parse calls; fail fast instead.
        raise ValueError("max_retries must be >= 1")

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    for attempt in range(max_retries):
        try:
            chat_completion = client.chat.completions.create(
                model="meta-llama/Llama-3.1-70B-Instruct",
                messages=messages,
                max_tokens=2048,
            )
            return chat_completion.choices[0].message.content.strip()
        except Exception:
            if attempt < max_retries - 1:
                time.sleep(1)  # brief pause before retrying the judge server
            else:
                raise  # re-raise the last failure with its original traceback
+
+
def evaluate_reflection(completion, gold_response, responses_that_enforce_preferences):
    """Score a single reflection completion (0-3) with the LLM judge.

    Args:
        completion: Raw model completion; its ``agent_notes`` JSON field is
            the reflection being graded.
        gold_response: Reference ("gold") reflection for the same conversation.
        responses_that_enforce_preferences: User messages in which the user
            explicitly enforced their preferences.

    Returns:
        A float in [0, 3] from the judge, or 0 when the completion or the
        judge's verdict cannot be parsed.
    """
    completion_text = extract_json_answer(completion)
    if completion_text == "":
        print(f"Poorly formatted completion: {completion}")
        return 0

    user_messages_where_they_enforce_preferences = ""
    for i, response in enumerate(responses_that_enforce_preferences):
        user_messages_where_they_enforce_preferences += f"User message #{i+1}: {response}\n"

    reflection_evaluation_prompt = f"""You are an expert evaluator analyzing a conversational agent's reflection of a conversation, where they analyze the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.

Throughout the conversation, the user explicitly enforces their preferences whenever necessary. The agent analyzes the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations.

# Your Task:
Evaluate whether the agent's reflection successfully captures the user's preferences and provides actionable notes to help them satisfy these preferences in future conversations.

# Agent's Reflection:
{completion_text}

# User Messages Where They Enforce Their Preferences:
{user_messages_where_they_enforce_preferences}

# Gold Reflection:
Here is a gold reflection for the same conversation. Use this as a reference to evaluate the agent's reflection.
{gold_response}

# Evaluation Criteria:
Assess the reflection on four dimensions:
- **Coverage (Completeness):** Does the agent's reflection capture all of the user's preferences?
- **Actionability (Quality):** Does the agent's reflection provide actionable notes and details that help the agent satisfy these preferences in future conversations?
- **Accuracy (No Hallucination):** Are all points grounded in actual user statements? Does the reflection avoid inventing preferences or misrepresenting user statements?
- **Clarity:** Is the reflection well-organized and clearly formatted? Does the reflection avoid redundancy, with each preference stated once without repetitive or overlapping notes?

You will output a score from 0-3, where:
- 0: Does not effectively capture user preferences: gaps in coverage, or significant hallucinations
- 1: Captures some preferences with limited actionable notes, may hallucinate some preferences
- 2: Captures most preferences with actionable notes, may have some slight hallucinations
- 3: Comprehensively captures all preferences with highly actionable notes and no hallucinations

# Output Format:
{{
    "reasoning": # Brief explanation of your decision
    "reflection_score": # 0-3
}}

Output a properly formatted JSON response, as specified by the Output Format."""

    reflection_score_response = ask_judge(reflection_evaluation_prompt)
    try:
        parsed = repair_json(reflection_score_response, return_objects=True)
        # The judge sometimes emits the score as a string ("2"); float() covers
        # both. A missing key / non-dict reply previously raised straight out
        # of the reward function and could kill the training loop.
        reflection_score = float(parsed["reflection_score"])
    except Exception as e:
        print(f"Error parsing judge verdict: {e}")
        return 0
    # Clamp to the documented 0-3 range in case the judge goes off-script.
    reflection_score = max(0.0, min(3.0, reflection_score))

    print(f"Reflection Score: {reflection_score}")
    return reflection_score
+
+
def soft_format_reward(solution_str):
    """Award 0.5 when the completion parses to a JSON object with both required fields.

    Args:
        solution_str: The model's raw completion text.

    Returns:
        0.5 if the (repaired) JSON is a dict containing both "agent_notes"
        and "user_preferences_reasoning"; 0.0 otherwise.
    """
    reward = 0.0
    try:
        parsed_json = repair_json(solution_str, return_objects=True)
        # Bug fix: repair_json returns a plain str for non-JSON input, and
        # `"agent_notes" in <str>` is a substring test — it awarded the format
        # reward to completions that merely *mention* the field names. Only a
        # dict counts as a valid JSON object.
        if (
            isinstance(parsed_json, dict)
            and "agent_notes" in parsed_json
            and "user_preferences_reasoning" in parsed_json
        ):
            reward = 0.5
    except Exception:
        pass

    print(f"Soft Format Reward: {reward}")
    return reward
+
+
# VERL reward function signature: (data_source, solution_str, ground_truth, extra_info)
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """
    Main reward function for VERL (named 'compute_score' for default VERL compatibility).
    This matches the signature expected by VERL's reward managers.

    Args:
        data_source: Source identifier for the data
        solution_str: The model's generated completion
        ground_truth: Not used directly, passed in extra_info
        extra_info: Dictionary containing 'gold_response' and 'responses_that_enforce_preferences'

    Returns:
        float: Combined reward score (format reward + reflection score)
    """
    if extra_info is None:
        print("Warning: extra_info is None")
        return 0.0

    gold_response = extra_info.get('gold_response', '')
    responses_that_enforce_preferences = extra_info.get('responses_that_enforce_preferences', [])

    # Soft format reward
    format_reward = soft_format_reward(solution_str)
    # Reflection quality reward
    reflection_score = evaluate_reflection(solution_str, gold_response, responses_that_enforce_preferences)

    # The judge's score may arrive as a string (e.g. "2"); coerce defensively
    # so the addition below cannot raise TypeError mid-training.
    try:
        reflection_score = float(reflection_score)
    except (TypeError, ValueError):
        print(f"Non-numeric reflection score {reflection_score!r}; treating as 0")
        reflection_score = 0.0

    total_reward = float(format_reward) + reflection_score

    # NOTE(review): this list grows without bound over a long run — presumably
    # drained/reset elsewhere by the training loop; confirm.
    reflection_scores_tracker["scores"].append(reflection_score)
    print(f"Total Reward: {total_reward} (Format: {format_reward}, Reflection: {reflection_score})")

    return total_reward