summaryrefslogtreecommitdiff
path: root/putnamsup
diff options
context:
space:
mode:
Diffstat (limited to 'putnamsup')
-rw-r--r--putnamsup/evaluate_putnam_gap.py74
-rw-r--r--putnamsup/putnam_utils.py95
-rw-r--r--putnamsup/putnamgap_viewer.py277
-rw-r--r--putnamsup/requirements.txt7
-rw-r--r--putnamsup/run_putnam_gap.py167
-rw-r--r--putnamsup/run_putnam_gap_openrouter.py124
6 files changed, 744 insertions, 0 deletions
diff --git a/putnamsup/evaluate_putnam_gap.py b/putnamsup/evaluate_putnam_gap.py
new file mode 100644
index 0000000..5c9f35e
--- /dev/null
+++ b/putnamsup/evaluate_putnam_gap.py
@@ -0,0 +1,74 @@
+import json
+import argparse
+import re
+
def normalize_answer(text):
    """Normalize an answer string for naive substring comparison.

    Lowercases and strips the text, removes LaTeX math delimiters
    (\\( \\) \\[ \\]) and collapses runs of whitespace, so that e.g.
    "\\(42\\)" and "42" normalize to the same string.

    Args:
        text: Raw answer text, or None.

    Returns:
        Normalized string ("" for None input).
    """
    if text is None:
        return ""
    text = text.strip().lower()
    # Replace LaTeX math delimiters with spaces for a plain-text check.
    text = re.sub(r'\\[\(\)\[\]]', ' ', text)
    # Fix: collapse whitespace so delimiter removal does not leave stray
    # gaps that would defeat the substring check in simple_evaluate().
    text = re.sub(r'\s+', ' ', text).strip()
    return text
+
def simple_evaluate(ground_truth, generated):
    """Naive containment-based check of a generated answer.

    For short ground truths (likely a number or closed-form expression),
    returns True when the normalized ground truth appears as a substring
    of the normalized generated text. Longer ground truths (full proofs)
    always return False — use an LLM judge for those.

    Args:
        ground_truth: Reference answer text (may be None).
        generated: Model-generated answer text (may be None).

    Returns:
        bool: heuristic match verdict.
    """
    gt_norm = normalize_answer(ground_truth)
    gen_norm = normalize_answer(generated)

    # Fix: an empty ground truth is a substring of everything, which
    # previously counted every such record as a match.
    if not gt_norm:
        return False

    # Short ground truth (likely a number or variable): substring check.
    if len(gt_norm) < 20:
        return gt_norm in gen_norm

    # For longer proofs, this metric is useless.
    return False
+
def main():
    """Read a JSONL results file and report naive heuristic accuracy,
    overall and broken down by problem type."""
    parser = argparse.ArgumentParser(description="Evaluate PutnamGAP results")
    parser.add_argument("--results_file", type=str, required=True, help="Path to JSONL results file")
    args = parser.parse_args()

    total = 0
    correct_heuristic = 0
    by_type = {}

    print(f"Evaluating {args.results_file}...")

    with open(args.results_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            data = json.loads(line)
            prob_type = data.get("problem_type", "unknown")

            total += 1
            if prob_type not in by_type:
                by_type[prob_type] = {"count": 0, "heuristic_match": 0}

            by_type[prob_type]["count"] += 1

            # Placeholder evaluation; real proof grading needs an LLM judge.
            # Fix: .get() guards against records missing either field
            # (previously a KeyError aborted the whole evaluation).
            is_match = simple_evaluate(data.get("solution"), data.get("generated_solution"))

            if is_match:
                correct_heuristic += 1
                by_type[prob_type]["heuristic_match"] += 1

    print(f"Total processed: {total}")
    # Fix: the overall tally was accumulated but never reported.
    overall_acc = (correct_heuristic / total) * 100 if total > 0 else 0
    print(f"Overall heuristic matches: {correct_heuristic} ({overall_acc:.2f}%)")
    print("-" * 40)
    print("Breakdown by Problem Type:")
    for p_type, stats in by_type.items():
        acc = (stats["heuristic_match"] / stats["count"]) * 100 if stats["count"] > 0 else 0
        print(f"  {p_type}: {stats['count']} items, {stats['heuristic_match']} heuristic matches ({acc:.2f}%)")
    print("-" * 40)
    print("Note: The heuristic match is very basic (checks if short ground truth is substring of generated output).")
    print("For 'proof' problems, this metric is not reliable. Use an LLM-based judge for accurate evaluation.")

if __name__ == "__main__":
    main()
+
diff --git a/putnamsup/putnam_utils.py b/putnamsup/putnam_utils.py
new file mode 100644
index 0000000..7761c49
--- /dev/null
+++ b/putnamsup/putnam_utils.py
@@ -0,0 +1,95 @@
+import os
+import json
+from typing import Dict, Any, Generator, Tuple, Optional, List
+
# Variant keys available per problem: "original" (the untransformed item)
# plus the transformed variants stored under each record's "variants"
# mapping. Kept in sync with the list in putnamgap_viewer.py.
SUPPORTED_VARIANTS = [
    "original",
    "descriptive_long",
    "descriptive_long_confusing",
    "descriptive_long_misleading",
    "garbled_string",
    "kernel_variant",
]
+
def get_original_qa(d: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Return the (question, solution) pair of the untransformed problem.

    Falls back to the legacy "answer" key when "solution" is absent.
    """
    return d.get("question"), d.get("solution", d.get("answer"))
+
def get_variant_qa(d: Dict[str, Any], variant_key: str) -> Tuple[Optional[str], Optional[str]]:
    """Return the (question, solution) pair for one named variant.

    Yields (None, None) when the "variants" mapping or the requested entry
    is missing or malformed; "answer" is a legacy alias for "solution".
    """
    variants = d.get("variants")
    if isinstance(variants, dict):
        entry = variants.get(variant_key)
        if isinstance(entry, dict):
            return entry.get("question"), entry.get("solution", entry.get("answer"))
    return None, None
+
def load_dataset(data_dir: str, selected_variants: Optional[List[str]] = None) -> Generator[Dict[str, Any], None, None]:
    """Yield one problem instance per (file, variant) pair found in data_dir.

    Every yielded dict carries: file_index, problem_type, variant,
    question, solution. Files that fail to parse are reported and skipped.

    Args:
        data_dir: Path to the dataset directory.
        selected_variants: Variants to include (values from
            SUPPORTED_VARIANTS). None means all supported variants.

    Raises:
        ValueError: If data_dir is not an existing directory.
    """
    if not os.path.isdir(data_dir):
        raise ValueError(f"Directory not found: {data_dir}")

    # Warn about (but do not drop) unrecognized variant names.
    if selected_variants:
        for v in selected_variants:
            if v not in SUPPORTED_VARIANTS:
                print(f"Warning: Variant '{v}' not recognized. Supported: {SUPPORTED_VARIANTS}")

    target_variants = selected_variants if selected_variants else SUPPORTED_VARIANTS

    for fname in sorted(n for n in os.listdir(data_dir) if n.lower().endswith(".json")):
        filepath = os.path.join(data_dir, fname)
        try:
            with open(filepath, "r", encoding="utf-8") as fp:
                data = json.load(fp)
        except Exception as e:
            print(f"Error loading {filepath}: {e}")
            continue

        # Fall back to the filename when the record has no explicit index.
        file_index = data.get("index", fname)
        prob_type = data.get("problem_type", "unknown")

        # "original" is first in SUPPORTED_VARIANTS, so it is emitted first,
        # matching the original two-phase yield order.
        for var_key in SUPPORTED_VARIANTS:
            if var_key not in target_variants:
                continue
            if var_key == "original":
                q, a = get_original_qa(data)
            else:
                q, a = get_variant_qa(data, var_key)
            if q and a:
                yield {
                    "file_index": file_index,
                    "problem_type": prob_type,
                    "variant": var_key,
                    "question": q,
                    "solution": a,
                }
diff --git a/putnamsup/putnamgap_viewer.py b/putnamsup/putnamgap_viewer.py
new file mode 100644
index 0000000..d3678f1
--- /dev/null
+++ b/putnamsup/putnamgap_viewer.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python3
+"""
+Streamlit-based PutnamGAP dataset viewer.
+
+Features:
+- Scans preprocess/PutnamGAP for JSON files and allows prev/next navigation
+- Select specific file from a dropdown
+- Choose which variant to display: original or one of:
+ descriptive_long, descriptive_long_confusing, descriptive_long_misleading, garbled_string, kernel_variant
+- Toggle to show Question, Solution (a.k.a. Answer), or Both
+- TeX rendering via Markdown by default, with optional HTML+MathJax fallback
+"""
+import json
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import streamlit as st
+from streamlit.components.v1 import html as st_html
+
+
# Dataset lives in a "PutnamGAP" directory next to this script.
DATA_DIR = os.path.join(os.path.dirname(__file__), "PutnamGAP")
# "original" plus the transformed variant keys stored under each record's
# "variants" mapping; the same list appears in putnam_utils.py.
SUPPORTED_VARIANTS = [
    "original",
    "descriptive_long",
    "descriptive_long_confusing",
    "descriptive_long_misleading",
    "garbled_string",
    "kernel_variant",
]
+
+
def discover_json_files(data_dir: str) -> List[str]:
    """Return sorted full paths of all *.json files directly in data_dir.

    Returns an empty list when the directory does not exist.
    """
    if not os.path.isdir(data_dir):
        return []
    json_names = (n for n in os.listdir(data_dir) if n.lower().endswith(".json"))
    return sorted(os.path.join(data_dir, n) for n in json_names)
+
+
def load_json(filepath: str) -> Dict[str, Any]:
    """Parse and return the JSON document stored at filepath (UTF-8)."""
    with open(filepath, "r", encoding="utf-8") as fh:
        raw = fh.read()
    return json.loads(raw)
+
+
def get_original_qa(d: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Pull the untransformed (question, solution) pair out of a record.

    "answer" is accepted as a legacy alias when the "solution" key is absent.
    """
    question: Optional[str] = d.get("question")
    if "solution" in d:
        solution: Optional[str] = d["solution"]
    else:
        solution = d.get("answer")
    return question, solution
+
+
def get_variant_qa(
    d: Dict[str, Any], variant_key: str
) -> Tuple[Optional[str], Optional[str]]:
    """Pull (question, solution) for the named variant, or (None, None).

    A missing or malformed "variants" mapping (or entry) yields (None, None);
    "answer" is accepted as a legacy alias for "solution".
    """
    entry = None
    variants = d.get("variants")
    if isinstance(variants, dict):
        entry = variants.get(variant_key)
    if not isinstance(entry, dict):
        return None, None
    return entry.get("question"), entry.get("solution", entry.get("answer"))
+
+
def render_markdown_with_math(text: str) -> None:
    """Render *text* via st.markdown, which handles $...$/$$...$$ math."""
    # unsafe_allow_html lets raw HTML fragments in the dataset text pass
    # through; NOTE(review): content is trusted dataset text, not user input.
    st.markdown(text, unsafe_allow_html=True)
+
+
def render_with_mathjax_html(blocks: List[Tuple[str, str]]) -> None:
    """
    Render content with MathJax v3 inside a single HTML component.

    blocks: list of (heading, content) tuples. All blocks are rendered in
    one component so MathJax is loaded exactly once per rerun.
    """
    # Build a small HTML page with MathJax v3; render all blocks together.
    content_sections = []
    for heading, content in blocks:
        section_html = f"""
        <section style="margin-bottom: 1.25rem;">
          <h3 style="margin: 0 0 .5rem 0; font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;">
            {heading}
          </h3>
          <div class="mj-content">{content}</div>
        </section>
        """
        content_sections.append(section_html)

    # Self-contained page: the MathJax config script must precede the CDN
    # loader script. Doubled braces {{ }} are f-string escapes for literal
    # braces in the JS/CSS; \\\\( emits \\( so MathJax sees the \( delimiter.
    page = f"""
<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <script>
      window.MathJax = {{
        tex: {{
          inlineMath: [['$', '$'], ['\\\\(', '\\\\)']],
          displayMath: [['$$', '$$'], ['\\\\[', '\\\\]']]
        }},
        svg: {{ fontCache: 'global' }}
      }};
    </script>
    <script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-svg.js" async></script>
    <style>
      html, body {{
        background: #0f0f10;
        color: #f5f6f7;
      }}
      body {{
        padding: 0.5rem 0.25rem;
        color: #f5f6f7;
        background: #0f0f10;
      }}
      .mj-content {{
        line-height: 1.6;
        white-space: pre-wrap;
        word-wrap: break-word;
        font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;
        font-size: 1rem;
        color: #f5f6f7;
        background: #0f0f10;
        padding: 0.25rem 0.25rem;
        border-radius: 4px;
      }}
      code, pre {{
        font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
        color: #e6e6e6;
      }}
      svg, .MathJax, .mjx-svg, .mjx-mrow {{
        color: #f5f6f7;
      }}
    </style>
  </head>
  <body>
    {''.join(content_sections)}
  </body>
</html>
"""
    # Fixed-height scrollable iframe; content taller than 600px scrolls.
    st_html(page, height=600, scrolling=True)
+
+
def main() -> None:
    """Streamlit entry point.

    Sidebar controls pick a JSON file, a variant, and display options;
    the chosen question/solution is then rendered either through
    Streamlit Markdown or through the HTML+MathJax component.
    """
    st.set_page_config(page_title="PutnamGAP Viewer", layout="wide")
    st.title("PutnamGAP 数据可视化与校对")
    st.caption("浏览原题与不同变体;支持 TeX 渲染与文件前后切换。")

    files = discover_json_files(DATA_DIR)
    if not files:
        # st.stop() halts the script run here; nothing below executes.
        st.error(f"未在目录中发现 JSON 文件:{DATA_DIR}")
        st.stop()

    # Sidebar controls
    with st.sidebar:
        st.subheader("文件与显示设置")

        # Single source of truth for navigation: session_state.file_index.
        # Both the selectbox and the prev/next buttons write through it.
        file_labels = [os.path.basename(p) for p in files]
        if "file_index" not in st.session_state:
            st.session_state.file_index = 0

        selected_label = st.selectbox(
            "选择题目文件",
            options=file_labels,
            index=st.session_state.file_index,
        )
        # Sync file_index if user chose a different label in the dropdown.
        current_index = file_labels.index(selected_label)
        if current_index != st.session_state.file_index:
            st.session_state.file_index = current_index

        # Variant selection: human-readable labels map 1:1 to internal keys.
        variant_human_labels = {
            "original": "原题 original",
            "descriptive_long": "descriptive_long",
            "descriptive_long_confusing": "descriptive_long_confusing",
            "descriptive_long_misleading": "descriptive_long_misleading",
            "garbled_string": "garbled_string",
            "kernel_variant": "kernel_variant",
        }
        variant_choice_label = st.radio(
            "选择显示内容",
            options=[variant_human_labels[k] for k in SUPPORTED_VARIANTS],
            index=0,
        )
        # Reverse map the chosen label back to the internal variant key.
        selected_variant = {
            v: k for k, v in variant_human_labels.items()
        }[variant_choice_label]

        # Display options (which parts to show, and which renderer).
        show_mode = st.radio(
            "显示部分",
            options=["Question", "Solution", "Both"],
            index=0,
            horizontal=True,
        )
        render_mode = st.radio(
            "渲染方式",
            options=["Markdown (默认)", "HTML + MathJax"],
            index=1,
        )
        show_meta = st.checkbox("显示原始 JSON 片段", value=False)

    # Prev/Next navigation buttons in header row; modulo wraps around the
    # file list, and st.rerun() refreshes the page with the new index.
    left, mid, right = st.columns([1, 6, 1])
    with left:
        if st.button("⬅️ 上一题", use_container_width=True):
            new_index = (st.session_state.file_index - 1) % len(files)
            st.session_state.file_index = new_index
            st.rerun()
    with right:
        if st.button("下一题 ➡️", use_container_width=True):
            new_index = (st.session_state.file_index + 1) % len(files)
            st.session_state.file_index = new_index
            st.rerun()

    current_path = files[st.session_state.file_index]
    data = load_json(current_path)

    st.write(f"当前文件:`{os.path.basename(current_path)}` ({st.session_state.file_index + 1}/{len(files)})")
    st.divider()

    # Resolve question/solution for the chosen variant.
    if selected_variant == "original":
        q_text, s_text = get_original_qa(data)
    else:
        q_text, s_text = get_variant_qa(data, selected_variant)

    # Assemble (heading, content) blocks to render; warn for missing parts.
    blocks: List[Tuple[str, str]] = []
    if show_mode in ("Question", "Both"):
        if q_text:
            blocks.append(("Question", q_text))
        else:
            st.warning("该选择下未找到 Question。")
    if show_mode in ("Solution", "Both"):
        if s_text:
            blocks.append(("Solution", s_text))
        else:
            st.warning("该选择下未找到 Solution/Answer。")

    if len(blocks) > 0:
        if render_mode.startswith("Markdown"):
            for heading, content in blocks:
                st.subheader(heading)
                render_markdown_with_math(content)
                st.markdown("---")
        else:
            render_with_mathjax_html(blocks)
    else:
        st.info("无可显示内容。")

    if show_meta:
        with st.expander("原始 JSON(截断显示)", expanded=False):
            # Show a trimmed version to avoid overwhelming the UI.
            preview: Dict[str, Any] = {}
            for k in ("index", "type", "tag", "difficulty", "problem_type"):
                if k in data:
                    preview[k] = data[k]
            preview["keys"] = list(data.keys())
            st.json(preview)

            st.caption("完整 JSON 路径:")
            st.code(current_path)

    st.caption("提示:可以使用侧边栏选择具体文件与变体,也可通过顶部按钮快速前后切换 JSON。")


if __name__ == "__main__":
    main()
+
+
diff --git a/putnamsup/requirements.txt b/putnamsup/requirements.txt
new file mode 100644
index 0000000..981cde1
--- /dev/null
+++ b/putnamsup/requirements.txt
@@ -0,0 +1,7 @@
+torch>=2.0.0
+transformers>=4.37.0
+accelerate>=0.26.0
+tqdm>=4.66.0
+openai>=1.0.0
+streamlit>=1.30.0
+
diff --git a/putnamsup/run_putnam_gap.py b/putnamsup/run_putnam_gap.py
new file mode 100644
index 0000000..73f0ef6
--- /dev/null
+++ b/putnamsup/run_putnam_gap.py
@@ -0,0 +1,167 @@
+import os
+import argparse
+import torch
+import time
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from tqdm import tqdm
+from putnam_utils import load_dataset, SUPPORTED_VARIANTS
+import json
+
def run_inference_batch(model, tokenizer, questions: list, device: str) -> list:
    """Generate one greedy (do_sample=False) solution per question.

    Args:
        model: Causal LM loaded via AutoModelForCausalLM.
        tokenizer: Matching tokenizer. Assumed to pad on the LEFT —
            main() loads it with padding_side='left', which the final
            slicing step relies on.
        questions: Problem statements to solve.
        device: Target device string, or "auto" to follow model.device.

    Returns:
        List of generated solution strings (prompt stripped), one per question.
    """
    prompts = [f"Problem:\n{q}\n\nPlease solve the problem above step by step and provide the final answer.\n\nSolution:\n" for q in questions]

    # Determine target device for inputs; with device_map="auto" the model
    # itself reports where its (first) parameters live.
    if device == "auto":
        target_device = model.device
    else:
        target_device = device

    # Prefer the model's chat template when one is defined; fall back to
    # the raw prompt if formatting fails for any reason.
    input_texts = []
    if tokenizer.chat_template:
        for p in prompts:
            messages = [{"role": "user", "content": p}]
            try:
                formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                input_texts.append(formatted)
            except Exception:
                input_texts.append(p)
    else:
        input_texts = prompts

    # Pad so ragged prompts form one rectangular batch tensor.
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(target_device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )

    # generate() returns prompt + continuation per row. Because padding is
    # on the LEFT, every row's prompt occupies exactly the first
    # inputs.input_ids.shape[1] positions, so slicing those off leaves only
    # the newly generated tokens for each row.
    generated_texts = tokenizer.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return [t.strip() for t in generated_texts]
+
def main():
    """CLI entry point: load the PutnamGAP dataset, run batched HF-model
    inference over every selected problem variant, and stream results to
    a JSONL file (one record per problem variant)."""
    parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset")
    parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files")
    parser.add_argument("--model_name_or_path", type=str, required=True, help="Hugging Face model name or path")
    parser.add_argument("--output_file", type=str, default="putnam_gap_results.jsonl", help="Output file path")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of problems to run")
    parser.add_argument("--limit_per_variant", type=int, default=None, help="Limit number of problems per variant")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on (use 'auto' for multi-GPU)")
    parser.add_argument("--dry_run", action="store_true", help="Only load data and print first few examples, do not load model")
    parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. Choices: {','.join(SUPPORTED_VARIANTS)}")
    args = parser.parse_args()

    # Parse variants argument (populated below once args are known).
    selected_variants = None

    # Diagnostic check: device_count() can report GPUs even when
    # is_available() is False (driver/runtime version mismatch).
    if torch.cuda.device_count() > 0 and not torch.cuda.is_available():
        print("\n" + "!"*60)
        print(f"WARNING: PyTorch detects {torch.cuda.device_count()} CUDA devices but cannot use them.")
        print(f"torch.cuda.is_available() == False")
        print(f"Current PyTorch version: {torch.__version__}")
        print(f"Your driver probably supports an older CUDA version than this PyTorch build.")
        print("!"*60 + "\n")

    if args.variants:
        selected_variants = [v.strip() for v in args.variants.split(",")]
        print(f"Filtering for variants: {selected_variants}")

    print(f"Scanning data from {args.data_dir}...")
    dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants))
    print(f"Found {len(dataset)} problem variants.")

    # Optional per-variant cap: keep only the first N instances of each variant.
    if args.limit_per_variant:
        from collections import defaultdict
        counts = defaultdict(int)
        filtered_dataset = []
        for item in dataset:
            v = item['variant']
            if counts[v] < args.limit_per_variant:
                filtered_dataset.append(item)
                counts[v] += 1
        dataset = filtered_dataset
        print(f"Filtered to {len(dataset)} examples (max {args.limit_per_variant} per variant).")

    # Dry run: preview the first example and exit before touching the model.
    if args.dry_run:
        if dataset:
            print("\n--- Example 1 ---")
            print(f"Index: {dataset[0]['file_index']}")
            print(f"Variant: {dataset[0]['variant']}")
            print(f"Question: {dataset[0]['question'][:200]}...")
            print(f"Solution: {dataset[0]['solution'][:200]}...")
        return

    print(f"Loading model: {args.model_name_or_path} on {args.device}")

    try:
        # Left padding is required: run_inference_batch slices generated
        # tokens off the end of each row.
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True, padding_side='left')
        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.pad_token_id = 0

        # Determine dtype: fp16 on accelerators, fp32 on CPU.
        torch_dtype = torch.float16
        if args.device == "cpu":
            torch_dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            device_map=args.device,
            trust_remote_code=True,
            torch_dtype=torch_dtype
        )
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    # NOTE(review): --limit is applied only after the model has loaded;
    # consider moving it before the load to save time on small test runs.
    if args.limit:
        dataset = dataset[:args.limit]
        print(f"Limiting to first {args.limit} examples.")

    with open(args.output_file, "w", encoding="utf-8") as f_out:
        batch_size = args.batch_size
        for i in tqdm(range(0, len(dataset), batch_size), desc="Running Inference"):
            batch = dataset[i : i + batch_size]
            questions = [item["question"] for item in batch]

            try:
                generated_answers = run_inference_batch(model, tokenizer, questions, args.device)
            except Exception as e:
                # A failed batch is recorded with error markers, not fatal.
                print(f"Error generating for batch starting at index {i}: {e}")
                generated_answers = [f"<ERROR: {str(e)}>" for _ in batch]

            for item, ans in zip(batch, generated_answers):
                result_entry = {
                    "file_index": item["file_index"],
                    "problem_type": item["problem_type"],
                    "variant": item["variant"],
                    "question": item["question"],
                    "solution": item["solution"],
                    "generated_solution": ans
                }

                # Flush per record so partial results survive interruption.
                f_out.write(json.dumps(result_entry, ensure_ascii=False) + "\n")
                f_out.flush()

    print(f"Done. Results saved to {args.output_file}")

if __name__ == "__main__":
    main()
diff --git a/putnamsup/run_putnam_gap_openrouter.py b/putnamsup/run_putnam_gap_openrouter.py
new file mode 100644
index 0000000..8a23141
--- /dev/null
+++ b/putnamsup/run_putnam_gap_openrouter.py
@@ -0,0 +1,124 @@
+import os
+import json
+import argparse
+import asyncio
+import time
+from tqdm.asyncio import tqdm
+from putnam_utils import load_dataset, SUPPORTED_VARIANTS
+
+try:
+ from openai import AsyncOpenAI
+except ImportError:
+ AsyncOpenAI = None
+
async def process_item(sem, client, model_name, item):
    """Generate a solution for one dataset item under a concurrency semaphore.

    API failures are captured as an "<API ERROR: ...>" marker string instead
    of raising, so a single bad request never aborts the whole run.

    Returns:
        dict: the input record's fields plus generated_solution and model.
    """
    async with sem:
        question = item["question"]
        prompt = f"Problem:\n{question}\n\nPlease solve the problem above step by step and provide the final answer.\n\nSolution:\n"

        try:
            # Deterministic single-turn chat completion via the async client.
            completion = await client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=2048,
                extra_headers={
                    "HTTP-Referer": "https://github.com/PutnamGAP",
                    "X-Title": "PutnamGAP Eval",
                },
            )
            generated_answer = completion.choices[0].message.content
        except Exception as exc:
            generated_answer = f"<API ERROR: {str(exc)}>"

        return {
            "file_index": item["file_index"],
            "problem_type": item["problem_type"],
            "variant": item["variant"],
            "question": question,
            "solution": item["solution"],
            "generated_solution": generated_answer,
            "model": model_name,
        }
+
async def run_async_inference(args, dataset):
    """Fan out one chat-completion request per dataset item, bounded by a
    semaphore, writing each result to args.output_file as it completes.

    Args:
        args: Parsed CLI namespace (api_key, base_url, model_name,
            concurrency, output_file).
        dataset: Iterable of problem-instance dicts from load_dataset().
    """
    if AsyncOpenAI is None:
        print("Error: 'openai' library not found. Please install it via: pip install openai")
        return

    if not args.api_key:
        print("Error: API key not provided. Use --api_key or set OPENROUTER_API_KEY env var.")
        return

    print(f"Initializing AsyncOpenAI client with base_url={args.base_url}")
    client = AsyncOpenAI(
        base_url=args.base_url,
        api_key=args.api_key,
    )

    try:
        concurrency = args.concurrency
        print(f"Running with concurrency: {concurrency}")
        sem = asyncio.Semaphore(concurrency)

        tasks = [process_item(sem, client, args.model_name, item) for item in dataset]

        print(f"Starting {len(tasks)} tasks using model: {args.model_name}")

        with open(args.output_file, "w", encoding="utf-8") as f_out:
            for future in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Async Inference"):
                result = await future
                # Write incrementally so partial results survive interruption.
                f_out.write(json.dumps(result, ensure_ascii=False) + "\n")
                f_out.flush()
    finally:
        # Fix: release the client's underlying HTTP connection pool;
        # previously the client was never closed.
        await client.close()

    print(f"Done. Results saved to {args.output_file}")
+
def main():
    """CLI entry point: scan the PutnamGAP dataset, optionally preview or
    truncate it, then run async OpenRouter inference over it."""
    parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset via OpenRouter (Async)")
    parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files")
    parser.add_argument("--model_name", type=str, required=True, help="OpenRouter model name")
    parser.add_argument("--output_file", type=str, default="putnam_gap_openrouter_results.jsonl", help="Output file path")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of problems to run (for testing)")
    parser.add_argument("--concurrency", type=int, default=10, help="Number of concurrent requests")
    parser.add_argument("--api_key", type=str, default=os.getenv("OPENROUTER_API_KEY"), help="OpenRouter API Key")
    parser.add_argument("--base_url", type=str, default="https://openrouter.ai/api/v1", help="API Base URL")
    parser.add_argument("--dry_run", action="store_true", help="Only load data and print info")
    parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. Choices: {','.join(SUPPORTED_VARIANTS)}")

    args = parser.parse_args()

    # Parse variants argument
    selected_variants = None
    if args.variants:
        selected_variants = [v.strip() for v in args.variants.split(",")]
        print(f"Filtering for variants: {selected_variants}")

    print(f"Scanning data from {args.data_dir}...")
    dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants))
    print(f"Found {len(dataset)} problem variants.")

    # Dry run: preview the first example and exit without calling the API.
    if args.dry_run:
        if dataset:
            print("\n--- Example 1 ---")
            print(f"Index: {dataset[0]['file_index']}")
            print(f"Variant: {dataset[0]['variant']}")
            print(f"Question: {dataset[0]['question'][:200]}...")
            # Consistency fix: run_putnam_gap.py's dry run also previews
            # the solution; do the same here.
            print(f"Solution: {dataset[0]['solution'][:200]}...")
        return

    if args.limit:
        dataset = dataset[:args.limit]
        print(f"Limiting to first {args.limit} examples.")

    asyncio.run(run_async_inference(args, dataset))

if __name__ == "__main__":
    main()