summary | refs | log | tree | commit | diff
path: root/scripts/prepare_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/prepare_data.py')
-rwxr-xr-x scripts/prepare_data.py 258
1 files changed, 258 insertions, 0 deletions
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
new file mode 100755
index 0000000..5ef3c29
--- /dev/null
+++ b/scripts/prepare_data.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+"""
+Prepare REAL datasets for RLVR floating-point precision experiments.
+
+Downloads from HuggingFace:
+- Training: GSM8K train (7473 samples)
+- Evaluation: GSM8K test, MATH-500, AIME, AMC, MMLU-STEM, HumanEval
+
+Usage:
+ python scripts/prepare_data.py
+"""
+
+import json
+import os
+import random
+from pathlib import Path
+from datasets import load_dataset
+from tqdm import tqdm
+
# Directory (relative to the current working directory) that receives every
# JSON file written by this script.
DATA_DIR = Path("data")
# Created at import time so the writers below can assume it already exists.
os.makedirs(DATA_DIR, exist_ok=True)
+
+
def save_json(data: list, path: Path):
    """Write ``data`` to ``path`` as 2-space-indented JSON and print a summary.

    Args:
        data: JSON-serializable list of records.
        path: Destination file; opened in ``"w"`` mode, so an existing file
            is overwritten.
    """
    with open(path, "w") as out:
        out.write(json.dumps(data, indent=2))
    print(f" Saved {len(data)} samples to {path}")
+
+
def prepare_gsm8k_train():
    """Download the GSM8K ``train`` split and write it to ``dm_train.json``.

    Each record stores the question as ``prompt``, the untouched dataset
    answer text as ``solution``, and the stripped text after the last
    ``####`` marker as ``answer``.

    Returns:
        The list of record dicts that was written.
    """
    print("\n=== Downloading GSM8K Train ===")
    split = load_dataset("openai/gsm8k", "main", split="train")

    records = [
        {
            "id": f"gsm8k_train_{idx}",
            "prompt": row["question"],
            # Final answer is whatever follows the last "####" marker.
            "answer": row["answer"].split("####")[-1].strip(),
            "solution": row["answer"],
            "source": "gsm8k_train",
        }
        for idx, row in enumerate(tqdm(split, desc="Processing"))
    ]

    save_json(records, DATA_DIR / "dm_train.json")
    return records
+
+
def prepare_gsm8k_test():
    """Download the GSM8K ``test`` split and write the evaluation files.

    Writes the full split to ``gsm8k.json`` and its first 500 records to
    ``dm_val.json`` (the on-task validation subset).

    Returns:
        The full list of test record dicts.
    """
    print("\n=== Downloading GSM8K Test ===")
    split = load_dataset("openai/gsm8k", "main", split="test")

    records = []
    for idx, row in enumerate(tqdm(split, desc="Processing")):
        full_solution = row["answer"]
        # Final answer is whatever follows the last "####" marker.
        final_answer = full_solution.split("####")[-1].strip()
        record = {
            "id": f"gsm8k_test_{idx}",
            "prompt": row["question"],
            "answer": final_answer,
            "solution": full_solution,
            "source": "gsm8k",
        }
        records.append(record)

    save_json(records, DATA_DIR / "gsm8k.json")

    # dm_val is a prefix of the test split, not a separate hold-out.
    validation_subset = records[:500]
    save_json(validation_subset, DATA_DIR / "dm_val.json")
    return records
+
+
def prepare_math500():
    """Download the MATH-500 ``test`` split and write it to ``math500.json``.

    ``subject`` and ``level`` are copied when present and default to an
    empty string otherwise.

    Returns:
        The list of record dicts that was written.
    """
    print("\n=== Downloading MATH-500 ===")
    split = load_dataset("HuggingFaceH4/MATH-500", split="test")

    records = [
        {
            "id": f"math500_{idx}",
            "prompt": row["problem"],
            "answer": row["answer"],
            "solution": row["solution"],
            "subject": row.get("subject", ""),
            "level": row.get("level", ""),
            "source": "math500",
        }
        for idx, row in enumerate(tqdm(split, desc="Processing"))
    ]

    save_json(records, DATA_DIR / "math500.json")
    return records
+
+
def prepare_aime():
    """Download the AI-MO AIME validation set and write the AIME files.

    Writes three files under ``DATA_DIR``:

    * ``aime_all.json`` - every record in the dataset.
    * ``aime24.json`` / ``aime25.json`` - the records whose ``url`` names
      that contest year (pattern ``<YYYY>_AIME``). When no record can be
      matched to the year, the previous positional slice (``data[:30]`` /
      ``data[30:60]``) is written instead and a warning is printed, because
      that slice is not guaranteed to hold problems from the labelled year.

    NOTE(review): the dataset card appears to describe AIME 2022-2024
    problems only, in which case ``aime25.json`` always takes the fallback
    path - confirm, and source real 2025 problems if that file matters.

    Returns:
        The full list of record dicts.
    """
    import re

    print("\n=== Downloading AIME ===")
    ds = load_dataset("AI-MO/aimo-validation-aime", split="train")

    data = []
    for i, sample in enumerate(tqdm(ds, desc="Processing")):
        data.append({
            "id": f"aime_{i}",
            "prompt": sample["problem"],
            "answer": str(sample["answer"]),
            "solution": sample.get("solution", ""),
            "url": sample.get("url", ""),
            "source": "aime"
        })

    year_pattern = re.compile(r"(\d{4})_AIME")

    def _contest_year(record):
        """Return the contest year embedded in the record URL, or None."""
        match = year_pattern.search(record["url"] or "")
        return int(match.group(1)) if match else None

    def _select_year(year, positional_fallback):
        """Pick records for ``year``; fall back to the positional slice."""
        matched = [record for record in data if _contest_year(record) == year]
        if matched:
            return matched
        print(
            f" Warning: no AIME {year} problems identified from URLs; "
            f"using positional slice, which may not be from {year}"
        )
        return positional_fallback

    # Real AIME has 15 problems per contest, 2 contests per year = 30/year.
    save_json(_select_year(2024, data[:30]), DATA_DIR / "aime24.json")
    save_json(_select_year(2025, data[30:60]), DATA_DIR / "aime25.json")
    save_json(data, DATA_DIR / "aime_all.json")
    return data
+
+
def prepare_amc():
    """Download the AI-MO AMC validation set and write it to ``amc23.json``.

    Answers are converted to strings; ``solution`` defaults to an empty
    string when the dataset row has none.

    Returns:
        The list of record dicts that was written.
    """
    print("\n=== Downloading AMC ===")
    split = load_dataset("AI-MO/aimo-validation-amc", split="train")

    records = [
        {
            "id": f"amc_{idx}",
            "prompt": row["problem"],
            "answer": str(row["answer"]),
            "solution": row.get("solution", ""),
            "source": "amc",
        }
        for idx, row in enumerate(tqdm(split, desc="Processing"))
    ]

    save_json(records, DATA_DIR / "amc23.json")
    return records
+
+
def prepare_mmlu_stem():
    """Build a 500-question MMLU-STEM sample and write ``mmlu_stem.json``.

    Loads the ``test`` split of each listed STEM subject from ``cais/mmlu``,
    formats every question as lettered multiple choice ("(A) ...", one
    option per line), and labels the answer with the matching letter.

    A subject that fails to load or process is skipped with a warning and
    contributes no rows at all (rows are committed only after the whole
    subject succeeded). If more than 500 rows were collected, 500 are drawn
    with a private ``random.Random(42)`` so the subset is reproducible
    without reseeding the process-wide ``random`` state.

    Returns:
        The list of record dicts that was written.
    """
    print("\n=== Downloading MMLU-STEM ===")

    stem_subjects = [
        "abstract_algebra", "astronomy", "college_biology", "college_chemistry",
        "college_computer_science", "college_mathematics", "college_physics",
        "computer_security", "conceptual_physics", "electrical_engineering",
        "elementary_mathematics", "high_school_biology", "high_school_chemistry",
        "high_school_computer_science", "high_school_mathematics", "high_school_physics",
        "high_school_statistics", "machine_learning"
    ]

    data = []
    for subject in tqdm(stem_subjects, desc="Loading subjects"):
        try:
            ds = load_dataset("cais/mmlu", subject, split="test")
            # Stage rows locally so a mid-subject failure leaves no partial data.
            subject_rows = []
            for i, sample in enumerate(ds):
                # Format as multiple choice: question line, then "(A) ..." lines.
                options = "".join(
                    f"({chr(65 + j)}) {choice}\n"
                    for j, choice in enumerate(sample["choices"])
                )
                subject_rows.append({
                    "id": f"mmlu_{subject}_{i}",
                    "prompt": f"{sample['question']}\n" + options,
                    "answer": chr(65 + sample["answer"]),
                    "subject": subject,
                    "source": "mmlu_stem"
                })
        except Exception as e:
            # Deliberate best-effort: one unavailable subject must not abort the run.
            print(f" Warning: Skipping {subject}: {e}")
        else:
            data.extend(subject_rows)

    # Take a random subset of 500; Random(42) yields the same draw as
    # seed(42) + sample() on the module-level generator.
    if len(data) > 500:
        data = random.Random(42).sample(data, 500)

    save_json(data, DATA_DIR / "mmlu_stem.json")
    return data
+
+
def prepare_humaneval():
    """Download the HumanEval ``test`` split and write ``humaneval.json``.

    Each record keeps the function prompt, the canonical solution (stored
    under ``answer``), the entry-point name and the test code.

    Returns:
        The list of record dicts that was written.
    """
    print("\n=== Downloading HumanEval ===")
    split = load_dataset("openai/openai_humaneval", split="test")

    records = [
        {
            "id": f"humaneval_{idx}",
            "prompt": row["prompt"],
            "answer": row["canonical_solution"],
            "entry_point": row["entry_point"],
            "test": row["test"],
            "source": "humaneval",
        }
        for idx, row in enumerate(tqdm(split, desc="Processing"))
    ]

    save_json(records, DATA_DIR / "humaneval.json")
    return records
+
+
def verify_data():
    """Print a per-file summary of every ``*.json`` directly in ``DATA_DIR``.

    For each file (in sorted order) reports the record count, the number of
    distinct ``prompt`` values, a duplicate warning when the two differ, and
    the first 60 characters of the first prompt when the file is non-empty.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Verifying Data Quality")
    print(rule)

    for json_path in sorted(DATA_DIR.glob("*.json")):
        with open(json_path) as handle:
            records = json.load(handle)

        # Duplicate detection is by exact prompt text.
        total = len(records)
        distinct = len({record["prompt"] for record in records})

        if distinct == total:
            status = "OK"
        else:
            status = f"WARN: {total - distinct} duplicates"
        print(f" {json_path.name}: {total} samples, {distinct} unique [{status}]")

        # Show the first example as a quick sanity check.
        if records:
            print(f" Example: {records[0]['prompt'][:60]}...")
+
+
def main():
    """Run the whole pipeline: back up old JSON, download, then verify.

    Existing ``*.json`` files in ``DATA_DIR`` are moved into
    ``backup_synthetic`` only when that directory does not exist yet, so a
    previous backup is never overwritten by a later run.
    """
    rule = "=" * 60
    print(rule)
    print("RLVR Real Data Preparation")
    print(rule)

    # One-time backup of whatever JSON files were there before.
    backup_dir = DATA_DIR / "backup_synthetic"
    needs_backup = not backup_dir.exists() and any(DATA_DIR.glob("*.json"))
    if needs_backup:
        backup_dir.mkdir(exist_ok=True)
        for old_file in DATA_DIR.glob("*.json"):
            old_file.rename(backup_dir / old_file.name)
        print(f"Backed up synthetic data to {backup_dir}")

    # Training data first, then each evaluation set, in the original order.
    steps = (
        prepare_gsm8k_train,
        prepare_gsm8k_test,
        prepare_math500,
        prepare_aime,
        prepare_amc,
        prepare_mmlu_stem,
        prepare_humaneval,
    )
    for step in steps:
        step()

    verify_data()

    print("\n" + rule)
    print("Data preparation complete!")
    print(rule)
+
+
+if __name__ == "__main__":
+ main()