Diffstat (limited to 'scripts/process_putnam_batch.py')
-rw-r--r--  scripts/process_putnam_batch.py  239
1 file changed, 239 insertions(+), 0 deletions(-)
diff --git a/scripts/process_putnam_batch.py b/scripts/process_putnam_batch.py
new file mode 100644
index 0000000..27e0465
--- /dev/null
+++ b/scripts/process_putnam_batch.py
@@ -0,0 +1,239 @@
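+"""Collect and post-process results for a submitted Putnam evaluation batch.
+
+Downloads the batch output, parses each analysis payload, retries failed or
+missing requests synchronously, saves the merged results, and prints summary
+statistics.
+
+Assumed setup (this script does not submit the batch itself): a companion
+submission script has already written the batch IDs to BATCH_ID_FILE, and
+OPENAI_API_KEY is exported in the environment, e.g.
+
+    OPENAI_API_KEY=sk-... python scripts/process_putnam_batch.py
+"""
+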
+import json
+import os
+from typing import Any, Dict, List, Tuple
+
+from openai import OpenAI
+
+# Configuration
+API_KEY = os.environ.get("OPENAI_API_KEY")  # read from the environment; never hard-code secrets
+BATCH_ID_FILE = "data/putnam_eval/submitted_batch_ids.json"
+INPUT_FILE = "data/putnam_eval/putnam_eval_batch.jsonl"
+OUTPUT_FILE = "data/putnam_eval/final_results.json"
+MODEL_NAME = "gpt-5"
+
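+# Each line of INPUT_FILE is assumed to be a Batch API chat-completions request,
+# shaped roughly like this sketch (field values are illustrative):
+# {"custom_id": "putnam-001", "method": "POST", "url": "/v1/chat/completions",
+#  "body": {"model": "...", "messages": [...], "response_format": {...}}}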
+def load_input_requests(filepath: str) -> Dict[str, Any]:
+ """Load the original requests to allow retrying."""
+ requests_map = {}
+ print(f"Loading input requests from {filepath}...")
+ with open(filepath, "r", encoding="utf-8") as f:
+ for line in f:
+ if not line.strip():
+ continue
+ item = json.loads(line)
+ requests_map[item["custom_id"]] = item
+ return requests_map
+
+def retrieve_batch_results(client: OpenAI, batch_id: str) -> List[Dict[str, Any]]:
+ """Retrieve and parse batch results."""
+ print(f"Checking status for batch {batch_id}...")
+ batch = client.batches.retrieve(batch_id)
+
+ print(f"Batch Status: {batch.status}")
+ print(f"Output File ID: {batch.output_file_id}")
+ print(f"Error File ID: {batch.error_file_id}")
+
+ results = []
+
+ if batch.output_file_id:
+ print("Downloading output file...")
+ file_response = client.files.content(batch.output_file_id)
+ file_content = file_response.read().decode("utf-8")
+
+ for line in file_content.splitlines():
+ if line.strip():
+ results.append(json.loads(line))
+
+ if batch.error_file_id:
+ print("Downloading error file (if any)...")
+ # Usually contains request-level errors that didn't generate a response object in output
+ try:
+ err_response = client.files.content(batch.error_file_id)
+ err_content = err_response.read().decode("utf-8")
+ for line in err_content.splitlines():
+ if line.strip():
+ results.append(json.loads(line))
+ except Exception as e:
+ print(f"Note: Could not download/parse error file: {e}")
+
+ return results
+
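+# Batch result lines are expected to follow the Batch API output shape, roughly:
+# {"custom_id": "putnam-001", "error": null,
+#  "response": {"status_code": 200,
+#               "body": {"choices": [{"message": {"content": "<json string>"}}]}}}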
+def process_results_and_find_failures(results: List[Dict[str, Any]], all_request_ids: set) -> Tuple[List[Dict[str, Any]], List[str]]:
+ """Separate successful parsable results from failures."""
+ valid_results = []
+ failed_ids = []
+ seen_ids = set()
+
+ for res in results:
+ custom_id = res.get("custom_id")
+ seen_ids.add(custom_id)
+
+ # Check for API level errors
+ if res.get("error"):
+ print(f"Request {custom_id} failed with error: {res['error']}")
+ failed_ids.append(custom_id)
+ continue
+
+ response = res.get("response", {})
+ if response.get("status_code") != 200:
+ print(f"Request {custom_id} failed with status {response.get('status_code')}")
+ failed_ids.append(custom_id)
+ continue
+
+ # Try to parse the content as JSON
+ try:
+ body = response.get("body", {})
+ choices = body.get("choices", [])
+ if not choices:
+ print(f"Request {custom_id} has no choices.")
+ failed_ids.append(custom_id)
+ continue
+
+ content_str = choices[0].get("message", {}).get("content", "")
+ content_json = json.loads(content_str)
+
+ valid_results.append({
+ "custom_id": custom_id,
+ "analysis": content_json
+ })
+ except json.JSONDecodeError:
+ print(f"Request {custom_id} returned invalid JSON content.")
+ failed_ids.append(custom_id)
+ except Exception as e:
+ print(f"Request {custom_id} unexpected processing error: {e}")
+ failed_ids.append(custom_id)
+
+ # Check for completely missing requests
+ missing_ids = all_request_ids - seen_ids
+ if missing_ids:
+ print(f"Found {len(missing_ids)} missing requests that were not in the batch output.")
+ failed_ids.extend(list(missing_ids))
+
+ return valid_results, failed_ids
+
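+# Design note: failures are retried one at a time against the synchronous
+# chat.completions endpoint. If large numbers of requests fail, resubmitting
+# them as a fresh batch would likely be cheaper, but a simple loop is adequate
+# at this scale.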
+def retry_failed_requests(client: OpenAI, failed_ids: List[str], input_map: Dict[str, Any]) -> List[Dict[str, Any]]:
+ """Retry specific requests synchronously."""
+ retried_results = []
+ print(f"\nRetrying {len(failed_ids)} failed requests synchronously...")
+
+ for i, custom_id in enumerate(failed_ids):
+ if custom_id not in input_map:
+ print(f"Warning: Original request for {custom_id} not found.")
+ continue
+
+ print(f"Retrying {i+1}/{len(failed_ids)}: {custom_id}")
+ original_req = input_map[custom_id]
+ body = original_req["body"]
+
+        try:
+            # Rebuild the call from the original request body, forwarding the
+            # optional fields only when present (response_format=None may be rejected).
+            kwargs: Dict[str, Any] = {
+                "model": MODEL_NAME,  # enforce the script-level model, not the batch request's
+                "messages": body["messages"],
+            }
+            if body.get("response_format") is not None:
+                kwargs["response_format"] = body["response_format"]
+            if "temperature" in body:
+                kwargs["temperature"] = body["temperature"]
+            response = client.chat.completions.create(**kwargs)
+
+ content_str = response.choices[0].message.content
+ content_json = json.loads(content_str)
+
+ retried_results.append({
+ "custom_id": custom_id,
+ "analysis": content_json
+ })
+ except Exception as e:
+ print(f"Retry failed for {custom_id}: {e}")
+
+ return retried_results
+
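+# The "analysis" payload is expected to be the grader's JSON verdict; the field
+# names below are inferred from the accesses in print_stats:
+# {"variant_validity": {"is_problem_valid": true, "is_solution_correct": true},
+#  "relation_to_original": {"is_equivalent": false, "is_strongly_related": true}}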
+def print_stats(final_results: List[Dict[str, Any]]):
+ """Calculate and print statistics."""
+ total = len(final_results)
+ if total == 0:
+ print("No results to analyze.")
+ return
+
+ # Categories
+ valid_variant_count = 0
+ correct_solution_count = 0
+ equivalent_count = 0
+ strongly_related_count = 0
+
+    # Variants that are valid, correctly solved, and equivalent or strongly related
+    both_valid_and_equiv = 0
+
+ print(f"\n--- Statistics (N={total}) ---")
+
+ for item in final_results:
+ analysis = item["analysis"]
+ validity = analysis.get("variant_validity", {})
+ relation = analysis.get("relation_to_original", {})
+
+ is_valid = validity.get("is_problem_valid", False)
+ is_correct = validity.get("is_solution_correct", False)
+ is_equiv = relation.get("is_equivalent", False)
+ is_related = relation.get("is_strongly_related", False)
+
+        if is_valid:
+            valid_variant_count += 1
+        if is_correct:
+            correct_solution_count += 1
+        if is_equiv:
+            equivalent_count += 1
+        if is_related:
+            strongly_related_count += 1
+
+ if is_valid and is_correct and (is_equiv or is_related):
+ both_valid_and_equiv += 1
+
+ print(f"Variant Valid: {valid_variant_count} ({valid_variant_count/total:.1%})")
+ print(f"Solution Correct: {correct_solution_count} ({correct_solution_count/total:.1%})")
+ print(f"Equivalent: {equivalent_count} ({equivalent_count/total:.1%})")
+ print(f"Strongly Related: {strongly_related_count} ({strongly_related_count/total:.1%})")
+ print(f"Valid & Rel/Equiv: {both_valid_and_equiv} ({both_valid_and_equiv/total:.1%})")
+
+def main():
+ if not API_KEY:
+ print("Error: API_KEY not set.")
+ return
+
+ client = OpenAI(api_key=API_KEY)
+
+ # 1. Get Batch ID
+ if not os.path.exists(BATCH_ID_FILE):
+ print(f"Batch ID file not found at {BATCH_ID_FILE}")
+ return
+
+ with open(BATCH_ID_FILE, "r") as f:
+ batch_ids = json.load(f)
+ if not batch_ids:
+ print("No batch IDs found.")
+ return
+ batch_id = batch_ids[-1] # Take the latest one
+ print(f"Processing Batch ID: {batch_id}")
+
+ # 2. Retrieve Results
+ raw_results = retrieve_batch_results(client, batch_id)
+
+ # 3. Load Inputs (to identify missing/failed IDs)
+ input_map = load_input_requests(INPUT_FILE)
+ all_request_ids = set(input_map.keys())
+
+ # 4. Parse and Find Failures
+ valid_results, failed_ids = process_results_and_find_failures(raw_results, all_request_ids)
+ print(f"Successfully parsed: {len(valid_results)}")
+ print(f"Failed/Missing: {len(failed_ids)}")
+
+ # 5. Retry Failures
+ if failed_ids:
+ retry_results = retry_failed_requests(client, failed_ids, input_map)
+ valid_results.extend(retry_results)
+
+ # 6. Save Final Results
+ print(f"Saving {len(valid_results)} results to {OUTPUT_FILE}...")
+ with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
+ json.dump(valid_results, f, indent=2)
+
+ # 7. Stats
+ print_stats(valid_results)
+
+if __name__ == "__main__":
+ main()