| | | |
|---|---|---|
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-04 18:59:35 -0600 |
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-04 18:59:35 -0600 |
| commit | f1c2cc22d46a6976df3555391e667c7e61592fad (patch) | |
| tree | 0b37b52c8ff91042a742d3b3ec54542cb6d6e2f6 /scripts/prepare_data.py | |
Diffstat (limited to 'scripts/prepare_data.py')
| mode | file | lines |
|---|---|---|
| -rwxr-xr-x | scripts/prepare_data.py | 258 |
1 file changed, 258 insertions(+), 0 deletions(-)
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
new file mode 100755
index 0000000..5ef3c29
--- /dev/null
+++ b/scripts/prepare_data.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+"""
+Prepare REAL datasets for RLVR floating-point precision experiments.
+
+Downloads from HuggingFace:
+- Training: GSM8K train (7473 samples)
+- Evaluation: GSM8K test, MATH-500, AIME, AMC, MMLU-STEM, HumanEval
+
+Usage:
+    python scripts/prepare_data.py
+"""
+
+import json
+import os
+import random
+from pathlib import Path
+from datasets import load_dataset
+from tqdm import tqdm
+
+DATA_DIR = Path("data")
+DATA_DIR.mkdir(exist_ok=True)
+
+
+def save_json(data: list, path: Path):
+    """Save data as JSON file."""
+    with open(path, "w") as f:
+        json.dump(data, f, indent=2)
+    print(f" Saved {len(data)} samples to {path}")
+
+
+def prepare_gsm8k_train():
+    """Prepare GSM8K training data."""
+    print("\n=== Downloading GSM8K Train ===")
+    ds = load_dataset("openai/gsm8k", "main", split="train")
+
+    data = []
+    for i, sample in enumerate(tqdm(ds, desc="Processing")):
+        # Extract answer from "#### N" format
+        answer = sample["answer"].split("####")[-1].strip()
+        data.append({
+            "id": f"gsm8k_train_{i}",
+            "prompt": sample["question"],
+            "answer": answer,
+            "solution": sample["answer"],
+            "source": "gsm8k_train"
+        })
+
+    save_json(data, DATA_DIR / "dm_train.json")
+    return data
+
+
+def prepare_gsm8k_test():
+    """Prepare GSM8K test data for evaluation."""
+    print("\n=== Downloading GSM8K Test ===")
+    ds = load_dataset("openai/gsm8k", "main", split="test")
+
+    data = []
+    for i, sample in enumerate(tqdm(ds, desc="Processing")):
+        answer = sample["answer"].split("####")[-1].strip()
+        data.append({
+            "id": f"gsm8k_test_{i}",
+            "prompt": sample["question"],
+            "answer": answer,
+            "solution": sample["answer"],
+            "source": "gsm8k"
+        })
+
+    save_json(data, DATA_DIR / "gsm8k.json")
+
+    # Also create dm_val as a subset (first 500 for on-task eval)
+    save_json(data[:500], DATA_DIR / "dm_val.json")
+    return data
+
+
+def prepare_math500():
+    """Prepare MATH-500 dataset."""
+    print("\n=== Downloading MATH-500 ===")
+    ds = load_dataset("HuggingFaceH4/MATH-500", split="test")
+
+    data = []
+    for i, sample in enumerate(tqdm(ds, desc="Processing")):
+        data.append({
+            "id": f"math500_{i}",
+            "prompt": sample["problem"],
+            "answer": sample["answer"],
+            "solution": sample["solution"],
+            "subject": sample.get("subject", ""),
+            "level": sample.get("level", ""),
+            "source": "math500"
+        })
+
+    save_json(data, DATA_DIR / "math500.json")
+    return data
+
+
+def prepare_aime():
+    """Prepare AIME dataset from AI-MO."""
+    print("\n=== Downloading AIME ===")
+    ds = load_dataset("AI-MO/aimo-validation-aime", split="train")
+
+    data = []
+    for i, sample in enumerate(tqdm(ds, desc="Processing")):
+        data.append({
+            "id": f"aime_{i}",
+            "prompt": sample["problem"],
+            "answer": str(sample["answer"]),
+            "solution": sample.get("solution", ""),
+            "url": sample.get("url", ""),
+            "source": "aime"
+        })
+
+    # Split into aime24 and aime25
+    # Real AIME has 15 problems per contest, 2 contests per year = 30/year
+    save_json(data[:30], DATA_DIR / "aime24.json")
+    save_json(data[30:60], DATA_DIR / "aime25.json")
+    save_json(data, DATA_DIR / "aime_all.json")
+    return data
+
+
+def prepare_amc():
+    """Prepare AMC dataset from AI-MO."""
+    print("\n=== Downloading AMC ===")
+    ds = load_dataset("AI-MO/aimo-validation-amc", split="train")
+
+    data = []
+    for i, sample in enumerate(tqdm(ds, desc="Processing")):
+        data.append({
+            "id": f"amc_{i}",
+            "prompt": sample["problem"],
+            "answer": str(sample["answer"]),
+            "solution": sample.get("solution", ""),
+            "source": "amc"
+        })
+
+    save_json(data, DATA_DIR / "amc23.json")
+    return data
+
+
+def prepare_mmlu_stem():
+    """Prepare MMLU-STEM subset."""
+    print("\n=== Downloading MMLU-STEM ===")
+
+    stem_subjects = [
+        "abstract_algebra", "astronomy", "college_biology", "college_chemistry",
+        "college_computer_science", "college_mathematics", "college_physics",
+        "computer_security", "conceptual_physics", "electrical_engineering",
+        "elementary_mathematics", "high_school_biology", "high_school_chemistry",
+        "high_school_computer_science", "high_school_mathematics", "high_school_physics",
+        "high_school_statistics", "machine_learning"
+    ]
+
+    data = []
+    for subject in tqdm(stem_subjects, desc="Loading subjects"):
+        try:
+            ds = load_dataset("cais/mmlu", subject, split="test")
+            for i, sample in enumerate(ds):
+                choices = sample["choices"]
+                correct_idx = sample["answer"]
+                # Format as multiple choice
+                prompt = f"{sample['question']}\n"
+                for j, choice in enumerate(choices):
+                    prompt += f"({chr(65+j)}) {choice}\n"
+
+                data.append({
+                    "id": f"mmlu_{subject}_{i}",
+                    "prompt": prompt,
+                    "answer": chr(65 + correct_idx),
+                    "subject": subject,
+                    "source": "mmlu_stem"
+                })
+        except Exception as e:
+            print(f" Warning: Skipping {subject}: {e}")
+
+    # Take a random subset of 500
+    random.seed(42)
+    if len(data) > 500:
+        data = random.sample(data, 500)
+
+    save_json(data, DATA_DIR / "mmlu_stem.json")
+    return data
+
+
+def prepare_humaneval():
+    """Prepare HumanEval code dataset."""
+    print("\n=== Downloading HumanEval ===")
+    ds = load_dataset("openai/openai_humaneval", split="test")
+
+    data = []
+    for i, sample in enumerate(tqdm(ds, desc="Processing")):
+        data.append({
+            "id": f"humaneval_{i}",
+            "prompt": sample["prompt"],
+            "answer": sample["canonical_solution"],
+            "entry_point": sample["entry_point"],
+            "test": sample["test"],
+            "source": "humaneval"
+        })
+
+    save_json(data, DATA_DIR / "humaneval.json")
+    return data
+
+
+def verify_data():
+    """Verify downloaded data quality."""
+    print("\n" + "=" * 60)
+    print("Verifying Data Quality")
+    print("=" * 60)
+
+    for f in sorted(DATA_DIR.glob("*.json")):
+        with open(f) as fp:
+            data = json.load(fp)
+
+        # Check for unique prompts
+        prompts = [d["prompt"] for d in data]
+        unique = len(set(prompts))
+
+        status = "OK" if unique == len(prompts) else f"WARN: {len(prompts)-unique} duplicates"
+        print(f" {f.name}: {len(data)} samples, {unique} unique [{status}]")
+
+        # Show first example
+        if data:
+            print(f" Example: {data[0]['prompt'][:60]}...")
+
+
+def main():
+    print("=" * 60)
+    print("RLVR Real Data Preparation")
+    print("=" * 60)
+
+    # Backup old data
+    backup_dir = DATA_DIR / "backup_synthetic"
+    if not backup_dir.exists() and any(DATA_DIR.glob("*.json")):
+        backup_dir.mkdir(exist_ok=True)
+        for f in DATA_DIR.glob("*.json"):
+            f.rename(backup_dir / f.name)
+        print(f"Backed up synthetic data to {backup_dir}")
+
+    # Training data
+    prepare_gsm8k_train()
+
+    # Evaluation data
+    prepare_gsm8k_test()
+    prepare_math500()
+    prepare_aime()
+    prepare_amc()
+    prepare_mmlu_stem()
+    prepare_humaneval()
+
+    # Verify
+    verify_data()
+
+    print("\n" + "=" * 60)
+    print("Data preparation complete!")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
