diff options
Diffstat (limited to 'scripts/finish_retry_batches.py')
| -rw-r--r-- | scripts/finish_retry_batches.py | 154 |
1 file changed, 154 insertions, 0 deletions
"""Finish recovering results from submitted OpenAI retry batches.

Workflow:
1. Download whatever output each submitted retry batch produced (even if the
   batch expired) and append successfully parsed records to the labeled
   dataset file.
2. Determine which custom_ids are still missing.
3. Re-run the missing requests one at a time via the direct chat-completions
   API and append those results as well.
"""

import json
import os
import asyncio
from openai import OpenAI, AsyncOpenAI
from typing import Dict, Any, Set, List

# --- Configuration ---
BATCH_IDS_FILE = "data/raw_datasets/submitted_retry_batch_ids.json"
# The input file for *this specific retry batch* run
RETRY_INPUT_SOURCE = "data/raw_datasets/retry_requests.jsonl"
# Where to append the final results
OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
MODEL_NAME = "gpt-5.1"
# System prompt used for the direct-API fix-up pass.
SYSTEM_PROMPT_FILE = "fine_tuning_prompt_template.txt"


def load_retry_queries() -> Dict[str, Dict[str, Any]]:
    """Load the requests that were submitted in the retry batch.

    Returns a mapping of custom_id -> {"query": <original user message>}.
    Each non-empty line of RETRY_INPUT_SOURCE is a Batch API request object:
    {"custom_id": "...", "body": {"messages": [..., {"role": "user", ...}]}}.
    """
    print("Loading retry source requests...")
    mapping: Dict[str, Dict[str, Any]] = {}
    with open(RETRY_INPUT_SOURCE, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            req = json.loads(line)
            custom_id = req["custom_id"]
            # Extract the user query back out of the request body: the first
            # user-role message is the original query.
            user_content = ""
            for m in req["body"]["messages"]:
                if m["role"] == "user":
                    user_content = m["content"]
                    break
            mapping[custom_id] = {
                "query": user_content,
                # NOTE: the original source info was lost in the retry
                # conversion; ideally it should have been propagated in
                # request metadata.
            }
    return mapping


def _make_record(custom_id: str, query: str, source: str,
                 parsed_json: Dict[str, Any]) -> Dict[str, Any]:
    """Build one output dataset record (shared by both recovery paths)."""
    return {
        "custom_id": custom_id,
        "original_query": query,
        "source": source,
        "extracted_json": parsed_json,
        "has_preference": len(parsed_json.get("preferences", [])) > 0,
    }


def _recover_from_batches(sync_client: OpenAI,
                          batch_ids: List[str],
                          query_map: Dict[str, Dict[str, Any]],
                          f_out,
                          processed_ids: Set[str]) -> int:
    """Download batch outputs and append parsed records to *f_out*.

    Marks each successfully recovered custom_id in *processed_ids* and
    returns the number of records written. Malformed or failed result lines
    are logged and skipped rather than aborting the rest of the batch.
    """
    success_count = 0
    for b_id in batch_ids:
        try:
            batch = sync_client.batches.retrieve(b_id)
            if not batch.output_file_id:
                continue
            content = sync_client.files.content(batch.output_file_id).text
        except Exception as e:
            print(f"Error checking batch {b_id}: {e}")
            continue

        for line in content.splitlines():
            if not line.strip():
                continue
            try:
                res = json.loads(line)
                custom_id = res["custom_id"]
                # Failed items can carry "response": null or an error shape;
                # guard so one bad line doesn't abort the whole batch.
                response = res.get("response") or {}
                if response.get("status_code") != 200:
                    continue
                body = response["body"]
                llm_content = body["choices"][0]["message"]["content"]
                parsed_json = json.loads(llm_content)
            except (KeyError, IndexError, TypeError,
                    json.JSONDecodeError) as e:
                # Leave the item unprocessed; the direct-API pass picks it up.
                print(f"  Skipping malformed result line in {b_id}: {e}")
                continue

            original = query_map.get(custom_id)
            if original:
                record = _make_record(
                    custom_id, original["query"],
                    # Lost original source, marking as recovery
                    "retry_recovery",
                    parsed_json,
                )
                f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
                processed_ids.add(custom_id)
                success_count += 1
    return success_count


async def _fix_missing_direct(async_client: AsyncOpenAI,
                              missing_ids: List[str],
                              query_map: Dict[str, Dict[str, Any]],
                              f_out) -> int:
    """Re-run still-missing requests via the direct chat-completions API.

    Returns the number of records written; per-item failures are logged and
    skipped so the rest of the queue still runs.
    """
    # Load System Prompt once for the whole fix-up pass.
    with open(SYSTEM_PROMPT_FILE, "r", encoding="utf-8") as f:
        sys_prompt = f.read()

    success_count = 0
    for cid in missing_ids:
        query = query_map[cid]["query"]
        print(f"  Fixing {cid}...")
        try:
            resp = await async_client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": query},
                ],
                response_format={"type": "json_object"},
            )
            parsed_json = json.loads(resp.choices[0].message.content)
            record = _make_record(cid, query, "retry_direct_fix", parsed_json)
            f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
            success_count += 1
        except Exception as e:
            print(f"  Failed to fix {cid}: {e}")
    return success_count


async def process_and_finish():
    """Recover all retry-batch results, then patch the gaps via direct API."""
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return

    sync_client = OpenAI(api_key=api_key)
    async_client = AsyncOpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r") as f:
        batch_ids = json.load(f)

    query_map = load_retry_queries()
    processed_ids: Set[str] = set()

    print(f"Total requests in retry batch: {len(query_map)}")

    # 1. Download results from Batch API (even if expired)
    print("Downloading batch results...")
    with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
        success_count = _recover_from_batches(
            sync_client, batch_ids, query_map, f_out, processed_ids)

    # 2. Identify Missing
    missing_ids = [cid for cid in query_map if cid not in processed_ids]
    print(f"\nMissing/Failed items: {len(missing_ids)}")

    # 3. Finish with Direct API
    if missing_ids:
        print("Processing missing items via Direct API...")
        with open(OUTPUT_LABEL_FILE, "a", encoding="utf-8") as f_out:
            success_count += await _fix_missing_direct(
                async_client, missing_ids, query_map, f_out)

    print("\n" + "=" * 50)
    print("ALL RETRY BATCHES RECOVERED.")
    print(f"Total processed in this run: {success_count}")
    print(f"Full dataset updated at: {OUTPUT_LABEL_FILE}")
    print("=" * 50)


if __name__ == "__main__":
    asyncio.run(process_and_finish())
