diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/submit_oasst1_batch.py | |
Diffstat (limited to 'scripts/submit_oasst1_batch.py')
| -rw-r--r-- | scripts/submit_oasst1_batch.py | 120 |
1 files changed, 120 insertions, 0 deletions
import json
import os
import time

from openai import OpenAI

# --- Configuration ---
INPUT_FILE = "data/raw_datasets/oasst1_queries.jsonl"
BATCH_DIR = "data/raw_datasets/batch_files_oasst1"
METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"
MODEL_NAME = "gpt-5.1"
BATCH_SIZE_LIMIT = 49000  # max requests per uploaded batch input file

# --- Load System Prompt ---
# Read at import time so a missing template fails fast, before any API calls.
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()


def _load_items(path):
    """Read a JSONL file and return a list of parsed records (blank lines skipped)."""
    items = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                items.append(json.loads(line))
    return items


def _write_metadata_map(items):
    """Write the custom_id -> source-record map and tag each item with its id.

    Writing the map before any submission ensures the mapping survives even if
    batch submission fails mid-way. Mutates each item in place, adding the
    "_temp_custom_id" key that the batch writer reuses.
    """
    with open(METADATA_FILE, "w", encoding="utf-8") as f_meta:
        for idx, item in enumerate(items):
            custom_id = f"oasst1_req_{idx}"
            meta_record = {
                "custom_id": custom_id,
                "user_id": item.get("user_id"),
                "session_id": item.get("session_id"),
                "turn_id": item.get("turn_id"),
                "original_query": item.get("original_query") or item.get("query"),
            }
            f_meta.write(json.dumps(meta_record, ensure_ascii=False) + "\n")
            # Store the id back on the item so both files use the same custom_id.
            item["_temp_custom_id"] = custom_id


def _write_batch_file(chunk, chunk_filename):
    """Write one Batch API input file: a JSONL of /v1/chat/completions requests."""
    with open(chunk_filename, "w", encoding="utf-8") as f_out:
        for item in chunk:
            query = item.get("original_query") or item.get("query")
            request_obj = {
                "custom_id": item["_temp_custom_id"],
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": MODEL_NAME,
                    "messages": [
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": query},
                    ],
                    "temperature": 0.0,
                    "response_format": {"type": "json_object"},
                },
            }
            # ensure_ascii=False keeps non-ASCII queries readable, consistent
            # with how the metadata map is written.
            f_out.write(json.dumps(request_obj, ensure_ascii=False) + "\n")


def submit_oasst1_batch():
    """Split the OASST1 queries into chunks and submit each to the OpenAI Batch API.

    Side effects: writes METADATA_FILE, one JSONL input file per chunk under
    BATCH_DIR, and a JSON list of submitted batch ids. Returns None; prints
    progress to stdout and returns early if OPENAI_API_KEY is not set.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    os.makedirs(BATCH_DIR, exist_ok=True)

    print(f"Reading from {INPUT_FILE}...")
    all_lines = _load_items(INPUT_FILE)
    total_items = len(all_lines)
    print(f"Total OASST1 items: {total_items}")

    # 1. Generate the metadata map first (survives a mid-way submission failure).
    print(f"Generating metadata map to {METADATA_FILE}...")
    _write_metadata_map(all_lines)

    # 2. Split into chunks and submit each as one batch job.
    batch_ids = []
    for batch_idx, i in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
        chunk = all_lines[i : i + BATCH_SIZE_LIMIT]
        chunk_filename = os.path.join(BATCH_DIR, f"oasst1_batch_part_{batch_idx}.jsonl")

        print(f"\n--- Processing OASST1 Batch {batch_idx} ({len(chunk)} items) ---")
        _write_batch_file(chunk, chunk_filename)
        print(f"File created: {chunk_filename}")

        print("Uploading to OpenAI...")
        # Context manager closes the upload handle (the original leaked it).
        with open(chunk_filename, "rb") as fh:
            batch_file_obj = client.files.create(file=fh, purpose="batch")
        file_id = batch_file_obj.id
        print(f"Uploaded. File ID: {file_id}")

        print("Submitting Batch Job...")
        batch_job = client.batches.create(
            input_file_id=file_id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "description": f"Pers. Extractor OASST1 Part {batch_idx}",
                "dataset": "oasst1",
            },
        )
        print(f"Submitted. Batch ID: {batch_job.id}")
        batch_ids.append(batch_job.id)

        time.sleep(1)  # gentle pacing between submissions

    id_file = "data/raw_datasets/submitted_oasst1_batch_ids.json"
    with open(id_file, "w") as f:
        json.dump(batch_ids, f, indent=2)

    print(f"\nALL DONE! Submitted {len(batch_ids)} OASST1 batches.")
    print(f"Metadata saved to {METADATA_FILE}")
    print(f"Batch IDs saved to {id_file}")


if __name__ == "__main__":
    submit_oasst1_batch()
