1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
import json
import os
import asyncio
from openai import OpenAI, AsyncOpenAI
from typing import Dict, Any, Set, List
# --- Configuration ---
# JSON file holding the list of batch ids that were submitted in the retry run.
BATCH_IDS_FILE = "data/raw_datasets/submitted_retry_batch_ids.json"
# The input file for *this specific retry batch* run
RETRY_INPUT_SOURCE = "data/raw_datasets/retry_requests.jsonl"
# Where to append the final results
OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
# Model used for the direct-API backfill of items missing from batch output.
MODEL_NAME = "gpt-5.1"
def load_retry_queries(path=None) -> Dict[str, Dict[str, Any]]:
    """
    Load the requests that were submitted in the retry batch.

    Each non-blank line of the source file is a Batch API request object of
    the form::

        {"custom_id": "...", "body": {"messages": [{"role": ..., "content": ...}, ...]}}

    Args:
        path: JSONL file to read. Defaults to ``RETRY_INPUT_SOURCE`` when
            ``None`` (kept as a runtime default so the module-level constant
            stays the single source of truth).

    Returns:
        Mapping of ``custom_id`` to ``{"query": <user message content>}``.
        The query is ``""`` when the request has no user message.
    """
    if path is None:
        path = RETRY_INPUT_SOURCE
    print("Loading retry source requests...")
    mapping: Dict[str, Dict[str, Any]] = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            req = json.loads(line)
            custom_id = req["custom_id"]
            # Recover the original user query from the request body; fall
            # back to "" when no user message is present.
            user_content = next(
                (m["content"] for m in req["body"]["messages"] if m["role"] == "user"),
                "",
            )
            # NOTE(review): source info may have been lost in the retry
            # conversion; only the query is preserved here. Ideally the
            # source should have been propagated in request metadata.
            mapping[custom_id] = {"query": user_content}
    return mapping
async def process_and_finish() -> None:
    """
    Recover results from previously submitted retry batches, then backfill
    any still-missing items via the direct chat-completions API.

    Workflow:
      1. For every batch id in ``BATCH_IDS_FILE``, download the batch output
         file (available even for expired batches) and append successfully
         parsed records to ``OUTPUT_LABEL_FILE``.
      2. Compute which custom_ids were not recovered.
      3. Re-run those requests individually against the direct API with the
         system prompt from ``fine_tuning_prompt_template.txt`` and append
         the results as well.

    Requires the ``OPENAI_API_KEY`` environment variable. Returns nothing;
    all output is appended to ``OUTPUT_LABEL_FILE`` and progress is printed.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    sync_client = OpenAI(api_key=api_key)
    async_client = AsyncOpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return
    with open(BATCH_IDS_FILE, "r") as f:
        batch_ids = json.load(f)

    query_map = load_retry_queries()
    processed_ids: Set[str] = set()
    print(f"Total requests in retry batch: {len(query_map)}")
    success_count = 0

    # 1. Download results from the Batch API (output files stay downloadable
    # even when the batch itself has expired).
    print("Downloading batch results...")
    with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            try:
                batch = sync_client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    continue
                content = sync_client.files.content(batch.output_file_id).text
                for line in content.splitlines():
                    if not line.strip():
                        continue
                    res = json.loads(line)
                    custom_id = res["custom_id"]
                    if res["response"]["status_code"] != 200:
                        continue
                    try:
                        body = res["response"]["body"]
                        llm_content = body["choices"][0]["message"]["content"]
                        parsed_json = json.loads(llm_content)
                    except (KeyError, IndexError, TypeError, json.JSONDecodeError) as e:
                        # Malformed response body or non-JSON model output.
                        # Log it (instead of silently swallowing) and leave the
                        # item unprocessed so the direct-API pass retries it.
                        print(f"  Skipping malformed result {custom_id}: {e}")
                        continue
                    original = query_map.get(custom_id)
                    if original is None:
                        # Result for an id we did not submit in this run.
                        continue
                    record = {
                        "custom_id": custom_id,
                        "original_query": original["query"],
                        "source": "retry_recovery",  # Lost original source, marking as recovery
                        "extracted_json": parsed_json,
                        "has_preference": len(parsed_json.get("preferences", [])) > 0,
                    }
                    f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                    processed_ids.add(custom_id)
                    success_count += 1
            except Exception as e:
                # Network/API failure for this batch; its items remain in the
                # missing set and are retried in the direct-API pass below.
                print(f"Error checking batch {b_id}: {e}")

    # 2. Identify custom_ids that were not recovered (submission order kept).
    missing_ids = [cid for cid in query_map if cid not in processed_ids]
    print(f"\nMissing/Failed items: {len(missing_ids)}")

    # 3. Finish the stragglers with the direct chat-completions API.
    if missing_ids:
        print("Processing missing items via Direct API...")
        with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
            sys_prompt = f.read()
        with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
            for cid in missing_ids:
                query = query_map[cid]["query"]
                print(f"  Fixing {cid}...")
                try:
                    resp = await async_client.chat.completions.create(
                        model=MODEL_NAME,
                        messages=[
                            {"role": "system", "content": sys_prompt},
                            {"role": "user", "content": query},
                        ],
                        response_format={"type": "json_object"},
                    )
                    parsed_json = json.loads(resp.choices[0].message.content)
                    record = {
                        "custom_id": cid,
                        "original_query": query,
                        "source": "retry_direct_fix",
                        "extracted_json": parsed_json,
                        "has_preference": len(parsed_json.get("preferences", [])) > 0,
                    }
                    f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                    success_count += 1
                except Exception as e:
                    # Best-effort per item: report and move on to the next id.
                    print(f"  Failed to fix {cid}: {e}")

    print("\n" + "=" * 50)
    print("ALL RETRY BATCHES RECOVERED.")
    print(f"Total processed in this run: {success_count}")
    print(f"Full dataset updated at: {OUTPUT_LABEL_FILE}")
    print("=" * 50)
if __name__ == "__main__":
    # Script entry point: run the async recovery workflow to completion.
    asyncio.run(process_and_finish())
|