diff options
Diffstat (limited to 'putnamsup')
| -rw-r--r-- | putnamsup/evaluate_putnam_gap.py | 74 | ||||
| -rw-r--r-- | putnamsup/putnam_utils.py | 95 | ||||
| -rw-r--r-- | putnamsup/putnamgap_viewer.py | 277 | ||||
| -rw-r--r-- | putnamsup/requirements.txt | 7 | ||||
| -rw-r--r-- | putnamsup/run_putnam_gap.py | 167 | ||||
| -rw-r--r-- | putnamsup/run_putnam_gap_openrouter.py | 124 |
6 files changed, 744 insertions, 0 deletions
# === putnamsup/evaluate_putnam_gap.py ========================================
import json
import argparse
import re


def normalize_answer(text):
    """Normalize an answer string for fuzzy comparison.

    Lower-cases, strips surrounding whitespace, and replaces the LaTeX
    delimiters \\( \\) \\[ \\] with spaces so that e.g. "\\(42\\)" can match
    a plain "42".

    Args:
        text: Raw answer string, or None.

    Returns:
        The normalized string; "" when text is None.
    """
    if text is None:
        return ""
    text = text.strip().lower()
    # Remove latex formatting for simple check
    text = re.sub(r'\\[\(\)\[\]]', ' ', text)
    return text


def simple_evaluate(ground_truth, generated):
    """Very naive heuristic evaluator.

    Returns True if a *short* ground truth (likely a number or a variable,
    < 20 chars after normalization) appears as a substring of the generated
    text. Long ground truths (proofs) always return False — use an
    LLM-based judge for those.
    """
    gt_norm = normalize_answer(ground_truth)
    gen_norm = normalize_answer(generated)

    # Bug fix: an empty ground truth previously matched everything,
    # because "" is a substring of every string. Count it as a non-match.
    if not gt_norm:
        return False

    # If ground truth is very short (likely a number or variable), check
    # whether it appears in the generated text.
    if len(gt_norm) < 20:
        return gt_norm in gen_norm

    # For longer proofs, this metric is useless.
    return False


def main():
    """CLI entry point: tally heuristic matches per problem type in a JSONL results file."""
    parser = argparse.ArgumentParser(description="Evaluate PutnamGAP results")
    parser.add_argument("--results_file", type=str, required=True, help="Path to JSONL results file")
    args = parser.parse_args()

    total = 0
    correct_heuristic = 0
    by_type = {}  # problem_type -> {"count": int, "heuristic_match": int}

    print(f"Evaluating {args.results_file}...")

    with open(args.results_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            data = json.loads(line)
            prob_type = data.get("problem_type", "unknown")

            total += 1
            if prob_type not in by_type:
                by_type[prob_type] = {"count": 0, "heuristic_match": 0}

            by_type[prob_type]["count"] += 1

            # This is a placeholder evaluation.
            # Real evaluation for proofs needs an LLM judge.
            is_match = simple_evaluate(data["solution"], data["generated_solution"])

            if is_match:
                correct_heuristic += 1
                by_type[prob_type]["heuristic_match"] += 1

    print(f"Total processed: {total}")
    print("-" * 40)
    print("Breakdown by Problem Type:")
    for p_type, stats in by_type.items():
        acc = (stats["heuristic_match"] / stats["count"]) * 100 if stats["count"] > 0 else 0
        print(f" {p_type}: {stats['count']} items, {stats['heuristic_match']} heuristic matches ({acc:.2f}%)")
    print("-" * 40)
    print("Note: The heuristic match is very basic (checks if short ground truth is substring of generated output).")
    print("For 'proof' problems, this metric is not reliable. Use an LLM-based judge for accurate evaluation.")


if __name__ == "__main__":
    main()


# === putnamsup/putnam_utils.py ===============================================
import os
from typing import Dict, Any, Generator, Tuple, Optional, List

# Supported variants as seen in putnamgap_viewer.py
SUPPORTED_VARIANTS = [
    "original",
    "descriptive_long",
    "descriptive_long_confusing",
    "descriptive_long_misleading",
    "garbled_string",
    "kernel_variant",
]


def get_original_qa(d: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Extract the original (question, solution) pair from a problem record.

    Falls back to the "answer" key when "solution" is absent.
    """
    question = d.get("question")
    solution = d.get("solution", d.get("answer"))
    return question, solution


def get_variant_qa(d: Dict[str, Any], variant_key: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract the (question, solution) pair for one named variant.

    Returns (None, None) when the record has no well-formed "variants"
    dict or the requested variant is missing/malformed.
    """
    variants = d.get("variants")
    if not isinstance(variants, dict):
        return None, None
    var = variants.get(variant_key)
    if not isinstance(var, dict):
        return None, None
    question = var.get("question")
    solution = var.get("solution", var.get("answer"))
    return question, solution


def load_dataset(data_dir: str, selected_variants: Optional[List[str]] = None) -> Generator[Dict[str, Any], None, None]:
    """
    Iterates over all JSON files in data_dir and yields problem instances.
    Each instance is a dict with keys: file_index, problem_type, variant,
    question, solution.

    Args:
        data_dir: Path to the dataset directory.
        selected_variants: List of variants to include. If None, include all.
            Supported values are in SUPPORTED_VARIANTS.

    Raises:
        ValueError: If data_dir does not exist.
    """
    if not os.path.isdir(data_dir):
        raise ValueError(f"Directory not found: {data_dir}")

    # Validate selected_variants (warn only; unknown names simply never match)
    if selected_variants:
        for v in selected_variants:
            if v not in SUPPORTED_VARIANTS:
                print(f"Warning: Variant '{v}' not recognized. Supported: {SUPPORTED_VARIANTS}")

    # If no filter provided, use all supported
    target_variants = selected_variants if selected_variants else SUPPORTED_VARIANTS

    files = [f for f in os.listdir(data_dir) if f.lower().endswith(".json")]
    files.sort()

    for f in files:
        filepath = os.path.join(data_dir, f)
        try:
            with open(filepath, "r", encoding="utf-8") as fp:
                data = json.load(fp)
        except Exception as e:
            # Best-effort: skip unreadable/invalid files rather than abort the scan.
            print(f"Error loading {filepath}: {e}")
            continue

        file_index = data.get("index", f)  # Use filename as index if 'index' key missing
        prob_type = data.get("problem_type", "unknown")

        # 1. Original
        if "original" in target_variants:
            q, a = get_original_qa(data)
            if q and a:
                yield {
                    "file_index": file_index,
                    "problem_type": prob_type,
                    "variant": "original",
                    "question": q,
                    "solution": a
                }

        # 2.
Variants + for var_key in SUPPORTED_VARIANTS: + if var_key == "original": continue + if var_key not in target_variants: continue + + q, a = get_variant_qa(data, var_key) + if q and a: + yield { + "file_index": file_index, + "problem_type": prob_type, + "variant": var_key, + "question": q, + "solution": a + } diff --git a/putnamsup/putnamgap_viewer.py b/putnamsup/putnamgap_viewer.py new file mode 100644 index 0000000..d3678f1 --- /dev/null +++ b/putnamsup/putnamgap_viewer.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +""" +Streamlit-based PutnamGAP dataset viewer. + +Features: +- Scans preprocess/PutnamGAP for JSON files and allows prev/next navigation +- Select specific file from a dropdown +- Choose which variant to display: original or one of: + descriptive_long, descriptive_long_confusing, descriptive_long_misleading, garbled_string, kernel_variant +- Toggle to show Question, Solution (a.k.a. Answer), or Both +- TeX rendering via Markdown by default, with optional HTML+MathJax fallback +""" +import json +import os +from typing import Any, Dict, List, Optional, Tuple + +import streamlit as st +from streamlit.components.v1 import html as st_html + + +DATA_DIR = os.path.join(os.path.dirname(__file__), "PutnamGAP") +SUPPORTED_VARIANTS = [ + "original", + "descriptive_long", + "descriptive_long_confusing", + "descriptive_long_misleading", + "garbled_string", + "kernel_variant", +] + + +def discover_json_files(data_dir: str) -> List[str]: + if not os.path.isdir(data_dir): + return [] + files = [ + os.path.join(data_dir, f) + for f in os.listdir(data_dir) + if f.lower().endswith(".json") + ] + files.sort() + return files + + +def load_json(filepath: str) -> Dict[str, Any]: + with open(filepath, "r", encoding="utf-8") as f: + return json.load(f) + + +def get_original_qa(d: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]: + # Prefer "question"/"solution"; gracefully fall back to "answer" if present + question: Optional[str] = d.get("question") + solution: 
Optional[str] = d.get("solution", d.get("answer")) + return question, solution + + +def get_variant_qa( + d: Dict[str, Any], variant_key: str +) -> Tuple[Optional[str], Optional[str]]: + variants = d.get("variants") + if not isinstance(variants, dict): + return None, None + var = variants.get(variant_key) + if not isinstance(var, dict): + return None, None + question: Optional[str] = var.get("question") + solution: Optional[str] = var.get("solution", var.get("answer")) + return question, solution + + +def render_markdown_with_math(text: str) -> None: + # Streamlit markdown supports MathJax ($...$, $$...$$) + st.markdown(text, unsafe_allow_html=True) + + +def render_with_mathjax_html(blocks: List[Tuple[str, str]]) -> None: + """ + Render content with MathJax v3 inside a single HTML component. + blocks: list of (heading, content) tuples + """ + # Build a small HTML page with MathJax v3; render all blocks together. + content_sections = [] + for heading, content in blocks: + section_html = f""" + <section style="margin-bottom: 1.25rem;"> + <h3 style="margin: 0 0 .5rem 0; font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;"> + {heading} + </h3> + <div class="mj-content">{content}</div> + </section> + """ + content_sections.append(section_html) + + page = f""" +<!DOCTYPE html> +<html> + <head> + <meta charset="utf-8"> + <meta name="viewport" content="width=device-width, initial-scale=1"> + <script> + window.MathJax = {{ + tex: {{ + inlineMath: [['$', '$'], ['\\\\(', '\\\\)']], + displayMath: [['$$', '$$'], ['\\\\[', '\\\\]']] + }}, + svg: {{ fontCache: 'global' }} + }}; + </script> + <script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-svg.js" async></script> + <style> + html, body {{ + background: #0f0f10; + color: #f5f6f7; + }} + body {{ + padding: 0.5rem 0.25rem; + color: #f5f6f7; + background: #0f0f10; + }} + .mj-content {{ + line-height: 1.6; + white-space: pre-wrap; + word-wrap: break-word; + font-family: ui-sans-serif, 
system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial; + font-size: 1rem; + color: #f5f6f7; + background: #0f0f10; + padding: 0.25rem 0.25rem; + border-radius: 4px; + }} + code, pre {{ + font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; + color: #e6e6e6; + }} + svg, .MathJax, .mjx-svg, .mjx-mrow {{ + color: #f5f6f7; + }} + </style> + </head> + <body> + {''.join(content_sections)} + </body> +</html> +""" + st_html(page, height=600, scrolling=True) + + +def main() -> None: + st.set_page_config(page_title="PutnamGAP Viewer", layout="wide") + st.title("PutnamGAP 数据可视化与校对") + st.caption("浏览原题与不同变体;支持 TeX 渲染与文件前后切换。") + + files = discover_json_files(DATA_DIR) + if not files: + st.error(f"未在目录中发现 JSON 文件:{DATA_DIR}") + st.stop() + + # Sidebar controls + with st.sidebar: + st.subheader("文件与显示设置") + + # Single source of truth for navigation: file_index + file_labels = [os.path.basename(p) for p in files] + if "file_index" not in st.session_state: + st.session_state.file_index = 0 + + selected_label = st.selectbox( + "选择题目文件", + options=file_labels, + index=st.session_state.file_index, + ) + # Sync file_index if user chose a different label + current_index = file_labels.index(selected_label) + if current_index != st.session_state.file_index: + st.session_state.file_index = current_index + + # Variant selection + variant_human_labels = { + "original": "原题 original", + "descriptive_long": "descriptive_long", + "descriptive_long_confusing": "descriptive_long_confusing", + "descriptive_long_misleading": "descriptive_long_misleading", + "garbled_string": "garbled_string", + "kernel_variant": "kernel_variant", + } + variant_choice_label = st.radio( + "选择显示内容", + options=[variant_human_labels[k] for k in SUPPORTED_VARIANTS], + index=0, + ) + # Reverse map to internal key + selected_variant = { + v: k for k, v in variant_human_labels.items() + }[variant_choice_label] + + # Display options + show_mode = st.radio( + 
"显示部分", + options=["Question", "Solution", "Both"], + index=0, + horizontal=True, + ) + render_mode = st.radio( + "渲染方式", + options=["Markdown (默认)", "HTML + MathJax"], + index=1, + ) + show_meta = st.checkbox("显示原始 JSON 片段", value=False) + + # Prev/Next navigation buttons in header row + left, mid, right = st.columns([1, 6, 1]) + with left: + if st.button("⬅️ 上一题", use_container_width=True): + new_index = (st.session_state.file_index - 1) % len(files) + st.session_state.file_index = new_index + st.rerun() + with right: + if st.button("下一题 ➡️", use_container_width=True): + new_index = (st.session_state.file_index + 1) % len(files) + st.session_state.file_index = new_index + st.rerun() + + current_path = files[st.session_state.file_index] + data = load_json(current_path) + + st.write(f"当前文件:`{os.path.basename(current_path)}` ({st.session_state.file_index + 1}/{len(files)})") + st.divider() + + # Resolve question/solution for chosen variant + if selected_variant == "original": + q_text, s_text = get_original_qa(data) + else: + q_text, s_text = get_variant_qa(data, selected_variant) + + # Assemble content blocks to render + blocks: List[Tuple[str, str]] = [] + if show_mode in ("Question", "Both"): + if q_text: + blocks.append(("Question", q_text)) + else: + st.warning("该选择下未找到 Question。") + if show_mode in ("Solution", "Both"): + if s_text: + blocks.append(("Solution", s_text)) + else: + st.warning("该选择下未找到 Solution/Answer。") + + if len(blocks) > 0: + if render_mode.startswith("Markdown"): + for heading, content in blocks: + st.subheader(heading) + render_markdown_with_math(content) + st.markdown("---") + else: + render_with_mathjax_html(blocks) + else: + st.info("无可显示内容。") + + if show_meta: + with st.expander("原始 JSON(截断显示)", expanded=False): + # Show a trimmed version to avoid overwhelming the UI + preview: Dict[str, Any] = {} + for k in ("index", "type", "tag", "difficulty", "problem_type"): + if k in data: + preview[k] = data[k] + preview["keys"] = 
list(data.keys()) + st.json(preview) + + st.caption("完整 JSON 路径:") + st.code(current_path) + + st.caption("提示:可以使用侧边栏选择具体文件与变体,也可通过顶部按钮快速前后切换 JSON。") + + +if __name__ == "__main__": + main() + + diff --git a/putnamsup/requirements.txt b/putnamsup/requirements.txt new file mode 100644 index 0000000..981cde1 --- /dev/null +++ b/putnamsup/requirements.txt @@ -0,0 +1,7 @@ +torch>=2.0.0 +transformers>=4.37.0 +accelerate>=0.26.0 +tqdm>=4.66.0 +openai>=1.0.0 +streamlit>=1.30.0 + diff --git a/putnamsup/run_putnam_gap.py b/putnamsup/run_putnam_gap.py new file mode 100644 index 0000000..73f0ef6 --- /dev/null +++ b/putnamsup/run_putnam_gap.py @@ -0,0 +1,167 @@ +import os +import argparse +import torch +import time +from transformers import AutoModelForCausalLM, AutoTokenizer +from tqdm import tqdm +from putnam_utils import load_dataset, SUPPORTED_VARIANTS +import json + +def run_inference_batch(model, tokenizer, questions: list, device: str) -> list: + """ + Runs generation for a batch of questions. + """ + prompts = [f"Problem:\n{q}\n\nPlease solve the problem above step by step and provide the final answer.\n\nSolution:\n" for q in questions] + + # Determine target device for inputs + if device == "auto": + target_device = model.device + else: + target_device = device + + input_texts = [] + if tokenizer.chat_template: + for p in prompts: + messages = [{"role": "user", "content": p}] + try: + formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_texts.append(formatted) + except Exception: + input_texts.append(p) + else: + input_texts = prompts + + # Tokenize with padding + inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(target_device) + + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=1024, + do_sample=False, + pad_token_id=tokenizer.pad_token_id + ) + + # Decode only new tokens + # output_ids contains input_ids + new_tokens. We need to slice. 
+ # However, input lengths might vary due to padding. + # batch_decode usually decodes everything. + # A common trick is to decode everything and then strip the prompt, but prompts are different. + # Better: tokenizer.batch_decode(output_ids[:, inputs.input_ids.shape[1]:]) works if left-padded and consistent length? + # No, with left padding, the new tokens are at the end. + + generated_texts = tokenizer.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) + return [t.strip() for t in generated_texts] + +def main(): + parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset") + parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files") + parser.add_argument("--model_name_or_path", type=str, required=True, help="Hugging Face model name or path") + parser.add_argument("--output_file", type=str, default="putnam_gap_results.jsonl", help="Output file path") + parser.add_argument("--limit", type=int, default=None, help="Limit total number of problems to run") + parser.add_argument("--limit_per_variant", type=int, default=None, help="Limit number of problems per variant") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on (use 'auto' for multi-GPU)") + parser.add_argument("--dry_run", action="store_true", help="Only load data and print first few examples, do not load model") + parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. 
Choices: {','.join(SUPPORTED_VARIANTS)}") + args = parser.parse_args() + + # Parse variants argument + selected_variants = None + + # Diagnostic check for CUDA availability + if torch.cuda.device_count() > 0 and not torch.cuda.is_available(): + print("\n" + "!"*60) + print(f"WARNING: PyTorch detects {torch.cuda.device_count()} CUDA devices but cannot use them.") + print(f"torch.cuda.is_available() == False") + print(f"Current PyTorch version: {torch.__version__}") + print(f"Your driver probably supports an older CUDA version than this PyTorch build.") + print("!"*60 + "\n") + + if args.variants: + selected_variants = [v.strip() for v in args.variants.split(",")] + print(f"Filtering for variants: {selected_variants}") + + print(f"Scanning data from {args.data_dir}...") + dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants)) + print(f"Found {len(dataset)} problem variants.") + + if args.limit_per_variant: + from collections import defaultdict + counts = defaultdict(int) + filtered_dataset = [] + for item in dataset: + v = item['variant'] + if counts[v] < args.limit_per_variant: + filtered_dataset.append(item) + counts[v] += 1 + dataset = filtered_dataset + print(f"Filtered to {len(dataset)} examples (max {args.limit_per_variant} per variant).") + + if args.dry_run: + if dataset: + print("\n--- Example 1 ---") + print(f"Index: {dataset[0]['file_index']}") + print(f"Variant: {dataset[0]['variant']}") + print(f"Question: {dataset[0]['question'][:200]}...") + print(f"Solution: {dataset[0]['solution'][:200]}...") + return + + print(f"Loading model: {args.model_name_or_path} on {args.device}") + + try: + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True, padding_side='left') + if tokenizer.pad_token_id is None: + if tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.pad_token_id = 0 + + # Determine dtype + torch_dtype = torch.float16 + if args.device == 
"cpu": + torch_dtype = torch.float32 + + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + device_map=args.device, + trust_remote_code=True, + torch_dtype=torch_dtype + ) + except Exception as e: + print(f"Failed to load model: {e}") + return + + if args.limit: + dataset = dataset[:args.limit] + print(f"Limiting to first {args.limit} examples.") + + with open(args.output_file, "w", encoding="utf-8") as f_out: + batch_size = args.batch_size + for i in tqdm(range(0, len(dataset), batch_size), desc="Running Inference"): + batch = dataset[i : i + batch_size] + questions = [item["question"] for item in batch] + + try: + generated_answers = run_inference_batch(model, tokenizer, questions, args.device) + except Exception as e: + print(f"Error generating for batch starting at index {i}: {e}") + generated_answers = [f"<ERROR: {str(e)}>" for _ in batch] + + for item, ans in zip(batch, generated_answers): + result_entry = { + "file_index": item["file_index"], + "problem_type": item["problem_type"], + "variant": item["variant"], + "question": item["question"], + "solution": item["solution"], + "generated_solution": ans + } + + f_out.write(json.dumps(result_entry, ensure_ascii=False) + "\n") + f_out.flush() + + print(f"Done. Results saved to {args.output_file}") + +if __name__ == "__main__": + main() diff --git a/putnamsup/run_putnam_gap_openrouter.py b/putnamsup/run_putnam_gap_openrouter.py new file mode 100644 index 0000000..8a23141 --- /dev/null +++ b/putnamsup/run_putnam_gap_openrouter.py @@ -0,0 +1,124 @@ +import os +import json +import argparse +import asyncio +import time +from tqdm.asyncio import tqdm +from putnam_utils import load_dataset, SUPPORTED_VARIANTS + +try: + from openai import AsyncOpenAI +except ImportError: + AsyncOpenAI = None + +async def process_item(sem, client, model_name, item): + """ + Process a single item with semaphore for concurrency control. 
+ """ + async with sem: + question = item["question"] + prompt = f"Problem:\n{question}\n\nPlease solve the problem above step by step and provide the final answer.\n\nSolution:\n" + messages = [{"role": "user", "content": prompt}] + + try: + # Call API asynchronously + completion = await client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0.0, + max_tokens=2048, + extra_headers={ + "HTTP-Referer": "https://github.com/PutnamGAP", + "X-Title": "PutnamGAP Eval", + } + ) + generated_answer = completion.choices[0].message.content + except Exception as e: + generated_answer = f"<API ERROR: {str(e)}>" + + # Construct result entry + result_entry = { + "file_index": item["file_index"], + "problem_type": item["problem_type"], + "variant": item["variant"], + "question": question, + "solution": item["solution"], + "generated_solution": generated_answer, + "model": model_name + } + return result_entry + +async def run_async_inference(args, dataset): + if AsyncOpenAI is None: + print("Error: 'openai' library not found. Please install it via: pip install openai") + return + + if not args.api_key: + print("Error: API key not provided. Use --api_key or set OPENROUTER_API_KEY env var.") + return + + print(f"Initializing AsyncOpenAI client with base_url={args.base_url}") + client = AsyncOpenAI( + base_url=args.base_url, + api_key=args.api_key, + ) + + concurrency = args.concurrency + print(f"Running with concurrency: {concurrency}") + sem = asyncio.Semaphore(concurrency) + + tasks = [] + for item in dataset: + task = process_item(sem, client, args.model_name, item) + tasks.append(task) + + print(f"Starting {len(tasks)} tasks using model: {args.model_name}") + + with open(args.output_file, "w", encoding="utf-8") as f_out: + for future in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Async Inference"): + result = await future + f_out.write(json.dumps(result, ensure_ascii=False) + "\n") + f_out.flush() + + print(f"Done. 
Results saved to {args.output_file}") + +def main(): + parser = argparse.ArgumentParser(description="Run inference on PutnamGAP dataset via OpenRouter (Async)") + parser.add_argument("--data_dir", type=str, default="PutnamGAP", help="Path to PutnamGAP JSON files") + parser.add_argument("--model_name", type=str, required=True, help="OpenRouter model name") + parser.add_argument("--output_file", type=str, default="putnam_gap_openrouter_results.jsonl", help="Output file path") + parser.add_argument("--limit", type=int, default=None, help="Limit number of problems to run (for testing)") + parser.add_argument("--concurrency", type=int, default=10, help="Number of concurrent requests") + parser.add_argument("--api_key", type=str, default=os.getenv("OPENROUTER_API_KEY"), help="OpenRouter API Key") + parser.add_argument("--base_url", type=str, default="https://openrouter.ai/api/v1", help="API Base URL") + parser.add_argument("--dry_run", action="store_true", help="Only load data and print info") + parser.add_argument("--variants", type=str, default=None, help=f"Comma-separated list of variants to include. Choices: {','.join(SUPPORTED_VARIANTS)}") + + args = parser.parse_args() + + # Parse variants argument + selected_variants = None + if args.variants: + selected_variants = [v.strip() for v in args.variants.split(",")] + print(f"Filtering for variants: {selected_variants}") + + print(f"Scanning data from {args.data_dir}...") + dataset = list(load_dataset(args.data_dir, selected_variants=selected_variants)) + print(f"Found {len(dataset)} problem variants.") + + if args.dry_run: + if dataset: + print("\n--- Example 1 ---") + print(f"Index: {dataset[0]['file_index']}") + print(f"Variant: {dataset[0]['variant']}") + print(f"Question: {dataset[0]['question'][:200]}...") + return + + if args.limit: + dataset = dataset[:args.limit] + print(f"Limiting to first {args.limit} examples.") + + asyncio.run(run_async_inference(args, dataset)) + +if __name__ == "__main__": + main() |
