summaryrefslogtreecommitdiff
path: root/collaborativeagents/training/grpo/generate_grpo_data.py
blob: 7f67b0ed70ebb3657c9f28e0a7c8c514e461843f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json

from collaborativeagents.utils import get_conversation_string
from collaborativeagents.prompts import update_agent_notes_prompt

# Input files: one .jsonl path per source dataset. All of them sit in the same
# directory and share the same filename suffix, so only the dataset prefix varies.
_UNPROCESSED_CONVERSATIONS_DIR = (
    "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/"
    "training_data/unprocessed_conversations/"
)
_TRAINING_FILE_SUFFIX = (
    "_llama70b_user_llama70b_agent_training_data_with_reflection_eval_size_20.jsonl"
)

logiqa_llama70b_training_data = _UNPROCESSED_CONVERSATIONS_DIR + "logiqa" + _TRAINING_FILE_SUFFIX
math_500_llama70b_training_data = _UNPROCESSED_CONVERSATIONS_DIR + "math_500" + _TRAINING_FILE_SUFFIX
math_hard_llama70b_training_data = _UNPROCESSED_CONVERSATIONS_DIR + "math_hard" + _TRAINING_FILE_SUFFIX
medqa_llama70b_training_data = _UNPROCESSED_CONVERSATIONS_DIR + "medqa" + _TRAINING_FILE_SUFFIX
mmlu_llama70b_training_data = _UNPROCESSED_CONVERSATIONS_DIR + "mmlu" + _TRAINING_FILE_SUFFIX

# Build one training example per generated conversation, across every input file.
# Each example holds:
#   - "messages": the agent-notes update prompt (user turn) and the recorded
#     agent notes (assistant turn),
#   - "responses_that_enforce_preferences": responses from the full conversation
#     log whose entry has a truthy "enforce_preferences" value,
#   - "user_profile": the "i", "persona" and "preferences" fields of the user record.
processed_data = []
for file_path in [
    logiqa_llama70b_training_data,
    math_500_llama70b_training_data,
    math_hard_llama70b_training_data,
    medqa_llama70b_training_data,
    mmlu_llama70b_training_data,
]:
    # One JSON record per line. UTF-8 is requested explicitly so decoding does not
    # depend on the machine's locale, and whitespace-only lines are skipped so a
    # stray blank line does not abort the whole run with a JSONDecodeError.
    with open(file_path, "r", encoding="utf-8") as f:
        unprocessed_data = [json.loads(line) for line in f if line.strip()]

    for user_elem in unprocessed_data:
        for conversation_elem in user_elem["generated_conversations"]:
            conversation_str = get_conversation_string(conversation_elem["conversation"])

            # The prompt is always rendered with empty prior notes (agent_notes="").
            formatted_update_agent_notes_prompt = update_agent_notes_prompt.format(
                agent_notes="",
                conversation_str=conversation_str,
            )
            # Training target: the conversation's agent notes as indented JSON text.
            agent_notes_response = json.dumps(conversation_elem["agent_notes"], indent=2)

            training_conversation = [
                {"role": "user", "content": formatted_update_agent_notes_prompt},
                {"role": "assistant", "content": agent_notes_response},
            ]

            # Log entries without the flag (or with a falsy flag) are ignored.
            responses_that_enforce_preferences = [
                elem["response"]
                for elem in conversation_elem["full_conversation_log"]
                if elem.get("enforce_preferences")
            ]

            user_profile = {
                "i": user_elem["i"],
                "persona": user_elem["persona"],
                "preferences": user_elem["preferences"],
            }

            processed_data.append({
                "messages": training_conversation,
                "responses_that_enforce_preferences": responses_that_enforce_preferences,
                "user_profile": user_profile,
            })


# Write every collected example to the GRPO training-data file in a single dump.
# NOTE(review): the file name ends in ".jsonl", but json.dump with indent=2 writes
# the whole list as one pretty-printed JSON array, not one object per line.
# Confirm the consumer expects this before changing either side.
output_path = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/grpo/training_data/session_level_reflection_grpo_data.jsonl"
with open(output_path, "w") as output_file:
    json.dump(processed_data, output_file, indent=2)