summaryrefslogtreecommitdiff
path: root/scripts/submit_oasst1_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/submit_oasst1_batch.py')
-rw-r--r--scripts/submit_oasst1_batch.py120
1 files changed, 120 insertions, 0 deletions
diff --git a/scripts/submit_oasst1_batch.py b/scripts/submit_oasst1_batch.py
new file mode 100644
index 0000000..1a96dd0
--- /dev/null
+++ b/scripts/submit_oasst1_batch.py
@@ -0,0 +1,120 @@
+import json
+import os
+import time
+from openai import OpenAI
+
# --- Configuration ---
# Source queries: JSONL, one record per line. Records carry at least one of
# "original_query"/"query" plus user/session/turn identifiers (see usage below).
INPUT_FILE = "data/raw_datasets/oasst1_queries.jsonl"
# Output directory for the generated per-chunk batch request files.
BATCH_DIR = "data/raw_datasets/batch_files_oasst1"
# custom_id -> original-record map, written before submission so results can be
# joined back to their source rows even if submission fails part-way.
METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
MODEL_NAME = "gpt-5.1"
# Max requests per batch file — presumably chosen to stay just under the
# OpenAI Batch API per-file request limit; confirm against current API docs.
BATCH_SIZE_LIMIT = 49000

# --- Load System Prompt ---
# Read once at import time; raises FileNotFoundError if the template is absent
# (intentional fail-fast: nothing useful can run without the prompt).
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()
+
def submit_oasst1_batch():
    """Chunk OASST1 queries into OpenAI Batch API jobs and submit them.

    Workflow:
      1. Load all records from ``INPUT_FILE`` (JSONL).
      2. Write a ``custom_id`` -> source-record metadata map to
         ``METADATA_FILE`` *before* submitting anything, so the mapping
         survives a mid-run submission failure.
      3. Split the records into chunks of at most ``BATCH_SIZE_LIMIT``,
         write one request file per chunk under ``BATCH_DIR``, upload it,
         and create a batch job for it.
      4. Persist all submitted batch IDs to a JSON file.

    Returns:
        list[str]: IDs of the submitted batch jobs (empty list if the
        ``OPENAI_API_KEY`` environment variable is not set).
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return []
    client = OpenAI(api_key=api_key)

    os.makedirs(BATCH_DIR, exist_ok=True)

    print(f"Reading from {INPUT_FILE}...")

    all_lines = []
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():  # tolerate blank lines in the JSONL
                all_lines.append(json.loads(line))

    total_items = len(all_lines)
    print(f"Total OASST1 items: {total_items}")

    # 1. Generate Metadata Map first.
    # This ensures we have the mapping even if batch submission fails mid-way.
    print(f"Generating metadata map to {METADATA_FILE}...")
    with open(METADATA_FILE, "w", encoding="utf-8") as f_meta:
        for idx, item in enumerate(all_lines):
            custom_id = f"oasst1_req_{idx}"
            meta_record = {
                "custom_id": custom_id,
                "user_id": item.get("user_id"),
                "session_id": item.get("session_id"),
                "turn_id": item.get("turn_id"),
                "original_query": item.get("original_query") or item.get("query")
            }
            f_meta.write(json.dumps(meta_record, ensure_ascii=False) + "\n")

            # Store custom_id back on the in-memory record so the batch files
            # below use exactly the same IDs as the metadata map.
            item["_temp_custom_id"] = custom_id

    # 2. Split and Submit
    batch_ids = []

    for batch_idx, start in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
        chunk = all_lines[start : start + BATCH_SIZE_LIMIT]
        chunk_filename = os.path.join(BATCH_DIR, f"oasst1_batch_part_{batch_idx}.jsonl")

        print(f"\n--- Processing OASST1 Batch {batch_idx} ({len(chunk)} items) ---")

        with open(chunk_filename, "w", encoding="utf-8") as f_out:
            for item in chunk:
                custom_id = item["_temp_custom_id"]
                query = item.get("original_query") or item.get("query")

                # FIX: guard against records with no usable query text — a
                # null/empty "content" would be rejected by the API and could
                # poison the whole batch file.
                if not query:
                    print(f"Warning: skipping {custom_id} (missing/empty query)")
                    continue

                request_obj = {
                    "custom_id": custom_id,
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": MODEL_NAME,
                        "messages": [
                            {"role": "system", "content": SYSTEM_PROMPT},
                            {"role": "user", "content": query}
                        ],
                        "temperature": 0.0,  # deterministic extraction
                        "response_format": {"type": "json_object"}
                    }
                }
                # FIX: ensure_ascii=False for consistency with the metadata
                # writes above — both files are UTF-8 JSONL.
                f_out.write(json.dumps(request_obj, ensure_ascii=False) + "\n")

        print(f"File created: {chunk_filename}")

        print("Uploading to OpenAI...")
        # FIX: close the upload handle deterministically (the original leaked
        # an open() passed directly to client.files.create).
        with open(chunk_filename, "rb") as f_upload:
            batch_file_obj = client.files.create(
                file=f_upload,
                purpose="batch"
            )
        file_id = batch_file_obj.id
        print(f"Uploaded. File ID: {file_id}")

        print("Submitting Batch Job...")
        batch_job = client.batches.create(
            input_file_id=file_id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "description": f"Pers. Extractor OASST1 Part {batch_idx}",
                "dataset": "oasst1"
            }
        )
        print(f"Submitted. Batch ID: {batch_job.id}")
        batch_ids.append(batch_job.id)

        # Brief pause between submissions to be gentle on the API.
        time.sleep(1)

    id_file = "data/raw_datasets/submitted_oasst1_batch_ids.json"
    with open(id_file, "w") as f:
        json.dump(batch_ids, f, indent=2)

    print(f"\nALL DONE! Submitted {len(batch_ids)} OASST1 batches.")
    print(f"Metadata saved to {METADATA_FILE}")
    print(f"Batch IDs saved to {id_file}")
    return batch_ids
+
# Script entry point: submit all OASST1 batches when run directly.
if __name__ == "__main__":
    submit_oasst1_batch()
+