diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/process_putnam_batch.py | |
Diffstat (limited to 'scripts/process_putnam_batch.py')
| -rw-r--r-- | scripts/process_putnam_batch.py | 239 |
1 files changed, 239 insertions, 0 deletions
"""Retrieve, repair, and summarize an OpenAI Batch run over Putnam eval variants.

Pipeline: load the latest submitted batch id, download its output/error files,
parse each per-request JSON analysis, synchronously retry anything that failed
or went missing, save the merged results, and print aggregate statistics.
"""

import json
import os
from typing import Any, Dict, List

from openai import OpenAI

# Configuration.
# SECURITY: the API key must come from the environment — never hardcode
# secrets in source control. main() aborts cleanly when it is unset.
API_KEY = os.environ.get("OPENAI_API_KEY", "")
BATCH_ID_FILE = "data/putnam_eval/submitted_batch_ids.json"
INPUT_FILE = "data/putnam_eval/putnam_eval_batch.jsonl"
OUTPUT_FILE = "data/putnam_eval/final_results.json"
MODEL_NAME = "gpt-5"


def load_input_requests(filepath: str) -> Dict[str, Any]:
    """Load the original batch requests, keyed by custom_id, to allow retrying.

    Args:
        filepath: Path to the JSONL file of batch request objects.

    Returns:
        Mapping of custom_id -> original request dict.
    """
    requests_map: Dict[str, Any] = {}
    print(f"Loading input requests from {filepath}...")
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            requests_map[item["custom_id"]] = item
    return requests_map


def retrieve_batch_results(client: OpenAI, batch_id: str) -> List[Dict[str, Any]]:
    """Download and parse both the output and error files of a batch.

    Args:
        client: Authenticated OpenAI client.
        batch_id: Identifier of the submitted batch.

    Returns:
        List of parsed per-request result objects (successes and errors mixed;
        the caller separates them).
    """
    print(f"Checking status for batch {batch_id}...")
    batch = client.batches.retrieve(batch_id)

    print(f"Batch Status: {batch.status}")
    print(f"Output File ID: {batch.output_file_id}")
    print(f"Error File ID: {batch.error_file_id}")

    results: List[Dict[str, Any]] = []

    if batch.output_file_id:
        print("Downloading output file...")
        file_response = client.files.content(batch.output_file_id)
        file_content = file_response.read().decode("utf-8")
        for line in file_content.splitlines():
            if line.strip():
                results.append(json.loads(line))

    if batch.error_file_id:
        print("Downloading error file (if any)...")
        # Usually contains request-level errors that didn't generate a
        # response object in the output file. Best-effort: a missing or
        # malformed error file should not abort the whole run.
        try:
            err_response = client.files.content(batch.error_file_id)
            err_content = err_response.read().decode("utf-8")
            for line in err_content.splitlines():
                if line.strip():
                    results.append(json.loads(line))
        except Exception as e:
            print(f"Note: Could not download/parse error file: {e}")

    return results


def process_results_and_find_failures(
    results: List[Dict[str, Any]], all_request_ids: set
) -> tuple[List[Dict[str, Any]], List[str]]:
    """Separate successful parsable results from failures.

    A request is a failure if it carries an API-level error, a non-200 status,
    no choices, or message content that is not valid JSON. Requests present in
    the input but absent from the batch output are also counted as failures.

    Args:
        results: Raw parsed batch result objects.
        all_request_ids: Every custom_id that was submitted.

    Returns:
        (valid_results, failed_ids) where each valid result is
        {"custom_id": ..., "analysis": <parsed JSON content>}.
    """
    valid_results: List[Dict[str, Any]] = []
    failed_ids: List[str] = []
    seen_ids: set = set()

    for res in results:
        custom_id = res.get("custom_id")
        seen_ids.add(custom_id)

        # Check for API level errors.
        if res.get("error"):
            print(f"Request {custom_id} failed with error: {res['error']}")
            failed_ids.append(custom_id)
            continue

        response = res.get("response", {})
        if response.get("status_code") != 200:
            print(f"Request {custom_id} failed with status {response.get('status_code')}")
            failed_ids.append(custom_id)
            continue

        # Try to parse the model's message content as JSON.
        try:
            body = response.get("body", {})
            choices = body.get("choices", [])
            if not choices:
                print(f"Request {custom_id} has no choices.")
                failed_ids.append(custom_id)
                continue

            content_str = choices[0].get("message", {}).get("content", "")
            content_json = json.loads(content_str)

            valid_results.append({
                "custom_id": custom_id,
                "analysis": content_json
            })
        except json.JSONDecodeError:
            print(f"Request {custom_id} returned invalid JSON content.")
            failed_ids.append(custom_id)
        except Exception as e:
            print(f"Request {custom_id} unexpected processing error: {e}")
            failed_ids.append(custom_id)

    # Requests that never appeared in the batch output at all.
    missing_ids = all_request_ids - seen_ids
    if missing_ids:
        print(f"Found {len(missing_ids)} missing requests that were not in the batch output.")
        failed_ids.extend(list(missing_ids))

    return valid_results, failed_ids


def retry_failed_requests(
    client: OpenAI, failed_ids: List[str], input_map: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Retry specific requests synchronously via the Chat Completions API.

    Args:
        client: Authenticated OpenAI client.
        failed_ids: custom_ids to retry.
        input_map: custom_id -> original request dict (from the input JSONL).

    Returns:
        Successfully retried results in the same shape as the batch successes.
    """
    retried_results: List[Dict[str, Any]] = []
    print(f"\nRetrying {len(failed_ids)} failed requests synchronously...")

    for i, custom_id in enumerate(failed_ids):
        if custom_id not in input_map:
            print(f"Warning: Original request for {custom_id} not found.")
            continue

        print(f"Retrying {i+1}/{len(failed_ids)}: {custom_id}")
        original_req = input_map[custom_id]
        body = original_req["body"]

        try:
            # Build kwargs conditionally: passing response_format=None
            # explicitly is rejected by the SDK when the original request
            # simply omitted it.
            kwargs: Dict[str, Any] = {
                # Use the model from the script constant, not necessarily the
                # batch one, so retries are forced onto MODEL_NAME.
                "model": MODEL_NAME,
                "messages": body["messages"],
                "temperature": body.get("temperature", 1.0),
            }
            if body.get("response_format") is not None:
                kwargs["response_format"] = body["response_format"]

            response = client.chat.completions.create(**kwargs)

            content_str = response.choices[0].message.content
            content_json = json.loads(content_str)

            retried_results.append({
                "custom_id": custom_id,
                "analysis": content_json
            })
        except Exception as e:
            # Best-effort: a failed retry is dropped from the final results.
            print(f"Retry failed for {custom_id}: {e}")

    return retried_results


def print_stats(final_results: List[Dict[str, Any]]):
    """Calculate and print aggregate statistics over the parsed analyses.

    Each analysis is expected to contain "variant_validity" and
    "relation_to_original" sub-objects with boolean flags; missing keys
    default to False.
    """
    total = len(final_results)
    if total == 0:
        print("No results to analyze.")
        return

    # Category counters.
    valid_variant_count = 0
    correct_solution_count = 0
    equivalent_count = 0
    strongly_related_count = 0

    # Combined criterion: valid, correct, and equivalent-or-related.
    both_valid_and_equiv = 0

    print(f"\n--- Statistics (N={total}) ---")

    for item in final_results:
        analysis = item["analysis"]
        validity = analysis.get("variant_validity", {})
        relation = analysis.get("relation_to_original", {})

        is_valid = validity.get("is_problem_valid", False)
        is_correct = validity.get("is_solution_correct", False)
        is_equiv = relation.get("is_equivalent", False)
        is_related = relation.get("is_strongly_related", False)

        if is_valid:
            valid_variant_count += 1
        if is_correct:
            correct_solution_count += 1
        if is_equiv:
            equivalent_count += 1
        if is_related:
            strongly_related_count += 1

        if is_valid and is_correct and (is_equiv or is_related):
            both_valid_and_equiv += 1

    print(f"Variant Valid: {valid_variant_count} ({valid_variant_count/total:.1%})")
    print(f"Solution Correct: {correct_solution_count} ({correct_solution_count/total:.1%})")
    print(f"Equivalent: {equivalent_count} ({equivalent_count/total:.1%})")
    print(f"Strongly Related: {strongly_related_count} ({strongly_related_count/total:.1%})")
    print(f"Valid & Rel/Equiv: {both_valid_and_equiv} ({both_valid_and_equiv/total:.1%})")


def main():
    """Drive the full retrieve -> parse -> retry -> save -> stats pipeline."""
    if not API_KEY:
        print("Error: API_KEY not set.")
        return

    client = OpenAI(api_key=API_KEY)

    # 1. Get Batch ID
    if not os.path.exists(BATCH_ID_FILE):
        print(f"Batch ID file not found at {BATCH_ID_FILE}")
        return

    with open(BATCH_ID_FILE, "r") as f:
        batch_ids = json.load(f)
    if not batch_ids:
        print("No batch IDs found.")
        return
    batch_id = batch_ids[-1]  # Take the latest one
    print(f"Processing Batch ID: {batch_id}")

    # 2. Retrieve Results
    raw_results = retrieve_batch_results(client, batch_id)

    # 3. Load Inputs (to identify missing/failed IDs)
    input_map = load_input_requests(INPUT_FILE)
    all_request_ids = set(input_map.keys())

    # 4. Parse and Find Failures
    valid_results, failed_ids = process_results_and_find_failures(raw_results, all_request_ids)
    print(f"Successfully parsed: {len(valid_results)}")
    print(f"Failed/Missing: {len(failed_ids)}")

    # 5. Retry Failures
    if failed_ids:
        retry_results = retry_failed_requests(client, failed_ids, input_map)
        valid_results.extend(retry_results)

    # 6. Save Final Results
    print(f"Saving {len(valid_results)} results to {OUTPUT_FILE}...")
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(valid_results, f, indent=2)

    # 7. Stats
    print_stats(valid_results)


if __name__ == "__main__":
    main()
