"""Recover a retry Batch API run: download whatever the batch produced,
then finish any missing items with direct API calls.

Steps:
  1. Download results from the Batch API (works even for expired batches).
  2. Diff against the submitted requests to find missing/failed items.
  3. Re-run the missing items via the chat completions API directly.
"""

import asyncio
import json
import os
from typing import Any, Dict, Set

from openai import AsyncOpenAI, OpenAI

# --- Configuration ---
BATCH_IDS_FILE = "data/raw_datasets/submitted_retry_batch_ids.json"
# The input file for *this specific retry batch* run
RETRY_INPUT_SOURCE = "data/raw_datasets/retry_requests.jsonl"
# Where to append the final results
OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
MODEL_NAME = "gpt-5.1"


def load_retry_queries() -> Dict[str, Dict[str, Any]]:
    """
    Load the requests that were submitted in the retry batch.
    These are plain Batch API request objects, one JSON object per line.
    """
    print("Loading retry source requests...")
    mapping: Dict[str, Dict[str, Any]] = {}
    with open(RETRY_INPUT_SOURCE, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            req = json.loads(line)
            # Structure: {"custom_id": "...", "body": {"messages": [..., {"role": "user", "content": "..."}]}}
            custom_id = req["custom_id"]
            # Extract the user query back out of the request body
            user_content = ""
            for m in req["body"]["messages"]:
                if m["role"] == "user":
                    user_content = m["content"]
                    break
            mapping[custom_id] = {
                "query": user_content,
                # Source info may have been lost in the retry conversion;
                # for now we only need the query. (Ideally the source would
                # have been propagated in metadata.)
            }
    return mapping


async def process_and_finish():
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    sync_client = OpenAI(api_key=api_key)
    async_client = AsyncOpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    query_map = load_retry_queries()
    processed_ids: Set[str] = set()
    print(f"Total requests in retry batch: {len(query_map)}")

    success_count = 0

    # 1. Download results from the Batch API (even expired batches may
    #    have a partial output file).
    print("Downloading batch results...")
    with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
        for b_id in batch_ids:
            try:
                batch = sync_client.batches.retrieve(b_id)
                if not batch.output_file_id:
                    continue
                content = sync_client.files.content(batch.output_file_id).text
                for line in content.splitlines():
                    if not line.strip():
                        continue
                    res = json.loads(line)
                    custom_id = res["custom_id"]
                    # "response" is null for request-level errors.
                    if not res.get("response") or res["response"]["status_code"] != 200:
                        continue
                    try:
                        body = res["response"]["body"]
                        llm_content = body["choices"][0]["message"]["content"]
                        parsed_json = json.loads(llm_content)
                    except (KeyError, IndexError, TypeError, json.JSONDecodeError):
                        # Unparseable result: leave it for the direct-API pass below.
                        continue
                    original = query_map.get(custom_id)
                    if original:
                        record = {
                            "custom_id": custom_id,
                            "original_query": original["query"],
                            # Original source was lost; mark as recovery.
                            "source": "retry_recovery",
                            "extracted_json": parsed_json,
                            "has_preference": len(parsed_json.get("preferences", [])) > 0,
                        }
                        f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                        processed_ids.add(custom_id)
                        success_count += 1
            except Exception as e:
                print(f"Error checking batch {b_id}: {e}")

    # 2. Identify missing/failed items.
    missing_ids = [cid for cid in query_map if cid not in processed_ids]
    print(f"\nMissing/Failed items: {len(missing_ids)}")
    # 3. Finish the missing items with direct API calls.
    if missing_ids:
        print("Processing missing items via Direct API...")
        # Load the system prompt used for labeling.
        with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
            sys_prompt = f.read()

        with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
            # Sequential awaits keep this simple; fine for a small tail
            # of missing items.
            for cid in missing_ids:
                query = query_map[cid]["query"]
                print(f"  Fixing {cid}...")
                try:
                    resp = await async_client.chat.completions.create(
                        model=MODEL_NAME,
                        messages=[
                            {"role": "system", "content": sys_prompt},
                            {"role": "user", "content": query},
                        ],
                        response_format={"type": "json_object"},
                    )
                    parsed_json = json.loads(resp.choices[0].message.content)
                    record = {
                        "custom_id": cid,
                        "original_query": query,
                        "source": "retry_direct_fix",
                        "extracted_json": parsed_json,
                        "has_preference": len(parsed_json.get("preferences", [])) > 0,
                    }
                    f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                    success_count += 1
                except Exception as e:
                    print(f"  Failed to fix {cid}: {e}")

    print("\n" + "=" * 50)
    print("RETRY RECOVERY RUN COMPLETE.")
    print(f"Total processed in this run: {success_count}")
    print(f"Full dataset updated at: {OUTPUT_LABEL_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    asyncio.run(process_and_finish())