diff options
Diffstat (limited to 'collaborativeagents/training/sft/generate_sft_data.py')
| -rw-r--r-- | collaborativeagents/training/sft/generate_sft_data.py | 49 |
1 files changed, 49 insertions, 0 deletions
"""Build session-level reflection SFT data from raw conversation logs.

For every generated conversation in the raw ``.jsonl`` logs, one chat-style
training example is produced: the user turn is ``update_agent_notes_prompt``
filled with an empty ``agent_notes`` and the conversation string, and the
assistant turn is the conversation's ``agent_notes`` serialized as indented
JSON. All examples are written to a single output file.
"""
import json

from collaborativeagents.utils import get_conversation_string
from collaborativeagents.prompts import update_agent_notes_prompt

_RAW_DIR = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/unprocessed_conversations"
_RAW_SUFFIX = "_llama70b_user_llama70b_agent_training_data_with_reflection_eval_size_20.jsonl"

# One raw conversation log per source dataset (names kept from the original script).
logiqa_llama70b_training_data = f"{_RAW_DIR}/logiqa{_RAW_SUFFIX}"
math_500_llama70b_training_data = f"{_RAW_DIR}/math_500{_RAW_SUFFIX}"
math_hard_llama70b_training_data = f"{_RAW_DIR}/math_hard{_RAW_SUFFIX}"
medqa_llama70b_training_data = f"{_RAW_DIR}/medqa{_RAW_SUFFIX}"
mmlu_llama70b_training_data = f"{_RAW_DIR}/mmlu{_RAW_SUFFIX}"

INPUT_PATHS = [
    logiqa_llama70b_training_data,
    math_500_llama70b_training_data,
    math_hard_llama70b_training_data,
    medqa_llama70b_training_data,
    mmlu_llama70b_training_data,
]

OUTPUT_PATH = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/session_level_reflection_sft_data.jsonl"


def load_jsonl(path: str) -> list:
    """Return the JSON values stored one per line in *path*.

    Blank (whitespace-only) lines are skipped so that a trailing newline or
    an accidental empty line does not abort the whole run.

    Raises:
        OSError: if *path* cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def build_training_example(
    user_elem: dict,
    conversation_elem: dict,
    *,
    include_metadata: bool = False,
) -> dict:
    """Turn one generated conversation into a chat-format SFT example.

    Args:
        user_elem: One record of the raw log. Only read when
            *include_metadata* is true (keys ``i``, ``persona``,
            ``preferences``).
        conversation_elem: One entry of ``user_elem["generated_conversations"]``.
            Its ``conversation`` and ``agent_notes`` keys are always read;
            ``full_conversation_log`` is read only when *include_metadata*
            is true.
        include_metadata: When true, also attach
            ``responses_that_enforce_preferences`` and ``user_profile`` to the
            example. Off by default, which matches the fields the original
            script actually emitted.

    Returns:
        A dict with a two-turn ``messages`` list (user prompt, assistant
        response), plus the optional metadata fields.

    Raises:
        KeyError: if a required key is missing from the input records.
    """
    conversation_str = get_conversation_string(conversation_elem["conversation"])

    # The prompt is always rendered with empty prior notes.
    prompt = update_agent_notes_prompt.format(
        agent_notes="",
        conversation_str=conversation_str,
    )
    agent_notes_response = json.dumps(conversation_elem["agent_notes"], indent=2)

    example = {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": agent_notes_response},
        ],
    }

    if include_metadata:
        example["responses_that_enforce_preferences"] = [
            elem["response"]
            for elem in conversation_elem["full_conversation_log"]
            if elem.get("enforce_preferences")
        ]
        example["user_profile"] = {
            "i": user_elem["i"],
            "persona": user_elem["persona"],
            "preferences": user_elem["preferences"],
        }

    return example


def main() -> None:
    """Read every raw log, build the SFT examples, and write them out."""
    processed_data = []
    for file_path in INPUT_PATHS:
        for user_elem in load_jsonl(file_path):
            for conversation_elem in user_elem["generated_conversations"]:
                processed_data.append(
                    build_training_example(user_elem, conversation_elem)
                )

    # NOTE(review): the output file carries a ".jsonl" extension but is
    # written as one indented JSON array, not one object per line. The format
    # is kept as-is because downstream readers are not visible from this
    # script -- confirm what they expect before switching to true JSON Lines.
    with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
        json.dump(processed_data, f, indent=2)


if __name__ == "__main__":
    main()
