summaryrefslogtreecommitdiff
path: root/scripts/submit_oasst1_batch.py
blob: 1a96dd0cd31622c11a226d25b81e82a93d7d0078 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import json
import os
import time
from openai import OpenAI

# --- Configuration ---
# Every path below is relative, so it resolves against the current working
# directory the script is launched from.
INPUT_FILE = "data/raw_datasets/oasst1_queries.jsonl"  # JSONL of query records
BATCH_DIR = "data/raw_datasets/batch_files_oasst1"  # generated batch input files
METADATA_FILE = "data/raw_datasets/oasst1_metadata_map.jsonl"  # custom_id -> record info
MODEL_NAME = "gpt-5.1"
# Maximum number of requests placed in a single batch input file; presumably
# chosen to stay under the Batch API's per-batch request cap (TODO confirm).
BATCH_SIZE_LIMIT = 49_000

# --- Load System Prompt ---
# Read once at import time; the same prompt text is sent as the system message
# of every request in every batch.
with open("fine_tuning_prompt_template.txt", mode="r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()

def _custom_id(idx):
    """Return the batch ``custom_id`` for the record at position *idx* of the input."""
    return f"oasst1_req_{idx}"


def _query_text(item):
    """Return the query text of *item*, preferring ``original_query``.

    Falls back to ``item.get("query")`` (returned as-is, possibly ``None`` or
    empty) when ``original_query`` is missing or falsy.
    NOTE(review): a ``None`` query becomes a request with ``content: null``;
    confirm how the Batch API treats such lines.
    """
    return item.get("original_query") or item.get("query")


def _load_queries(path):
    """Read a JSONL file and return one parsed object per non-blank line."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _write_metadata_map(records, path):
    """Write one metadata line per record, keyed by the record's batch custom_id.

    This runs before any batch is submitted, so the custom_id -> source-record
    mapping exists even if a later upload or submission fails.
    """
    with open(path, "w", encoding="utf-8") as f_meta:
        for idx, item in enumerate(records):
            meta_record = {
                "custom_id": _custom_id(idx),
                "user_id": item.get("user_id"),
                "session_id": item.get("session_id"),
                "turn_id": item.get("turn_id"),
                "original_query": _query_text(item),
            }
            f_meta.write(json.dumps(meta_record, ensure_ascii=False) + "\n")


def _build_request(custom_id, query):
    """Return one Batch API request line (as a dict) for a JSON-mode chat completion."""
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": MODEL_NAME,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": query},
            ],
            "temperature": 0.0,
            "response_format": {"type": "json_object"},
        },
    }


def _write_batch_file(chunk, first_idx, path):
    """Write *chunk* as a Batch API input JSONL file at *path*.

    *first_idx* is the position of ``chunk[0]`` in the full input list, so the
    custom_ids written here match those in the metadata map.
    """
    with open(path, "w", encoding="utf-8") as f_out:
        for offset, item in enumerate(chunk):
            request_obj = _build_request(_custom_id(first_idx + offset), _query_text(item))
            f_out.write(json.dumps(request_obj) + "\n")


def _save_batch_ids(batch_ids, path):
    """Overwrite *path* with the JSON list of submitted batch IDs."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(batch_ids, f, indent=2)


def submit_oasst1_batch():
    """Build OASST1 chat-completion batch files and submit them to the OpenAI Batch API.

    Steps:
      1. Read the query records from ``INPUT_FILE``.
      2. Write ``METADATA_FILE``, mapping each custom_id to the record's
         user/session/turn IDs and query text.
      3. Split the records into chunks of at most ``BATCH_SIZE_LIMIT``. Write
         each chunk as a JSONL file under ``BATCH_DIR``, upload it, and create
         a batch job with a 24h completion window.
      4. Record the submitted batch IDs in
         ``data/raw_datasets/submitted_oasst1_batch_ids.json`` after every
         successful submission.

    If ``OPENAI_API_KEY`` is unset, prints an error and returns without doing
    anything. Exceptions from file I/O, JSON parsing or the OpenAI client
    propagate. Any batch IDs submitted before the failure are already on disk.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    os.makedirs(BATCH_DIR, exist_ok=True)

    print(f"Reading from {INPUT_FILE}...")
    all_lines = _load_queries(INPUT_FILE)
    total_items = len(all_lines)
    print(f"Total OASST1 items: {total_items}")

    # 1. Generate the metadata map first, so the mapping survives a mid-way failure.
    print(f"Generating metadata map to {METADATA_FILE}...")
    _write_metadata_map(all_lines, METADATA_FILE)

    # 2. Split and submit.
    id_file = "data/raw_datasets/submitted_oasst1_batch_ids.json"
    batch_ids = []

    for batch_idx, start in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
        chunk = all_lines[start : start + BATCH_SIZE_LIMIT]
        chunk_filename = os.path.join(BATCH_DIR, f"oasst1_batch_part_{batch_idx}.jsonl")

        print(f"\n--- Processing OASST1 Batch {batch_idx} ({len(chunk)} items) ---")
        _write_batch_file(chunk, start, chunk_filename)
        print(f"File created: {chunk_filename}")

        print("Uploading to OpenAI...")
        # Use a context manager so the upload handle is closed after the request.
        with open(chunk_filename, "rb") as upload_fh:
            batch_file_obj = client.files.create(file=upload_fh, purpose="batch")
        file_id = batch_file_obj.id
        print(f"Uploaded. File ID: {file_id}")

        print("Submitting Batch Job...")
        batch_job = client.batches.create(
            input_file_id=file_id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "description": f"Pers. Extractor OASST1 Part {batch_idx}",
                "dataset": "oasst1",
            },
        )
        print(f"Submitted. Batch ID: {batch_job.id}")
        batch_ids.append(batch_job.id)
        # Persist after every submission: if a later chunk fails, the IDs of
        # jobs already created on OpenAI's side are not lost.
        _save_batch_ids(batch_ids, id_file)

        time.sleep(1)

    # Final write also covers the zero-batch case (empty input -> "[]").
    _save_batch_ids(batch_ids, id_file)

    print(f"\nALL DONE! Submitted {len(batch_ids)} OASST1 batches.")
    print(f"Metadata saved to {METADATA_FILE}")
    print(f"Batch IDs saved to {id_file}")


if __name__ == "__main__":
    submit_oasst1_batch()