summaryrefslogtreecommitdiff
path: root/scripts/submit_batch.py
blob: e848dc5b10020b6837f9fe2f1044d70b99f95bd8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import os
import time
from typing import List
from openai import OpenAI

# --- Configuration ---
# Source JSONL of queries, and the directory the per-part batch input files
# are written to.
INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
BATCH_DIR = "data/raw_datasets/batch_files"
# Chat model used for every request in every batch.
MODEL_NAME = "gpt-5.1"  # Or "gpt-4o"
# Maximum requests per batch input file; kept below the 50,000-request cap.
BATCH_SIZE_LIMIT = 49_000

# --- Load System Prompt ---
# Read once at import time. The path is relative to the current working
# directory, so the script must be launched from the directory containing it.
with open("fine_tuning_prompt_template.txt", encoding="utf-8") as prompt_file:
    SYSTEM_PROMPT = prompt_file.read()

# OpenAI documents a 200 MB cap on a single Batch API input file, in addition
# to the 50,000-request cap that BATCH_SIZE_LIMIT guards. Every request line
# repeats SYSTEM_PROMPT, so a long prompt can hit the byte cap first.
# Kept conservative (decimal MB) because this is only used for a warning.
_MAX_BATCH_FILE_BYTES = 200 * 1000 * 1000


def _build_request(custom_id, query):
    """Return one Batch API request (as a dict) for a single user query.

    Model, system prompt and sampling settings come from the module-level
    configuration; only ``custom_id`` and the user message vary per request.
    """
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": MODEL_NAME,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": query},
            ],
            "temperature": 0.0,
            "response_format": {"type": "json_object"},
        },
    }


def _write_chunk_file(path, chunk, start_idx):
    """Write ``chunk`` to ``path`` as a JSONL Batch API input file.

    Each request's ``custom_id`` is ``req_<global index>``, where the global
    index is the item's position in the full input (``start_idx`` plus its
    offset within ``chunk``), so results can be joined back to source rows.
    """
    with open(path, "w", encoding="utf-8") as f_out:
        for global_idx, item in enumerate(chunk, start=start_idx):
            request_obj = _build_request(f"req_{global_idx}", item["query"])
            f_out.write(json.dumps(request_obj) + "\n")


def _warn_if_oversized(path):
    """Print a warning if ``path`` exceeds the Batch API input-file size cap."""
    size = os.path.getsize(path)
    if size > _MAX_BATCH_FILE_BYTES:
        print(
            f"Warning: {path} is {size / 1e6:.1f} MB, above the ~200 MB Batch API "
            "input file limit; the upload or batch validation will likely fail. "
            "Consider lowering BATCH_SIZE_LIMIT."
        )


def _upload_and_submit(client, chunk_filename, batch_idx):
    """Upload ``chunk_filename`` and create a 24h chat-completions batch job.

    Returns:
        The ID of the created batch job.
    """
    print("Uploading to OpenAI...")
    # Close the local file handle as soon as the upload call returns.
    with open(chunk_filename, "rb") as upload_fh:
        batch_file_obj = client.files.create(file=upload_fh, purpose="batch")
    file_id = batch_file_obj.id
    print(f"Uploaded. File ID: {file_id}")

    print("Submitting Batch Job...")
    batch_job = client.batches.create(
        input_file_id=file_id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": f"Pers. Extractor Part {batch_idx}",
            "part_index": str(batch_idx),
        },
    )
    print(f"Submitted. Batch ID: {batch_job.id}")
    return batch_job.id


def _save_batch_ids(id_file, batch_ids):
    """Overwrite ``id_file`` with the JSON list ``batch_ids``."""
    with open(id_file, "w", encoding="utf-8") as f:
        json.dump(batch_ids, f, indent=2)


def prepare_and_submit_batches():
    """Split INPUT_FILE into Batch API input files and submit one job per file.

    Reads every non-blank line of ``INPUT_FILE`` as JSON (each object must
    have a ``"query"`` key). It writes chunks of at most ``BATCH_SIZE_LIMIT``
    requests to ``BATCH_DIR`` as ``batch_input_part_<n>.jsonl``, uploads each
    chunk and creates a ``/v1/chat/completions`` batch job for it. The
    submitted batch IDs are saved as a JSON list to
    ``data/raw_datasets/submitted_batch_ids.json``.

    The ID file is rewritten after every successful submission. If a later
    upload or submission raises, the IDs of the jobs already created remain
    on disk instead of being lost.

    If ``OPENAI_API_KEY`` is not set, prints an error and returns ``None``
    without doing anything else. Errors from file I/O, JSON parsing or the
    OpenAI client propagate to the caller.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    os.makedirs(BATCH_DIR, exist_ok=True)

    print(f"Reading from {INPUT_FILE}...")
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        all_lines = [json.loads(line) for line in f if line.strip()]

    total_items = len(all_lines)
    print(f"Total items: {total_items}")

    id_file = "data/raw_datasets/submitted_batch_ids.json"
    batch_ids = []

    for batch_idx, start in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
        chunk = all_lines[start : start + BATCH_SIZE_LIMIT]
        chunk_filename = os.path.join(BATCH_DIR, f"batch_input_part_{batch_idx}.jsonl")

        print(f"\n--- Processing Batch {batch_idx} ({len(chunk)} items) ---")

        _write_chunk_file(chunk_filename, chunk, start)
        print(f"File created: {chunk_filename}")
        _warn_if_oversized(chunk_filename)

        batch_ids.append(_upload_and_submit(client, chunk_filename, batch_idx))
        # Persist after every submission: a submitted job is running and
        # billed, so its ID must survive a failure on a later chunk.
        _save_batch_ids(id_file, batch_ids)

        # Sleep briefly to be nice to API
        time.sleep(1)

    # Final write also covers empty input (no batches), producing "[]".
    _save_batch_ids(id_file, batch_ids)

    print(f"\nALL DONE! Submitted {len(batch_ids)} batches.")
    print(f"Batch IDs saved to {id_file}")
    print("Run scripts/check_batch_status.py (you need to write it) to monitor.")

if __name__ == "__main__":
    # Run the prepare/upload/submit pipeline only when executed as a script,
    # never as a side effect of importing this module.
    prepare_and_submit_batches()