summaryrefslogtreecommitdiff
path: root/putnamsup/run_putnam_gap.py
diff options
context:
space:
mode:
Diffstat (limited to 'putnamsup/run_putnam_gap.py')
-rw-r--r--putnamsup/run_putnam_gap.py167
1 files changed, 167 insertions, 0 deletions
diff --git a/putnamsup/run_putnam_gap.py b/putnamsup/run_putnam_gap.py
new file mode 100644
index 0000000..73f0ef6
--- /dev/null
+++ b/putnamsup/run_putnam_gap.py
@@ -0,0 +1,167 @@
+import os
+import argparse
+import torch
+import time
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from tqdm import tqdm
+from putnam_utils import load_dataset, SUPPORTED_VARIANTS
+import json
+
def run_inference_batch(model, tokenizer, questions: list, device: str) -> list:
    """Generate one solution per question with greedy decoding.

    Args:
        model: causal LM exposing a Hugging Face-style ``generate`` method.
        tokenizer: matching tokenizer; must be configured with left padding
            so generated tokens line up after the shared prompt length.
        questions: problem statements to solve.
        device: device string for the input tensors, or "auto" to place
            inputs wherever ``device_map="auto"`` put the model.

    Returns:
        List of generated solution strings (whitespace-stripped), one per
        question, in input order.
    """
    prompts = [
        f"Problem:\n{q}\n\nPlease solve the problem above step by step and provide the final answer.\n\nSolution:\n"
        for q in questions
    ]

    # With device_map="auto" the model chooses placement; follow it.
    target_device = model.device if device == "auto" else device

    # Prefer the model's chat template when one is defined. getattr guards
    # against tokenizer classes that predate the chat_template attribute;
    # a failing template application falls back to the raw prompt.
    input_texts = []
    if getattr(tokenizer, "chat_template", None):
        for p in prompts:
            messages = [{"role": "user", "content": p}]
            try:
                input_texts.append(
                    tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                )
            except Exception:
                input_texts.append(p)
    else:
        input_texts = prompts

    # Tokenize the whole batch with padding (left-side, set at load time).
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(target_device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + continuation per row. Because padding is on
    # the LEFT, every row's prompt occupies exactly the first
    # inputs.input_ids.shape[1] positions, so slicing there keeps only the
    # newly generated tokens for every row.
    generated_texts = tokenizer.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return [t.strip() for t in generated_texts]
+
def main():
    """CLI entry point: scan PutnamGAP problems, run batched generation, write JSONL.

    Results are appended line-by-line (and flushed) so partial output
    survives a mid-run crash.
    """
    parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset")
    parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files")
    parser.add_argument("--model_name_or_path", type=str, required=True, help="Hugging Face model name or path")
    parser.add_argument("--output_file", type=str, default="putnam_gap_results.jsonl", help="Output file path")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of problems to run")
    parser.add_argument("--limit_per_variant", type=int, default=None, help="Limit number of problems per variant")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on (use 'auto' for multi-GPU)")
    parser.add_argument("--dry_run", action="store_true", help="Only load data and print first few examples, do not load model")
    parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. Choices: {','.join(SUPPORTED_VARIANTS)}")
    args = parser.parse_args()

    # Diagnostic: device_count() > 0 while is_available() is False usually
    # indicates a driver/toolkit mismatch for this PyTorch build -- warn
    # loudly, then continue (the default device will already be "cpu").
    if torch.cuda.device_count() > 0 and not torch.cuda.is_available():
        print("\n" + "!" * 60)
        print(f"WARNING: PyTorch detects {torch.cuda.device_count()} CUDA devices but cannot use them.")
        print(f"torch.cuda.is_available() == False")
        print(f"Current PyTorch version: {torch.__version__}")
        print(f"Your driver probably supports an older CUDA version than this PyTorch build.")
        print("!" * 60 + "\n")

    # Parse and validate the variant filter. Fail fast on unknown names
    # instead of silently producing an empty dataset downstream.
    selected_variants = None
    if args.variants:
        selected_variants = [v.strip() for v in args.variants.split(",")]
        unknown = [v for v in selected_variants if v not in SUPPORTED_VARIANTS]
        if unknown:
            parser.error(f"Unknown variants {unknown}. Choices: {','.join(SUPPORTED_VARIANTS)}")
        print(f"Filtering for variants: {selected_variants}")

    print(f"Scanning data from {args.data_dir}...")
    dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants))
    print(f"Found {len(dataset)} problem variants.")

    # Cap how many examples each variant contributes, preserving file order.
    if args.limit_per_variant:
        from collections import defaultdict
        counts = defaultdict(int)
        filtered_dataset = []
        for item in dataset:
            v = item['variant']
            if counts[v] < args.limit_per_variant:
                filtered_dataset.append(item)
                counts[v] += 1
        dataset = filtered_dataset
        print(f"Filtered to {len(dataset)} examples (max {args.limit_per_variant} per variant).")

    # Dry run: show one example and exit before any model download/load.
    if args.dry_run:
        if dataset:
            print("\n--- Example 1 ---")
            print(f"Index: {dataset[0]['file_index']}")
            print(f"Variant: {dataset[0]['variant']}")
            print(f"Question: {dataset[0]['question'][:200]}...")
            print(f"Solution: {dataset[0]['solution'][:200]}...")
        return

    print(f"Loading model: {args.model_name_or_path} on {args.device}")

    try:
        # Left padding so run_inference_batch can strip the prompt by length.
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True, padding_side='left')
        if tokenizer.pad_token_id is None:
            # Prefer EOS as the pad token; id 0 is a last resort and may
            # collide with a real vocabulary entry -- TODO confirm per model.
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.pad_token_id = 0

        # fp16 on accelerators; fp32 on CPU where fp16 kernels are spotty.
        torch_dtype = torch.float32 if args.device == "cpu" else torch.float16

        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            device_map=args.device,
            trust_remote_code=True,
            torch_dtype=torch_dtype
        )
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    if args.limit:
        dataset = dataset[:args.limit]
        print(f"Limiting to first {args.limit} examples.")

    with open(args.output_file, "w", encoding="utf-8") as f_out:
        batch_size = args.batch_size
        for i in tqdm(range(0, len(dataset), batch_size), desc="Running Inference"):
            batch = dataset[i : i + batch_size]
            questions = [item["question"] for item in batch]

            try:
                generated_answers = run_inference_batch(model, tokenizer, questions, args.device)
            except Exception as e:
                # Record the failure per item rather than aborting the run.
                print(f"Error generating for batch starting at index {i}: {e}")
                generated_answers = [f"<ERROR: {str(e)}>" for _ in batch]

            for item, ans in zip(batch, generated_answers):
                result_entry = {
                    "file_index": item["file_index"],
                    "problem_type": item["problem_type"],
                    "variant": item["variant"],
                    "question": item["question"],
                    "solution": item["solution"],
                    "generated_solution": ans
                }
                # Flush per line so partial results survive crashes.
                f_out.write(json.dumps(result_entry, ensure_ascii=False) + "\n")
                f_out.flush()

    print(f"Done. Results saved to {args.output_file}")
+
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()