summaryrefslogtreecommitdiff
path: root/scripts/finish_retry_batches.py
blob: f2663272aeeae814b4c623480b59c30002c35dbe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Set

from openai import AsyncOpenAI, OpenAI

# --- Configuration ---
# Model requested when an item is re-run through the direct chat API.
MODEL_NAME: str = "gpt-5.1"

# JSONL file with the requests submitted in *this specific retry batch* run.
RETRY_INPUT_SOURCE: str = "data/raw_datasets/retry_requests.jsonl"
# JSON file with the IDs of the submitted retry batches.
BATCH_IDS_FILE: str = "data/raw_datasets/submitted_retry_batch_ids.json"
# JSONL dataset that the recovered records are appended to.
OUTPUT_LABEL_FILE: str = "data/raw_datasets/labeled_full_dataset_batch.jsonl"

def load_retry_queries(path: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
    """
    Load the requests that were submitted in the retry batch.

    Every non-blank line of the file is parsed as a JSON request object of
    the form
    ``{"custom_id": "...", "body": {"messages": [..., {"role": "user", "content": "..."}]}}``.

    Args:
        path: JSONL file to read. When omitted, ``RETRY_INPUT_SOURCE`` is used.

    Returns:
        A mapping ``custom_id -> {"query": <content of the first user message>}``.
        The query is ``""`` when a request contains no user message. If a
        ``custom_id`` occurs more than once, the last occurrence wins.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a request lacks ``custom_id`` or ``body.messages``.
    """
    # Resolve the default at call time so callers can point at another file.
    source_path = RETRY_INPUT_SOURCE if path is None else path

    print("Loading retry source requests...")
    mapping: Dict[str, Dict[str, Any]] = {}
    with open(source_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            req = json.loads(line)
            # Recover the user query from the request body: first message
            # whose role is "user", or "" when there is none.
            user_content = next(
                (
                    m["content"]
                    for m in req["body"]["messages"]
                    if m.get("role") == "user"
                ),
                "",
            )
            # Only the query is kept; the request does not carry the
            # original source label of the item.
            mapping[req["custom_id"]] = {"query": user_content}
    return mapping

# Failures that mean "this single result is unusable". Anything else (disk
# errors, interrupts, programming errors) is allowed to propagate.
_ITEM_ERRORS = (ValueError, KeyError, IndexError, TypeError, AttributeError)


def _build_record(custom_id: str, query: str, source: str, parsed_json: Any) -> Dict[str, Any]:
    """Assemble one output row; raises ValueError if the model output is not a JSON object."""
    if not isinstance(parsed_json, dict):
        raise ValueError(f"{custom_id}: model output is not a JSON object")
    return {
        "custom_id": custom_id,
        "original_query": query,
        "source": source,
        "extracted_json": parsed_json,
        "has_preference": len(parsed_json.get("preferences", [])) > 0,
    }


def _parse_batch_line(line: str) -> tuple:
    """Decode one batch output line into ``(custom_id, parsed_json)``; raises on unusable results."""
    result = json.loads(line)
    custom_id = result["custom_id"]
    # A result line may carry no response object at all; treat that like any
    # other non-200 outcome instead of crashing on the subscript.
    response = result.get("response") or {}
    status = response.get("status_code")
    if status != 200:
        raise ValueError(f"{custom_id}: status_code={status}")
    llm_content = response["body"]["choices"][0]["message"]["content"]
    return custom_id, json.loads(llm_content)


async def process_and_finish():
    """
    Collect the results of the submitted retry batches and fill in the gaps.

    Steps:
      1. For every batch ID in ``BATCH_IDS_FILE``, download the batch output
         file (if any) and append each usable result to ``OUTPUT_LABEL_FILE``.
      2. Work out which requests from the retry input produced no record.
      3. Re-run those requests one by one through the chat completions API
         and append the results to the same file.

    Returns early (after printing an error) when ``OPENAI_API_KEY`` is not set
    or ``BATCH_IDS_FILE`` does not exist. Per-item failures are printed and
    skipped; they never abort the remaining items.

    NOTE(review): the output file is opened in append mode and is not checked
    for records written by an earlier run, so running this twice appends the
    same records again.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    sync_client = OpenAI(api_key=api_key)
    async_client = AsyncOpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r") as f:
        batch_ids = json.load(f)

    query_map = load_retry_queries()
    processed_ids: Set[str] = set()

    print(f"Total requests in retry batch: {len(query_map)}")

    success_count = 0

    # 1. Download results from Batch API (even if expired)
    print("Downloading batch results...")
    with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            # Only the network calls are guarded here, so a failure to fetch
            # one batch is reported and the next batch is still attempted.
            try:
                batch = sync_client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    status = getattr(batch, "status", "unknown")
                    print(f"  Batch {b_id} has no output file (status: {status})")
                    continue
                content = sync_client.files.content(batch.output_file_id).text
            except Exception as e:
                print(f"Error checking batch {b_id}: {e}")
                continue

            for line in content.splitlines():
                if not line.strip():
                    continue
                # Each line is handled independently: one bad line must not
                # discard the remaining results of the batch.
                try:
                    custom_id, parsed_json = _parse_batch_line(line)
                    original = query_map.get(custom_id)
                    if not original or custom_id in processed_ids:
                        # Not part of this retry run, or already recorded
                        # from another batch.
                        continue
                    record = _build_record(
                        custom_id,
                        original["query"],
                        "retry_recovery",  # Lost original source, marking as recovery
                        parsed_json,
                    )
                except _ITEM_ERRORS as e:
                    # The item stays out of processed_ids and is retried via
                    # the direct API in step 3.
                    print(f"  Skipping unusable batch result in {b_id}: {e}")
                    continue

                f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                processed_ids.add(custom_id)
                success_count += 1

    # 2. Identify Missing
    missing_ids = [cid for cid in query_map if cid not in processed_ids]
    print(f"\nMissing/Failed items: {len(missing_ids)}")

    # 3. Finish with Direct API
    failed_ids: List[str] = []
    if missing_ids:
        print("Processing missing items via Direct API...")

        # Load System Prompt
        with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
            sys_prompt = f.read()

        with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
            for cid in missing_ids:
                query = query_map[cid]["query"]
                print(f"  Fixing {cid}...")

                try:
                    resp = await async_client.chat.completions.create(
                        model=MODEL_NAME,
                        messages=[
                            {"role": "system", "content": sys_prompt},
                            {"role": "user", "content": query},
                        ],
                        response_format={"type": "json_object"},
                    )
                    parsed_json = json.loads(resp.choices[0].message.content)
                    record = _build_record(cid, query, "retry_direct_fix", parsed_json)
                except Exception as e:
                    # Best effort per item: report and move on to the next one.
                    print(f"  Failed to fix {cid}: {e}")
                    failed_ids.append(cid)
                    continue

                f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                success_count += 1

    print("\n" + "=" * 50)
    if failed_ids:
        # Do not claim full recovery when some items still have no record.
        print(f"RETRY RECOVERY INCOMPLETE: {len(failed_ids)} item(s) still failing.")
    else:
        print("ALL RETRY BATCHES RECOVERED.")
    print(f"Total processed in this run: {success_count}")
    print(f"Full dataset updated at: {OUTPUT_LABEL_FILE}")
    print("=" * 50)

if __name__ == "__main__":
    # Script entry point: create the recovery coroutine and run it to
    # completion on a fresh event loop.
    recovery = process_and_finish()
    asyncio.run(recovery)