summaryrefslogtreecommitdiff
path: root/scripts/finish_retry_batches.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/finish_retry_batches.py')
-rw-r--r--scripts/finish_retry_batches.py154
1 file changed, 154 insertions(+), 0 deletions(-)
diff --git a/scripts/finish_retry_batches.py b/scripts/finish_retry_batches.py
new file mode 100644
index 0000000..f266327
--- /dev/null
+++ b/scripts/finish_retry_batches.py
@@ -0,0 +1,154 @@
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Set

from openai import AsyncOpenAI, OpenAI
+
+# --- Configuration ---
+BATCH_IDS_FILE = "data/raw_datasets/submitted_retry_batch_ids.json"
+# The input file for *this specific retry batch* run
+RETRY_INPUT_SOURCE = "data/raw_datasets/retry_requests.jsonl"
+# Where to append the final results
+OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
+MODEL_NAME = "gpt-5.1"
+
+def load_retry_queries() -> Dict[str, Dict[str, Any]]:
+ """
+ Load the requests that were submitted in the retry batch.
+ These are essentially JSON Request objects.
+ """
+ print("Loading retry source requests...")
+ mapping = {}
+ with open(RETRY_INPUT_SOURCE, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ req = json.loads(line)
+ # Structure: {"custom_id": "...", "body": {"messages": [..., {"role": "user", "content": "..."}]}}
+ custom_id = req["custom_id"]
+ # Extract user query back from the request body
+ user_content = ""
+ for m in req["body"]["messages"]:
+ if m["role"] == "user":
+ user_content = m["content"]
+ break
+
+ mapping[custom_id] = {
+ "query": user_content,
+ # We might have lost source info in the retry conversion if not careful,
+ # but for now let's assume we just need the query.
+ # (Ideally we should have propagated source in metadata)
+ }
+ return mapping
+
+async def process_and_finish():
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ print("Error: OPENAI_API_KEY not set.")
+ return
+
+ sync_client = OpenAI(api_key=api_key)
+ async_client = AsyncOpenAI(api_key=api_key)
+
+ if not os.path.exists(BATCH_IDS_FILE):
+ print(f"Error: {BATCH_IDS_FILE} not found.")
+ return
+
+ with open(BATCH_IDS_FILE, "r") as f:
+ batch_ids = json.load(f)
+
+ query_map = load_retry_queries()
+ processed_ids: Set[str] = set()
+
+ print(f"Total requests in retry batch: {len(query_map)}")
+
+ success_count = 0
+
+ # 1. Download results from Batch API (even if expired)
+ print("Downloading batch results...")
+ with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
+ for b_id in batch_ids:
+ try:
+ batch = sync_client.batches.retrieve(b_id)
+ if batch.output_file_id:
+ content = sync_client.files.content(batch.output_file_id).text
+ for line in content.splitlines():
+ if not line.strip(): continue
+ res = json.loads(line)
+ custom_id = res["custom_id"]
+
+ if res["response"]["status_code"] == 200:
+ try:
+ body = res["response"]["body"]
+ llm_content = body["choices"][0]["message"]["content"]
+ parsed_json = json.loads(llm_content)
+
+ original = query_map.get(custom_id)
+ if original:
+ record = {
+ "custom_id": custom_id,
+ "original_query": original["query"],
+ "source": "retry_recovery", # Lost original source, marking as recovery
+ "extracted_json": parsed_json,
+ "has_preference": len(parsed_json.get("preferences", [])) > 0
+ }
+ f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ processed_ids.add(custom_id)
+ success_count += 1
+ except:
+ pass
+ except Exception as e:
+ print(f"Error checking batch {b_id}: {e}")
+
+ # 2. Identify Missing
+ missing_ids = [cid for cid in query_map.keys() if cid not in processed_ids]
+ print(f"\nMissing/Failed items: {len(missing_ids)}")
+
+ # 3. Finish with Direct API
+ if missing_ids:
+ print("Processing missing items via Direct API...")
+
+ # Load System Prompt
+ with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ sys_prompt = f.read()
+
+ with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
+ for cid in missing_ids:
+ item = query_map[cid]
+ query = item["query"]
+ print(f" Fixing {cid}...")
+
+ try:
+ resp = await async_client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=[
+ {"role": "system", "content": sys_prompt},
+ {"role": "user", "content": query}
+ ],
+ response_format={"type": "json_object"}
+ )
+
+ content = resp.choices[0].message.content
+ parsed_json = json.loads(content)
+
+ record = {
+ "custom_id": cid,
+ "original_query": query,
+ "source": "retry_direct_fix",
+ "extracted_json": parsed_json,
+ "has_preference": len(parsed_json.get("preferences", [])) > 0
+ }
+ f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
+ success_count += 1
+
+ except Exception as e:
+ print(f" Failed to fix {cid}: {e}")
+
+ print("\n" + "="*50)
+ print("ALL RETRY BATCHES RECOVERED.")
+ print(f"Total processed in this run: {success_count}")
+ print(f"Full dataset updated at: {OUTPUT_LABEL_FILE}")
+ print("="*50)
+
+if __name__ == "__main__":
+ asyncio.run(process_and_finish())
+