summaryrefslogtreecommitdiff
path: root/scripts/submit_synthesis_batch.py
blob: 025782d217518853ea25a4958536a0e1bc9df3ac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import os
from openai import OpenAI
import time

# --- Configuration ---
# JSONL file of seed records; each non-blank line is one JSON object that the
# submitter reads "original_query" (required) and "custom_id" (optional) from.
INPUT_SEEDS = "data/raw_datasets/positive_seeds.jsonl"
# Directory where the per-chunk batch request files are written before upload.
BATCH_DIR = "data/raw_datasets/batch_files_synthesis"
MODEL_NAME = "gpt-5.1"  # Chat model used for every request (alternative: gpt-4o)
# Maximum number of requests per batch file; seeds are split into chunks of at
# most this many. With ~31k seeds this produces 2 files.
# NOTE(review): must stay within the Batch API per-batch request-count and
# input-file-size limits — confirm against the current OpenAI docs.
BATCH_SIZE_LIMIT = 30000

# System prompt: asks the model for 5 stylistic rewrites of the user's query,
# returned as a JSON object of the form {"rewrites": [5 strings]}.
SYNTHESIS_SYSTEM_PROMPT = """You are a data augmentation assistant. 
Your task is to rewrite a User Query that contains specific preferences into 5 different variations.
The goal is to train a model to recognize these preferences in various contexts.

Variations required:
1. Formal/Polite: Use sophisticated language and polite markers.
2. Casual/Direct: Use slang, abbreviations, or very direct commands.
3. Implicit/Contextual: Embed the preference naturally within a larger context or story, making it harder to spot.
4. Distractor-Heavy: Mix the preference with irrelevant information or another task.
5. Imperative/Short: Extremely concise, almost robotic.

Output strictly a JSON object with a single key "rewrites" containing a list of 5 strings.
Example: {"rewrites": ["string1", "string2", "string3", "string4", "string5"]}
"""

def _load_seeds(path):
    """Return the decoded JSON value of every non-blank line in a JSONL file.

    Args:
        path: Path to a UTF-8 encoded JSON Lines file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_request(item, global_index):
    """Build one Batch API request (one JSONL line) for a single seed.

    Args:
        item: Seed record; must contain ``"original_query"`` and may contain
            ``"custom_id"``.
        global_index: Position of ``item`` in the full seed list. Used to
            derive a fallback id when ``custom_id`` is missing.

    Returns:
        A dict in the Batch API ``/v1/chat/completions`` request-line format.

    Raises:
        KeyError: if ``item`` has no ``"original_query"``.
    """
    # The Batch API requires custom_id to be unique within a batch, so the
    # fallback must be derived from this seed's own position (not from the
    # chunk offset, which is shared by every seed in the chunk).
    base_id = item.get("custom_id", f"seed_{global_index}")
    # TODO: consider also passing item["extracted_json"] as a "Core Preference"
    # hint so rewrites preserve the core intent (if seeds carry that key).
    user_content = f"Original Query: {item['original_query']}"
    return {
        "custom_id": f"syn_{base_id}",  # prefix distinguishes synthesis requests
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": MODEL_NAME,
            "messages": [
                {"role": "system", "content": SYNTHESIS_SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            "temperature": 0.7,  # higher temperature for more diverse rewrites
            "response_format": {"type": "json_object"},
        },
    }


def _save_batch_ids(id_file, batch_ids):
    """Write ``batch_ids`` to ``id_file`` as an indented JSON list."""
    with open(id_file, "w", encoding="utf-8") as f:
        json.dump(batch_ids, f, indent=2)


def submit_synthesis_batch():
    """Build, upload and submit OpenAI Batch jobs that rewrite each seed query.

    Reads seed records from ``INPUT_SEEDS`` and splits them into chunks of at
    most ``BATCH_SIZE_LIMIT``. For each chunk it writes a JSONL request file
    under ``BATCH_DIR``, uploads it with ``purpose="batch"``, and creates a
    ``/v1/chat/completions`` batch job with a 24h completion window.

    The IDs of the submitted batches are written as a JSON list to
    ``data/raw_datasets/submitted_synthesis_batch_ids.json``. The file is
    rewritten after every successful submission, so IDs of batches that were
    already submitted survive a later failure.

    Prints an error and returns early if ``OPENAI_API_KEY`` is unset or the
    seed file does not exist. Exceptions from file I/O or the OpenAI client
    propagate to the caller.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    os.makedirs(BATCH_DIR, exist_ok=True)

    if not os.path.exists(INPUT_SEEDS):
        print(f"Error: {INPUT_SEEDS} not found.")
        return

    print(f"Reading seeds from {INPUT_SEEDS}...")
    seeds = _load_seeds(INPUT_SEEDS)
    total_items = len(seeds)
    print(f"Total seeds: {total_items}")

    id_file = "data/raw_datasets/submitted_synthesis_batch_ids.json"
    batch_ids = []

    # Split and submit
    for batch_idx, start in enumerate(range(0, total_items, BATCH_SIZE_LIMIT)):
        chunk = seeds[start : start + BATCH_SIZE_LIMIT]
        chunk_filename = os.path.join(BATCH_DIR, f"synthesis_batch_part_{batch_idx}.jsonl")

        print(f"\n--- Processing Synthesis Batch {batch_idx} ({len(chunk)} items) ---")

        # 1. Create the request file for this chunk.
        with open(chunk_filename, "w", encoding="utf-8") as f_out:
            for offset, item in enumerate(chunk):
                request_obj = _build_request(item, start + offset)
                f_out.write(json.dumps(request_obj) + "\n")

        print(f"File created: {chunk_filename}")

        # 2. Upload; the context manager closes the handle once the upload returns.
        print("Uploading to OpenAI...")
        with open(chunk_filename, "rb") as f_in:
            batch_file_obj = client.files.create(file=f_in, purpose="batch")
        file_id = batch_file_obj.id
        print(f"Uploaded. File ID: {file_id}")

        # 3. Submit the batch job.
        print("Submitting Batch Job...")
        batch_job = client.batches.create(
            input_file_id=file_id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "description": f"Pers. Extractor Synthesis Part {batch_idx}",
                "type": "synthesis",
            },
        )
        print(f"Submitted. Batch ID: {batch_job.id}")
        batch_ids.append(batch_job.id)
        # Persist immediately: a submitted batch is already queued (and billed),
        # so its ID must not be lost if a later chunk fails.
        _save_batch_ids(id_file, batch_ids)

        time.sleep(1)

    # Final write also covers the zero-seed case (writes an empty list).
    _save_batch_ids(id_file, batch_ids)

    print(f"\nALL DONE! Submitted {len(batch_ids)} synthesis batches.")
    print(f"Batch IDs saved to {id_file}")

if __name__ == "__main__":
    # Script entry point: build, upload and submit all synthesis batches.
    submit_synthesis_batch()