1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
|
import os
import json
import argparse
import asyncio
import time
from tqdm.asyncio import tqdm
from putnam_utils import load_dataset, SUPPORTED_VARIANTS
try:
from openai import AsyncOpenAI
except ImportError:
AsyncOpenAI = None
async def process_item(sem, client, model_name, item):
    """
    Ask the model to solve one problem and return a result record for it.

    The semaphore limits how many of these coroutines talk to the API at the
    same time. Any exception raised while requesting the completion is not
    propagated; it is stored in the record's ``generated_solution`` field as
    ``<API ERROR: ...>`` instead.

    Args:
        sem: Async context manager used as the concurrency gate.
        client: Object exposing ``chat.completions.create(...)`` as a coroutine.
        model_name: Model identifier forwarded to the API and echoed in the record.
        item: Mapping with the keys ``question``, ``file_index``,
            ``problem_type``, ``variant`` and ``solution``.

    Returns:
        dict: The item's metadata plus ``generated_solution`` and ``model``.
    """
    async with sem:
        problem_text = item["question"]
        user_prompt = (
            f"Problem:\n{problem_text}\n\n"
            "Please solve the problem above step by step and provide the final answer.\n\n"
            "Solution:\n"
        )
        request_kwargs = {
            "model": model_name,
            "messages": [{"role": "user", "content": user_prompt}],
            "temperature": 0.0,
            "max_tokens": 2048,
            "extra_headers": {
                "HTTP-Referer": "https://github.com/PutnamGAP",
                "X-Title": "PutnamGAP Eval",
            },
        }

        try:
            # Asynchronous API round-trip; failures are recorded, not raised.
            response = await client.chat.completions.create(**request_kwargs)
        except Exception as exc:
            answer_text = f"<API ERROR: {exc!s}>"
        else:
            answer_text = response.choices[0].message.content

        # Key order is kept stable because the record is serialised to JSONL.
        record = {key: item[key] for key in ("file_index", "problem_type", "variant")}
        record["question"] = problem_text
        record["solution"] = item["solution"]
        record["generated_solution"] = answer_text
        record["model"] = model_name
        return record
async def run_async_inference(args, dataset):
    """Send every dataset item to the API concurrently and write results as JSONL.

    Results are written in completion order (not dataset order), one JSON
    object per line, to ``args.output_file``.

    Args:
        args: Parsed CLI namespace. Reads ``api_key``, ``base_url``,
            ``concurrency``, ``model_name`` and ``output_file``.
        dataset: Iterable of items; each one is passed to ``process_item``.

    Returns:
        None. If the ``openai`` package is missing, no API key is set, or
        ``concurrency`` is below 1, an error is printed and the function
        returns without creating a client or touching the output file.
    """
    if AsyncOpenAI is None:
        print("Error: 'openai' library not found. Please install it via: pip install openai")
        return
    if not args.api_key:
        print("Error: API key not provided. Use --api_key or set OPENROUTER_API_KEY env var.")
        return

    concurrency = args.concurrency
    if concurrency < 1:
        # Semaphore(0) would make every task wait forever, and a negative
        # value raises ValueError; fail early with a readable message instead.
        print(f"Error: --concurrency must be at least 1 (got {concurrency}).")
        return

    print(f"Initializing AsyncOpenAI client with base_url={args.base_url}")
    client = AsyncOpenAI(
        base_url=args.base_url,
        api_key=args.api_key,
    )
    try:
        print(f"Running with concurrency: {concurrency}")
        sem = asyncio.Semaphore(concurrency)

        # Coroutines are created up front; the semaphore inside process_item
        # bounds how many are actually calling the API at any moment.
        tasks = [process_item(sem, client, args.model_name, item) for item in dataset]
        print(f"Starting {len(tasks)} tasks using model: {args.model_name}")

        with open(args.output_file, "w", encoding="utf-8") as f_out:
            for future in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Async Inference"):
                result = await future
                f_out.write(json.dumps(result, ensure_ascii=False) + "\n")
                # Flush per record so finished results are pushed out of the
                # Python buffer even if a later task aborts the run.
                f_out.flush()
    finally:
        # Release the client's underlying HTTP connections on every exit path.
        await client.close()

    print(f"Done. Results saved to {args.output_file}")
def main():
    """Command-line entry point: parse arguments, load data, run inference.

    Flow:
      1. Parse CLI options (``--model_name`` is required).
      2. Load the dataset from ``--data_dir``, optionally filtered by the
         comma-separated ``--variants`` list.
      3. With ``--dry_run``, print the first example (if any) and stop
         before any API call.
      4. Otherwise apply ``--limit`` and hand the items to
         ``run_async_inference``.

    A negative ``--limit`` is rejected via ``parser.error`` (exit status 2).
    ``--limit 0`` keeps its existing meaning of "no limit".
    """
    parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset via OpenRouter (Async)")
    parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files")
    parser.add_argument("--model_name", type=str, required=True, help="OpenRouter model name")
    parser.add_argument("--output_file", type=str, default="putnam_gap_openrouter_results.jsonl", help="Output file path")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of problems to run (for testing)")
    parser.add_argument("--concurrency", type=int, default=10, help="Number of concurrent requests")
    parser.add_argument("--api_key", type=str, default=os.getenv("OPENROUTER_API_KEY"), help="OpenRouter API Key")
    parser.add_argument("--base_url", type=str, default="https://openrouter.ai/api/v1", help="API Base URL")
    parser.add_argument("--dry_run", action="store_true", help="Only load data and print info")
    parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. Choices: {','.join(SUPPORTED_VARIANTS)}")
    args = parser.parse_args()

    # A negative limit would slice as dataset[:-n], silently dropping the
    # last n items while claiming to keep the "first -n"; reject it up front.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    # Parse variants argument
    selected_variants = None
    if args.variants:
        selected_variants = [v.strip() for v in args.variants.split(",")]
        print(f"Filtering for variants: {selected_variants}")

    print(f"Scanning data from {args.data_dir}...")
    dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants))
    print(f"Found {len(dataset)} problem variants.")

    if args.dry_run:
        if dataset:
            print("\n--- Example 1 ---")
            print(f"Index: {dataset[0]['file_index']}")
            print(f"Variant: {dataset[0]['variant']}")
            print(f"Question: {dataset[0]['question'][:200]}...")
        return

    # Truthiness check is intentional: 0 (like None) means "no limit".
    if args.limit:
        dataset = dataset[:args.limit]
        print(f"Limiting to first {args.limit} examples.")

    asyncio.run(run_async_inference(args, dataset))


if __name__ == "__main__":
    main()
|