diff options
Diffstat (limited to 'putnamsup/run_putnam_gap.py')
| -rw-r--r-- | putnamsup/run_putnam_gap.py | 167 |
1 files changed, 167 insertions, 0 deletions
import os
import argparse
import json
import time
from collections import defaultdict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

from putnam_utils import load_dataset, SUPPORTED_VARIANTS


def run_inference_batch(model, tokenizer, questions: list, device: str) -> list:
    """
    Run greedy generation for a batch of questions.

    Args:
        model: a causal LM exposing ``generate`` (transformers model).
        tokenizer: the matching tokenizer. It must be LEFT-padded (see
            ``main``, which loads it with ``padding_side='left'``) so that
            the generated continuation can be sliced off uniformly.
        questions: list of problem statements (plain text).
        device: device string for the input tensors, or "auto" to place
            them on ``model.device`` (multi-GPU / ``device_map="auto"``).

    Returns:
        One generated solution string per question, whitespace-stripped.
    """
    prompts = [f"Problem:\n{q}\n\nPlease solve the problem above step by step and provide the final answer.\n\nSolution:\n" for q in questions]

    # Inputs must live on the same device as the (first shard of the) model.
    target_device = model.device if device == "auto" else device

    # Prefer the model's chat template when one is defined; fall back to the
    # raw prompt per-item if applying the template fails. getattr guards
    # against older tokenizers that lack the attribute entirely.
    if getattr(tokenizer, "chat_template", None):
        input_texts = []
        for p in prompts:
            messages = [{"role": "user", "content": p}]
            try:
                input_texts.append(
                    tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                )
            except Exception:
                input_texts.append(p)
    else:
        input_texts = prompts

    # Pad to a rectangular batch; truncation uses the tokenizer's model max.
    inputs = tokenizer(
        input_texts, return_tensors="pt", padding=True, truncation=True
    ).to(target_device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,  # greedy decoding for reproducibility
            pad_token_id=tokenizer.pad_token_id,
        )

    # With LEFT padding, every row of inputs.input_ids has the same length,
    # so the newly generated tokens for every row are exactly
    # output_ids[:, prompt_len:].
    prompt_len = inputs.input_ids.shape[1]
    generated_texts = tokenizer.batch_decode(
        output_ids[:, prompt_len:], skip_special_tokens=True
    )
    return [t.strip() for t in generated_texts]


def main():
    """CLI entry point: load PutnamGAP problems, run a HF causal LM over
    them in batches, and write JSONL results (one record per problem)."""
    parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset")
    parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files")
    parser.add_argument("--model_name_or_path", type=str, required=True, help="Hugging Face model name or path")
    parser.add_argument("--output_file", type=str, default="putnam_gap_results.jsonl", help="Output file path")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of problems to run")
    parser.add_argument("--limit_per_variant", type=int, default=None, help="Limit number of problems per variant")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on (use 'auto' for multi-GPU)")
    parser.add_argument("--dry_run", action="store_true", help="Only load data and print first few examples, do not load model")
    parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. Choices: {','.join(SUPPORTED_VARIANTS)}")
    args = parser.parse_args()

    # Parse variants argument
    selected_variants = None

    # Diagnostic check: devices are visible but unusable, which usually
    # means the installed PyTorch build targets a newer CUDA than the driver.
    if torch.cuda.device_count() > 0 and not torch.cuda.is_available():
        print("\n" + "!" * 60)
        print(f"WARNING: PyTorch detects {torch.cuda.device_count()} CUDA devices but cannot use them.")
        print(f"torch.cuda.is_available() == False")
        print(f"Current PyTorch version: {torch.__version__}")
        print(f"Your driver probably supports an older CUDA version than this PyTorch build.")
        print("!" * 60 + "\n")

    if args.variants:
        selected_variants = [v.strip() for v in args.variants.split(",")]
        print(f"Filtering for variants: {selected_variants}")

    print(f"Scanning data from {args.data_dir}...")
    dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants))
    print(f"Found {len(dataset)} problem variants.")

    # Optionally cap how many examples each variant contributes, preserving
    # the dataset's original ordering.
    if args.limit_per_variant:
        counts = defaultdict(int)
        filtered_dataset = []
        for item in dataset:
            v = item['variant']
            if counts[v] < args.limit_per_variant:
                filtered_dataset.append(item)
                counts[v] += 1
        dataset = filtered_dataset
        print(f"Filtered to {len(dataset)} examples (max {args.limit_per_variant} per variant).")

    # Dry run: show the first example and exit before touching the model.
    if args.dry_run:
        if dataset:
            print("\n--- Example 1 ---")
            print(f"Index: {dataset[0]['file_index']}")
            print(f"Variant: {dataset[0]['variant']}")
            print(f"Question: {dataset[0]['question'][:200]}...")
            print(f"Solution: {dataset[0]['solution'][:200]}...")
        return

    print(f"Loading model: {args.model_name_or_path} on {args.device}")

    try:
        # Left padding is required by run_inference_batch's output slicing.
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True, padding_side='left')
        if tokenizer.pad_token_id is None:
            # Many causal LMs ship without a pad token; reuse EOS, else 0.
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.pad_token_id = 0

        # fp16 on accelerators, fp32 on CPU (fp16 CPU inference is poorly
        # supported and slow).
        torch_dtype = torch.float16
        if args.device == "cpu":
            torch_dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            device_map=args.device,
            trust_remote_code=True,
            torch_dtype=torch_dtype
        )
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    if args.limit:
        dataset = dataset[:args.limit]
        print(f"Limiting to first {args.limit} examples.")

    with open(args.output_file, "w", encoding="utf-8") as f_out:
        batch_size = args.batch_size
        for i in tqdm(range(0, len(dataset), batch_size), desc="Running Inference"):
            batch = dataset[i : i + batch_size]
            questions = [item["question"] for item in batch]

            # A failed batch is recorded as an <ERROR: ...> answer for each
            # of its items rather than aborting the whole run.
            try:
                generated_answers = run_inference_batch(model, tokenizer, questions, args.device)
            except Exception as e:
                print(f"Error generating for batch starting at index {i}: {e}")
                generated_answers = [f"<ERROR: {str(e)}>" for _ in batch]

            for item, ans in zip(batch, generated_answers):
                result_entry = {
                    "file_index": item["file_index"],
                    "problem_type": item["problem_type"],
                    "variant": item["variant"],
                    "question": item["question"],
                    "solution": item["solution"],
                    "generated_solution": ans
                }
                # Write and flush per record so partial results survive a crash.
                f_out.write(json.dumps(result_entry, ensure_ascii=False) + "\n")
                f_out.flush()

    print(f"Done. Results saved to {args.output_file}")


if __name__ == "__main__":
    main()
