"""Collect the results of an OpenAI batch job, retry failures, and print stats.

Workflow (see ``main``):

1. Read the most recently submitted batch ID from ``BATCH_ID_FILE``.
2. Download the batch's output and error files.
3. Parse each response's message content as JSON.
4. Synchronously retry every request that failed or is missing.
5. Write the merged results to ``OUTPUT_FILE`` and print summary statistics.
"""

import json
import os
import time
from collections import Counter
from typing import Any, Dict, List

try:
    from openai import OpenAI
except ImportError:  # The SDK is only needed for the network steps.
    # Keeps the pure parsing/statistics helpers importable without the SDK;
    # main() reports the missing dependency before any client is built.
    OpenAI = None

# Configuration
# SECURITY: the API key must never be committed to source control. It is read
# from the environment instead. Any key that was previously hard-coded in this
# file must be treated as leaked and revoked.
API_KEY = os.environ.get("OPENAI_API_KEY", "")
BATCH_ID_FILE = "data/putnam_eval/submitted_batch_ids.json"
INPUT_FILE = "data/putnam_eval/putnam_eval_batch.jsonl"
OUTPUT_FILE = "data/putnam_eval/final_results.json"
MODEL_NAME = "gpt-5"


def load_input_requests(filepath: str) -> Dict[str, Any]:
    """Load the original batch requests so failed ones can be retried.

    Args:
        filepath: Path to a JSONL file with one request object per line.
            Blank lines are skipped.

    Returns:
        A mapping from each request's ``custom_id`` to the full request object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a request has no ``custom_id``.
    """
    requests_map = {}
    print(f"Loading input requests from {filepath}...")
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            requests_map[item["custom_id"]] = item
    return requests_map


def _download_jsonl(client: OpenAI, file_id: str) -> List[Dict[str, Any]]:
    """Download file ``file_id`` and parse it as JSONL, skipping blank lines."""
    raw_text = client.files.content(file_id).read().decode("utf-8")
    return [json.loads(line) for line in raw_text.splitlines() if line.strip()]


def retrieve_batch_results(client: OpenAI, batch_id: str) -> List[Dict[str, Any]]:
    """Retrieve and parse the output and error records of a batch.

    Args:
        client: OpenAI client used to query the batch and download its files.
        batch_id: Identifier of the batch to inspect.

    Returns:
        Parsed records from the output file followed by those from the error
        file. The list is empty when the batch has produced neither file.
    """
    print(f"Checking status for batch {batch_id}...")
    batch = client.batches.retrieve(batch_id)

    print(f"Batch Status: {batch.status}")
    print(f"Output File ID: {batch.output_file_id}")
    print(f"Error File ID: {batch.error_file_id}")

    if batch.status != "completed":
        # Callers treat every request absent from the returned records as
        # missing, so an unfinished batch would otherwise go unnoticed.
        print(
            f"Warning: batch status is '{batch.status}', not 'completed'; "
            "results may be partial or absent."
        )

    results = []

    if batch.output_file_id:
        print("Downloading output file...")
        results.extend(_download_jsonl(client, batch.output_file_id))

    if batch.error_file_id:
        print("Downloading error file (if any)...")
        # Usually contains request-level errors that didn't generate a
        # response object in the output file. Best effort: a problem here must
        # not discard the output records already collected.
        try:
            results.extend(_download_jsonl(client, batch.error_file_id))
        except Exception as e:
            print(f"Note: Could not download/parse error file: {e}")

    return results


def _extract_analysis(res: Dict[str, Any]) -> tuple:
    """Pull the parsed JSON content out of one batch record.

    Returns:
        ``(analysis, None)`` on success, or ``(None, reason)`` on failure,
        where ``reason`` is the text to print after ``"Request <id> "``.
    """
    # Check for API level errors
    if res.get("error"):
        return None, f"failed with error: {res['error']}"

    # Records may carry an explicit ``"response": null``; ``or {}`` covers both
    # the missing key and the null value.
    response = res.get("response") or {}
    status_code = response.get("status_code")
    if status_code != 200:
        return None, f"failed with status {status_code}"

    # Try to parse the content as JSON
    try:
        body = response.get("body") or {}
        choices = body.get("choices") or []
        if not choices:
            return None, "has no choices."
        content_str = choices[0].get("message", {}).get("content", "")
        return json.loads(content_str), None
    except json.JSONDecodeError:
        return None, "returned invalid JSON content."
    except Exception as e:
        # E.g. null content (TypeError) or an unexpectedly shaped record.
        return None, f"unexpected processing error: {e}"


def process_results_and_find_failures(
    results: List[Dict[str, Any]], all_request_ids: set
) -> tuple[List[Dict[str, Any]], List[str]]:
    """Separate successfully parsed results from failures.

    Args:
        results: Raw batch records (output and error lines).
        all_request_ids: The ``custom_id`` of every request that was submitted.

    Returns:
        ``(valid_results, failed_ids)``. ``valid_results`` holds one
        ``{"custom_id", "analysis"}`` dict per request whose content parsed as
        JSON. ``failed_ids`` lists, without duplicates, the IDs that failed in
        the order first seen, followed by the IDs absent from ``results`` in
        sorted order. An ID with at least one successful record is never
        reported as failed.
    """
    valid_results = []
    failed_ids = []
    seen_ids = set()
    succeeded_ids = set()

    for res in results:
        custom_id = res.get("custom_id")
        seen_ids.add(custom_id)

        analysis, failure_reason = _extract_analysis(res)
        if failure_reason is not None:
            print(f"Request {custom_id} {failure_reason}")
            failed_ids.append(custom_id)
            continue

        if custom_id in succeeded_ids:
            # Keep only the first successful record per request.
            continue
        succeeded_ids.add(custom_id)
        valid_results.append({"custom_id": custom_id, "analysis": analysis})

    # The same ID can show up in both the output and the error file; retry it
    # at most once, and not at all if one of its records succeeded.
    unique_failures = list(
        dict.fromkeys(fid for fid in failed_ids if fid not in succeeded_ids)
    )

    # Check for completely missing requests
    missing_ids = all_request_ids - seen_ids
    if missing_ids:
        print(
            f"Found {len(missing_ids)} missing requests that were not in the batch output."
        )
        # Sorted so the retry order is reproducible between runs.
        unique_failures.extend(sorted(missing_ids, key=str))

    return valid_results, unique_failures


def retry_failed_requests(
    client: OpenAI, failed_ids: List[str], input_map: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Retry specific requests synchronously via the chat completions API.

    Args:
        client: OpenAI client used to issue the requests.
        failed_ids: ``custom_id`` values to retry.
        input_map: Mapping from ``custom_id`` to the original request object.

    Returns:
        One ``{"custom_id", "analysis"}`` dict per retry whose response content
        parsed as JSON. IDs without an original request, and retries that
        fail, are reported on stdout and omitted.
    """
    retried_results = []
    print(f"\nRetrying {len(failed_ids)} failed requests synchronously...")

    for position, custom_id in enumerate(failed_ids, start=1):
        original_req = input_map.get(custom_id)
        if original_req is None:
            print(f"Warning: Original request for {custom_id} not found.")
            continue

        print(f"Retrying {position}/{len(failed_ids)}: {custom_id}")

        try:
            body = original_req["body"]
            # MODEL_NAME deliberately overrides whatever model the batch
            # request named.
            request_kwargs = {"model": MODEL_NAME, "messages": body["messages"]}
            # Forward optional parameters only when the original request set
            # them, so the API applies its own defaults instead of receiving
            # explicit nulls.
            for optional_key in ("response_format", "temperature"):
                if body.get(optional_key) is not None:
                    request_kwargs[optional_key] = body[optional_key]

            response = client.chat.completions.create(**request_kwargs)
            content_json = json.loads(response.choices[0].message.content)
        except Exception as e:
            # Best effort per request: one failure must not stop the others.
            print(f"Retry failed for {custom_id}: {e}")
            continue

        retried_results.append({"custom_id": custom_id, "analysis": content_json})

    return retried_results


def _as_dict(value: Any) -> Dict[str, Any]:
    """Return ``value`` if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _tally_flags(final_results: List[Dict[str, Any]]) -> Counter:
    """Count how many results set each of the analysis flags."""
    tally = Counter()
    for item in final_results:
        # The model's JSON is not guaranteed to be an object, so every level
        # is normalised before it is queried.
        analysis = _as_dict(item.get("analysis"))
        validity = _as_dict(analysis.get("variant_validity"))
        relation = _as_dict(analysis.get("relation_to_original"))

        is_valid = validity.get("is_problem_valid", False)
        is_correct = validity.get("is_solution_correct", False)
        is_equiv = relation.get("is_equivalent", False)
        is_related = relation.get("is_strongly_related", False)

        if is_valid:
            tally["valid"] += 1
        if is_correct:
            tally["correct"] += 1
        if is_equiv:
            tally["equivalent"] += 1
        if is_related:
            tally["related"] += 1
        if is_valid and is_correct and (is_equiv or is_related):
            tally["valid_and_related"] += 1
    return tally


def print_stats(final_results: List[Dict[str, Any]]):
    """Calculate and print summary statistics for the final results.

    Args:
        final_results: ``{"custom_id", "analysis"}`` dicts. An analysis that
            is not a JSON object counts as having every flag unset.
    """
    total = len(final_results)
    if total == 0:
        print("No results to analyze.")
        return

    print(f"\n--- Statistics (N={total}) ---")
    tally = _tally_flags(final_results)

    rows = (
        ("Variant Valid", "valid"),
        ("Solution Correct", "correct"),
        ("Equivalent", "equivalent"),
        ("Strongly Related", "related"),
        ("Valid & Rel/Equiv", "valid_and_related"),
    )
    for label, key in rows:
        count = tally[key]
        print(f"{label}: {count} ({count/total:.1%})")


def main():
    """Run the retrieve -> parse -> retry -> save -> report pipeline."""
    if not API_KEY:
        print("Error: API_KEY not set. Export OPENAI_API_KEY before running.")
        return
    if OpenAI is None:
        print("Error: the 'openai' package is not installed.")
        return

    client = OpenAI(api_key=API_KEY)

    # 1. Get Batch ID
    if not os.path.exists(BATCH_ID_FILE):
        print(f"Batch ID file not found at {BATCH_ID_FILE}")
        return

    with open(BATCH_ID_FILE, "r", encoding="utf-8") as f:
        batch_ids = json.load(f)

    if not batch_ids:
        print("No batch IDs found.")
        return
    if not isinstance(batch_ids, list):
        print(f"Expected a JSON list of batch IDs in {BATCH_ID_FILE}.")
        return

    batch_id = batch_ids[-1]  # Take the latest one
    print(f"Processing Batch ID: {batch_id}")

    # 2. Retrieve Results
    raw_results = retrieve_batch_results(client, batch_id)

    # 3. Load Inputs (to identify missing/failed IDs)
    input_map = load_input_requests(INPUT_FILE)
    all_request_ids = set(input_map.keys())

    # 4. Parse and Find Failures
    valid_results, failed_ids = process_results_and_find_failures(
        raw_results, all_request_ids
    )

    print(f"Successfully parsed: {len(valid_results)}")
    print(f"Failed/Missing: {len(failed_ids)}")

    # 5. Retry Failures
    if failed_ids:
        retry_results = retry_failed_requests(client, failed_ids, input_map)
        valid_results.extend(retry_results)

    # 6. Save Final Results
    print(f"Saving {len(valid_results)} results to {OUTPUT_FILE}...")
    # Create the target directory first so a fresh checkout does not lose the
    # results after all the API work is done.
    os.makedirs(os.path.dirname(OUTPUT_FILE) or ".", exist_ok=True)
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(valid_results, f, indent=2)

    # 7. Stats
    print_stats(valid_results)


if __name__ == "__main__":
    main()