summaryrefslogtreecommitdiff
path: root/putnamsup/run_putnam_gap_openrouter.py
blob: 8a23141c959bc7a79e0f191174753c438de618d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import json
import argparse
import asyncio
import time
from tqdm.asyncio import tqdm
from putnam_utils import load_dataset, SUPPORTED_VARIANTS

try:
    from openai import AsyncOpenAI
except ImportError:
    AsyncOpenAI = None

async def process_item(sem, client, model_name, item):
    """
    Process a single item with semaphore for concurrency control.
    """
    async with sem:
        question = item["question"]
        prompt = f"Problem:\n{question}\n\nPlease solve the problem above step by step and provide the final answer.\n\nSolution:\n"
        messages = [{"role": "user", "content": prompt}]
        
        try:
            # Call API asynchronously
            completion = await client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=2048,
                extra_headers={
                   "HTTP-Referer": "https://github.com/PutnamGAP",
                   "X-Title": "PutnamGAP Eval",
                }
            )
            generated_answer = completion.choices[0].message.content
        except Exception as e:
            generated_answer = f"<API ERROR: {str(e)}>"

        # Construct result entry
        result_entry = {
            "file_index": item["file_index"],
            "problem_type": item["problem_type"],
            "variant": item["variant"],
            "question": question,
            "solution": item["solution"],
            "generated_solution": generated_answer,
            "model": model_name
        }
        return result_entry

async def run_async_inference(args, dataset):
    """Run concurrent chat-completion inference over *dataset*.

    Writes one JSON object per line (JSONL) to ``args.output_file`` as
    results complete. Prints an error and returns early when the
    ``openai`` package is missing or no API key was supplied.
    """
    if AsyncOpenAI is None:
        print("Error: 'openai' library not found. Please install it via: pip install openai")
        return

    if not args.api_key:
        print("Error: API key not provided. Use --api_key or set OPENROUTER_API_KEY env var.")
        return

    print(f"Initializing AsyncOpenAI client with base_url={args.base_url}")
    client = AsyncOpenAI(
        base_url=args.base_url,
        api_key=args.api_key,
    )

    try:
        concurrency = args.concurrency
        print(f"Running with concurrency: {concurrency}")
        # The semaphore bounds the number of simultaneous API requests.
        sem = asyncio.Semaphore(concurrency)

        tasks = [process_item(sem, client, args.model_name, item) for item in dataset]
        print(f"Starting {len(tasks)} tasks using model: {args.model_name}")

        with open(args.output_file, "w", encoding="utf-8") as f_out:
            # as_completed streams each result to disk as soon as its request
            # finishes, instead of waiting for the whole batch.
            for future in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Async Inference"):
                result = await future
                f_out.write(json.dumps(result, ensure_ascii=False) + "\n")
                f_out.flush()  # keep partial results if the run is interrupted
    finally:
        # Fix: close the client's underlying HTTP connection pool. The
        # original never closed it, leaking connections on every run.
        await client.close()

    print(f"Done. Results saved to {args.output_file}")

def main():
    """CLI entry point: parse arguments, load the dataset, run inference.

    Supports variant filtering (``--variants``), a problem-count limit
    (``--limit``), and a ``--dry_run`` mode that only inspects the data.
    """
    parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset via OpenRouter (Async)")
    parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files")
    parser.add_argument("--model_name", type=str, required=True, help="OpenRouter model name")
    parser.add_argument("--output_file", type=str, default="putnam_gap_openrouter_results.jsonl", help="Output file path")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of problems to run (for testing)")
    parser.add_argument("--concurrency", type=int, default=10, help="Number of concurrent requests")
    parser.add_argument("--api_key", type=str, default=os.getenv("OPENROUTER_API_KEY"), help="OpenRouter API Key")
    parser.add_argument("--base_url", type=str, default="https://openrouter.ai/api/v1", help="API Base URL")
    parser.add_argument("--dry_run", action="store_true", help="Only load data and print info")
    parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. Choices: {','.join(SUPPORTED_VARIANTS)}")

    args = parser.parse_args()

    # Parse variants argument
    selected_variants = None
    if args.variants:
        selected_variants = [v.strip() for v in args.variants.split(",")]
        print(f"Filtering for variants: {selected_variants}")

    print(f"Scanning data from {args.data_dir}...")
    dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants))
    print(f"Found {len(dataset)} problem variants.")

    if args.dry_run:
        if dataset:
            print("\n--- Example 1 ---")
            print(f"Index: {dataset[0]['file_index']}")
            print(f"Variant: {dataset[0]['variant']}")
            print(f"Question: {dataset[0]['question'][:200]}...")
        # Fix: always stop after a dry run. Previously the return was
        # nested inside `if dataset:`, so a dry run over an empty dataset
        # fell through and launched real inference.
        return

    if args.limit:
        dataset = dataset[:args.limit]
        print(f"Limiting to first {args.limit} examples.")

    asyncio.run(run_async_inference(args, dataset))

# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()