summaryrefslogtreecommitdiff
path: root/putnamsup/run_putnam_gap.py
blob: 73f0ef6b6fa1daebb80c0b01ee3c4857be40f1cc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import argparse
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from putnam_utils import load_dataset, SUPPORTED_VARIANTS
import json

def run_inference_batch(
    model,
    tokenizer,
    questions: list,
    device: str,
    max_new_tokens: int = 1024,
) -> list:
    """Greedily generate one solution per question with a single batched ``generate`` call.

    Each question is wrapped in a fixed "Problem / Solution" prompt and, when the
    tokenizer defines a chat template, rendered as a single user turn.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Matching tokenizer. The prompt-stripping slice below is only
            correct when the tokenizer pads on the left.
        questions: Problem statements to solve.
        device: Device to move the tokenized inputs to, or ``"auto"`` to use
            ``model.device``.
        max_new_tokens: Maximum number of tokens generated per answer
            (default 1024, the previously hard-coded value).

    Returns:
        One whitespace-stripped completion per question, in input order, with
        the prompt tokens removed. An empty ``questions`` list returns ``[]``.
    """
    if not questions:
        # Tokenizing an empty batch fails; there is nothing to generate anyway.
        return []

    prompts = [f"Problem:\n{q}\n\nPlease solve the problem above step by step and provide the final answer.\n\nSolution:\n" for q in questions]

    # With device_map="auto" there is no single user-specified device, so fall
    # back to the device the model reports for its parameters.
    target_device = model.device if device == "auto" else device

    # Render prompts through the chat template when one exists; a prompt the
    # template rejects falls back to its raw text.
    all_templated = bool(getattr(tokenizer, "chat_template", None))
    if all_templated:
        input_texts = []
        for prompt in prompts:
            messages = [{"role": "user", "content": prompt}]
            try:
                input_texts.append(
                    tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                )
            except Exception:
                input_texts.append(prompt)
                all_templated = False
    else:
        input_texts = prompts

    # A rendered chat template may already contain the model's special tokens
    # (such as BOS); letting the tokenizer add them again would duplicate them.
    # Only suppress them when every prompt actually went through the template.
    inputs = tokenizer(
        input_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=not all_templated,
    ).to(target_device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + continuation for every row. With left padding,
    # each row's (padded) prompt occupies exactly the first input_ids.shape[1]
    # columns, so one slice strips every prompt at once.
    prompt_width = inputs["input_ids"].shape[1]
    generated_texts = tokenizer.batch_decode(output_ids[:, prompt_width:], skip_special_tokens=True)
    return [text.strip() for text in generated_texts]

def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset")
    parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files")
    # Not argparse-required so that --dry_run works without naming a model;
    # _validate_args enforces it for real runs.
    parser.add_argument("--model_name_or_path", type=str, default=None, help="Hugging Face model name or path")
    parser.add_argument("--output_file", type=str, default="putnam_gap_results.jsonl", help="Output file path")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of problems to run")
    parser.add_argument("--limit_per_variant", type=int, default=None, help="Limit number of problems per variant")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on (use 'auto' for multi-GPU)")
    parser.add_argument("--dry_run", action="store_true", help="Only load data and print first few examples, do not load model")
    parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. Choices: {','.join(SUPPORTED_VARIANTS)}")
    return parser


def _validate_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject argument values that would otherwise crash or silently misbehave later."""
    if args.batch_size < 1:
        # range(0, n, 0) would raise ValueError only once inference starts.
        parser.error("--batch_size must be >= 1")
    for name in ("limit", "limit_per_variant"):
        value = getattr(args, name)
        if value is not None and value < 0:
            # A negative --limit would slice from the end (dataset[:-k]).
            parser.error(f"--{name} must be >= 0")
    if not args.dry_run and not args.model_name_or_path:
        parser.error("--model_name_or_path is required unless --dry_run is given")


def _warn_if_cuda_unusable() -> None:
    """Print a banner when PyTorch sees CUDA devices but reports CUDA as unavailable."""
    if torch.cuda.device_count() > 0 and not torch.cuda.is_available():
        print("\n" + "!" * 60)
        print(f"WARNING: PyTorch detects {torch.cuda.device_count()} CUDA devices but cannot use them.")
        print("torch.cuda.is_available() == False")
        print(f"Current PyTorch version: {torch.__version__}")
        print("Your driver probably supports an older CUDA version than this PyTorch build.")
        print("!" * 60 + "\n")


def _parse_variants(parser: argparse.ArgumentParser, raw):
    """Split the comma-separated ``--variants`` value into names; ``None`` means all variants."""
    if not raw:
        return None
    # Drop empty entries produced by stray commas, e.g. "a,,b" or "a,".
    variants = [v.strip() for v in raw.split(",") if v.strip()]
    if not variants:
        parser.error("--variants must name at least one variant")
    unknown = [v for v in variants if v not in SUPPORTED_VARIANTS]
    if unknown:
        # Warn instead of failing: load_dataset decides what it actually accepts.
        print(f"WARNING: unknown variant(s) {unknown}. Choices: {','.join(SUPPORTED_VARIANTS)}")
    print(f"Filtering for variants: {variants}")
    return variants


def _cap_per_variant(dataset: list, limit_per_variant: int) -> list:
    """Keep at most ``limit_per_variant`` items per ``item['variant']``, preserving order."""
    from collections import defaultdict

    counts = defaultdict(int)
    kept = []
    for item in dataset:
        variant = item["variant"]
        if counts[variant] < limit_per_variant:
            kept.append(item)
            counts[variant] += 1
    return kept


def _print_dry_run(dataset: list, num_examples: int = 3) -> None:
    """Print a truncated preview of the first ``num_examples`` dataset items."""
    for n, item in enumerate(dataset[:num_examples], start=1):
        print(f"\n--- Example {n} ---")
        print(f"Index: {item['file_index']}")
        print(f"Variant: {item['variant']}")
        print(f"Question: {item['question'][:200]}...")
        print(f"Solution: {item['solution'][:200]}...")


def _load_model_and_tokenizer(model_name_or_path: str, device: str):
    """Load a left-padding tokenizer and causal LM; exceptions from transformers propagate.

    Returns:
        ``(model, tokenizer)``.
    """
    # NOTE: trust_remote_code=True executes Python code shipped with the
    # checkpoint; only point this script at models you trust.
    # Left padding keeps every prompt flush against its generated tokens, which
    # the batched prompt-stripping slice in run_inference_batch depends on.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token_id is None:
        # Batched generation needs a pad id; reuse EOS when available, else 0.
        tokenizer.pad_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    # float32 on CPU, float16 everywhere else (including device_map="auto").
    torch_dtype = torch.float32 if device == "cpu" else torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map=device,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    )
    return model, tokenizer


def _run_and_write(model, tokenizer, dataset: list, output_file: str, batch_size: int, device: str) -> None:
    """Generate answers batch by batch and stream them to ``output_file`` as JSONL.

    The output file is overwritten. A batch whose generation raises is recorded
    with an ``<ERROR: ...>`` placeholder for each of its items so one failure
    does not abort the run. The file is flushed after every batch so results
    written so far reach the OS even if the run is interrupted later.
    """
    with open(output_file, "w", encoding="utf-8") as f_out:
        for start in tqdm(range(0, len(dataset), batch_size), desc="Running Inference"):
            batch = dataset[start:start + batch_size]
            questions = [item["question"] for item in batch]

            try:
                generated_answers = run_inference_batch(model, tokenizer, questions, device)
            except Exception as e:
                # Deliberate best-effort boundary: report, record placeholders, continue.
                print(f"Error generating for batch starting at index {start}: {e}")
                generated_answers = [f"<ERROR: {str(e)}>" for _ in batch]

            for item, ans in zip(batch, generated_answers):
                result_entry = {
                    "file_index": item["file_index"],
                    "problem_type": item["problem_type"],
                    "variant": item["variant"],
                    "question": item["question"],
                    "solution": item["solution"],
                    "generated_solution": ans,
                }
                f_out.write(json.dumps(result_entry, ensure_ascii=False) + "\n")
            f_out.flush()


def main():
    """Command-line entry point: load PutnamGAP problems, generate solutions, write JSONL.

    Parses arguments, loads the dataset, and applies the ``--variants``,
    ``--limit_per_variant`` and ``--limit`` filters in that order. It then either
    prints a preview (``--dry_run``) or loads the model and writes one JSON
    object per problem to ``--output_file``. Invalid arguments exit with
    status 2 (argparse). A model-load failure exits with status 1.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    _validate_args(parser, args)

    _warn_if_cuda_unusable()

    selected_variants = _parse_variants(parser, args.variants)

    print(f"Scanning data from {args.data_dir}...")
    dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants))
    print(f"Found {len(dataset)} problem variants.")

    if args.limit_per_variant is not None:
        dataset = _cap_per_variant(dataset, args.limit_per_variant)
        print(f"Filtered to {len(dataset)} examples (max {args.limit_per_variant} per variant).")

    # Apply the global limit before the dry-run preview and the model load so
    # both reflect exactly the examples that would be processed.
    if args.limit is not None:
        dataset = dataset[:args.limit]
        print(f"Limiting to first {args.limit} examples.")

    if args.dry_run:
        _print_dry_run(dataset)
        return

    print(f"Loading model: {args.model_name_or_path} on {args.device}")
    try:
        model, tokenizer = _load_model_and_tokenizer(args.model_name_or_path, args.device)
    except Exception as e:
        print(f"Failed to load model: {e}")
        # Exit non-zero so shell pipelines and job schedulers see the failure.
        raise SystemExit(1) from e

    _run_and_write(model, tokenizer, dataset, args.output_file, args.batch_size, args.device)
    print(f"Done. Results saved to {args.output_file}")


if __name__ == "__main__":
    main()