summaryrefslogtreecommitdiff
path: root/collaborativeagents/training/sft/generate_sft_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'collaborativeagents/training/sft/generate_sft_data.py')
-rw-r--r--collaborativeagents/training/sft/generate_sft_data.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/collaborativeagents/training/sft/generate_sft_data.py b/collaborativeagents/training/sft/generate_sft_data.py
new file mode 100644
index 0000000..4f34f03
--- /dev/null
+++ b/collaborativeagents/training/sft/generate_sft_data.py
@@ -0,0 +1,49 @@
+import json
+
+from collaborativeagents.utils import get_conversation_string
+from collaborativeagents.prompts import update_agent_notes_prompt
+
+logiqa_llama70b_training_data = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/unprocessed_conversations/logiqa_llama70b_user_llama70b_agent_training_data_with_reflection_eval_size_20.jsonl"
+math_500_llama70b_training_data = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/unprocessed_conversations/math_500_llama70b_user_llama70b_agent_training_data_with_reflection_eval_size_20.jsonl"
+math_hard_llama70b_training_data = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/unprocessed_conversations/math_hard_llama70b_user_llama70b_agent_training_data_with_reflection_eval_size_20.jsonl"
+medqa_llama70b_training_data = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/unprocessed_conversations/medqa_llama70b_user_llama70b_agent_training_data_with_reflection_eval_size_20.jsonl"
+mmlu_llama70b_training_data = "/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/unprocessed_conversations/mmlu_llama70b_user_llama70b_agent_training_data_with_reflection_eval_size_20.jsonl"
+
+processed_data = []
+for file_path in [logiqa_llama70b_training_data, math_500_llama70b_training_data, math_hard_llama70b_training_data, medqa_llama70b_training_data, mmlu_llama70b_training_data]:
+ unprocessed_data = []
+ with open(file_path, "r") as f:
+ for line in f:
+ unprocessed_data.append(json.loads(line))
+
+ for user_elem in unprocessed_data:
+ for conversation_elem in user_elem["generated_conversations"]:
+ conversation_str = get_conversation_string(conversation_elem["conversation"])
+
+ formatted_update_agent_notes_prompt = update_agent_notes_prompt.format(agent_notes="", conversation_str=conversation_str)
+ agent_notes_response = json.dumps(conversation_elem["agent_notes"], indent=2)
+
+ training_conversation = [
+ {"role": "user", "content": formatted_update_agent_notes_prompt},
+ {"role": "assistant", "content": agent_notes_response}
+ ]
+
+ responses_that_enforce_preferences = [
+ elem["response"] for elem in conversation_elem["full_conversation_log"] if "enforce_preferences" in elem and elem["enforce_preferences"]
+ ]
+
+ user_profile = {
+ "i": user_elem["i"],
+ "persona": user_elem["persona"],
+ "preferences": user_elem["preferences"]
+ }
+
+ processed_data.append({
+ "messages": training_conversation,
+ # "responses_that_enforce_preferences": responses_that_enforce_preferences,
+ # "user_profile": user_profile
+ })
+
+
+with open("/shared/storage-01/users/mehri2/mem/collaborativeagents/training/sft/training_data/session_level_reflection_sft_data.jsonl", "w") as f:
+ json.dump(processed_data, f, indent=2)