diff options
Diffstat (limited to 'scripts/retrieve_batch_results.py')
| -rw-r--r-- | scripts/retrieve_batch_results.py | 151 |
1 file changed, 151 insertions, 0 deletions
import json
import os
import time
from typing import Dict, Any, List, Set

from openai import OpenAI

# --- Configuration ---
BATCH_IDS_FILE = "data/raw_datasets/submitted_batch_ids.json"
ORIGINAL_INPUT_FILE = "data/raw_datasets/combined_raw_queries.jsonl"
OUTPUT_LABEL_FILE = "data/raw_datasets/labeled_full_dataset_batch.jsonl"
RETRY_INPUT_FILE = "data/raw_datasets/retry_requests.jsonl"
MODEL_NAME = "gpt-5.1"  # Needed to reconstruct request bodies for retries

# Load the system prompt locally to avoid import errors.
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()


def load_original_queries() -> Dict[str, Dict[str, Any]]:
    """Map synthetic custom_ids ("req_<idx>") back to the original query records.

    The submission script assigned custom_ids by input line index, so
    enumerating the same file in the same order reproduces the mapping.

    Returns:
        Dict keyed by custom_id ("req_0", "req_1", ...) with the parsed
        JSONL record for each non-blank input line as the value.
    """
    print("Loading original queries map...")
    mapping: Dict[str, Dict[str, Any]] = {}
    with open(ORIGINAL_INPUT_FILE, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if line.strip():
                mapping[f"req_{idx}"] = json.loads(line)
    return mapping


def _build_retry_request(custom_id: str, original_item: Dict[str, Any]) -> Dict[str, Any]:
    """Reconstruct the Batch API request object for a query that must be retried."""
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": MODEL_NAME,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": original_item["query"]},
            ],
            "temperature": 0.0,
            "response_format": {"type": "json_object"},
        },
    }


def process_batch_results():
    """Download completed batch outputs, persist parsed labels, and emit retries.

    Side effects:
        * Overwrites OUTPUT_LABEL_FILE with one JSONL record per response
          that returned HTTP 200 and parsed as JSON.
        * Overwrites RETRY_INPUT_FILE with reconstructed Batch API requests
          for every original query that did not yield a usable response
          (non-200, parse failure, or missing entirely).
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not set.")
        return
    client = OpenAI(api_key=api_key)

    if not os.path.exists(BATCH_IDS_FILE):
        print(f"Error: {BATCH_IDS_FILE} not found.")
        return

    with open(BATCH_IDS_FILE, "r") as f:
        batch_ids = json.load(f)

    query_map = load_original_queries()
    processed_ids: Set[str] = set()

    # NOTE: older output records carried no custom_id, so content-based
    # dedup against an existing output file is unreliable. We deliberately
    # rebuild the output from scratch on every run; anything missing
    # afterwards is regenerated into the retry file.
    print("Starting fresh download/processing (Overwriting output)...")

    success_count = 0
    fail_count = 0

    with open(OUTPUT_LABEL_FILE, "w", encoding="utf-8") as f_success:
        for b_id in batch_ids:
            print(f"\nProcessing Batch {b_id}...")
            try:
                batch = client.batches.retrieve(b_id)

                # 1. Output file (successful responses).
                if batch.output_file_id:
                    print(f"  Downloading output {batch.output_file_id}...")
                    content = client.files.content(batch.output_file_id).text

                    for line in content.splitlines():
                        if not line.strip():
                            continue
                        res = json.loads(line)
                        custom_id = res["custom_id"]

                        # "response" may be null for failed requests; guard
                        # so one bad line cannot abort the whole batch.
                        response = res.get("response") or {}
                        if response.get("status_code") == 200:
                            try:
                                body = response["body"]
                                llm_content = body["choices"][0]["message"]["content"]
                                parsed_json = json.loads(llm_content)

                                original_item = query_map.get(custom_id)
                                if original_item:
                                    record = {
                                        # custom_id kept to aid later debugging/dedup.
                                        "custom_id": custom_id,
                                        "original_query": original_item["query"],
                                        "source": original_item.get("source"),
                                        "extracted_json": parsed_json,
                                        "has_preference": len(parsed_json.get("preferences", [])) > 0,
                                    }
                                    f_success.write(json.dumps(record, ensure_ascii=False) + "\n")
                                    processed_ids.add(custom_id)
                                    success_count += 1
                            except Exception as e:
                                # Malformed or non-JSON model output: leave the
                                # id out of processed_ids so it gets retried.
                                print(f"  Parse Error {custom_id}: {e}")
                        # Non-200 responses are handled implicitly: they stay
                        # out of processed_ids and fall into the retry set.

                # 2. Error file (explicit failures). The global "missing
                # check" below already generates retries for these, so this
                # is informational only.
                if batch.error_file_id:
                    print(f"  Downloading ERROR {batch.error_file_id}...")

            except Exception as e:
                # Batch-level failure (network, auth, missing batch): report
                # and continue; its requests will land in the retry file.
                print(f"  CRITICAL ERROR processing batch {b_id}: {e}")

    # --- Missing check & retry generation ---
    print(f"\nVerifying completeness... (Total Queries: {len(query_map)})")
    print(f"Successful processed: {len(processed_ids)}")

    with open(RETRY_INPUT_FILE, "w", encoding="utf-8") as f_retry:
        for custom_id, original_item in query_map.items():
            if custom_id not in processed_ids:
                fail_count += 1
                f_retry.write(json.dumps(_build_retry_request(custom_id, original_item)) + "\n")

    print("\n" + "=" * 50)
    print("Processing Complete.")
    print(f"Successful: {success_count} (Saved to {OUTPUT_LABEL_FILE})")
    print(f"To Retry: {fail_count} (Saved to {RETRY_INPUT_FILE})")
    print("=" * 50)


if __name__ == "__main__":
    process_batch_results()
