From f1c2cc22d46a6976df3555391e667c7e61592fad Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Wed, 4 Feb 2026 18:59:35 -0600 Subject: Initial commit: RL floating-point noise project --- .gitignore | 5 + README.md | 196 ++++++++++ analyze_results.py | 741 +++++++++++++++++++++++++++++++++++ config.py | 328 ++++++++++++++++ configs/deepspeed_zero2.json | 31 ++ configs/deepspeed_zero3.json | 38 ++ configs/eval_tasks_config.json | 99 +++++ eval_policy.py | 621 ++++++++++++++++++++++++++++++ requirements.txt | 39 ++ run_experiments.py | 601 +++++++++++++++++++++++++++++ scripts/prepare_data.py | 258 +++++++++++++ scripts/run_evaluation.sh | 58 +++ scripts/run_full_experiment.sh | 106 +++++ scripts/run_training.sh | 50 +++ scripts/setup_env.sh | 79 ++++ scripts/slurm_train.sh | 145 +++++++ scripts/submit_all_jobs.sh | 66 ++++ scripts/submit_single_job.sh | 32 ++ scripts/test_quick.sh | 78 ++++ train_rlvr.py | 849 +++++++++++++++++++++++++++++++++++++++++ utils_bf16_sparsity.py | 459 ++++++++++++++++++++++ utils_kl.py | 419 ++++++++++++++++++++ utils_math_eval.py | 367 ++++++++++++++++++ 23 files changed, 5665 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 analyze_results.py create mode 100644 config.py create mode 100644 configs/deepspeed_zero2.json create mode 100644 configs/deepspeed_zero3.json create mode 100644 configs/eval_tasks_config.json create mode 100644 eval_policy.py create mode 100644 requirements.txt create mode 100644 run_experiments.py create mode 100755 scripts/prepare_data.py create mode 100755 scripts/run_evaluation.sh create mode 100755 scripts/run_full_experiment.sh create mode 100755 scripts/run_training.sh create mode 100755 scripts/setup_env.sh create mode 100755 scripts/slurm_train.sh create mode 100755 scripts/submit_all_jobs.sh create mode 100755 scripts/submit_single_job.sh create mode 100644 scripts/test_quick.sh create mode 100644 train_rlvr.py create mode 100644 utils_bf16_sparsity.py create mode 100644 utils_kl.py create mode 100644 utils_math_eval.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a76904f --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +*.pyc +data/ +results/ +.claude/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..f1695f6 --- /dev/null +++ b/README.md @@ -0,0 +1,196 @@ +# RLVR Floating-Point Precision Experiment + +This repository implements experiments to study the effects of floating-point precision (FP32 vs bf16) on RLVR (Reinforcement Learning with Verifiable Rewards) training. + +## Overview + +The experiment aims to verify three key hypotheses based on the RLVR Three-Gate Theory: + +1. **On-task performance is insensitive to floating-point noise**: Training task performance should be similar between FP32 and bf16 precision modes. + +2. **Off-task performance is sensitive to floating-point noise**: Out-of-distribution task performance should show higher variance in bf16 mode due to numerical noise accumulation. + +3. **KL divergence patterns differ by task type**: On-task KL should be constrained by DAPO's implicit leash (Gate I), while off-task KL may drift more in bf16 mode. 
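+
+For reference, hypothesis 3 is tested with a Monte Carlo KL estimate: the per-token log-probability gap between the fine-tuned and base policies, summed over each sampled response (see `eval_policy.py`). A minimal sketch, assuming `logp_ft` and `logp_base` hold the per-token log-probabilities of a response sampled from the fine-tuned policy:
+
+```python
+import torch
+
+def sequence_kl(logp_ft: torch.Tensor, logp_base: torch.Tensor) -> float:
+    """Monte Carlo estimate of KL(pi_theta || pi_0) for one sampled response."""
+    # The response is sampled from pi_theta, so the expectation over y is
+    # estimated directly by the realized tokens.
+    return (logp_ft - logp_base).sum().item()
+```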
+ +## Project Structure + +``` +rl-floating-noise/ +├── config.py # Configuration definitions +├── train_rlvr.py # Training script with DAPO algorithm +├── eval_policy.py # Evaluation script (J_k, KL_k) +├── utils_math_eval.py # Math answer verification utilities +├── utils_kl.py # KL divergence computation utilities +├── utils_bf16_sparsity.py # bf16-aware update sparsity analysis +├── run_experiments.py # Experiment orchestration script +├── analyze_results.py # Results analysis and hypothesis testing +├── requirements.txt # Python dependencies +├── configs/ +│ └── eval_tasks_config.json # Evaluation task configurations +├── scripts/ +│ ├── prepare_data.py # Dataset preparation script +│ ├── run_training.sh # Single training job script +│ ├── run_evaluation.sh # Single evaluation job script +│ └── run_full_experiment.sh # Full experiment pipeline +├── data/ # Training and evaluation datasets +└── results/ # Experiment outputs + ├── train_logs/ # Training checkpoints and logs + ├── eval_metrics/ # Evaluation results + └── analysis/ # Analysis outputs and plots +``` + +## Installation + +```bash +# Create conda environment +conda create -n rlvr-fp python=3.10 -y +conda activate rlvr-fp + +# Install dependencies +pip install -r requirements.txt + +# Install VeRL (optional, for full DAPO implementation) +pip install git+https://github.com/volcengine/verl.git +``` + +## Quick Start + +### 1. Prepare Data + +Generate sample datasets for development: + +```bash +python scripts/prepare_data.py --output_dir ./data +``` + +For production experiments, download the actual datasets: +- DM (DAPO-Math-17k + MATH) +- AIME24, AIME25, AMC23, MATH-500 +- GSM8K, MMLU-STEM, HumanEval + +### 2. Run Single Training Job + +```bash +# Train with bf16 precision +python train_rlvr.py \ + --precision_mode bf16 \ + --seed 1 \ + --output_dir results/train_logs/bf16_seed1 \ + --train_dataset_path data/dm_train.json + +# Train with fp32 precision +python train_rlvr.py \ + --precision_mode fp32 \ + --seed 1 \ + --output_dir results/train_logs/fp32_seed1 \ + --train_dataset_path data/dm_train.json +``` + +### 3. Run Evaluation + +```bash +python eval_policy.py \ + --base_ckpt Qwen/Qwen2.5-Math-7B \ + --ft_ckpt results/train_logs/bf16_seed1/final_model \ + --eval_tasks_config configs/eval_tasks_config.json \ + --output_path results/eval_metrics/bf16_seed1.json \ + --eval_base +``` + +### 4. Run Full Experiment + +```bash +# Run complete experiment pipeline +bash scripts/run_full_experiment.sh + +# Or use Python orchestrator +python run_experiments.py --mode full --seeds 1 2 3 4 5 +``` + +### 5. 
Analyze Results
+
+```bash
+python analyze_results.py \
+    --results_dir results/eval_metrics \
+    --output_dir results/analysis \
+    --on_task dm_val \
+    --off_task aime24 aime25 amc23 math500 mmlu_stem humaneval
+```
+
+## Precision Configurations
+
+### P-high (FP32)
+- Master weights stored in FP32
+- AMP disabled; all compute in FP32
+- Minimal numerical noise
+
+### P-bf16 (Default RLVR)
+- Master weights stored in bf16
+- AMP enabled with bf16 compute
+- Higher numerical noise (Gate III effects)
+
+Both modes share the same determinism and dropout settings (deterministic
+algorithms enabled, dropout disabled; see `config.py`), so parameter dtype is
+the only difference between them.
+
+## Key Metrics
+
+### Performance (J_k)
+- Pass@1 accuracy for verifiable tasks
+- Computed via Monte Carlo sampling
+
+### Performance Delta (ΔJ_k)
+```
+ΔJ_k = J_k(θ_T) - J_k(θ_0)
+```
+
+### KL Divergence
+```
+KL_k ≈ E_x E_y~π_θ [log π_θ(y|x) - log π_0(y|x)]
+```
+
+### bf16 Update Sparsity
+```
+sparsity = 1 - |{i: |w_i - w'_i| > η·max(|w_i|, |w'_i|)}| / n
+```
+
+## Expected Results
+
+Based on RLVR theory predictions:
+
+| Metric | On-task | Off-task |
+|--------|---------|----------|
+| E[ΔJ] difference | Small (~0) | Variable |
+| Var[ΔJ] (bf16 vs fp32) | Similar | bf16 >> fp32 |
+| KL divergence | Constrained | Higher variance |
+| bf16 sparsity | 36-92% | - |
+
+## Configuration
+
+### Training Hyperparameters (RLVR defaults)
+- Model: Qwen2.5-Math-7B
+- Algorithm: DAPO (clip-only, β=0)
+- Batch size: 256
+- Learning rate: 1e-6
+- Training steps: 300
+- Rollouts per prompt: 16
+
+The single-GPU configuration in `config.py` scales these down (batch size 32,
+150 steps, 4 rollouts per prompt) to fit one H200.
+
+### Evaluation Settings
+- Temperature: 0.7
+- Top-p: 0.8
+- Max generation length: 2048-4096 (task-dependent)
+
+## Citation
+
+If you use this code, please cite the RLVR paper:
+
+```bibtex
+@article{rlvr2024,
+  title={Reinforcement Learning with Verifiable Rewards},
+  author={...},
+  year={2024}
+}
+```
+
+## License
+
+MIT License
+
diff --git a/analyze_results.py b/analyze_results.py
new file mode 100644
index 0000000..265c005
--- /dev/null
+++ b/analyze_results.py
@@ -0,0 +1,741 @@
+#!/usr/bin/env python3
+# analyze_results.py
+"""
+Results Analysis for RLVR Floating-Point Precision Experiments.
+
+This script analyzes the experimental results to verify the hypotheses:
+1. On-task performance is insensitive to floating-point noise
+2. Off-task performance and KL divergence are sensitive to precision
+
+Computes:
+- Mean and variance of ΔJ_k for each precision mode
+- KL divergence patterns for on-task vs off-task
+- Statistical tests for significance
+- Visualizations
+
+Usage:
+    python analyze_results.py \
+        --results_dir results/eval_metrics \
+        --output_dir results/analysis
+"""
+
+import argparse
+import json
+import os
+import glob
+from typing import Dict, Any, List, Tuple, Optional
+from collections import defaultdict
+import logging
+
+import numpy as np
+from scipy import stats
+
+# Optional: for visualizations. The backend must be selected before pyplot
+# is imported for 'Agg' to take effect reliably.
+try:
+    import matplotlib
+    matplotlib.use('Agg')  # Non-interactive backend
+    import matplotlib.pyplot as plt
+    HAS_MATPLOTLIB = True
+except ImportError:
+    HAS_MATPLOTLIB = False
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_eval_results(results_dir: str) -> Dict[str, Dict[str, Any]]:
+    """
+    Load all evaluation results from a directory.
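+
+    Files are expected to be named "{precision_mode}_seed{seed}.json";
+    "*_sparsity.json" files are skipped here and loaded by
+    load_sparsity_results() instead.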
+ + Returns: + Dictionary mapping "{precision_mode}_seed{seed}" to results + """ + results = {} + + pattern = os.path.join(results_dir, "*.json") + for filepath in glob.glob(pattern): + filename = os.path.basename(filepath) + + # Skip sparsity files + if "sparsity" in filename: + continue + + # Parse filename: {precision_mode}_seed{seed}.json + name = filename.replace(".json", "") + + try: + with open(filepath, "r") as f: + data = json.load(f) + results[name] = data + logger.info(f"Loaded {filepath}") + except Exception as e: + logger.warning(f"Failed to load {filepath}: {e}") + + return results + + +def load_sparsity_results(results_dir: str) -> Dict[str, Dict[str, Any]]: + """Load bf16 sparsity analysis results.""" + results = {} + + pattern = os.path.join(results_dir, "*_sparsity.json") + for filepath in glob.glob(pattern): + filename = os.path.basename(filepath) + name = filename.replace("_sparsity.json", "") + + try: + with open(filepath, "r") as f: + data = json.load(f) + results[name] = data + logger.info(f"Loaded sparsity: {filepath}") + except Exception as e: + logger.warning(f"Failed to load {filepath}: {e}") + + return results + + +def parse_run_name(name: str) -> Tuple[str, int]: + """Parse run name to get precision mode and seed.""" + # Format: {precision_mode}_seed{seed} + parts = name.split("_seed") + if len(parts) == 2: + precision_mode = parts[0] + seed = int(parts[1]) + return precision_mode, seed + + raise ValueError(f"Cannot parse run name: {name}") + + +# ============================================================================ +# Metrics Aggregation +# ============================================================================ + +def aggregate_by_precision( + results: Dict[str, Dict[str, Any]] +) -> Dict[str, Dict[str, List[float]]]: + """ + Aggregate results by precision mode. + + Returns: + {precision_mode: {task_name: [scores from different seeds]}} + """ + aggregated: Dict[str, Dict[str, List[float]]] = defaultdict( + lambda: defaultdict(list) + ) + + for run_name, run_results in results.items(): + try: + precision_mode, seed = parse_run_name(run_name) + except ValueError: + continue + + tasks = run_results.get("tasks", {}) + for task_name, task_data in tasks.items(): + # Score + if "ft_avg_score" in task_data: + aggregated[precision_mode][f"{task_name}_score"].append( + task_data["ft_avg_score"] + ) + + # Delta J + if "delta_j" in task_data: + aggregated[precision_mode][f"{task_name}_delta_j"].append( + task_data["delta_j"] + ) + + # KL + if "avg_kl" in task_data: + aggregated[precision_mode][f"{task_name}_kl"].append( + task_data["avg_kl"] + ) + + return dict(aggregated) + + +def compute_statistics(values: List[float]) -> Dict[str, float]: + """Compute statistics for a list of values.""" + if not values: + return { + "mean": 0.0, + "std": 0.0, + "var": 0.0, + "min": 0.0, + "max": 0.0, + "n": 0, + } + + arr = np.array(values) + return { + "mean": float(np.mean(arr)), + "std": float(np.std(arr)), + "var": float(np.var(arr)), + "min": float(np.min(arr)), + "max": float(np.max(arr)), + "n": len(arr), + } + + +# ============================================================================ +# Hypothesis Testing +# ============================================================================ + +def test_variance_ratio( + values1: List[float], + values2: List[float], + alpha: float = 0.05 +) -> Dict[str, Any]: + """ + Test if variance of values2 is significantly greater than values1. + + Uses F-test (Levene's test is more robust but F-test is simpler). 
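+
+    The statistic is F = Var(values2) / Var(values1), with
+    (len(values2) - 1, len(values1) - 1) degrees of freedom; the one-tailed
+    p-value is P(F' >= F) under the null hypothesis of equal variances.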
+ """ + if len(values1) < 2 or len(values2) < 2: + return { + "test": "variance_ratio", + "valid": False, + "reason": "Not enough samples", + } + + var1 = np.var(values1, ddof=1) + var2 = np.var(values2, ddof=1) + + # F statistic + if var1 > 0: + f_stat = var2 / var1 + else: + f_stat = float("inf") + + df1 = len(values2) - 1 + df2 = len(values1) - 1 + + # p-value (one-tailed: var2 > var1) + p_value = 1 - stats.f.cdf(f_stat, df1, df2) + + return { + "test": "variance_ratio", + "valid": True, + "var1": var1, + "var2": var2, + "f_statistic": f_stat, + "p_value": p_value, + "significant": p_value < alpha, + "alpha": alpha, + } + + +def test_mean_difference( + values1: List[float], + values2: List[float], + alpha: float = 0.05 +) -> Dict[str, Any]: + """ + Test if means are significantly different (two-tailed t-test). + """ + if len(values1) < 2 or len(values2) < 2: + return { + "test": "mean_difference", + "valid": False, + "reason": "Not enough samples", + } + + t_stat, p_value = stats.ttest_ind(values1, values2) + + return { + "test": "mean_difference", + "valid": True, + "mean1": float(np.mean(values1)), + "mean2": float(np.mean(values2)), + "t_statistic": float(t_stat), + "p_value": float(p_value), + "significant": p_value < alpha, + "alpha": alpha, + } + + +# ============================================================================ +# Hypothesis Verification +# ============================================================================ + +def verify_hypothesis_1( + aggregated: Dict[str, Dict[str, List[float]]], + on_task_names: List[str] +) -> Dict[str, Any]: + """ + Verify Hypothesis 1: On-task performance is insensitive to precision. + + Expected: + - E[ΔJ_0^{fp32}] ≈ E[ΔJ_0^{bf16}] > 0 + - Var[ΔJ_0^{fp32}] and Var[ΔJ_0^{bf16}] are both small + """ + results = { + "hypothesis": "On-task performance insensitive to precision", + "tasks": {}, + } + + fp32_data = aggregated.get("fp32", {}) + bf16_data = aggregated.get("bf16", {}) + + for task_name in on_task_names: + key = f"{task_name}_delta_j" + + fp32_values = fp32_data.get(key, []) + bf16_values = bf16_data.get(key, []) + + task_result = { + "fp32": compute_statistics(fp32_values), + "bf16": compute_statistics(bf16_values), + } + + # Test mean difference (should NOT be significant) + mean_test = test_mean_difference(fp32_values, bf16_values) + task_result["mean_test"] = mean_test + + # Test variance ratio (should NOT show bf16 >> fp32) + var_test = test_variance_ratio(fp32_values, bf16_values) + task_result["variance_test"] = var_test + + # Verify expected pattern + fp32_mean = task_result["fp32"]["mean"] + bf16_mean = task_result["bf16"]["mean"] + fp32_var = task_result["fp32"]["var"] + bf16_var = task_result["bf16"]["var"] + + task_result["verification"] = { + "means_similar": abs(fp32_mean - bf16_mean) < 0.05, # Within 5% + "both_positive": fp32_mean > 0 and bf16_mean > 0, + "variances_small": fp32_var < 0.01 and bf16_var < 0.01, + "hypothesis_supported": ( + abs(fp32_mean - bf16_mean) < 0.05 and + fp32_mean > 0 and bf16_mean > 0 + ), + } + + results["tasks"][task_name] = task_result + + # Overall verdict + all_supported = all( + t["verification"]["hypothesis_supported"] + for t in results["tasks"].values() + ) + results["overall_supported"] = all_supported + + return results + + +def verify_hypothesis_2( + aggregated: Dict[str, Dict[str, List[float]]], + off_task_names: List[str], + on_task_names: List[str] +) -> Dict[str, Any]: + """ + Verify Hypothesis 2: Off-task performance is sensitive to precision. 
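+
+    Support is counted per task with the one-tailed variance-ratio test plus
+    a direct Var[ΔJ^{bf16}] > Var[ΔJ^{fp32}] comparison across seeds.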
+ + Expected: + - Var[ΔJ_k^{bf16}] >> Var[ΔJ_k^{fp32}] >> Var[ΔJ_0] + """ + results = { + "hypothesis": "Off-task performance sensitive to precision", + "tasks": {}, + } + + fp32_data = aggregated.get("fp32", {}) + bf16_data = aggregated.get("bf16", {}) + + # Get on-task variance for comparison + on_task_variances = [] + for task_name in on_task_names: + key = f"{task_name}_delta_j" + for precision_data in [fp32_data, bf16_data]: + values = precision_data.get(key, []) + if values: + on_task_variances.append(np.var(values)) + + on_task_avg_var = np.mean(on_task_variances) if on_task_variances else 0.0 + + for task_name in off_task_names: + key = f"{task_name}_delta_j" + + fp32_values = fp32_data.get(key, []) + bf16_values = bf16_data.get(key, []) + + task_result = { + "fp32": compute_statistics(fp32_values), + "bf16": compute_statistics(bf16_values), + "on_task_avg_var": on_task_avg_var, + } + + # Test variance ratio (bf16 should be >> fp32) + var_test = test_variance_ratio(fp32_values, bf16_values) + task_result["variance_test"] = var_test + + # Verify expected pattern + fp32_var = task_result["fp32"]["var"] + bf16_var = task_result["bf16"]["var"] + + task_result["verification"] = { + "bf16_var_gt_fp32": bf16_var > fp32_var, + "bf16_var_gt_fp32_by_5x": bf16_var > 5 * fp32_var if fp32_var > 0 else False, + "fp32_var_gt_ontask": fp32_var > on_task_avg_var, + "variance_ratio": bf16_var / fp32_var if fp32_var > 0 else float("inf"), + "hypothesis_supported": bf16_var > fp32_var, + } + + results["tasks"][task_name] = task_result + + # Count how many tasks show expected pattern + supported_count = sum( + 1 for t in results["tasks"].values() + if t["verification"]["hypothesis_supported"] + ) + results["num_tasks_supported"] = supported_count + results["num_tasks_total"] = len(off_task_names) + results["overall_supported"] = supported_count > len(off_task_names) // 2 + + return results + + +def verify_hypothesis_3( + aggregated: Dict[str, Dict[str, List[float]]], + on_task_names: List[str], + off_task_names: List[str] +) -> Dict[str, Any]: + """ + Verify Hypothesis 3: KL divergence patterns. 
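+
+    On-task KL means are compared with a two-sample t-test; off-task KL
+    spread is compared with the one-tailed variance-ratio test.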
+ + Expected: + - On-task KL is similar between fp32 and bf16 (DAPO implicit leash) + - Off-task KL has higher variance in bf16 + """ + results = { + "hypothesis": "KL divergence patterns differ by task type", + "on_task": {}, + "off_task": {}, + } + + fp32_data = aggregated.get("fp32", {}) + bf16_data = aggregated.get("bf16", {}) + + # On-task KL analysis + for task_name in on_task_names: + key = f"{task_name}_kl" + + fp32_values = fp32_data.get(key, []) + bf16_values = bf16_data.get(key, []) + + task_result = { + "fp32": compute_statistics(fp32_values), + "bf16": compute_statistics(bf16_values), + } + + mean_test = test_mean_difference(fp32_values, bf16_values) + task_result["mean_test"] = mean_test + + # Verify: KL should be similar (implicit leash working) + task_result["kl_similar"] = not mean_test.get("significant", True) + + results["on_task"][task_name] = task_result + + # Off-task KL analysis + for task_name in off_task_names: + key = f"{task_name}_kl" + + fp32_values = fp32_data.get(key, []) + bf16_values = bf16_data.get(key, []) + + task_result = { + "fp32": compute_statistics(fp32_values), + "bf16": compute_statistics(bf16_values), + } + + var_test = test_variance_ratio(fp32_values, bf16_values) + task_result["variance_test"] = var_test + + # Verify: bf16 should have higher variance + task_result["bf16_higher_variance"] = var_test.get("significant", False) + + results["off_task"][task_name] = task_result + + # Overall assessment + on_task_similar = all( + t.get("kl_similar", False) for t in results["on_task"].values() + ) + off_task_variance_higher = sum( + 1 for t in results["off_task"].values() + if t.get("bf16_higher_variance", False) + ) + + results["summary"] = { + "on_task_kl_similar": on_task_similar, + "off_task_higher_variance_count": off_task_variance_higher, + "off_task_total": len(off_task_names), + } + + return results + + +# ============================================================================ +# Visualization +# ============================================================================ + +def plot_delta_j_comparison( + aggregated: Dict[str, Dict[str, List[float]]], + task_names: List[str], + output_path: str +) -> None: + """Plot ΔJ comparison between precision modes.""" + if not HAS_MATPLOTLIB: + logger.warning("matplotlib not available, skipping plot") + return + + fig, ax = plt.subplots(figsize=(12, 6)) + + x = np.arange(len(task_names)) + width = 0.35 + + fp32_data = aggregated.get("fp32", {}) + bf16_data = aggregated.get("bf16", {}) + + fp32_means = [] + fp32_stds = [] + bf16_means = [] + bf16_stds = [] + + for task_name in task_names: + key = f"{task_name}_delta_j" + + fp32_values = fp32_data.get(key, [0]) + bf16_values = bf16_data.get(key, [0]) + + fp32_means.append(np.mean(fp32_values)) + fp32_stds.append(np.std(fp32_values)) + bf16_means.append(np.mean(bf16_values)) + bf16_stds.append(np.std(bf16_values)) + + ax.bar(x - width/2, fp32_means, width, yerr=fp32_stds, + label='FP32', color='steelblue', capsize=5) + ax.bar(x + width/2, bf16_means, width, yerr=bf16_stds, + label='bf16', color='coral', capsize=5) + + ax.set_ylabel('ΔJ (Performance Delta)') + ax.set_xlabel('Task') + ax.set_title('Performance Delta by Precision Mode') + ax.set_xticks(x) + ax.set_xticklabels(task_names, rotation=45, ha='right') + ax.legend() + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + + plt.tight_layout() + plt.savefig(output_path, dpi=150) + plt.close() + + logger.info(f"Saved plot to {output_path}") + + +def plot_variance_comparison( + aggregated: 
Dict[str, Dict[str, List[float]]], + task_names: List[str], + output_path: str +) -> None: + """Plot variance comparison between precision modes.""" + if not HAS_MATPLOTLIB: + logger.warning("matplotlib not available, skipping plot") + return + + fig, ax = plt.subplots(figsize=(12, 6)) + + x = np.arange(len(task_names)) + width = 0.35 + + fp32_data = aggregated.get("fp32", {}) + bf16_data = aggregated.get("bf16", {}) + + fp32_vars = [] + bf16_vars = [] + + for task_name in task_names: + key = f"{task_name}_delta_j" + + fp32_values = fp32_data.get(key, [0]) + bf16_values = bf16_data.get(key, [0]) + + fp32_vars.append(np.var(fp32_values)) + bf16_vars.append(np.var(bf16_values)) + + ax.bar(x - width/2, fp32_vars, width, label='FP32', color='steelblue') + ax.bar(x + width/2, bf16_vars, width, label='bf16', color='coral') + + ax.set_ylabel('Variance of ΔJ') + ax.set_xlabel('Task') + ax.set_title('Variance of Performance Delta by Precision Mode') + ax.set_xticks(x) + ax.set_xticklabels(task_names, rotation=45, ha='right') + ax.legend() + ax.set_yscale('log') + + plt.tight_layout() + plt.savefig(output_path, dpi=150) + plt.close() + + logger.info(f"Saved plot to {output_path}") + + +# ============================================================================ +# Main Analysis +# ============================================================================ + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Analyze RLVR floating-point precision experiment results" + ) + parser.add_argument( + "--results_dir", + type=str, + required=True, + help="Directory containing evaluation results" + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Directory to save analysis outputs" + ) + parser.add_argument( + "--on_task", + type=str, + nargs="+", + default=["dm_val"], + help="On-task (training distribution) task names" + ) + parser.add_argument( + "--off_task", + type=str, + nargs="+", + default=["aime24", "aime25", "amc23", "math500", "mmlu_stem", "humaneval"], + help="Off-task task names" + ) + return parser.parse_args() + + +def main() -> None: + """Main analysis function.""" + args = parse_args() + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Load results + logger.info(f"Loading results from {args.results_dir}") + eval_results = load_eval_results(args.results_dir) + sparsity_results = load_sparsity_results(args.results_dir) + + if not eval_results: + logger.error("No evaluation results found!") + return + + # Aggregate by precision mode + aggregated = aggregate_by_precision(eval_results) + + # Get all task names from results + all_tasks = set() + for run_results in eval_results.values(): + all_tasks.update(run_results.get("tasks", {}).keys()) + + # Filter task names to those present in results + on_task_names = [t for t in args.on_task if t in all_tasks] + off_task_names = [t for t in args.off_task if t in all_tasks] + + logger.info(f"On-task: {on_task_names}") + logger.info(f"Off-task: {off_task_names}") + + # Verify hypotheses + analysis = {} + + logger.info("\n" + "="*60) + logger.info("HYPOTHESIS 1: On-task insensitivity") + logger.info("="*60) + h1_result = verify_hypothesis_1(aggregated, on_task_names) + analysis["hypothesis_1"] = h1_result + logger.info(f"Supported: {h1_result['overall_supported']}") + + logger.info("\n" + "="*60) + logger.info("HYPOTHESIS 2: Off-task sensitivity") + logger.info("="*60) + h2_result = verify_hypothesis_2(aggregated, 
off_task_names, on_task_names) + analysis["hypothesis_2"] = h2_result + logger.info(f"Supported: {h2_result['overall_supported']} " + f"({h2_result['num_tasks_supported']}/{h2_result['num_tasks_total']} tasks)") + + logger.info("\n" + "="*60) + logger.info("HYPOTHESIS 3: KL divergence patterns") + logger.info("="*60) + h3_result = verify_hypothesis_3(aggregated, on_task_names, off_task_names) + analysis["hypothesis_3"] = h3_result + logger.info(f"On-task KL similar: {h3_result['summary']['on_task_kl_similar']}") + + # Add sparsity analysis + if sparsity_results: + sparsity_summary = {} + for name, data in sparsity_results.items(): + sparsity_info = data.get("sparsity", {}) + sparsity_summary[name] = { + "sparsity_percent": sparsity_info.get("sparsity_percent", 0), + "num_changed": sparsity_info.get("num_changed", 0), + } + analysis["bf16_sparsity"] = sparsity_summary + + # Save full analysis + analysis_path = os.path.join(args.output_dir, "full_analysis.json") + with open(analysis_path, "w") as f: + json.dump(analysis, f, indent=2, default=str) + logger.info(f"Saved analysis to {analysis_path}") + + # Generate plots + all_task_names = on_task_names + off_task_names + + plot_delta_j_comparison( + aggregated, + all_task_names, + os.path.join(args.output_dir, "delta_j_comparison.png") + ) + + plot_variance_comparison( + aggregated, + all_task_names, + os.path.join(args.output_dir, "variance_comparison.png") + ) + + # Print summary + print("\n" + "="*80) + print("ANALYSIS SUMMARY") + print("="*80) + + print("\nHypothesis 1 (On-task insensitivity):") + print(f" Supported: {h1_result['overall_supported']}") + + print("\nHypothesis 2 (Off-task sensitivity):") + print(f" Supported: {h2_result['overall_supported']}") + print(f" Tasks showing expected pattern: {h2_result['num_tasks_supported']}/{h2_result['num_tasks_total']}") + + print("\nHypothesis 3 (KL patterns):") + print(f" On-task KL similar across precision: {h3_result['summary']['on_task_kl_similar']}") + print(f" Off-task with higher bf16 variance: {h3_result['summary']['off_task_higher_variance_count']}/{h3_result['summary']['off_task_total']}") + + if sparsity_results: + print("\nbf16 Sparsity:") + for name, data in sorted(sparsity_summary.items()): + print(f" {name}: {data['sparsity_percent']:.1f}% sparse") + + print("="*80) + + +if __name__ == "__main__": + main() + diff --git a/config.py b/config.py new file mode 100644 index 0000000..898776c --- /dev/null +++ b/config.py @@ -0,0 +1,328 @@ +# config.py +""" +Configuration definitions for RLVR floating-point precision experiments. 
+
+This module defines configurations for:
+- Training parameters (DAPO algorithm, hyperparameters)
+- Precision settings (FP32 vs bf16)
+- Evaluation task specifications
+"""
+
+from dataclasses import dataclass, field
+from typing import List, Dict, Any, Optional
+import os
+
+
+@dataclass
+class TrainingConfig:
+    """Configuration for RLVR training with DAPO algorithm."""
+
+    # Model specification
+    model_name: str = "Qwen/Qwen2.5-Math-7B"
+
+    # Precision mode: "fp32" for high precision, "bf16" for default RLVR
+    precision_mode: str = "bf16"
+
+    # Batch configuration (sized for single H200 GPU with gradient checkpointing)
+    global_batch_size: int = 32  # Reduced from 256 for single GPU
+    micro_batch_size: int = 4  # Reduced further for fp32 memory safety
+    grad_accumulation_steps: int = 8  # Increased to maintain effective batch size
+
+    # Rollout configuration
+    num_rollouts_per_prompt: int = 4  # Reduced from 16 for speed
+    max_seq_len: int = 2048  # Reduced from 8192 (GSM8K answers are short)
+
+    # Training steps and checkpointing
+    # Note: with sequential generation, each step takes ~18 min on an H200;
+    # 150 steps ≈ 45 hours, which fits in the 2-day limit with buffer.
+    num_steps: int = 150
+    checkpoint_steps: List[int] = field(default_factory=lambda: [0, 50, 100, 150])
+
+    # Optimizer configuration (AdamW)
+    learning_rate: float = 1e-6
+    beta1: float = 0.9
+    beta2: float = 0.999
+    weight_decay: float = 0.01
+
+    # RL algorithm
+    rl_algorithm: str = "dapo"
+    clip_ratio: float = 0.2  # DAPO clip parameter
+    kl_coef: float = 0.0  # DAPO is clip-only, no explicit KL penalty
+
+    # Reproducibility
+    # IMPORTANT: keep these constant across precision modes so that
+    # floating-point precision is the ONLY source of variance.
+    seed: int = 1
+    use_dropout: bool = False  # Disabled to reduce stochasticity
+    use_deterministic_algorithms: bool = True  # Enabled for reproducibility
+
+    # Paths
+    output_dir: str = "./results/train_logs"
+    train_dataset_path: str = "./data/dm_train.json"
+
+    # GPU configuration (single GPU for this implementation)
+    num_gpus: int = 1  # Current implementation is single-GPU
+
+    def __post_init__(self):
+        """
+        Note: we intentionally keep dropout and determinism settings CONSTANT
+        across precision modes to isolate the effect of floating-point
+        precision.
+
+        An earlier version coupled dropout=False with fp32 and dropout=True
+        with bf16, which confounded precision effects with stochasticity
+        effects.
+
+        To study pure precision effects:
+        - Both modes use the SAME dropout setting (default: False)
+        - Both modes use the SAME determinism setting (default: True,
+          matching use_deterministic_algorithms above)
+
+        The only difference between fp32 and bf16 should be param_dtype.
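+
+        For example, make_training_config("fp32", ...) and
+        make_training_config("bf16", ...) produce configs that differ only in
+        precision_mode; use_dropout and use_deterministic_algorithms are
+        identical in both.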
+ """ + # Don't modify settings based on precision_mode - keep them independent + pass + + +@dataclass +class PrecisionConfig: + """Configuration for floating-point precision settings.""" + + # Parameter storage dtype + param_dtype: str = "bfloat16" # "float32" or "bfloat16" + + # Automatic mixed precision + use_amp: bool = True + amp_dtype: str = "bfloat16" # "float16" or "bfloat16" + + # Gradient and optimizer state always in FP32 + grad_dtype: str = "float32" + optimizer_dtype: str = "float32" + + # Deterministic algorithms + deterministic: bool = False + + # CUDNN settings + cudnn_benchmark: bool = True + cudnn_deterministic: bool = False + + +@dataclass +class EvalTaskConfig: + """Configuration for a single evaluation task.""" + + # Task identification + name: str = "" + task_type: str = "math" # "math", "code", "qa", "general" + + # Dataset + dataset_path: str = "" + num_samples: int = -1 # -1 means use all samples + + # Whether task has verifiable answers (math problems) + is_verifiable: bool = True + + # Metric type for non-verifiable tasks + metric_type: str = "accuracy" # "accuracy", "bleu", "rouge", "score" + + # Generation parameters + max_gen_len: int = 2048 + temperature: float = 0.7 + top_p: float = 0.8 + num_samples_per_prompt: int = 1 + + +@dataclass +class ExperimentConfig: + """Master configuration for the entire experiment.""" + + # Experiment identification + experiment_name: str = "fp_precision_rlvr" + + # Seeds for multiple runs + seeds: List[int] = field(default_factory=lambda: [1, 2, 3, 4, 5]) + + # Precision modes to compare + precision_modes: List[str] = field(default_factory=lambda: ["fp32", "bf16"]) + + # Base model checkpoint (shared starting point) + base_model_path: str = "Qwen/Qwen2.5-Math-7B" + + # Output directories + base_output_dir: str = "./results" + train_logs_dir: str = "./results/train_logs" + checkpoints_dir: str = "./results/checkpoints" + eval_metrics_dir: str = "./results/eval_metrics" + + # Evaluation configuration + eval_tasks_config_path: str = "./configs/eval_tasks_config.json" + + # bf16 sparsity analysis + bf16_sparsity_eta: float = 1e-3 + + +def make_precision_config(precision_mode: str) -> PrecisionConfig: + """ + Create precision configuration based on mode. + + IMPORTANT: Only the precision-related settings differ between modes. + All other settings (determinism, cudnn) are kept CONSTANT to isolate + the effect of floating-point precision on training outcomes. 
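+
+    Concretely, "fp32" disables AMP and stores parameters in float32, while
+    "bf16" enables AMP with bfloat16 parameters; both modes share fp32
+    gradients, fp32 optimizer state, and identical determinism/cudnn settings.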
+ + Args: + precision_mode: "fp32" for high precision, "bf16" for default RLVR + + Returns: + PrecisionConfig with appropriate settings + """ + # Common settings for both modes (to avoid confounds) + # Maximize determinism so precision is the ONLY source of variance + common_settings = { + "grad_dtype": "float32", + "optimizer_dtype": "float32", + "deterministic": True, # Enable deterministic algorithms + "cudnn_benchmark": False, # Disable for reproducibility + "cudnn_deterministic": True, # Enable for reproducibility + } + + if precision_mode == "fp32": + return PrecisionConfig( + param_dtype="float32", + use_amp=False, # No AMP needed for fp32 + amp_dtype="float32", + **common_settings + ) + elif precision_mode == "bf16": + return PrecisionConfig( + param_dtype="bfloat16", + use_amp=True, + amp_dtype="bfloat16", + **common_settings + ) + else: + raise ValueError(f"Unknown precision_mode: {precision_mode}") + + +def make_training_config( + precision_mode: str, + seed: int, + output_dir: str, + train_dataset_path: str, + model_name: str = "Qwen/Qwen2.5-Math-7B" +) -> TrainingConfig: + """ + Create training configuration for a specific run. + + Args: + precision_mode: "fp32" or "bf16" + seed: Random seed for this run + output_dir: Directory to save outputs + train_dataset_path: Path to training data + model_name: HuggingFace model identifier + + Returns: + TrainingConfig with all parameters set + """ + config = TrainingConfig( + model_name=model_name, + precision_mode=precision_mode, + seed=seed, + output_dir=output_dir, + train_dataset_path=train_dataset_path + ) + return config + + +def get_run_output_dir( + base_dir: str, + precision_mode: str, + seed: int +) -> str: + """Get output directory for a specific run.""" + return os.path.join(base_dir, f"{precision_mode}_seed{seed}") + + +def get_checkpoint_path( + output_dir: str, + step: Optional[int] = None +) -> str: + """Get checkpoint path for a specific step (None = final).""" + if step is None: + return os.path.join(output_dir, "final_model") + return os.path.join(output_dir, f"checkpoint_step{step}") + + +# Default evaluation tasks for the experiment +DEFAULT_EVAL_TASKS = [ + # On-task: Training distribution + EvalTaskConfig( + name="dm_val", + task_type="math", + dataset_path="./data/dm_val.json", + is_verifiable=True, + metric_type="accuracy", + max_gen_len=2048, + temperature=0.7, + top_p=0.8 + ), + # In-domain OOD: Math benchmarks + EvalTaskConfig( + name="aime24", + task_type="math", + dataset_path="./data/aime24.json", + is_verifiable=True, + metric_type="accuracy", + max_gen_len=4096, + temperature=0.7, + top_p=0.8 + ), + EvalTaskConfig( + name="aime25", + task_type="math", + dataset_path="./data/aime25.json", + is_verifiable=True, + metric_type="accuracy", + max_gen_len=4096, + temperature=0.7, + top_p=0.8 + ), + EvalTaskConfig( + name="amc23", + task_type="math", + dataset_path="./data/amc23.json", + is_verifiable=True, + metric_type="accuracy", + max_gen_len=2048, + temperature=0.7, + top_p=0.8 + ), + EvalTaskConfig( + name="math500", + task_type="math", + dataset_path="./data/math500.json", + is_verifiable=True, + metric_type="accuracy", + max_gen_len=2048, + temperature=0.7, + top_p=0.8 + ), + # Off-domain: General tasks + EvalTaskConfig( + name="mmlu_stem", + task_type="qa", + dataset_path="./data/mmlu_stem.json", + is_verifiable=True, + metric_type="accuracy", + max_gen_len=512, + temperature=0.3, + top_p=0.9 + ), + EvalTaskConfig( + name="humaneval", + task_type="code", + dataset_path="./data/humaneval.json", + 
is_verifiable=True, + metric_type="accuracy", + max_gen_len=1024, + temperature=0.2, + top_p=0.95 + ), +] + diff --git a/configs/deepspeed_zero2.json b/configs/deepspeed_zero2.json new file mode 100644 index 0000000..bb7f7aa --- /dev/null +++ b/configs/deepspeed_zero2.json @@ -0,0 +1,31 @@ +{ + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none" + }, + "contiguous_gradients": true, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "allgather_bucket_size": 5e8 + }, + + "bf16": { + "enabled": false + }, + + "fp16": { + "enabled": false + }, + + "gradient_clipping": 1.0, + + "zero_allow_untested_optimizer": true, + + "wall_clock_breakdown": false +} diff --git a/configs/deepspeed_zero3.json b/configs/deepspeed_zero3.json new file mode 100644 index 0000000..6e68c8f --- /dev/null +++ b/configs/deepspeed_zero3.json @@ -0,0 +1,38 @@ +{ + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none" + }, + "offload_param": { + "device": "none" + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + + "bf16": { + "enabled": false + }, + + "fp16": { + "enabled": false + }, + + "gradient_clipping": 1.0, + + "zero_allow_untested_optimizer": true, + + "wall_clock_breakdown": false +} diff --git a/configs/eval_tasks_config.json b/configs/eval_tasks_config.json new file mode 100644 index 0000000..e0dda43 --- /dev/null +++ b/configs/eval_tasks_config.json @@ -0,0 +1,99 @@ +[ + { + "name": "dm_val", + "task_type": "math", + "dataset_path": "./data/dm_val.json", + "is_verifiable": true, + "metric_type": "accuracy", + "num_samples": -1, + "max_gen_len": 2048, + "temperature": 0.7, + "top_p": 0.8, + "num_samples_per_prompt": 1 + }, + { + "name": "aime24", + "task_type": "math", + "dataset_path": "./data/aime24.json", + "is_verifiable": true, + "metric_type": "accuracy", + "num_samples": -1, + "max_gen_len": 4096, + "temperature": 0.7, + "top_p": 0.8, + "num_samples_per_prompt": 1 + }, + { + "name": "aime25", + "task_type": "math", + "dataset_path": "./data/aime25.json", + "is_verifiable": true, + "metric_type": "accuracy", + "num_samples": -1, + "max_gen_len": 4096, + "temperature": 0.7, + "top_p": 0.8, + "num_samples_per_prompt": 1 + }, + { + "name": "amc23", + "task_type": "math", + "dataset_path": "./data/amc23.json", + "is_verifiable": true, + "metric_type": "accuracy", + "num_samples": -1, + "max_gen_len": 2048, + "temperature": 0.7, + "top_p": 0.8, + "num_samples_per_prompt": 1 + }, + { + "name": "math500", + "task_type": "math", + "dataset_path": "./data/math500.json", + "is_verifiable": true, + "metric_type": "accuracy", + "num_samples": 500, + "max_gen_len": 2048, + "temperature": 0.7, + "top_p": 0.8, + "num_samples_per_prompt": 1 + }, + { + "name": "gsm8k", + "task_type": "math", + "dataset_path": "./data/gsm8k.json", + "is_verifiable": true, + "metric_type": "accuracy", + "num_samples": 500, + "max_gen_len": 1024, + "temperature": 0.7, + "top_p": 0.8, + "num_samples_per_prompt": 1 + }, + { + "name": "mmlu_stem", + 
"task_type": "qa", + "dataset_path": "./data/mmlu_stem.json", + "is_verifiable": true, + "metric_type": "accuracy", + "num_samples": 500, + "max_gen_len": 512, + "temperature": 0.3, + "top_p": 0.9, + "num_samples_per_prompt": 1 + }, + { + "name": "humaneval", + "task_type": "code", + "dataset_path": "./data/humaneval.json", + "is_verifiable": true, + "metric_type": "accuracy", + "num_samples": 164, + "max_gen_len": 1024, + "temperature": 0.2, + "top_p": 0.95, + "num_samples_per_prompt": 1 + } +] + diff --git a/eval_policy.py b/eval_policy.py new file mode 100644 index 0000000..cc30209 --- /dev/null +++ b/eval_policy.py @@ -0,0 +1,621 @@ +#!/usr/bin/env python3 +# eval_policy.py +""" +Policy Evaluation Script for RLVR Experiments. + +This script evaluates trained models on multiple tasks, computing: +- J_k: Task performance (pass@1 accuracy for verifiable tasks) +- KL_k: KL divergence from base model + +Usage: + python eval_policy.py \ + --base_ckpt Qwen/Qwen2.5-Math-7B \ + --ft_ckpt results/train_logs/fp32_seed1/final_model \ + --eval_tasks_config configs/eval_tasks_config.json \ + --output_path results/eval_metrics/fp32_seed1.json +""" + +import argparse +import json +import os +import logging +from typing import Dict, Any, List, Optional, Tuple +from dataclasses import dataclass, asdict + +import numpy as np +import torch +from torch.cuda.amp import autocast +from transformers import AutoModelForCausalLM, AutoTokenizer +from tqdm import tqdm + +from config import EvalTaskConfig + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Data Loading +# ============================================================================ + +def load_eval_tasks(eval_config_path: str) -> List[EvalTaskConfig]: + """Load evaluation task configurations from JSON file.""" + with open(eval_config_path, "r", encoding="utf-8") as f: + data = json.load(f) + + tasks: List[EvalTaskConfig] = [] + for task_item in data: + task = EvalTaskConfig( + name=task_item.get("name", ""), + task_type=task_item.get("task_type", "math"), + dataset_path=task_item.get("dataset_path", ""), + is_verifiable=task_item.get("is_verifiable", True), + metric_type=task_item.get("metric_type", "accuracy"), + num_samples=task_item.get("num_samples", -1), + max_gen_len=task_item.get("max_gen_len", 2048), + temperature=task_item.get("temperature", 0.7), + top_p=task_item.get("top_p", 0.8), + num_samples_per_prompt=task_item.get("num_samples_per_prompt", 1) + ) + tasks.append(task) + + logger.info(f"Loaded {len(tasks)} evaluation tasks from {eval_config_path}") + return tasks + + +def load_dataset(dataset_path: str, num_samples: int = -1) -> List[Dict[str, Any]]: + """Load evaluation dataset from JSON file.""" + with open(dataset_path, "r", encoding="utf-8") as f: + data = json.load(f) + + if num_samples > 0 and num_samples < len(data): + data = data[:num_samples] + + logger.info(f"Loaded {len(data)} examples from {dataset_path}") + return data + + +# ============================================================================ +# Answer Verification +# ============================================================================ + +def extract_boxed_answer(text: str) -> Optional[str]: + """Extract answer from \\boxed{} format.""" + import re + + # Find all \boxed{...} patterns + pattern = r"\\boxed\{([^}]*)\}" + matches = re.findall(pattern, 
text)
+
+    if matches:
+        return matches[-1].strip()  # Return last match
+
+    return None
+
+
+def extract_final_answer(text: str) -> Optional[str]:
+    """Extract final answer using various heuristics."""
+    # Try boxed format first
+    boxed = extract_boxed_answer(text)
+    if boxed:
+        return boxed
+
+    # Common answer patterns
+    patterns = [
+        r"[Tt]he (?:final )?answer is[:\s]+(.+?)(?:\.|$)",
+        r"[Tt]herefore[,:\s]+(?:the answer is[:\s]+)?(.+?)(?:\.|$)",
+        r"[Ss]o[,:\s]+(?:the answer is[:\s]+)?(.+?)(?:\.|$)",
+        r"[Hh]ence[,:\s]+(.+?)(?:\.|$)",
+        r"=\s*(.+?)$",
+    ]
+
+    import re
+    for pattern in patterns:
+        match = re.search(pattern, text, re.MULTILINE)
+        if match:
+            return match.group(1).strip()
+
+    return None
+
+
+def normalize_answer(answer: str) -> str:
+    """Normalize an answer string for comparison."""
+    if answer is None:
+        return ""
+
+    # Lowercase and strip surrounding whitespace
+    normalized = answer.lower().strip()
+
+    # Remove common formatting
+    normalized = normalized.replace(",", "")
+    normalized = normalized.replace("$", "")
+    normalized = normalized.replace("%", "")
+
+    # Try to extract a numeric value
+    import re
+    numeric_match = re.search(r"-?\d+\.?\d*", normalized)
+    if numeric_match:
+        return numeric_match.group()
+
+    return normalized
+
+
+def verify_math_answer(
+    response: str,
+    ground_truth: str
+) -> bool:
+    """
+    Verify whether the response contains the correct answer.
+
+    This is a simplified verifier. For production use, replace with
+    Eval-Chemy or a more sophisticated verification system.
+    """
+    # Extract answers
+    predicted = extract_final_answer(response)
+
+    if predicted is None:
+        return False
+
+    # Normalize for comparison
+    pred_normalized = normalize_answer(predicted)
+    gt_normalized = normalize_answer(ground_truth)
+
+    # Direct comparison
+    if pred_normalized == gt_normalized:
+        return True
+
+    # Try numeric comparison
+    try:
+        pred_num = float(pred_normalized)
+        gt_num = float(gt_normalized)
+        if abs(pred_num - gt_num) < 1e-6:
+            return True
+    except ValueError:
+        pass
+
+    return False
+
+
+# ============================================================================
+# KL Divergence Computation
+# ============================================================================
+
+def compute_sequence_kl(
+    finetuned_model: torch.nn.Module,
+    base_model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    response_start_idx: int,
+    device: torch.device
+) -> Tuple[float, int]:
+    """
+    Compute KL divergence for a single sequence.
+
+    KL(π_ft || π_base) ≈ Σ_t [log π_ft(y_t|x,y_{<t}) - log π_base(y_t|x,y_{<t})],
+    summed over the response tokens (positions >= response_start_idx).
+
+    Returns:
+        (kl_sum, num_response_tokens)
+    """
+    with torch.no_grad():
+        ft_logits = finetuned_model(
+            input_ids=input_ids, attention_mask=attention_mask
+        ).logits
+        base_logits = base_model(
+            input_ids=input_ids, attention_mask=attention_mask
+        ).logits
+
+    # Log-probabilities of the realized tokens; logits at position t
+    # predict token t+1, hence the shift by one.
+    ft_logprobs = torch.log_softmax(ft_logits[:, :-1, :].float(), dim=-1)
+    base_logprobs = torch.log_softmax(base_logits[:, :-1, :].float(), dim=-1)
+    targets = input_ids[:, 1:].unsqueeze(-1)
+
+    ft_token_logprobs = ft_logprobs.gather(-1, targets).squeeze(-1)
+    base_token_logprobs = base_logprobs.gather(-1, targets).squeeze(-1)
+
+    # Restrict to response tokens (target position t holds token t+1)
+    start = max(response_start_idx - 1, 0)
+    diff = ft_token_logprobs[:, start:] - base_token_logprobs[:, start:]
+
+    kl_sum = float(diff.sum().item())
+    num_tokens = int(diff.numel())
+    return kl_sum, num_tokens
+
+
+@dataclass
+class TaskResult:
+    """Evaluation results for a single task."""
+    task_name: str
+    task_type: str
+    num_examples: int
+    avg_score: float
+    std_score: float
+    avg_kl: float
+    std_kl: float
+    avg_response_length: float
+    scores: List[float]
+    kl_values: List[float]
+
+
+# ============================================================================
+# Task Evaluation
+# ============================================================================
+
+def evaluate_task(
+    base_model: torch.nn.Module,
+    base_tokenizer,
+    finetuned_model: torch.nn.Module,
+    finetuned_tokenizer,
+    task_config: EvalTaskConfig,
+    device: torch.device,
+    use_amp: bool = True
+) -> TaskResult:
+    """
+    Evaluate a single task.
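+
+    Responses are sampled with the task's temperature/top_p settings, so the
+    reported scores are Monte Carlo estimates of pass@1.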
+ + Computes: + - avg_score: Mean accuracy (for verifiable tasks) + - avg_kl: Mean KL divergence from base model + """ + dataset = load_dataset(task_config.dataset_path, task_config.num_samples) + + scores: List[float] = [] + kl_values: List[float] = [] + response_lengths: List[int] = [] + + finetuned_model.eval() + base_model.eval() + + amp_dtype = torch.bfloat16 if use_amp else torch.float32 + + for example in tqdm(dataset, desc=f"Evaluating {task_config.name}"): + prompt = example.get("prompt", example.get("question", "")) + ground_truth = example.get("answer", example.get("solution", None)) + + # Tokenize prompt + inputs = finetuned_tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=4096 + ) + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + prompt_len = input_ids.shape[1] + + # Generate response + with torch.no_grad(): + with autocast(enabled=use_amp, dtype=amp_dtype): + generated_ids = finetuned_model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=task_config.max_gen_len, + do_sample=True, + temperature=task_config.temperature, + top_p=task_config.top_p, + pad_token_id=finetuned_tokenizer.eos_token_id + ) + + # Decode response + response_ids = generated_ids[:, prompt_len:] + response_text = finetuned_tokenizer.batch_decode( + response_ids, + skip_special_tokens=True + )[0] + + response_lengths.append(len(response_ids[0])) + + # Compute score (accuracy for verifiable tasks) + if task_config.is_verifiable and ground_truth is not None: + is_correct = verify_math_answer(response_text, str(ground_truth)) + score = 1.0 if is_correct else 0.0 + else: + # For non-verifiable tasks, use placeholder + score = 0.0 + + scores.append(score) + + # Compute KL divergence + full_ids = generated_ids + full_attention = torch.ones_like(full_ids, device=device) + + kl_sum, num_tokens = compute_sequence_kl( + finetuned_model=finetuned_model, + base_model=base_model, + input_ids=full_ids, + attention_mask=full_attention, + response_start_idx=prompt_len, + device=device + ) + + if num_tokens > 0: + avg_kl_per_token = kl_sum / num_tokens + else: + avg_kl_per_token = 0.0 + + kl_values.append(kl_sum) # Total KL for sequence + + # Compute statistics + result = TaskResult( + task_name=task_config.name, + task_type=task_config.task_type, + num_examples=len(dataset), + avg_score=float(np.mean(scores)) if scores else 0.0, + std_score=float(np.std(scores)) if scores else 0.0, + avg_kl=float(np.mean(kl_values)) if kl_values else 0.0, + std_kl=float(np.std(kl_values)) if kl_values else 0.0, + avg_response_length=float(np.mean(response_lengths)) if response_lengths else 0.0, + scores=scores, + kl_values=kl_values + ) + + logger.info( + f"Task {task_config.name}: " + f"Score={result.avg_score:.4f} (±{result.std_score:.4f}), " + f"KL={result.avg_kl:.4f} (±{result.std_kl:.4f})" + ) + + return result + + +def evaluate_base_model( + base_model: torch.nn.Module, + base_tokenizer, + task_config: EvalTaskConfig, + device: torch.device, + use_amp: bool = True +) -> Dict[str, float]: + """Evaluate the base model (for computing ΔJ).""" + dataset = load_dataset(task_config.dataset_path, task_config.num_samples) + + scores: List[float] = [] + base_model.eval() + + amp_dtype = torch.bfloat16 if use_amp else torch.float32 + + for example in tqdm(dataset, desc=f"Evaluating base on {task_config.name}"): + prompt = example.get("prompt", example.get("question", "")) + ground_truth = example.get("answer", example.get("solution", 
None)) + + inputs = base_tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=4096 + ) + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + with torch.no_grad(): + with autocast(enabled=use_amp, dtype=amp_dtype): + generated_ids = base_model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=task_config.max_gen_len, + do_sample=True, + temperature=task_config.temperature, + top_p=task_config.top_p, + pad_token_id=base_tokenizer.eos_token_id + ) + + response_ids = generated_ids[:, input_ids.shape[1]:] + response_text = base_tokenizer.batch_decode( + response_ids, + skip_special_tokens=True + )[0] + + if task_config.is_verifiable and ground_truth is not None: + is_correct = verify_math_answer(response_text, str(ground_truth)) + score = 1.0 if is_correct else 0.0 + else: + score = 0.0 + + scores.append(score) + + return { + "avg_score": float(np.mean(scores)) if scores else 0.0, + "std_score": float(np.std(scores)) if scores else 0.0, + "num_examples": len(scores) + } + + +# ============================================================================ +# Main Evaluation Pipeline +# ============================================================================ + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Evaluate RLVR trained models on multiple tasks" + ) + parser.add_argument( + "--base_ckpt", + type=str, + required=True, + help="Path to base model checkpoint" + ) + parser.add_argument( + "--ft_ckpt", + type=str, + required=True, + help="Path to finetuned model checkpoint" + ) + parser.add_argument( + "--eval_tasks_config", + type=str, + required=True, + help="Path to evaluation tasks configuration JSON" + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Path to save evaluation results" + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for evaluation" + ) + parser.add_argument( + "--eval_base", + action="store_true", + help="Also evaluate base model (for computing delta J)" + ) + parser.add_argument( + "--use_amp", + action="store_true", + default=True, + help="Use automatic mixed precision" + ) + return parser.parse_args() + + +def main() -> None: + """Main evaluation function.""" + args = parse_args() + + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + + # Load tokenizers + logger.info(f"Loading base tokenizer from {args.base_ckpt}") + base_tokenizer = AutoTokenizer.from_pretrained( + args.base_ckpt, + use_fast=True, + trust_remote_code=True + ) + if base_tokenizer.pad_token is None: + base_tokenizer.pad_token = base_tokenizer.eos_token + + logger.info(f"Loading finetuned tokenizer from {args.ft_ckpt}") + ft_tokenizer = AutoTokenizer.from_pretrained( + args.ft_ckpt, + use_fast=True, + trust_remote_code=True + ) + if ft_tokenizer.pad_token is None: + ft_tokenizer.pad_token = ft_tokenizer.eos_token + + # Load models + logger.info(f"Loading base model from {args.base_ckpt}") + base_model = AutoModelForCausalLM.from_pretrained( + args.base_ckpt, + torch_dtype=torch.bfloat16, + device_map=None, + trust_remote_code=True + ).to(device) + base_model.eval() + + logger.info(f"Loading finetuned model from {args.ft_ckpt}") + ft_model = AutoModelForCausalLM.from_pretrained( + args.ft_ckpt, + torch_dtype=torch.bfloat16, + device_map=None, + trust_remote_code=True + 
).to(device) + ft_model.eval() + + # Load evaluation tasks + eval_tasks = load_eval_tasks(args.eval_tasks_config) + + # Evaluate on all tasks + all_results: Dict[str, Any] = { + "base_ckpt": args.base_ckpt, + "ft_ckpt": args.ft_ckpt, + "tasks": {} + } + + for task in eval_tasks: + logger.info(f"\n{'='*60}") + logger.info(f"Evaluating task: {task.name}") + logger.info(f"{'='*60}") + + # Evaluate finetuned model + result = evaluate_task( + base_model=base_model, + base_tokenizer=base_tokenizer, + finetuned_model=ft_model, + finetuned_tokenizer=ft_tokenizer, + task_config=task, + device=device, + use_amp=args.use_amp + ) + + task_results = { + "ft_avg_score": result.avg_score, + "ft_std_score": result.std_score, + "avg_kl": result.avg_kl, + "std_kl": result.std_kl, + "avg_response_length": result.avg_response_length, + "num_examples": result.num_examples, + } + + # Optionally evaluate base model + if args.eval_base: + base_result = evaluate_base_model( + base_model=base_model, + base_tokenizer=base_tokenizer, + task_config=task, + device=device, + use_amp=args.use_amp + ) + task_results["base_avg_score"] = base_result["avg_score"] + task_results["base_std_score"] = base_result["std_score"] + task_results["delta_j"] = result.avg_score - base_result["avg_score"] + + all_results["tasks"][task.name] = task_results + + # Save results + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + with open(args.output_path, "w", encoding="utf-8") as f: + json.dump(all_results, f, indent=2) + + logger.info(f"\nResults saved to {args.output_path}") + + # Print summary + print("\n" + "="*80) + print("EVALUATION SUMMARY") + print("="*80) + for task_name, task_result in all_results["tasks"].items(): + print(f"\n{task_name}:") + print(f" Score: {task_result['ft_avg_score']:.4f} (±{task_result['ft_std_score']:.4f})") + print(f" KL: {task_result['avg_kl']:.4f} (±{task_result['std_kl']:.4f})") + if "delta_j" in task_result: + print(f" ΔJ: {task_result['delta_j']:+.4f}") + print("="*80) + + +if __name__ == "__main__": + main() + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c108ec1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,39 @@ +# RLVR Floating-Point Precision Experiment Dependencies +# Core ML frameworks +torch>=2.0.0 +transformers>=4.36.0 +accelerate>=0.25.0 + +# RL framework (VeRL) +# Install from source: pip install git+https://github.com/volcengine/verl.git +# verl + +# Inference +vllm>=0.2.0 + +# Numerical computation +numpy>=1.24.0 +scipy>=1.10.0 + +# Visualization +matplotlib>=3.7.0 + +# Progress tracking +tqdm>=4.65.0 + +# Data handling +datasets>=2.14.0 + +# Utilities +pyyaml>=6.0 +jsonlines>=3.1.0 + +# Distributed training (optional, usually comes with torch) +# torch-distributed + +# Flash attention (optional, for faster inference) +# flash-attn>=2.3.0 + +# Evaluation utilities +# eval-chemy # Math verifier (install from RLVR repo) + diff --git a/run_experiments.py b/run_experiments.py new file mode 100644 index 0000000..0cbcd67 --- /dev/null +++ b/run_experiments.py @@ -0,0 +1,601 @@ +#!/usr/bin/env python3 +# run_experiments.py +""" +Experiment Runner for RLVR Floating-Point Precision Study. + +This script orchestrates the full experimental pipeline: +1. Training models with FP32 and bf16 precision across multiple seeds +2. Evaluating trained models on on-task and off-task benchmarks +3. 
Computing KL divergence and bf16 sparsity metrics + +Usage: + # Run full experiment + python run_experiments.py --mode full + + # Run training only + python run_experiments.py --mode train --precision_mode bf16 --seed 1 + + # Run evaluation only + python run_experiments.py --mode eval + + # Run analysis only + python run_experiments.py --mode analyze +""" + +import argparse +import json +import os +import subprocess +import sys +import logging +from typing import Dict, Any, List, Optional +from dataclasses import asdict +from concurrent.futures import ProcessPoolExecutor +import time + +from config import ( + ExperimentConfig, + make_training_config, + make_precision_config, + get_run_output_dir, + get_checkpoint_path, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Training Functions +# ============================================================================ + +def run_single_training( + precision_mode: str, + seed: int, + config: ExperimentConfig, + train_dataset_path: str, + dry_run: bool = False +) -> Dict[str, Any]: + """ + Run a single training job. + + Args: + precision_mode: "fp32" or "bf16" + seed: Random seed + config: Experiment configuration + train_dataset_path: Path to training data + dry_run: If True, only print command without running + + Returns: + Dictionary with job status + """ + output_dir = get_run_output_dir(config.train_logs_dir, precision_mode, seed) + + cmd = [ + sys.executable, + "train_rlvr.py", + "--precision_mode", precision_mode, + "--seed", str(seed), + "--output_dir", output_dir, + "--train_dataset_path", train_dataset_path, + "--model_name", config.base_model_path, + ] + + logger.info(f"Running training: {precision_mode} seed={seed}") + logger.info(f"Command: {' '.join(cmd)}") + + if dry_run: + return { + "status": "dry_run", + "precision_mode": precision_mode, + "seed": seed, + "output_dir": output_dir, + "command": " ".join(cmd), + } + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Run training + start_time = time.time() + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + duration = time.time() - start_time + + return { + "status": "success", + "precision_mode": precision_mode, + "seed": seed, + "output_dir": output_dir, + "duration_seconds": duration, + "stdout": result.stdout[-1000:], # Last 1000 chars + } + except subprocess.CalledProcessError as e: + return { + "status": "failed", + "precision_mode": precision_mode, + "seed": seed, + "output_dir": output_dir, + "error": str(e), + "stderr": e.stderr[-1000:] if e.stderr else None, + } + + +def run_all_training( + config: ExperimentConfig, + train_dataset_path: str, + dry_run: bool = False, + parallel: bool = False +) -> List[Dict[str, Any]]: + """ + Run training for all precision modes and seeds. 
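+
+    Dry-run sketch (hypothetical paths; see main() for how the config is built):
+
+        config = ExperimentConfig(
+            experiment_name="fp_precision_rlvr",
+            seeds=[1, 2],
+            precision_modes=["fp32", "bf16"],
+            base_model_path="Qwen/Qwen2.5-Math-7B",
+            base_output_dir="./results",
+            train_logs_dir="./results/train_logs",
+            checkpoints_dir="./results/checkpoints",
+            eval_metrics_dir="./results/eval_metrics",
+            eval_tasks_config_path="./configs/eval_tasks_config.json",
+        )
+        jobs = run_all_training(config, "./data/dm_train.json", dry_run=True)
+        # one {"status": "dry_run", ...} record per (precision_mode, seed) pair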
+ """ + jobs = [] + for precision_mode in config.precision_modes: + for seed in config.seeds: + jobs.append((precision_mode, seed)) + + results = [] + + if parallel and not dry_run: + # Run in parallel (one job per GPU assumed) + with ProcessPoolExecutor(max_workers=len(jobs)) as executor: + futures = [ + executor.submit( + run_single_training, + pm, s, config, train_dataset_path, dry_run + ) + for pm, s in jobs + ] + results = [f.result() for f in futures] + else: + # Run sequentially + for precision_mode, seed in jobs: + result = run_single_training( + precision_mode, seed, config, train_dataset_path, dry_run + ) + results.append(result) + + return results + + +# ============================================================================ +# Evaluation Functions +# ============================================================================ + +def run_single_evaluation( + precision_mode: str, + seed: int, + config: ExperimentConfig, + eval_base: bool = True, + dry_run: bool = False +) -> Dict[str, Any]: + """ + Run evaluation for a single trained model. + """ + run_dir = get_run_output_dir(config.train_logs_dir, precision_mode, seed) + ft_ckpt = get_checkpoint_path(run_dir) + output_path = os.path.join( + config.eval_metrics_dir, + f"{precision_mode}_seed{seed}.json" + ) + + cmd = [ + sys.executable, + "eval_policy.py", + "--base_ckpt", config.base_model_path, + "--ft_ckpt", ft_ckpt, + "--eval_tasks_config", config.eval_tasks_config_path, + "--output_path", output_path, + ] + + if eval_base: + cmd.append("--eval_base") + + logger.info(f"Running evaluation: {precision_mode} seed={seed}") + logger.info(f"Command: {' '.join(cmd)}") + + if dry_run: + return { + "status": "dry_run", + "precision_mode": precision_mode, + "seed": seed, + "output_path": output_path, + "command": " ".join(cmd), + } + + # Create output directory + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Check if checkpoint exists + if not os.path.exists(ft_ckpt): + return { + "status": "skipped", + "precision_mode": precision_mode, + "seed": seed, + "reason": f"Checkpoint not found: {ft_ckpt}", + } + + # Run evaluation + start_time = time.time() + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + duration = time.time() - start_time + + return { + "status": "success", + "precision_mode": precision_mode, + "seed": seed, + "output_path": output_path, + "duration_seconds": duration, + } + except subprocess.CalledProcessError as e: + return { + "status": "failed", + "precision_mode": precision_mode, + "seed": seed, + "error": str(e), + "stderr": e.stderr[-1000:] if e.stderr else None, + } + + +def run_all_evaluations( + config: ExperimentConfig, + eval_base: bool = True, + dry_run: bool = False +) -> List[Dict[str, Any]]: + """ + Run evaluation for all trained models. + """ + results = [] + + for precision_mode in config.precision_modes: + for seed in config.seeds: + result = run_single_evaluation( + precision_mode, seed, config, eval_base, dry_run + ) + results.append(result) + + return results + + +# ============================================================================ +# bf16 Sparsity Analysis +# ============================================================================ + +def run_sparsity_analysis( + config: ExperimentConfig, + dry_run: bool = False +) -> List[Dict[str, Any]]: + """ + Compute bf16 sparsity for all bf16 runs. 
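+
+    Each analyzed run contributes a summary record shaped like (numbers
+    illustrative, not measured):
+
+        {"status": "success", "precision_mode": "bf16", "seed": 1,
+         "sparsity_percent": 35.0,
+         "output_path": "./results/eval_metrics/bf16_seed1_sparsity.json"}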
+ """ + import torch + from transformers import AutoModelForCausalLM + from utils_bf16_sparsity import compute_bf16_sparsity, analyze_update_magnitudes + + results = [] + + # Only analyze bf16 runs + if "bf16" not in config.precision_modes: + logger.info("No bf16 runs to analyze for sparsity") + return results + + if dry_run: + for seed in config.seeds: + results.append({ + "status": "dry_run", + "precision_mode": "bf16", + "seed": seed, + }) + return results + + # Load base model once + logger.info(f"Loading base model: {config.base_model_path}") + base_model = AutoModelForCausalLM.from_pretrained( + config.base_model_path, + torch_dtype=torch.float32, + device_map="cpu" + ) + + for seed in config.seeds: + run_dir = get_run_output_dir(config.train_logs_dir, "bf16", seed) + ft_ckpt = get_checkpoint_path(run_dir) + + if not os.path.exists(ft_ckpt): + results.append({ + "status": "skipped", + "precision_mode": "bf16", + "seed": seed, + "reason": f"Checkpoint not found: {ft_ckpt}", + }) + continue + + logger.info(f"Computing sparsity for bf16 seed={seed}") + + # Load finetuned model + ft_model = AutoModelForCausalLM.from_pretrained( + ft_ckpt, + torch_dtype=torch.float32, + device_map="cpu" + ) + + # Compute sparsity + sparsity_result = compute_bf16_sparsity( + base_model=base_model, + finetuned_model=ft_model, + eta=config.bf16_sparsity_eta, + include_layer_stats=True + ) + + # Analyze update magnitudes + magnitude_result = analyze_update_magnitudes( + base_model=base_model, + finetuned_model=ft_model + ) + + # Save results + output_path = os.path.join( + config.eval_metrics_dir, + f"bf16_seed{seed}_sparsity.json" + ) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + full_result = { + "precision_mode": "bf16", + "seed": seed, + "sparsity": sparsity_result, + "magnitudes": magnitude_result, + } + + with open(output_path, "w") as f: + # Convert layer_stats to serializable format + serializable = { + k: v for k, v in full_result.items() + } + if "layer_stats" in serializable.get("sparsity", {}): + serializable["sparsity"]["layer_stats"] = { + k: {kk: vv for kk, vv in v.items() if kk != "shape"} + for k, v in serializable["sparsity"]["layer_stats"].items() + } + json.dump(serializable, f, indent=2, default=str) + + results.append({ + "status": "success", + "precision_mode": "bf16", + "seed": seed, + "sparsity_percent": sparsity_result["sparsity_percent"], + "output_path": output_path, + }) + + # Free memory + del ft_model + + return results + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run RLVR floating-point precision experiments" + ) + + parser.add_argument( + "--mode", + type=str, + default="full", + choices=["full", "train", "eval", "analyze", "sparsity"], + help="Execution mode" + ) + + # For single job mode + parser.add_argument( + "--precision_mode", + type=str, + choices=["fp32", "bf16"], + help="Precision mode (for train mode)" + ) + parser.add_argument( + "--seed", + type=int, + help="Random seed (for train mode)" + ) + + # Paths + parser.add_argument( + "--train_dataset_path", + type=str, + default="./data/dm_train.json", + help="Path to training dataset" + ) + parser.add_argument( + "--base_output_dir", + type=str, + default="./results", + help="Base output directory" + ) + parser.add_argument( + 
"--eval_tasks_config", + type=str, + default="./configs/eval_tasks_config.json", + help="Path to evaluation tasks config" + ) + parser.add_argument( + "--base_model", + type=str, + default="Qwen/Qwen2.5-Math-7B", + help="Base model path or HuggingFace ID" + ) + + # Execution options + parser.add_argument( + "--dry_run", + action="store_true", + help="Print commands without executing" + ) + parser.add_argument( + "--parallel", + action="store_true", + help="Run training jobs in parallel" + ) + parser.add_argument( + "--seeds", + type=int, + nargs="+", + default=[1, 2, 3, 4, 5], + help="Seeds to use" + ) + + return parser.parse_args() + + +def main() -> None: + """Main entry point.""" + args = parse_args() + + # Create experiment configuration + config = ExperimentConfig( + experiment_name="fp_precision_rlvr", + seeds=args.seeds, + precision_modes=["fp32", "bf16"], + base_model_path=args.base_model, + base_output_dir=args.base_output_dir, + train_logs_dir=os.path.join(args.base_output_dir, "train_logs"), + checkpoints_dir=os.path.join(args.base_output_dir, "checkpoints"), + eval_metrics_dir=os.path.join(args.base_output_dir, "eval_metrics"), + eval_tasks_config_path=args.eval_tasks_config, + ) + + # Create directories + os.makedirs(config.train_logs_dir, exist_ok=True) + os.makedirs(config.checkpoints_dir, exist_ok=True) + os.makedirs(config.eval_metrics_dir, exist_ok=True) + + # Save experiment config + config_path = os.path.join(args.base_output_dir, "experiment_config.json") + with open(config_path, "w") as f: + json.dump(asdict(config), f, indent=2) + logger.info(f"Saved experiment config to {config_path}") + + # Execute based on mode + if args.mode == "train": + if args.precision_mode and args.seed: + # Single training job + result = run_single_training( + args.precision_mode, + args.seed, + config, + args.train_dataset_path, + args.dry_run + ) + print(json.dumps(result, indent=2)) + else: + # All training jobs + results = run_all_training( + config, + args.train_dataset_path, + args.dry_run, + args.parallel + ) + print(json.dumps(results, indent=2)) + + elif args.mode == "eval": + results = run_all_evaluations(config, eval_base=True, dry_run=args.dry_run) + print(json.dumps(results, indent=2)) + + elif args.mode == "sparsity": + results = run_sparsity_analysis(config, dry_run=args.dry_run) + print(json.dumps(results, indent=2)) + + elif args.mode == "analyze": + # Run analysis script + analyze_cmd = [ + sys.executable, + "analyze_results.py", + "--results_dir", config.eval_metrics_dir, + "--output_dir", os.path.join(args.base_output_dir, "analysis"), + ] + if args.dry_run: + print(f"Would run: {' '.join(analyze_cmd)}") + else: + subprocess.run(analyze_cmd, check=True) + + elif args.mode == "full": + logger.info("="*60) + logger.info("RUNNING FULL EXPERIMENT PIPELINE") + logger.info("="*60) + + # Step 1: Training + logger.info("\n" + "="*60) + logger.info("STEP 1: Training") + logger.info("="*60) + train_results = run_all_training( + config, + args.train_dataset_path, + args.dry_run, + args.parallel + ) + + # Step 2: Evaluation + logger.info("\n" + "="*60) + logger.info("STEP 2: Evaluation") + logger.info("="*60) + eval_results = run_all_evaluations(config, dry_run=args.dry_run) + + # Step 3: Sparsity Analysis + logger.info("\n" + "="*60) + logger.info("STEP 3: Sparsity Analysis") + logger.info("="*60) + sparsity_results = run_sparsity_analysis(config, dry_run=args.dry_run) + + # Step 4: Results Analysis + logger.info("\n" + "="*60) + logger.info("STEP 4: Results Analysis") + 
logger.info("="*60) + if not args.dry_run: + analyze_cmd = [ + sys.executable, + "analyze_results.py", + "--results_dir", config.eval_metrics_dir, + "--output_dir", os.path.join(args.base_output_dir, "analysis"), + ] + subprocess.run(analyze_cmd, check=True) + + # Summary + logger.info("\n" + "="*60) + logger.info("EXPERIMENT COMPLETE") + logger.info("="*60) + + all_results = { + "training": train_results, + "evaluation": eval_results, + "sparsity": sparsity_results, + } + + summary_path = os.path.join(args.base_output_dir, "experiment_summary.json") + with open(summary_path, "w") as f: + json.dump(all_results, f, indent=2) + logger.info(f"Saved summary to {summary_path}") + + +if __name__ == "__main__": + main() + diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py new file mode 100755 index 0000000..5ef3c29 --- /dev/null +++ b/scripts/prepare_data.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +""" +Prepare REAL datasets for RLVR floating-point precision experiments. + +Downloads from HuggingFace: +- Training: GSM8K train (7473 samples) +- Evaluation: GSM8K test, MATH-500, AIME, AMC, MMLU-STEM, HumanEval + +Usage: + python scripts/prepare_data.py +""" + +import json +import os +import random +from pathlib import Path +from datasets import load_dataset +from tqdm import tqdm + +DATA_DIR = Path("data") +DATA_DIR.mkdir(exist_ok=True) + + +def save_json(data: list, path: Path): + """Save data as JSON file.""" + with open(path, "w") as f: + json.dump(data, f, indent=2) + print(f" Saved {len(data)} samples to {path}") + + +def prepare_gsm8k_train(): + """Prepare GSM8K training data.""" + print("\n=== Downloading GSM8K Train ===") + ds = load_dataset("openai/gsm8k", "main", split="train") + + data = [] + for i, sample in enumerate(tqdm(ds, desc="Processing")): + # Extract answer from "#### N" format + answer = sample["answer"].split("####")[-1].strip() + data.append({ + "id": f"gsm8k_train_{i}", + "prompt": sample["question"], + "answer": answer, + "solution": sample["answer"], + "source": "gsm8k_train" + }) + + save_json(data, DATA_DIR / "dm_train.json") + return data + + +def prepare_gsm8k_test(): + """Prepare GSM8K test data for evaluation.""" + print("\n=== Downloading GSM8K Test ===") + ds = load_dataset("openai/gsm8k", "main", split="test") + + data = [] + for i, sample in enumerate(tqdm(ds, desc="Processing")): + answer = sample["answer"].split("####")[-1].strip() + data.append({ + "id": f"gsm8k_test_{i}", + "prompt": sample["question"], + "answer": answer, + "solution": sample["answer"], + "source": "gsm8k" + }) + + save_json(data, DATA_DIR / "gsm8k.json") + + # Also create dm_val as a subset (first 500 for on-task eval) + save_json(data[:500], DATA_DIR / "dm_val.json") + return data + + +def prepare_math500(): + """Prepare MATH-500 dataset.""" + print("\n=== Downloading MATH-500 ===") + ds = load_dataset("HuggingFaceH4/MATH-500", split="test") + + data = [] + for i, sample in enumerate(tqdm(ds, desc="Processing")): + data.append({ + "id": f"math500_{i}", + "prompt": sample["problem"], + "answer": sample["answer"], + "solution": sample["solution"], + "subject": sample.get("subject", ""), + "level": sample.get("level", ""), + "source": "math500" + }) + + save_json(data, DATA_DIR / "math500.json") + return data + + +def prepare_aime(): + """Prepare AIME dataset from AI-MO.""" + print("\n=== Downloading AIME ===") + ds = load_dataset("AI-MO/aimo-validation-aime", split="train") + + data = [] + for i, sample in enumerate(tqdm(ds, desc="Processing")): + data.append({ + "id": 
f"aime_{i}", + "prompt": sample["problem"], + "answer": str(sample["answer"]), + "solution": sample.get("solution", ""), + "url": sample.get("url", ""), + "source": "aime" + }) + + # Split into aime24 and aime25 + # Real AIME has 15 problems per contest, 2 contests per year = 30/year + save_json(data[:30], DATA_DIR / "aime24.json") + save_json(data[30:60], DATA_DIR / "aime25.json") + save_json(data, DATA_DIR / "aime_all.json") + return data + + +def prepare_amc(): + """Prepare AMC dataset from AI-MO.""" + print("\n=== Downloading AMC ===") + ds = load_dataset("AI-MO/aimo-validation-amc", split="train") + + data = [] + for i, sample in enumerate(tqdm(ds, desc="Processing")): + data.append({ + "id": f"amc_{i}", + "prompt": sample["problem"], + "answer": str(sample["answer"]), + "solution": sample.get("solution", ""), + "source": "amc" + }) + + save_json(data, DATA_DIR / "amc23.json") + return data + + +def prepare_mmlu_stem(): + """Prepare MMLU-STEM subset.""" + print("\n=== Downloading MMLU-STEM ===") + + stem_subjects = [ + "abstract_algebra", "astronomy", "college_biology", "college_chemistry", + "college_computer_science", "college_mathematics", "college_physics", + "computer_security", "conceptual_physics", "electrical_engineering", + "elementary_mathematics", "high_school_biology", "high_school_chemistry", + "high_school_computer_science", "high_school_mathematics", "high_school_physics", + "high_school_statistics", "machine_learning" + ] + + data = [] + for subject in tqdm(stem_subjects, desc="Loading subjects"): + try: + ds = load_dataset("cais/mmlu", subject, split="test") + for i, sample in enumerate(ds): + choices = sample["choices"] + correct_idx = sample["answer"] + # Format as multiple choice + prompt = f"{sample['question']}\n" + for j, choice in enumerate(choices): + prompt += f"({chr(65+j)}) {choice}\n" + + data.append({ + "id": f"mmlu_{subject}_{i}", + "prompt": prompt, + "answer": chr(65 + correct_idx), + "subject": subject, + "source": "mmlu_stem" + }) + except Exception as e: + print(f" Warning: Skipping {subject}: {e}") + + # Take a random subset of 500 + random.seed(42) + if len(data) > 500: + data = random.sample(data, 500) + + save_json(data, DATA_DIR / "mmlu_stem.json") + return data + + +def prepare_humaneval(): + """Prepare HumanEval code dataset.""" + print("\n=== Downloading HumanEval ===") + ds = load_dataset("openai/openai_humaneval", split="test") + + data = [] + for i, sample in enumerate(tqdm(ds, desc="Processing")): + data.append({ + "id": f"humaneval_{i}", + "prompt": sample["prompt"], + "answer": sample["canonical_solution"], + "entry_point": sample["entry_point"], + "test": sample["test"], + "source": "humaneval" + }) + + save_json(data, DATA_DIR / "humaneval.json") + return data + + +def verify_data(): + """Verify downloaded data quality.""" + print("\n" + "=" * 60) + print("Verifying Data Quality") + print("=" * 60) + + for f in sorted(DATA_DIR.glob("*.json")): + with open(f) as fp: + data = json.load(fp) + + # Check for unique prompts + prompts = [d["prompt"] for d in data] + unique = len(set(prompts)) + + status = "OK" if unique == len(prompts) else f"WARN: {len(prompts)-unique} duplicates" + print(f" {f.name}: {len(data)} samples, {unique} unique [{status}]") + + # Show first example + if data: + print(f" Example: {data[0]['prompt'][:60]}...") + + +def main(): + print("=" * 60) + print("RLVR Real Data Preparation") + print("=" * 60) + + # Backup old data + backup_dir = DATA_DIR / "backup_synthetic" + if not backup_dir.exists() and 
any(DATA_DIR.glob("*.json")): + backup_dir.mkdir(exist_ok=True) + for f in DATA_DIR.glob("*.json"): + f.rename(backup_dir / f.name) + print(f"Backed up synthetic data to {backup_dir}") + + # Training data + prepare_gsm8k_train() + + # Evaluation data + prepare_gsm8k_test() + prepare_math500() + prepare_aime() + prepare_amc() + prepare_mmlu_stem() + prepare_humaneval() + + # Verify + verify_data() + + print("\n" + "=" * 60) + print("Data preparation complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_evaluation.sh b/scripts/run_evaluation.sh new file mode 100755 index 0000000..b39c230 --- /dev/null +++ b/scripts/run_evaluation.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# run_evaluation.sh +# Script to run evaluation on trained models + +set -e +set -o pipefail # Properly capture exit codes through pipes + +# Configuration +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"} + +# HuggingFace cache - use shared HDD storage to avoid quota issues +export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache" +export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub" +mkdir -p "$HF_HOME" "$HF_HUB_CACHE" + +# Default values +PRECISION_MODE=${1:-"bf16"} +SEED=${2:-1} +BASE_MODEL=${BASE_MODEL:-"Qwen/Qwen2.5-Math-7B"} +TRAIN_LOGS_DIR=${TRAIN_LOGS_DIR:-"./results/train_logs"} +EVAL_METRICS_DIR=${EVAL_METRICS_DIR:-"./results/eval_metrics"} +EVAL_CONFIG=${EVAL_CONFIG:-"./configs/eval_tasks_config.json"} + +# Paths +FT_CKPT="${TRAIN_LOGS_DIR}/${PRECISION_MODE}_seed${SEED}/final_model" +OUTPUT_PATH="${EVAL_METRICS_DIR}/${PRECISION_MODE}_seed${SEED}.json" + +# Create output directory +mkdir -p "$EVAL_METRICS_DIR" + +echo "==============================================" +echo "Model Evaluation" +echo "==============================================" +echo "Precision Mode: $PRECISION_MODE" +echo "Seed: $SEED" +echo "Base Model: $BASE_MODEL" +echo "Finetuned Model: $FT_CKPT" +echo "Output: $OUTPUT_PATH" +echo "==============================================" + +# Check if checkpoint exists +if [ ! -d "$FT_CKPT" ]; then + echo "Error: Checkpoint not found at $FT_CKPT" + exit 1 +fi + +# Run evaluation +python eval_policy.py \ + --base_ckpt "$BASE_MODEL" \ + --ft_ckpt "$FT_CKPT" \ + --eval_tasks_config "$EVAL_CONFIG" \ + --output_path "$OUTPUT_PATH" \ + --eval_base \ + --use_amp \ + 2>&1 | tee "${EVAL_METRICS_DIR}/${PRECISION_MODE}_seed${SEED}_eval.log" + +echo "Evaluation complete. 
Results saved to: $OUTPUT_PATH" + diff --git a/scripts/run_full_experiment.sh b/scripts/run_full_experiment.sh new file mode 100755 index 0000000..43e9dd5 --- /dev/null +++ b/scripts/run_full_experiment.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# run_full_experiment.sh +# Master script to run the complete RLVR floating-point precision experiment + +set -e + +# Configuration +SEEDS=(1 2 3 4 5) +PRECISION_MODES=("fp32" "bf16") +TRAIN_DATA=${TRAIN_DATA:-"./data/dm_train.json"} +OUTPUT_BASE=${OUTPUT_BASE:-"./results"} + +echo "==============================================" +echo "RLVR Floating-Point Precision Experiment" +echo "==============================================" +echo "Seeds: ${SEEDS[*]}" +echo "Precision Modes: ${PRECISION_MODES[*]}" +echo "Output: $OUTPUT_BASE" +echo "==============================================" + +# Create directories +mkdir -p "$OUTPUT_BASE/train_logs" +mkdir -p "$OUTPUT_BASE/eval_metrics" +mkdir -p "$OUTPUT_BASE/analysis" + +# Phase 1: Training +echo "" +echo "==============================================" +echo "PHASE 1: TRAINING" +echo "==============================================" + +for precision in "${PRECISION_MODES[@]}"; do + for seed in "${SEEDS[@]}"; do + echo "Training: precision=$precision, seed=$seed" + + OUTPUT_DIR="$OUTPUT_BASE/train_logs/${precision}_seed${seed}" + + # Skip if already completed + if [ -d "$OUTPUT_DIR/final_model" ]; then + echo " -> Skipping (already completed)" + continue + fi + + # Run training + bash scripts/run_training.sh "$precision" "$seed" + done +done + +# Phase 2: Evaluation +echo "" +echo "==============================================" +echo "PHASE 2: EVALUATION" +echo "==============================================" + +for precision in "${PRECISION_MODES[@]}"; do + for seed in "${SEEDS[@]}"; do + echo "Evaluating: precision=$precision, seed=$seed" + + OUTPUT_PATH="$OUTPUT_BASE/eval_metrics/${precision}_seed${seed}.json" + + # Skip if already completed + if [ -f "$OUTPUT_PATH" ]; then + echo " -> Skipping (already completed)" + continue + fi + + # Run evaluation + bash scripts/run_evaluation.sh "$precision" "$seed" + done +done + +# Phase 3: bf16 Sparsity Analysis +echo "" +echo "==============================================" +echo "PHASE 3: BF16 SPARSITY ANALYSIS" +echo "==============================================" + +python run_experiments.py --mode sparsity \ + --base_output_dir "$OUTPUT_BASE" \ + --seeds "${SEEDS[@]}" + +# Phase 4: Results Analysis +echo "" +echo "==============================================" +echo "PHASE 4: RESULTS ANALYSIS" +echo "==============================================" + +python analyze_results.py \ + --results_dir "$OUTPUT_BASE/eval_metrics" \ + --output_dir "$OUTPUT_BASE/analysis" \ + --on_task dm_val \ + --off_task aime24 aime25 amc23 math500 mmlu_stem humaneval + +echo "" +echo "==============================================" +echo "EXPERIMENT COMPLETE" +echo "==============================================" +echo "Results saved to: $OUTPUT_BASE" +echo "" +echo "Key output files:" +echo " - Training logs: $OUTPUT_BASE/train_logs/" +echo " - Evaluation metrics: $OUTPUT_BASE/eval_metrics/" +echo " - Analysis: $OUTPUT_BASE/analysis/full_analysis.json" +echo " - Plots: $OUTPUT_BASE/analysis/*.png" +echo "==============================================" + diff --git a/scripts/run_training.sh b/scripts/run_training.sh new file mode 100755 index 0000000..38b2fc8 --- /dev/null +++ b/scripts/run_training.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# run_training.sh +# Script to run RLVR training 
experiments with different precision modes + +set -e +set -o pipefail # Properly capture exit codes through pipes + +# Configuration +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1"} + +# HuggingFace cache - use shared HDD storage to avoid quota issues +export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache" +export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub" +mkdir -p "$HF_HOME" "$HF_HUB_CACHE" + +# Default values +PRECISION_MODE=${1:-"bf16"} +SEED=${2:-1} +TRAIN_DATA=${TRAIN_DATA:-"./data/dm_train.json"} +OUTPUT_BASE=${OUTPUT_BASE:-"./results/train_logs"} +MODEL_NAME=${MODEL_NAME:-"Qwen/Qwen2.5-Math-7B"} +NUM_STEPS=${NUM_STEPS:-300} + +# Create output directory +OUTPUT_DIR="${OUTPUT_BASE}/${PRECISION_MODE}_seed${SEED}" +mkdir -p "$OUTPUT_DIR" + +echo "==============================================" +echo "RLVR Training" +echo "==============================================" +echo "Precision Mode: $PRECISION_MODE" +echo "Seed: $SEED" +echo "Model: $MODEL_NAME" +echo "Training Data: $TRAIN_DATA" +echo "Output: $OUTPUT_DIR" +echo "Num Steps: $NUM_STEPS" +echo "==============================================" + +# Run training +python train_rlvr.py \ + --precision_mode "$PRECISION_MODE" \ + --seed "$SEED" \ + --output_dir "$OUTPUT_DIR" \ + --train_dataset_path "$TRAIN_DATA" \ + --model_name "$MODEL_NAME" \ + --num_steps "$NUM_STEPS" \ + 2>&1 | tee "${OUTPUT_DIR}/training.log" + +echo "Training complete. Output saved to: $OUTPUT_DIR" + diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh new file mode 100755 index 0000000..743f99a --- /dev/null +++ b/scripts/setup_env.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# setup_env.sh +# One-time setup script for the RLVR floating-point precision experiment +# Run this BEFORE submitting any jobs + +set -e + +CONDA_ENV="rlvr-fp" +PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise" + +echo "============================================" +echo "RLVR Environment Setup" +echo "============================================" + +# Setup HuggingFace cache directories +echo "" +echo "Setting up HuggingFace cache..." +HF_CACHE_DIR="/work/hdd/bfqt/yurenh2/huggingface_cache" +mkdir -p "$HF_CACHE_DIR/hub" "$HF_CACHE_DIR/transformers" +echo " Cache directory: $HF_CACHE_DIR" + +# Add to shell profile if not already present +PROFILE_FILE="$HOME/.bashrc" +if ! grep -q "HF_HOME.*huggingface_cache" "$PROFILE_FILE" 2>/dev/null; then + echo "" + echo "Adding HuggingFace cache settings to $PROFILE_FILE..." + cat >> "$PROFILE_FILE" << 'EOF' + +# HuggingFace cache - shared across all projects (added by RLVR setup) +export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache" +export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub" +export TRANSFORMERS_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/transformers" +EOF + echo " Added to $PROFILE_FILE" +else + echo " HuggingFace settings already in $PROFILE_FILE" +fi + +# Source to apply changes +source "$PROFILE_FILE" + +# Check if conda environment exists +echo "" +echo "Checking conda environment..." +source ~/.bashrc + +if conda env list | grep -q "^${CONDA_ENV} "; then + echo " Environment '$CONDA_ENV' already exists" + echo " To recreate, run: conda env remove -n $CONDA_ENV && $0" +else + echo " Creating conda environment: $CONDA_ENV" + conda create -n "$CONDA_ENV" python=3.10 -y + + echo "" + echo "Installing dependencies..." + conda activate "$CONDA_ENV" + cd "$PROJECT_DIR" + pip install -r requirements.txt + + echo "" + echo "Verifying installation..." 
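+    # Optional extra check: CUDA visibility (may print False on GPU-less login nodes)
+    python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"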
+ python -c "import torch; print(f'PyTorch: {torch.__version__}')" + python -c "import transformers; print(f'Transformers: {transformers.__version__}')" +fi + +echo "" +echo "============================================" +echo "Setup complete!" +echo "============================================" +echo "" +echo "To activate the environment:" +echo " conda activate $CONDA_ENV" +echo "" +echo "To run experiments:" +echo " ./scripts/submit_all_jobs.sh" +echo "" +echo "HuggingFace cache location: $HF_CACHE_DIR" +echo " (1TB quota, shared across all projects)" +echo "============================================" diff --git a/scripts/slurm_train.sh b/scripts/slurm_train.sh new file mode 100755 index 0000000..36bd5b1 --- /dev/null +++ b/scripts/slurm_train.sh @@ -0,0 +1,145 @@ +#!/bin/bash +#SBATCH --job-name=rlvr_fp_exp +#SBATCH --account=bfqt-delta-gpu +#SBATCH --partition=gpuH200x8 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --gres=gpu:h200:4 +#SBATCH --mem=200G +#SBATCH --time=2-00:00:00 +#SBATCH --output=results/slurm_logs/%x_%j.out +#SBATCH --error=results/slurm_logs/%x_%j.err +#SBATCH --mail-type=BEGIN,END,FAIL +#SBATCH --mail-user=$USER@example.com + +# Exit on error and propagate exit codes through pipes +set -o pipefail + +# ============================================ +# RLVR Floating-Point Precision Experiment +# H200x8 SLURM Job Script +# ============================================ + +# Configuration - modify these as needed +PRECISION_MODE=${PRECISION_MODE:-"bf16"} +SEED=${SEED:-1} +NUM_STEPS=${NUM_STEPS:-150} # ~45 hours on H200 with sequential generation + +# Paths +PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise" +CONDA_ENV="rlvr-fp" # Change to your conda env name +MODEL_NAME="Qwen/Qwen2.5-Math-7B" +TRAIN_DATA="${PROJECT_DIR}/data/dm_train.json" + +# ============================================ +# HuggingFace cache configuration +# Use shared HDD storage to avoid home directory quota issues +# This cache is shared across all projects +# ============================================ +export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache" +export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub" +mkdir -p "$HF_HOME" "$HF_HUB_CACHE" + +# Print job info +echo "============================================" +echo "SLURM Job ID: $SLURM_JOB_ID" +echo "Running on: $(hostname)" +echo "Start time: $(date)" +echo "============================================" +echo "Precision Mode: $PRECISION_MODE" +echo "Seed: $SEED" +echo "Num Steps: $NUM_STEPS" +echo "GPUs: $CUDA_VISIBLE_DEVICES" +echo "HF Cache: $HF_HOME" +echo "============================================" + +# Setup environment +cd "$PROJECT_DIR" +mkdir -p results/slurm_logs + +# Activate conda environment +source ~/.bashrc + +# Check if conda environment exists +if conda env list | grep -q "^${CONDA_ENV} "; then + echo "Activating existing conda environment: $CONDA_ENV" + conda activate "$CONDA_ENV" +else + echo "ERROR: Conda environment '$CONDA_ENV' does not exist!" 
+ echo "Please create it first by running:" + echo " conda create -n $CONDA_ENV python=3.10 -y" + echo " conda activate $CONDA_ENV" + echo " pip install -r requirements.txt" + exit 1 +fi + +# Verify activation succeeded +if [[ "$CONDA_DEFAULT_ENV" != "$CONDA_ENV" ]]; then + echo "ERROR: Failed to activate conda environment '$CONDA_ENV'" + exit 1 +fi +echo "Conda environment activated: $CONDA_DEFAULT_ENV" + +# Check GPU availability +nvidia-smi +echo "CUDA devices: $(python -c 'import torch; print(torch.cuda.device_count())')" + +# Output directory (use /work for large checkpoints - /projects is limited) +OUTPUT_DIR="/work/hdd/bfqt/yurenh2/rlvr_results/${PRECISION_MODE}_seed${SEED}" +mkdir -p "$OUTPUT_DIR" + +# DeepSpeed config (ZeRO-3 for full sharding of model/optimizer/gradients) +DEEPSPEED_CONFIG="${PROJECT_DIR}/configs/deepspeed_zero3.json" + +# Number of GPUs for DeepSpeed (all GPUs, ref model on same GPU as training per rank) +NUM_GPUS=$(python -c 'import torch; print(torch.cuda.device_count())') +echo "Using $NUM_GPUS GPUs for DeepSpeed training (ref model on each rank's GPU)" + +# Use random port to avoid conflicts with other jobs +MASTER_PORT=$((29500 + RANDOM % 1000)) +echo "Using master port: $MASTER_PORT" + +# Run training with DeepSpeed +echo "" +echo "Starting training with DeepSpeed ZeRO-3..." +echo "============================================" + +deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train_rlvr.py \ + --precision_mode "$PRECISION_MODE" \ + --seed "$SEED" \ + --output_dir "$OUTPUT_DIR" \ + --train_dataset_path "$TRAIN_DATA" \ + --model_name "$MODEL_NAME" \ + --num_steps "$NUM_STEPS" \ + --deepspeed "$DEEPSPEED_CONFIG" \ + 2>&1 | tee "${OUTPUT_DIR}/training_slurm.log" + +# IMPORTANT: With set -o pipefail, $? now captures python's exit code, not tee's +TRAIN_EXIT_CODE=$? + +echo "" +echo "============================================" +echo "Training completed with exit code: $TRAIN_EXIT_CODE" +echo "End time: $(date)" +echo "============================================" + +# If training succeeded, run evaluation +if [ $TRAIN_EXIT_CODE -eq 0 ]; then + echo "" + echo "Starting evaluation..." 
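+    # Evaluation reuses this job's GPU allocation so metrics are saved before exit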
+ echo "============================================" + + python eval_policy.py \ + --base_ckpt "$MODEL_NAME" \ + --ft_ckpt "${OUTPUT_DIR}/final_model" \ + --eval_tasks_config configs/eval_tasks_config.json \ + --output_path "results/eval_metrics/${PRECISION_MODE}_seed${SEED}.json" \ + --eval_base \ + --use_amp \ + 2>&1 | tee "${OUTPUT_DIR}/eval_slurm.log" +fi + +echo "" +echo "Job completed at: $(date)" + diff --git a/scripts/submit_all_jobs.sh b/scripts/submit_all_jobs.sh new file mode 100755 index 0000000..86c0f5d --- /dev/null +++ b/scripts/submit_all_jobs.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# submit_all_jobs.sh +# Submit all experiment jobs to SLURM queue +# Jobs will run automatically when resources become available + +set -e + +PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise" +cd "$PROJECT_DIR" + +# Create log directory +mkdir -p results/slurm_logs + +# Configuration +SEEDS=(1 2 3 4 5) +PRECISION_MODES=("fp32" "bf16") + +echo "============================================" +echo "Submitting RLVR Experiment Jobs" +echo "============================================" +echo "Seeds: ${SEEDS[*]}" +echo "Precision Modes: ${PRECISION_MODES[*]}" +echo "Total jobs: $((${#SEEDS[@]} * ${#PRECISION_MODES[@]}))" +echo "============================================" + +# Track submitted job IDs +declare -a JOB_IDS + +for precision in "${PRECISION_MODES[@]}"; do + for seed in "${SEEDS[@]}"; do + JOB_NAME="rlvr_${precision}_s${seed}" + + echo "Submitting: $JOB_NAME" + + # Submit job with environment variables + JOB_ID=$(sbatch \ + --job-name="$JOB_NAME" \ + --export=ALL,PRECISION_MODE="$precision",SEED="$seed" \ + scripts/slurm_train.sh | awk '{print $4}') + + JOB_IDS+=("$JOB_ID") + echo " -> Job ID: $JOB_ID" + done +done + +echo "" +echo "============================================" +echo "All jobs submitted!" +echo "Job IDs: ${JOB_IDS[*]}" +echo "============================================" +echo "" +echo "Monitor with:" +echo " squeue -u $USER" +echo " squeue -j $(IFS=,; echo "${JOB_IDS[*]}")" +echo "" +echo "View logs:" +echo " tail -f results/slurm_logs/rlvr_*.out" +echo "" +echo "Cancel all:" +echo " scancel ${JOB_IDS[*]}" +echo "============================================" + +# Save job IDs for reference +echo "${JOB_IDS[*]}" > results/slurm_logs/submitted_jobs.txt +echo "Job IDs saved to: results/slurm_logs/submitted_jobs.txt" + diff --git a/scripts/submit_single_job.sh b/scripts/submit_single_job.sh new file mode 100755 index 0000000..7fe7492 --- /dev/null +++ b/scripts/submit_single_job.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# submit_single_job.sh +# Submit a single training job +# Usage: ./submit_single_job.sh +# Example: ./submit_single_job.sh bf16 1 + +PRECISION_MODE=${1:-"bf16"} +SEED=${2:-1} + +PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise" +cd "$PROJECT_DIR" + +mkdir -p results/slurm_logs + +JOB_NAME="rlvr_${PRECISION_MODE}_s${SEED}" + +echo "Submitting job: $JOB_NAME" +echo " Precision: $PRECISION_MODE" +echo " Seed: $SEED" + +JOB_ID=$(sbatch \ + --job-name="$JOB_NAME" \ + --export=ALL,PRECISION_MODE="$PRECISION_MODE",SEED="$SEED" \ + scripts/slurm_train.sh | awk '{print $4}') + +echo "" +echo "Submitted! 
Job ID: $JOB_ID" +echo "" +echo "Monitor with: squeue -j $JOB_ID" +echo "View output: tail -f results/slurm_logs/${JOB_NAME}_${JOB_ID}.out" +echo "Cancel: scancel $JOB_ID" + diff --git a/scripts/test_quick.sh b/scripts/test_quick.sh new file mode 100644 index 0000000..f66e73b --- /dev/null +++ b/scripts/test_quick.sh @@ -0,0 +1,78 @@ +#!/bin/bash +#SBATCH --job-name=rlvr_test +#SBATCH --account=bfqt-delta-gpu +#SBATCH --partition=gpuH200x8 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --gres=gpu:h200:4 +#SBATCH --mem=200G +#SBATCH --time=04:00:00 +#SBATCH --output=results/slurm_logs/test_%j.out +#SBATCH --error=results/slurm_logs/test_%j.err + +set -o pipefail + +PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise" +cd "$PROJECT_DIR" + +source ~/.bashrc +conda activate rlvr-fp + +export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache" +export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub" + +echo "============================================" +echo "Quick test on $(hostname)" +echo "SLURM Job ID: $SLURM_JOB_ID" +nvidia-smi +echo "============================================" + +NUM_GPUS=$(python -c 'import torch; print(torch.cuda.device_count())') +echo "Using $NUM_GPUS GPUs for DeepSpeed" + +# Use random port to avoid conflicts +MASTER_PORT=$((29500 + RANDOM % 1000)) +echo "Using master port: $MASTER_PORT" + +# Test fp32 with just 3 steps +echo "Testing fp32..." +deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train_rlvr.py \ + --precision_mode fp32 \ + --seed 1 \ + --output_dir /work/hdd/bfqt/yurenh2/rlvr_results/test_fp32 \ + --train_dataset_path data/dm_train.json \ + --model_name Qwen/Qwen2.5-Math-7B \ + --num_steps 3 \ + --deepspeed configs/deepspeed_zero3.json + +FP32_EXIT=$? +echo "fp32 test exit code: $FP32_EXIT" + +if [ $FP32_EXIT -eq 0 ]; then + echo "fp32 test PASSED" + + # Also test bf16 + echo "Testing bf16..." + deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train_rlvr.py \ + --precision_mode bf16 \ + --seed 1 \ + --output_dir /work/hdd/bfqt/yurenh2/rlvr_results/test_bf16 \ + --train_dataset_path data/dm_train.json \ + --model_name Qwen/Qwen2.5-Math-7B \ + --num_steps 3 \ + --deepspeed configs/deepspeed_zero3.json + + BF16_EXIT=$? + echo "bf16 test exit code: $BF16_EXIT" + + if [ $BF16_EXIT -eq 0 ]; then + echo "============================================" + echo "ALL TESTS PASSED!" + echo "============================================" + else + echo "bf16 test FAILED" + fi +else + echo "fp32 test FAILED" +fi diff --git a/train_rlvr.py b/train_rlvr.py new file mode 100644 index 0000000..1076df8 --- /dev/null +++ b/train_rlvr.py @@ -0,0 +1,849 @@ +#!/usr/bin/env python3 +# train_rlvr.py +""" +RLVR Training Script with DAPO Algorithm. + +This script implements the training loop for reinforcement learning with +verifiable rewards (RLVR) using the DAPO algorithm. It supports two precision +configurations: + +- P-high (fp32): High precision master weights for low numerical noise +- P-bf16: Default RLVR configuration with bf16 master weights + +The script integrates with VeRL framework for distributed RL training. 
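+
+Precision is selected via make_precision_config(): "fp32" keeps float32 master
+weights, while "bf16" casts model parameters to bfloat16 after loading (see
+cast_model_param_dtype below); seeding and data handling are identical in both.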
+ +Usage: + python train_rlvr.py \ + --precision_mode fp32 \ + --seed 1 \ + --output_dir results/train_logs/fp32_seed1 \ + --train_dataset_path data/dm_train.json +""" + +import argparse +import json +import os +import random +import logging +from typing import Dict, Any, List, Optional, Tuple +from dataclasses import asdict + +import numpy as np +import torch +import torch.distributed as dist +from torch.cuda.amp import autocast, GradScaler +from transformers import AutoModelForCausalLM, AutoTokenizer +import deepspeed + +from config import ( + make_training_config, + make_precision_config, + TrainingConfig, + PrecisionConfig, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Seed and Determinism Utilities +# ============================================================================ + +def set_seed(seed: int) -> None: + """Set random seeds for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # For hash-based operations + os.environ["PYTHONHASHSEED"] = str(seed) + logger.info(f"Set random seed to {seed}") + + +def configure_torch_deterministic(deterministic: bool) -> None: + """Configure PyTorch deterministic algorithms.""" + if deterministic: + torch.use_deterministic_algorithms(True, warn_only=True) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + logger.info("Enabled deterministic algorithms") + else: + torch.use_deterministic_algorithms(False) + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + logger.info("Using non-deterministic algorithms (default)") + + +# ============================================================================ +# Model Utilities +# ============================================================================ + +def get_torch_dtype(dtype_str: str) -> torch.dtype: + """Convert string dtype to torch.dtype.""" + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + if dtype_str not in dtype_map: + raise ValueError(f"Unknown dtype: {dtype_str}") + return dtype_map[dtype_str] + + +def cast_model_param_dtype( + model: torch.nn.Module, + param_dtype: str +) -> torch.nn.Module: + """Cast model parameters to specified dtype.""" + dtype = get_torch_dtype(param_dtype) + model.to(dtype=dtype) + logger.info(f"Cast model parameters to {param_dtype}") + return model + + +def disable_dropout(model: torch.nn.Module) -> None: + """Disable all dropout layers in the model.""" + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0.0 + logger.info("Disabled all dropout layers") + + +def count_parameters(model: torch.nn.Module) -> int: + """Count trainable parameters.""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +# ============================================================================ +# Data Loading +# ============================================================================ + +def load_training_data(dataset_path: str) -> List[Dict[str, Any]]: + """Load training dataset from JSON file.""" + with open(dataset_path, "r", encoding="utf-8") as f: + data = json.load(f) + logger.info(f"Loaded {len(data)} training examples from {dataset_path}") + return 
data
+
+
+def sample_batch(
+    data: List[Dict[str, Any]],
+    batch_size: int,
+    rng: np.random.Generator
+) -> List[Dict[str, Any]]:
+    """Sample a batch of prompts from the dataset (without replacement)."""
+    indices = rng.choice(len(data), size=min(batch_size, len(data)), replace=False)
+    return [data[i] for i in indices]
+
+
+# ============================================================================
+# Reward Function (Math Verifier)
+# ============================================================================
+
+def compute_math_reward(
+    prompt: str,
+    response: str,
+    ground_truth: Optional[str] = None
+) -> float:
+    """
+    Compute reward for a math problem response.
+
+    Uses a simple rule-based verifier. In production, this should be replaced
+    with Eval-Chemy or a similar math verification system.
+
+    Args:
+        prompt: The math problem prompt
+        response: The model's generated response
+        ground_truth: The expected answer (if available)
+
+    Returns:
+        +1.0 if correct, -1.0 if incorrect, 0.0 if unverifiable
+    """
+    # TODO: Replace with actual math verifier (Eval-Chemy)
+    # This is a placeholder implementation
+
+    if ground_truth is None:
+        # Cannot verify without ground truth
+        return 0.0
+
+    # Extract final answer from response (simple heuristic)
+    response_lower = response.lower().strip()
+    gt_lower = ground_truth.lower().strip()
+
+    # Check for common answer formats
+    answer_markers = ["the answer is", "therefore", "=", "\\boxed{"]
+
+    for marker in answer_markers:
+        if marker in response_lower:
+            idx = response_lower.rfind(marker)
+            potential_answer = response_lower[idx:].strip()
+            if gt_lower in potential_answer:
+                return 1.0
+
+    # Fallback: direct containment anywhere in the response (this also covers
+    # every case the marker scan above accepts)
+    if gt_lower in response_lower:
+        return 1.0
+
+    return -1.0
+
+
+# ============================================================================
+# DAPO Algorithm Implementation
+# ============================================================================
+
+class DAPOTrainer:
+    """
+    DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) Trainer.
+
+    Implements clip-only DAPO with an implicit KL constraint through ratio
+    clipping. This is a simplified implementation; for production, use VeRL's
+    DapoTrainer.
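+
+    Minimal usage sketch (assumes the model, reference model, tokenizer, and
+    configs are already constructed, as in main() at the bottom of this file):
+
+        trainer = DAPOTrainer(model, ref_model, tokenizer,
+                              train_config, precision_config,
+                              device=torch.device("cuda"))
+        trainer.train(train_data)
+        final_dir = trainer.save_final()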
+ """ + + def __init__( + self, + model_engine, # DeepSpeed engine or raw model + ref_model: torch.nn.Module, + tokenizer, + train_config: TrainingConfig, + precision_config: PrecisionConfig, + device: torch.device, + ref_device: Optional[torch.device] = None, + use_deepspeed: bool = False + ) -> None: + self.use_deepspeed = use_deepspeed + if use_deepspeed: + self.model_engine = model_engine + self.model = model_engine.module + else: + self.model = model_engine + self.model_engine = None + self.ref_model = ref_model + self.tokenizer = tokenizer + self.train_config = train_config + self.precision_config = precision_config + self.device = device + self.ref_device = ref_device if ref_device is not None else device + + # Freeze reference model + for param in self.ref_model.parameters(): + param.requires_grad = False + self.ref_model.eval() + + # Setup optimizer (only if not using DeepSpeed) + if not use_deepspeed: + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=train_config.learning_rate, + betas=(train_config.beta1, train_config.beta2), + weight_decay=train_config.weight_decay + ) + else: + self.optimizer = None # DeepSpeed manages optimizer + + # Setup AMP scaler if using fp16 (not needed with DeepSpeed) + self.scaler = None + if not use_deepspeed and precision_config.use_amp and precision_config.amp_dtype == "float16": + self.scaler = GradScaler() + + # Training state + self.global_step = 0 + self.rng = np.random.default_rng(train_config.seed) + + # Metrics tracking + self.metrics_history: List[Dict[str, float]] = [] + + def compute_log_probs( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + labels: torch.Tensor + ) -> torch.Tensor: + """Compute token-level log probabilities.""" + with autocast( + enabled=self.precision_config.use_amp, + dtype=get_torch_dtype(self.precision_config.amp_dtype) + ): + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False + ) + + logits = outputs.logits + log_probs = torch.log_softmax(logits, dim=-1) + + # Gather log probs for actual tokens + # Shift for autoregressive: predict next token + shift_log_probs = log_probs[:, :-1, :] + shift_labels = labels[:, 1:] + + token_log_probs = torch.gather( + shift_log_probs, + dim=-1, + index=shift_labels.unsqueeze(-1) + ).squeeze(-1) + + return token_log_probs + + def compute_dapo_loss( + self, + policy_log_probs: torch.Tensor, + ref_log_probs: torch.Tensor, + rewards: torch.Tensor, + response_mask: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, float]]: + """ + Compute DAPO clip-only loss. + + DAPO uses ratio clipping without explicit KL penalty (beta=0). + The clipping provides implicit KL constraint (Gate I). 
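+
+        Concretely, with sequence-level ratio r = exp(sum_t (log pi - log pi_ref))
+        over response tokens and advantage A (the group-normalized reward), the
+        loss computed below is
+
+            L = -mean( min(r * A, clip(r, 1 - eps, 1 + eps) * A) ),  eps = clip_ratio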
+ """ + # Compute log ratios + log_ratios = policy_log_probs - ref_log_probs + + # Sum log ratios over response tokens + masked_log_ratios = log_ratios * response_mask + sequence_log_ratios = masked_log_ratios.sum(dim=-1) + + # Compute importance sampling ratios + ratios = torch.exp(sequence_log_ratios) + + # DAPO objective with clipping + clip_ratio = self.train_config.clip_ratio + + # Advantage estimation (simplified: just use rewards) + advantages = rewards + + # Clipped surrogate objective + unclipped = ratios * advantages + clipped = torch.clamp(ratios, 1 - clip_ratio, 1 + clip_ratio) * advantages + + # Take minimum for pessimistic update + loss = -torch.min(unclipped, clipped).mean() + + # Compute metrics + with torch.no_grad(): + approx_kl = (log_ratios * response_mask).sum() / response_mask.sum() + clip_fraction = ((ratios - 1).abs() > clip_ratio).float().mean() + + metrics = { + "loss": loss.item(), + "approx_kl": approx_kl.item(), + "clip_fraction": clip_fraction.item(), + "mean_ratio": ratios.mean().item(), + "mean_reward": rewards.mean().item(), + } + + return loss, metrics + + def generate_rollouts( + self, + prompts: List[str], + num_samples: int + ) -> List[Dict[str, Any]]: + """Generate rollouts for a batch of prompts.""" + rollouts = [] + + self.model.eval() + with torch.no_grad(): + for prompt in prompts: + inputs = self.tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=self.train_config.max_seq_len // 2 + ).to(self.device) + + for _ in range(num_samples): + with autocast( + enabled=self.precision_config.use_amp, + dtype=get_torch_dtype(self.precision_config.amp_dtype) + ): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.train_config.max_seq_len // 2, + do_sample=True, + temperature=0.7, + top_p=0.8, + pad_token_id=self.tokenizer.eos_token_id + ) + + response_ids = outputs[0, inputs["input_ids"].shape[1]:] + response_text = self.tokenizer.decode( + response_ids, + skip_special_tokens=True + ) + + rollouts.append({ + "prompt": prompt, + "response": response_text, + "input_ids": inputs["input_ids"][0], + "response_ids": response_ids, + "full_ids": outputs[0] + }) + + self.model.train() + return rollouts + + def train_step( + self, + batch: List[Dict[str, Any]] + ) -> Dict[str, float]: + """Execute one training step on a batch.""" + self.model.train() + + # Generate rollouts + prompts = [ex["prompt"] for ex in batch] + ground_truths = [ex.get("answer", None) for ex in batch] + + rollouts = self.generate_rollouts( + prompts, + self.train_config.num_rollouts_per_prompt + ) + + # Compute rewards + rewards = [] + for i, rollout in enumerate(rollouts): + prompt_idx = i // self.train_config.num_rollouts_per_prompt + gt = ground_truths[prompt_idx] if prompt_idx < len(ground_truths) else None + reward = compute_math_reward(rollout["prompt"], rollout["response"], gt) + rewards.append(reward) + + rewards_tensor = torch.tensor(rewards, device=self.device, dtype=torch.float32) + + # Skip if all rewards are the same (no learning signal) + if rewards_tensor.std() < 1e-6: + return {"skipped": 1.0} + + # Normalize rewards per prompt (advantage estimation) + rewards_per_prompt = rewards_tensor.view(-1, self.train_config.num_rollouts_per_prompt) + normalized_rewards = (rewards_per_prompt - rewards_per_prompt.mean(dim=1, keepdim=True)) + normalized_rewards = normalized_rewards / (rewards_per_prompt.std(dim=1, keepdim=True) + 1e-8) + normalized_rewards = normalized_rewards.view(-1) + + # Prepare for training (DeepSpeed handles zero_grad internally) + 
if not self.use_deepspeed: + self.optimizer.zero_grad() + + total_loss = 0.0 + all_metrics: Dict[str, List[float]] = {} + + # Process rollouts in micro-batches + num_rollouts = len(rollouts) + micro_batch_size = self.train_config.micro_batch_size + + for mb_start in range(0, num_rollouts, micro_batch_size): + mb_end = min(mb_start + micro_batch_size, num_rollouts) + mb_rollouts = rollouts[mb_start:mb_end] + mb_rewards = normalized_rewards[mb_start:mb_end] + + # Prepare batch tensors + max_len = max(len(r["full_ids"]) for r in mb_rollouts) + batch_input_ids = torch.zeros(len(mb_rollouts), max_len, dtype=torch.long, device=self.device) + batch_attention_mask = torch.zeros(len(mb_rollouts), max_len, dtype=torch.long, device=self.device) + batch_response_mask = torch.zeros(len(mb_rollouts), max_len - 1, dtype=torch.float32, device=self.device) + + for i, rollout in enumerate(mb_rollouts): + seq_len = len(rollout["full_ids"]) + batch_input_ids[i, :seq_len] = rollout["full_ids"] + batch_attention_mask[i, :seq_len] = 1 + prompt_len = len(rollout["input_ids"]) + batch_response_mask[i, prompt_len-1:seq_len-1] = 1 + + # Compute log probs for policy and reference + policy_log_probs = self.compute_log_probs( + self.model, + batch_input_ids, + batch_attention_mask, + batch_input_ids + ) + + with torch.no_grad(): + # Move tensors to ref_device if different from training device + if self.ref_device != self.device: + ref_input_ids = batch_input_ids.to(self.ref_device) + ref_attention_mask = batch_attention_mask.to(self.ref_device) + else: + ref_input_ids = batch_input_ids + ref_attention_mask = batch_attention_mask + + ref_log_probs = self.compute_log_probs( + self.ref_model, + ref_input_ids, + ref_attention_mask, + ref_input_ids + ) + + # Move ref_log_probs back to training device + if self.ref_device != self.device: + ref_log_probs = ref_log_probs.to(self.device) + + # Compute DAPO loss + loss, metrics = self.compute_dapo_loss( + policy_log_probs, + ref_log_probs, + mb_rewards, + batch_response_mask + ) + + # Scale loss for gradient accumulation (DeepSpeed handles this internally) + if self.use_deepspeed: + scaled_loss = loss + else: + scaled_loss = loss / self.train_config.grad_accumulation_steps + + # Backward pass + if self.use_deepspeed: + self.model_engine.backward(scaled_loss) + elif self.scaler is not None: + self.scaler.scale(scaled_loss).backward() + else: + scaled_loss.backward() + + total_loss += loss.item() + + for k, v in metrics.items(): + if k not in all_metrics: + all_metrics[k] = [] + all_metrics[k].append(v) + + # Optimizer step + if self.use_deepspeed: + self.model_engine.step() + elif self.scaler is not None: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + + self.global_step += 1 + + # Aggregate metrics + step_metrics = {k: np.mean(v) for k, v in all_metrics.items()} + step_metrics["total_loss"] = total_loss + step_metrics["step"] = self.global_step + + self.metrics_history.append(step_metrics) + + return step_metrics + + def train( + self, + train_data: List[Dict[str, Any]], + save_checkpoints: bool = True + ) -> None: + """Run the full training loop.""" + logger.info(f"Starting training for {self.train_config.num_steps} steps") + + checkpoint_steps = set(self.train_config.checkpoint_steps) + + for step in range(self.train_config.num_steps): + # Sample batch + batch = sample_batch( + train_data, + self.train_config.global_batch_size // self.train_config.num_rollouts_per_prompt, + self.rng + ) + + # Training step + metrics = 
self.train_step(batch) + + # Logging + if step % 10 == 0: + logger.info( + f"Step {step}/{self.train_config.num_steps} | " + f"Loss: {metrics.get('total_loss', 0):.4f} | " + f"KL: {metrics.get('approx_kl', 0):.4f} | " + f"Reward: {metrics.get('mean_reward', 0):.4f}" + ) + + # Checkpointing + if save_checkpoints and (step + 1) in checkpoint_steps: + self.save_checkpoint(step + 1) + + def save_checkpoint(self, step: int) -> None: + """Save model checkpoint.""" + ckpt_dir = os.path.join( + self.train_config.output_dir, + f"checkpoint_step{step}" + ) + os.makedirs(ckpt_dir, exist_ok=True) + + if self.use_deepspeed: + # Use DeepSpeed's save_checkpoint for proper ZeRO handling + self.model_engine.save_checkpoint(ckpt_dir, tag=f"step{step}") + else: + self.model.save_pretrained(ckpt_dir) + # Save training state + state = { + "step": step, + "optimizer_state_dict": self.optimizer.state_dict(), + "metrics_history": self.metrics_history, + } + torch.save(state, os.path.join(ckpt_dir, "training_state.pt")) + + self.tokenizer.save_pretrained(ckpt_dir) + logger.info(f"Saved checkpoint at step {step} to {ckpt_dir}") + + def save_final(self) -> str: + """Save final model.""" + final_dir = os.path.join(self.train_config.output_dir, "final_model") + os.makedirs(final_dir, exist_ok=True) + + if self.use_deepspeed: + # Use DeepSpeed's save_checkpoint for proper ZeRO-3 weight gathering + self.model_engine.save_checkpoint(final_dir, tag="final") + else: + self.model.save_pretrained(final_dir) + + self.tokenizer.save_pretrained(final_dir) + + # Save metrics + with open(os.path.join(final_dir, "metrics_history.json"), "w") as f: + json.dump(self.metrics_history, f, indent=2) + + logger.info(f"Saved final model to {final_dir}") + return final_dir + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="RLVR Training with DAPO Algorithm" + ) + parser.add_argument( + "--precision_mode", + type=str, + required=True, + choices=["fp32", "bf16"], + help="Precision mode: fp32 (high precision) or bf16 (default RLVR)" + ) + parser.add_argument( + "--seed", + type=int, + required=True, + help="Random seed for reproducibility" + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Directory to save outputs" + ) + parser.add_argument( + "--train_dataset_path", + type=str, + required=True, + help="Path to training dataset JSON" + ) + parser.add_argument( + "--model_name", + type=str, + default="Qwen/Qwen2.5-Math-7B", + help="HuggingFace model identifier" + ) + parser.add_argument( + "--num_steps", + type=int, + default=300, + help="Number of training steps" + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for training" + ) + parser.add_argument( + "--deepspeed", + type=str, + default=None, + help="Path to DeepSpeed config file" + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="Local rank for distributed training (set by DeepSpeed)" + ) + return parser.parse_args() + + +def main() -> None: + """Main training function.""" + args = parse_args() + + # Create configurations + train_config = make_training_config( + precision_mode=args.precision_mode, + seed=args.seed, + output_dir=args.output_dir, + train_dataset_path=args.train_dataset_path, + model_name=args.model_name + ) + 
train_config.num_steps = args.num_steps
+
+    precision_config = make_precision_config(args.precision_mode)
+
+    # Setup output directory
+    os.makedirs(train_config.output_dir, exist_ok=True)
+
+    # Save configurations
+    with open(os.path.join(train_config.output_dir, "train_config.json"), "w") as f:
+        json.dump(asdict(train_config), f, indent=2)
+    with open(os.path.join(train_config.output_dir, "precision_config.json"), "w") as f:
+        json.dump(asdict(precision_config), f, indent=2)
+
+    # Set seeds and determinism
+    set_seed(train_config.seed)
+    configure_torch_deterministic(precision_config.deterministic)
+
+    # Check if using DeepSpeed
+    use_deepspeed = args.deepspeed is not None
+
+    # Setup devices
+    num_gpus = torch.cuda.device_count()
+
+    if use_deepspeed:
+        # Initialize DeepSpeed distributed
+        deepspeed.init_distributed()
+        local_rank = args.local_rank if args.local_rank >= 0 else int(os.environ.get("LOCAL_RANK", 0))
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
+        # Put reference model on same GPU as training (each rank has its own copy)
+        # ZeRO-2 shards optimizer states, so there's room for the bf16 ref model (~14GB)
+        ref_device = device
+        logger.info(f"DeepSpeed: rank {local_rank}, training on {device}, ref model on {ref_device}")
+    else:
+        device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+        if num_gpus >= 2:
+            ref_device = torch.device("cuda:1")
+            logger.info(f"Using device: {device} for training, {ref_device} for reference model")
+        else:
+            ref_device = device
+            logger.info(f"Using device: {device} for both training and reference model")
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        train_config.model_name,
+        use_fast=True,
+        trust_remote_code=True
+    )
+    tokenizer.padding_side = "left"
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Load model
+    logger.info(f"Loading model: {train_config.model_name}")
+    model = AutoModelForCausalLM.from_pretrained(
+        train_config.model_name,
+        torch_dtype=torch.float32,  # Load in FP32 first
+        device_map=None,
+        trust_remote_code=True
+    )
+
+    # Cast to target precision
+    model = cast_model_param_dtype(model, precision_config.param_dtype)
+
+    # Enable gradient checkpointing to save memory (trades compute for memory)
+    model.gradient_checkpointing_enable()
+    logger.info("Enabled gradient checkpointing to reduce memory usage")
+
+    # Disable dropout if needed
+    if not train_config.use_dropout:
+        disable_dropout(model)
+
+    logger.info(f"Model loaded with {count_parameters(model):,} trainable parameters")
+
+    # Initialize DeepSpeed or move model to device
+    if use_deepspeed:
+        # Create optimizer for DeepSpeed
+        optimizer = torch.optim.AdamW(
+            model.parameters(),
+            lr=train_config.learning_rate,
+            betas=(train_config.beta1, train_config.beta2),
+            weight_decay=train_config.weight_decay
+        )
+
+        # Load DeepSpeed config
+        with open(args.deepspeed, 'r') as f:
+            ds_config = json.load(f)
+
+        # Compute batch sizes compatible with DeepSpeed
+        # DeepSpeed requires: train_batch_size = micro_batch * grad_acc * world_size
+        # WORLD_SIZE is set by the launcher; fall back to all visible GPUs, since
+        # the reference model now shares each rank's GPU (see device setup above)
+        world_size = int(os.environ.get("WORLD_SIZE", max(1, num_gpus)))
+        micro_batch = train_config.micro_batch_size
+
+        # Compute grad_acc to get closest to desired global batch size
+        desired_global = train_config.global_batch_size
+        grad_acc = max(1, desired_global // (micro_batch * world_size))
+        actual_global = micro_batch * grad_acc * world_size
+
+        ds_config["train_batch_size"] = actual_global
+
ds_config["train_micro_batch_size_per_gpu"] = micro_batch + ds_config["gradient_accumulation_steps"] = grad_acc + + logger.info(f"DeepSpeed batch config: global={actual_global}, micro={micro_batch}, grad_acc={grad_acc}, world_size={world_size}") + + # Initialize DeepSpeed engine + model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config + ) + logger.info(f"DeepSpeed ZeRO-2 initialized with {world_size} GPUs") + else: + model = model.to(device) + model_engine = model + + # Load reference model (frozen copy) + # Always use bf16 for reference model - it's frozen and only used for KL computation. + # This is a controlled variable: same precision for both fp32 and bf16 training modes. + # The experiment tests training precision effects, not reference model precision. + logger.info("Loading reference model (bf16 for both modes - controlled variable)") + ref_model = AutoModelForCausalLM.from_pretrained( + train_config.model_name, + torch_dtype=torch.bfloat16, # Always bf16 to save memory & control variable + device_map=None, + trust_remote_code=True + ) + ref_model = ref_model.to(ref_device) + ref_model.eval() + + # Load training data + train_data = load_training_data(train_config.train_dataset_path) + + # Initialize trainer + trainer = DAPOTrainer( + model_engine=model_engine, + ref_model=ref_model, + tokenizer=tokenizer, + train_config=train_config, + precision_config=precision_config, + device=device, + ref_device=ref_device, + use_deepspeed=use_deepspeed + ) + + # Run training + trainer.train(train_data, save_checkpoints=True) + + # Save final model + final_path = trainer.save_final() + logger.info(f"Training complete. Final model saved to: {final_path}") + + +if __name__ == "__main__": + main() + diff --git a/utils_bf16_sparsity.py b/utils_bf16_sparsity.py new file mode 100644 index 0000000..2e0729d --- /dev/null +++ b/utils_bf16_sparsity.py @@ -0,0 +1,459 @@ +# utils_bf16_sparsity.py +""" +bf16-Aware Update Sparsity Utilities. + +This module implements the bf16-aware update sparsity metric from the RLVR paper, +which measures how many parameter updates are "visible" after bf16 quantization. + +Key concepts: +- Due to bf16's limited precision (7 mantissa bits), small updates may be "swallowed" +- The bf16 ULP (Unit in Last Place) creates a minimum relative update threshold +- Updates smaller than ~0.2-0.4% may not be reflected in bf16 representation + +Reference: +- RLVR paper Definition 2.1 & 2.2 +- bf16 ULP analysis showing relative update threshold of 2^{-8} to 2^{-7} +""" + +import torch +import numpy as np +from typing import Dict, Any, Tuple, List, Optional +import logging +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# bf16 Equality Check +# ============================================================================ + +def bf16_approximately_equal( + w: torch.Tensor, + w_hat: torch.Tensor, + eta: float = 1e-3 +) -> torch.Tensor: + """ + Check if two tensors are approximately equal under bf16 precision. + + From RLVR Definition 2.1: + Two values w and w_hat are considered bf16-equal if: + |w_hat - w| <= eta * max(|w|, |w_hat|) + + When eta < 2^{-9}, this is equivalent to bit-wise bf16 equality. 
+ + Args: + w: Original weights tensor + w_hat: Updated weights tensor + eta: Relative tolerance (default 1e-3 as in RLVR) + + Returns: + Boolean mask where True indicates bf16-equal + """ + max_abs = torch.maximum(w.abs(), w_hat.abs()) + diff = (w_hat - w).abs() + + # Handle zero weights (avoid division by zero in relative comparison) + # For zeros, use absolute comparison + zero_mask = max_abs < 1e-10 + + # Relative comparison + relative_equal = diff <= eta * max_abs + + # For zeros, check if both are effectively zero + both_zero = w.abs() < 1e-10 + both_zero = both_zero & (w_hat.abs() < 1e-10) + + # Combine: either relatively equal, or both effectively zero + equal_mask = relative_equal | (zero_mask & both_zero) + + return equal_mask + + +def bf16_bitwise_equal( + w: torch.Tensor, + w_hat: torch.Tensor +) -> torch.Tensor: + """ + Check if two tensors are bitwise equal in bf16 representation. + + This is the strictest equality check - values must have identical + bf16 bit patterns. + + Args: + w: Original weights tensor + w_hat: Updated weights tensor + + Returns: + Boolean mask where True indicates bitwise bf16 equality + """ + # Convert to bf16 and compare + w_bf16 = w.to(torch.bfloat16) + w_hat_bf16 = w_hat.to(torch.bfloat16) + + # Bitwise comparison via view as int16 + w_bits = w_bf16.view(torch.int16) + w_hat_bits = w_hat_bf16.view(torch.int16) + + return w_bits == w_hat_bits + + +# ============================================================================ +# Update Count and Sparsity +# ============================================================================ + +def compute_bf16_update_count( + w: torch.Tensor, + w_hat: torch.Tensor, + eta: float = 1e-3 +) -> Tuple[int, int, int]: + """ + Compute bf16-aware update count. + + From RLVR Definition 2.2: + |θ_1 - θ_0|_{0,bf16,η} = #{i: w_hat_i not≈_{bf16,η} w_i} + + Args: + w: Original weights tensor + w_hat: Updated weights tensor + eta: Relative tolerance + + Returns: + Tuple of (num_changed, num_unchanged, total) + """ + equal_mask = bf16_approximately_equal(w, w_hat, eta=eta) + + total = int(equal_mask.numel()) + num_unchanged = int(equal_mask.sum().item()) + num_changed = total - num_unchanged + + return num_changed, num_unchanged, total + + +def compute_bf16_sparsity( + base_model: torch.nn.Module, + finetuned_model: torch.nn.Module, + eta: float = 1e-3, + include_layer_stats: bool = False +) -> Dict[str, Any]: + """ + Compute bf16-aware update sparsity between two models. + + Sparsity = 1 - |θ_1 - θ_0|_{0,bf16,η} / n + + Values closer to 1 mean more sparse (fewer visible updates). + Values closer to 0 mean more dense (more visible updates). + + RLVR Table 1 reports sparsity in range 36%-92% for their experiments. 
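+
+    Usage sketch (model loading omitted; variable names are illustrative):
+        >>> stats = compute_bf16_sparsity(base_model, finetuned_model, eta=1e-3)
+        >>> print(f"{stats['sparsity_percent']:.1f}% of parameters unchanged in bf16")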
+ + Args: + base_model: Original model (θ_0) + finetuned_model: Updated model (θ_1) + eta: Relative tolerance + include_layer_stats: If True, include per-layer statistics + + Returns: + Dictionary with sparsity metrics + """ + base_params = dict(base_model.named_parameters()) + ft_params = dict(finetuned_model.named_parameters()) + + total_elements = 0 + changed_elements = 0 + + layer_stats: Dict[str, Dict[str, Any]] = {} + + for name, base_param in base_params.items(): + if name not in ft_params: + logger.warning(f"Parameter {name} not found in finetuned model") + continue + + ft_param = ft_params[name] + + if base_param.shape != ft_param.shape: + logger.warning( + f"Shape mismatch for {name}: " + f"{base_param.shape} vs {ft_param.shape}" + ) + continue + + # Move to CPU for computation + w = base_param.detach().cpu().float().flatten() + w_hat = ft_param.detach().cpu().float().flatten() + + # Compute update count + num_changed, num_unchanged, total = compute_bf16_update_count( + w, w_hat, eta=eta + ) + + total_elements += total + changed_elements += num_changed + + if include_layer_stats: + layer_sparsity = 1.0 - num_changed / total if total > 0 else 1.0 + layer_stats[name] = { + "num_changed": num_changed, + "num_unchanged": num_unchanged, + "total": total, + "sparsity": layer_sparsity, + "shape": list(base_param.shape) + } + + # Compute overall sparsity + overall_sparsity = 1.0 - changed_elements / total_elements if total_elements > 0 else 1.0 + + result = { + "sparsity": overall_sparsity, + "sparsity_percent": overall_sparsity * 100, + "num_changed": changed_elements, + "num_unchanged": total_elements - changed_elements, + "total_parameters": total_elements, + "eta": eta, + "update_fraction": changed_elements / total_elements if total_elements > 0 else 0.0, + } + + if include_layer_stats: + result["layer_stats"] = layer_stats + + return result + + +# ============================================================================ +# Update Magnitude Analysis +# ============================================================================ + +def analyze_update_magnitudes( + base_model: torch.nn.Module, + finetuned_model: torch.nn.Module +) -> Dict[str, Any]: + """ + Analyze the distribution of update magnitudes. + + This helps understand which updates are below the bf16 ULP threshold. 
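+
+    Worked example (editor's illustration): for a weight near 1.0, the bf16
+    spacing is 2^{-7} ≈ 0.0078, so an additive update of 1e-4 (~0.01%
+    relative) rounds back to the original value and is effectively swallowed.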
+ + Returns statistics about: + - Absolute update magnitudes + - Relative update magnitudes + - Distribution relative to bf16 ULP + """ + base_params = dict(base_model.named_parameters()) + ft_params = dict(finetuned_model.named_parameters()) + + all_relative_updates: List[float] = [] + all_absolute_updates: List[float] = [] + + for name, base_param in base_params.items(): + if name not in ft_params: + continue + + ft_param = ft_params[name] + if base_param.shape != ft_param.shape: + continue + + w = base_param.detach().cpu().float().flatten() + w_hat = ft_param.detach().cpu().float().flatten() + + # Absolute updates + abs_updates = (w_hat - w).abs() + + # Relative updates (avoid division by zero) + max_abs = torch.maximum(w.abs(), w_hat.abs()) + valid_mask = max_abs > 1e-10 + + rel_updates = torch.zeros_like(abs_updates) + rel_updates[valid_mask] = abs_updates[valid_mask] / max_abs[valid_mask] + + # Sample for statistics (avoid memory issues) + sample_size = min(10000, len(abs_updates)) + indices = np.random.choice(len(abs_updates), sample_size, replace=False) + + all_absolute_updates.extend(abs_updates[indices].tolist()) + all_relative_updates.extend(rel_updates[indices].tolist()) + + abs_array = np.array(all_absolute_updates) + rel_array = np.array(all_relative_updates) + + # bf16 ULP threshold (approximately 2^{-8} to 2^{-7}, or 0.2% to 0.4%) + bf16_ulp_low = 2 ** -8 # ~0.39% + bf16_ulp_high = 2 ** -7 # ~0.78% + + # Fraction of updates below ULP threshold + below_low = (rel_array < bf16_ulp_low).mean() + below_high = (rel_array < bf16_ulp_high).mean() + + result = { + "absolute_updates": { + "mean": float(np.mean(abs_array)), + "std": float(np.std(abs_array)), + "median": float(np.median(abs_array)), + "min": float(np.min(abs_array)), + "max": float(np.max(abs_array)), + "percentiles": { + "p25": float(np.percentile(abs_array, 25)), + "p50": float(np.percentile(abs_array, 50)), + "p75": float(np.percentile(abs_array, 75)), + "p90": float(np.percentile(abs_array, 90)), + "p99": float(np.percentile(abs_array, 99)), + } + }, + "relative_updates": { + "mean": float(np.mean(rel_array)), + "std": float(np.std(rel_array)), + "median": float(np.median(rel_array)), + "min": float(np.min(rel_array)), + "max": float(np.max(rel_array)), + "percentiles": { + "p25": float(np.percentile(rel_array, 25)), + "p50": float(np.percentile(rel_array, 50)), + "p75": float(np.percentile(rel_array, 75)), + "p90": float(np.percentile(rel_array, 90)), + "p99": float(np.percentile(rel_array, 99)), + } + }, + "bf16_ulp_analysis": { + "ulp_low_threshold": bf16_ulp_low, + "ulp_high_threshold": bf16_ulp_high, + "fraction_below_low": float(below_low), + "fraction_below_high": float(below_high), + "estimated_swallowed_fraction": float(below_low), + } + } + + return result + + +# ============================================================================ +# Sparsity Trajectory +# ============================================================================ + +def compute_sparsity_trajectory( + base_model_path: str, + checkpoint_paths: List[str], + eta: float = 1e-3 +) -> List[Dict[str, Any]]: + """ + Compute bf16 sparsity for a sequence of checkpoints. + + Useful for understanding how sparsity evolves during training. 
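+
+    Usage sketch (checkpoint paths are placeholders):
+        >>> traj = compute_sparsity_trajectory(
+        ...     base_model_path="Qwen/Qwen2.5-Math-7B",
+        ...     checkpoint_paths=["ckpt/step100", "ckpt/step200", "ckpt/step300"],
+        ... )
+        >>> [round(t["sparsity_percent"], 1) for t in traj]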
+ + Args: + base_model_path: Path to base model + checkpoint_paths: List of checkpoint paths (in training order) + eta: Relative tolerance + + Returns: + List of sparsity results for each checkpoint + """ + from transformers import AutoModelForCausalLM + + # Load base model + logger.info(f"Loading base model from {base_model_path}") + base_model = AutoModelForCausalLM.from_pretrained( + base_model_path, + torch_dtype=torch.float32, + device_map="cpu" + ) + + trajectory = [] + + for ckpt_path in tqdm(checkpoint_paths, desc="Computing sparsity"): + # Load checkpoint + ckpt_model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + torch_dtype=torch.float32, + device_map="cpu" + ) + + # Compute sparsity + sparsity_result = compute_bf16_sparsity( + base_model=base_model, + finetuned_model=ckpt_model, + eta=eta, + include_layer_stats=False + ) + + trajectory.append({ + "checkpoint": ckpt_path, + "sparsity": sparsity_result["sparsity"], + "sparsity_percent": sparsity_result["sparsity_percent"], + "num_changed": sparsity_result["num_changed"], + }) + + # Free memory + del ckpt_model + + return trajectory + + +# ============================================================================ +# Layer-wise Sparsity Analysis +# ============================================================================ + +def analyze_layer_sparsity_patterns( + base_model: torch.nn.Module, + finetuned_model: torch.nn.Module, + eta: float = 1e-3 +) -> Dict[str, Any]: + """ + Analyze sparsity patterns across different layer types. + + Groups layers by type (attention, MLP, embeddings, etc.) and + reports aggregate sparsity statistics. + """ + sparsity_result = compute_bf16_sparsity( + base_model=base_model, + finetuned_model=finetuned_model, + eta=eta, + include_layer_stats=True + ) + + layer_stats = sparsity_result.get("layer_stats", {}) + + # Group by layer type + groups: Dict[str, List[Dict[str, Any]]] = { + "attention": [], + "mlp": [], + "embedding": [], + "norm": [], + "other": [] + } + + for name, stats in layer_stats.items(): + name_lower = name.lower() + + if any(k in name_lower for k in ["attn", "attention", "self_attn"]): + groups["attention"].append(stats) + elif any(k in name_lower for k in ["mlp", "fc", "dense", "linear"]): + groups["mlp"].append(stats) + elif any(k in name_lower for k in ["embed", "wte", "wpe"]): + groups["embedding"].append(stats) + elif any(k in name_lower for k in ["norm", "ln", "layer_norm"]): + groups["norm"].append(stats) + else: + groups["other"].append(stats) + + # Compute aggregate statistics per group + group_analysis = {} + for group_name, layer_list in groups.items(): + if not layer_list: + continue + + sparsities = [l["sparsity"] for l in layer_list] + total_params = sum(l["total"] for l in layer_list) + total_changed = sum(l["num_changed"] for l in layer_list) + + group_analysis[group_name] = { + "num_layers": len(layer_list), + "total_params": total_params, + "mean_sparsity": float(np.mean(sparsities)), + "std_sparsity": float(np.std(sparsities)), + "min_sparsity": float(np.min(sparsities)), + "max_sparsity": float(np.max(sparsities)), + "aggregate_sparsity": 1.0 - total_changed / total_params if total_params > 0 else 1.0, + } + + return { + "overall_sparsity": sparsity_result["sparsity"], + "group_analysis": group_analysis, + } + diff --git a/utils_kl.py b/utils_kl.py new file mode 100644 index 0000000..2be50a0 --- /dev/null +++ b/utils_kl.py @@ -0,0 +1,419 @@ +# utils_kl.py +""" +KL Divergence Utilities for RLVR Experiments. 
+
+This module provides utilities for computing KL divergence between
+policy distributions, including:
+- Token-level KL computation
+- Sequence-level KL aggregation
+- Dataset-level KL estimation
+"""
+
+import torch
+import torch.nn.functional as F
+from typing import Dict, Any, List, Tuple, Optional
+import numpy as np
+import logging
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Token-Level KL Computation
+# ============================================================================
+
+def compute_token_log_probs(
+    model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    labels: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Compute token-level log probabilities.
+
+    Args:
+        model: Language model
+        input_ids: Input token IDs [batch, seq_len]
+        attention_mask: Attention mask [batch, seq_len]
+        labels: Token labels for which to compute log probs (default: input_ids)
+
+    Returns:
+        Token log probabilities [batch, seq_len - 1]
+    """
+    if labels is None:
+        labels = input_ids
+
+    with torch.no_grad():
+        outputs = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            use_cache=False
+        )
+
+    logits = outputs.logits  # [batch, seq_len, vocab]
+
+    # Shift for autoregressive: predict token t from tokens 0..t-1
+    shift_logits = logits[:, :-1, :]  # [batch, seq_len-1, vocab]
+    shift_labels = labels[:, 1:]  # [batch, seq_len-1]
+
+    # Compute log probabilities
+    log_probs = F.log_softmax(shift_logits, dim=-1)
+
+    # Gather log probs for actual tokens
+    token_log_probs = torch.gather(
+        log_probs,
+        dim=-1,
+        index=shift_labels.unsqueeze(-1)
+    ).squeeze(-1)  # [batch, seq_len-1]
+
+    return token_log_probs
+
+
+def compute_kl_per_token(
+    policy_log_probs: torch.Tensor,
+    ref_log_probs: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute per-token KL divergence contributions.
+
+    Per-token log ratio: log π(y_t) - log π_ref(y_t)
+
+    Note: with y_t sampled from π, this log ratio is a single-sample
+    Monte Carlo estimate of KL(π || π_ref) at that position; individual
+    values can be negative even though the true KL is non-negative.
+    """
+    return policy_log_probs - ref_log_probs
+
+
+def compute_reverse_kl_per_token(
+    policy_logits: torch.Tensor,
+    ref_logits: torch.Tensor,
+    temperature: float = 1.0
+) -> torch.Tensor:
+    """
+    Compute per-token reverse KL divergence exactly, using full distributions.
+
+    KL(π || π_ref) = Σ_y π(y) [log π(y) - log π_ref(y)]
+
+    This is more expensive than the sampled log-ratio estimate above,
+    but gives the true KL at each position.
+    """
+    policy_probs = F.softmax(policy_logits / temperature, dim=-1)
+    policy_log_probs = F.log_softmax(policy_logits / temperature, dim=-1)
+    ref_log_probs = F.log_softmax(ref_logits / temperature, dim=-1)
+
+    # KL = Σ p(x) log(p(x)/q(x)) = Σ p(x) [log p(x) - log q(x)]
+    kl = (policy_probs * (policy_log_probs - ref_log_probs)).sum(dim=-1)
+
+    return kl
+
+
+# ============================================================================
+# Sequence-Level KL
+# ============================================================================
+
+def compute_sequence_kl(
+    policy_model: torch.nn.Module,
+    ref_model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    response_start_idx: int = 0,
+    normalize_by_length: bool = False
+) -> Dict[str, float]:
+    """
+    Compute KL divergence for a single sequence.
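+
+    Usage sketch (models and tensors prepared elsewhere; names illustrative):
+        >>> out = compute_sequence_kl(policy, ref, input_ids, attention_mask,
+        ...                           response_start_idx=prompt_len)
+        >>> out["kl"], out["mean_kl"], out["num_tokens"]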
+ + Args: + policy_model: Finetuned policy model + ref_model: Reference model + input_ids: Full sequence (prompt + response) [1, seq_len] + attention_mask: Attention mask [1, seq_len] + response_start_idx: Index where response starts + normalize_by_length: If True, return average KL per token + + Returns: + Dictionary with KL metrics + """ + # Get log probs from both models + policy_log_probs = compute_token_log_probs( + policy_model, input_ids, attention_mask + ) + ref_log_probs = compute_token_log_probs( + ref_model, input_ids, attention_mask + ) + + # Compute per-token KL + kl_per_token = compute_kl_per_token(policy_log_probs, ref_log_probs) + + # Create mask for response tokens only + seq_len = kl_per_token.shape[1] + response_mask = torch.zeros(1, seq_len, device=input_ids.device) + if response_start_idx > 0: + response_mask[:, response_start_idx-1:] = 1.0 + else: + response_mask[:, :] = 1.0 + + # Apply attention mask + valid_mask = attention_mask[:, 1:].float() * response_mask + + # Compute statistics + masked_kl = kl_per_token * valid_mask + num_tokens = valid_mask.sum().item() + total_kl = masked_kl.sum().item() + + result = { + "total_kl": total_kl, + "num_tokens": int(num_tokens), + "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0, + "max_kl": (kl_per_token * valid_mask).max().item() if num_tokens > 0 else 0.0, + "min_kl": (kl_per_token * valid_mask).min().item() if num_tokens > 0 else 0.0, + } + + if normalize_by_length: + result["kl"] = result["mean_kl"] + else: + result["kl"] = result["total_kl"] + + return result + + +# ============================================================================ +# Dataset-Level KL Estimation +# ============================================================================ + +def estimate_dataset_kl( + policy_model: torch.nn.Module, + ref_model: torch.nn.Module, + tokenizer, + prompts: List[str], + responses: List[str], + device: torch.device, + max_seq_len: int = 4096, + normalize_by_length: bool = False, + show_progress: bool = True +) -> Dict[str, Any]: + """ + Estimate KL divergence over a dataset. 
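+
+    Usage sketch (models and tokenizer loaded elsewhere; names illustrative):
+        >>> stats = estimate_dataset_kl(policy_model, ref_model, tokenizer,
+        ...                             prompts, responses,
+        ...                             device=torch.device("cuda"))
+        >>> stats["mean_kl"], stats["std_kl"]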
+ + Args: + policy_model: Finetuned policy model + ref_model: Reference model + tokenizer: Tokenizer for both models + prompts: List of prompts + responses: List of corresponding responses + device: Device to use + max_seq_len: Maximum sequence length + normalize_by_length: If True, use mean KL per token + show_progress: Show progress bar + + Returns: + Dictionary with dataset-level KL statistics + """ + assert len(prompts) == len(responses), \ + "Number of prompts must match responses" + + policy_model.eval() + ref_model.eval() + + all_kl_values: List[float] = [] + all_num_tokens: List[int] = [] + + iterator = zip(prompts, responses) + if show_progress: + iterator = tqdm( + list(iterator), + desc="Computing KL" + ) + + for prompt, response in iterator: + # Tokenize prompt + prompt_tokens = tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=max_seq_len // 2 + ) + prompt_len = prompt_tokens["input_ids"].shape[1] + + # Tokenize full sequence + full_text = prompt + response + full_tokens = tokenizer( + full_text, + return_tensors="pt", + truncation=True, + max_length=max_seq_len + ) + + input_ids = full_tokens["input_ids"].to(device) + attention_mask = full_tokens["attention_mask"].to(device) + + # Compute sequence KL + with torch.no_grad(): + kl_result = compute_sequence_kl( + policy_model=policy_model, + ref_model=ref_model, + input_ids=input_ids, + attention_mask=attention_mask, + response_start_idx=prompt_len, + normalize_by_length=normalize_by_length + ) + + all_kl_values.append(kl_result["kl"]) + all_num_tokens.append(kl_result["num_tokens"]) + + # Aggregate statistics + kl_array = np.array(all_kl_values) + + result = { + "mean_kl": float(np.mean(kl_array)), + "std_kl": float(np.std(kl_array)), + "median_kl": float(np.median(kl_array)), + "min_kl": float(np.min(kl_array)), + "max_kl": float(np.max(kl_array)), + "total_samples": len(prompts), + "total_tokens": sum(all_num_tokens), + "kl_values": all_kl_values, + } + + return result + + +# ============================================================================ +# On-Task vs Off-Task KL Analysis +# ============================================================================ + +def analyze_kl_by_task( + kl_results: Dict[str, Dict[str, Any]], + on_task_names: List[str], + off_task_names: List[str] +) -> Dict[str, Any]: + """ + Analyze KL divergence patterns for on-task vs off-task. 
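+
+    Usage sketch (task names follow configs/eval_tasks_config.json; illustrative):
+        >>> analysis = analyze_kl_by_task(
+        ...     kl_results,
+        ...     on_task_names=["MATH-500"],
+        ...     off_task_names=["GSM8K", "HumanEval"],
+        ... )
+        >>> analysis["off_to_on_ratio"]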
+ + Args: + kl_results: Dictionary mapping task names to KL results + on_task_names: List of on-task (training distribution) names + off_task_names: List of off-task names + + Returns: + Analysis of KL patterns + """ + on_task_kl = [] + off_task_kl = [] + + for name in on_task_names: + if name in kl_results: + on_task_kl.append(kl_results[name]["mean_kl"]) + + for name in off_task_names: + if name in kl_results: + off_task_kl.append(kl_results[name]["mean_kl"]) + + analysis = { + "on_task": { + "mean": float(np.mean(on_task_kl)) if on_task_kl else 0.0, + "std": float(np.std(on_task_kl)) if on_task_kl else 0.0, + "values": on_task_kl, + }, + "off_task": { + "mean": float(np.mean(off_task_kl)) if off_task_kl else 0.0, + "std": float(np.std(off_task_kl)) if off_task_kl else 0.0, + "values": off_task_kl, + }, + } + + # Compute ratio + if analysis["on_task"]["mean"] > 0: + analysis["off_to_on_ratio"] = ( + analysis["off_task"]["mean"] / analysis["on_task"]["mean"] + ) + else: + analysis["off_to_on_ratio"] = float("inf") + + return analysis + + +# ============================================================================ +# KL Contribution Analysis +# ============================================================================ + +def analyze_kl_contribution_by_layer( + model: torch.nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor +) -> Dict[str, float]: + """ + Analyze which layers contribute most to the final prediction. + + This is a simplified analysis - for full KL attribution, + you would need layer-wise probing. + """ + # This is a placeholder for more sophisticated analysis + # Full implementation would require modifying the model + # to output intermediate representations + + return { + "note": "Layer-wise KL attribution not implemented", + } + + +def compute_kl_trajectory( + checkpoints: List[str], + ref_model_path: str, + tokenizer_path: str, + prompts: List[str], + responses: List[str], + device: torch.device +) -> List[Dict[str, Any]]: + """ + Compute KL divergence trajectory over training checkpoints. + + Useful for understanding how KL evolves during training. + """ + from transformers import AutoModelForCausalLM, AutoTokenizer + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Load reference model + ref_model = AutoModelForCausalLM.from_pretrained( + ref_model_path, + torch_dtype=torch.bfloat16, + device_map=None + ).to(device) + ref_model.eval() + + trajectory = [] + + for ckpt_path in tqdm(checkpoints, desc="Computing KL trajectory"): + # Load checkpoint + policy_model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + torch_dtype=torch.bfloat16, + device_map=None + ).to(device) + policy_model.eval() + + # Estimate KL + kl_result = estimate_dataset_kl( + policy_model=policy_model, + ref_model=ref_model, + tokenizer=tokenizer, + prompts=prompts, + responses=responses, + device=device, + show_progress=False + ) + + trajectory.append({ + "checkpoint": ckpt_path, + "mean_kl": kl_result["mean_kl"], + "std_kl": kl_result["std_kl"], + }) + + # Free memory + del policy_model + torch.cuda.empty_cache() + + return trajectory + diff --git a/utils_math_eval.py b/utils_math_eval.py new file mode 100644 index 0000000..d4a1db2 --- /dev/null +++ b/utils_math_eval.py @@ -0,0 +1,367 @@ +# utils_math_eval.py +""" +Math Evaluation Utilities for RLVR Experiments. 
+ +This module provides utilities for: +- Extracting answers from model responses +- Verifying mathematical answers +- Computing accuracy metrics +""" + +import re +from typing import Optional, List, Dict, Any, Tuple +import logging + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Answer Extraction +# ============================================================================ + +def extract_boxed_content(text: str) -> List[str]: + """ + Extract all content from \\boxed{} patterns. + + Handles nested braces correctly. + """ + results = [] + i = 0 + + while i < len(text): + # Find \boxed{ + idx = text.find("\\boxed{", i) + if idx == -1: + break + + # Find matching closing brace + start = idx + 7 # len("\\boxed{") + depth = 1 + j = start + + while j < len(text) and depth > 0: + if text[j] == "{": + depth += 1 + elif text[j] == "}": + depth -= 1 + j += 1 + + if depth == 0: + content = text[start:j-1] + results.append(content.strip()) + + i = j + + return results + + +def extract_answer_from_boxed(text: str) -> Optional[str]: + """Extract the last boxed answer from text.""" + boxed_contents = extract_boxed_content(text) + if boxed_contents: + return boxed_contents[-1] + return None + + +def extract_answer_from_patterns(text: str) -> Optional[str]: + """ + Extract answer using common natural language patterns. + """ + # Patterns in order of priority + patterns = [ + # Explicit answer statements + (r"[Tt]he\s+(?:final\s+)?answer\s+is\s*[:\s]*(.+?)(?:\.|,|$)", 1), + (r"[Aa]nswer\s*[:\s]+(.+?)(?:\.|,|$)", 1), + + # Conclusion patterns + (r"[Tt]herefore\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1), + (r"[Hh]ence\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1), + (r"[Ss]o\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1), + (r"[Tt]hus\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1), + + # Equation result + (r"=\s*(\S+)\s*$", 1), + ] + + for pattern, group in patterns: + match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE) + if match: + answer = match.group(group).strip() + # Clean up trailing punctuation + answer = re.sub(r"[.,;:!?]+$", "", answer).strip() + if answer: + return answer + + return None + + +def extract_final_answer(text: str) -> Optional[str]: + """ + Extract the final answer from a model response. + + Priority: + 1. \\boxed{} format + 2. Natural language patterns + """ + # Try boxed format first + boxed = extract_answer_from_boxed(text) + if boxed: + return boxed + + # Try natural language patterns + pattern_answer = extract_answer_from_patterns(text) + if pattern_answer: + return pattern_answer + + return None + + +# ============================================================================ +# Answer Normalization +# ============================================================================ + +def normalize_numeric_answer(answer: str) -> Optional[float]: + """ + Normalize a numeric answer for comparison. 
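+
+    Doctest-style examples (illustrative inputs):
+        >>> normalize_numeric_answer("3/4")
+        0.75
+        >>> normalize_numeric_answer("50%")
+        0.5
+        >>> normalize_numeric_answer("1,234")
+        1234.0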
+ + Handles: + - Integers and decimals + - Fractions (simple forms like a/b) + - Scientific notation + - Percentages + """ + if not answer: + return None + + # Clean up the string + cleaned = answer.strip().lower() + cleaned = cleaned.replace(" ", "") + cleaned = cleaned.replace(",", "") + + # Handle percentages + if cleaned.endswith("%"): + cleaned = cleaned[:-1] + try: + return float(cleaned) / 100 + except ValueError: + pass + + # Handle fractions (a/b) + if "/" in cleaned: + parts = cleaned.split("/") + if len(parts) == 2: + try: + num = float(parts[0]) + denom = float(parts[1]) + if denom != 0: + return num / denom + except ValueError: + pass + + # Handle scientific notation and regular numbers + try: + return float(cleaned) + except ValueError: + pass + + return None + + +def normalize_text_answer(answer: str) -> str: + """ + Normalize a text answer for comparison. + + - Lowercase + - Remove extra whitespace + - Remove common formatting + """ + if not answer: + return "" + + normalized = answer.strip().lower() + + # Remove LaTeX formatting + normalized = re.sub(r"\\[a-zA-Z]+", "", normalized) + normalized = re.sub(r"[{}$]", "", normalized) + + # Normalize whitespace + normalized = " ".join(normalized.split()) + + # Remove common punctuation + normalized = re.sub(r"[.,;:!?]+$", "", normalized).strip() + + return normalized + + +# ============================================================================ +# Answer Comparison +# ============================================================================ + +def compare_numeric_answers( + predicted: str, + ground_truth: str, + tolerance: float = 1e-6 +) -> bool: + """ + Compare two answers numerically. + + Returns True if both can be parsed as numbers and are within tolerance. + """ + pred_num = normalize_numeric_answer(predicted) + gt_num = normalize_numeric_answer(ground_truth) + + if pred_num is None or gt_num is None: + return False + + # Absolute tolerance for small numbers + if abs(gt_num) < 1e-6: + return abs(pred_num - gt_num) < tolerance + + # Relative tolerance for larger numbers + rel_diff = abs(pred_num - gt_num) / abs(gt_num) + return rel_diff < tolerance + + +def compare_text_answers( + predicted: str, + ground_truth: str +) -> bool: + """Compare two text answers after normalization.""" + pred_norm = normalize_text_answer(predicted) + gt_norm = normalize_text_answer(ground_truth) + + return pred_norm == gt_norm + + +def verify_answer( + response: str, + ground_truth: str, + task_type: str = "math" +) -> Tuple[bool, Optional[str]]: + """ + Verify if the response contains the correct answer. 
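+
+    Example (illustrative):
+        >>> verify_answer(r"Thus the answer is \\boxed{42}.", "42")
+        (True, '42')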
+ + Args: + response: Model's full response + ground_truth: Expected answer + task_type: Type of task ("math", "qa", "code") + + Returns: + Tuple of (is_correct, extracted_answer) + """ + # Extract predicted answer + predicted = extract_final_answer(response) + + if predicted is None: + return False, None + + # Try numeric comparison first + if compare_numeric_answers(predicted, ground_truth): + return True, predicted + + # Try text comparison + if compare_text_answers(predicted, ground_truth): + return True, predicted + + # Check if ground truth is contained in predicted + gt_norm = normalize_text_answer(ground_truth) + pred_norm = normalize_text_answer(predicted) + + if gt_norm and gt_norm in pred_norm: + return True, predicted + + return False, predicted + + +# ============================================================================ +# Batch Evaluation +# ============================================================================ + +def evaluate_batch( + responses: List[str], + ground_truths: List[str], + task_type: str = "math" +) -> Dict[str, Any]: + """ + Evaluate a batch of responses. + + Args: + responses: List of model responses + ground_truths: List of expected answers + task_type: Type of task + + Returns: + Dictionary with evaluation metrics + """ + assert len(responses) == len(ground_truths), \ + "Number of responses must match ground truths" + + correct = 0 + total = len(responses) + results = [] + + for response, gt in zip(responses, ground_truths): + is_correct, extracted = verify_answer(response, gt, task_type) + correct += int(is_correct) + results.append({ + "is_correct": is_correct, + "extracted_answer": extracted, + "ground_truth": gt + }) + + accuracy = correct / total if total > 0 else 0.0 + + return { + "accuracy": accuracy, + "correct": correct, + "total": total, + "results": results + } + + +# ============================================================================ +# Answer Format Detection +# ============================================================================ + +def detect_answer_format(text: str) -> str: + """ + Detect the format of an answer. + + Returns one of: "boxed", "numeric", "fraction", "text", "unknown" + """ + if "\\boxed{" in text: + return "boxed" + + # Check for fraction + if re.match(r"^-?\d+/\d+$", text.strip()): + return "fraction" + + # Check for numeric + try: + float(text.strip().replace(",", "")) + return "numeric" + except ValueError: + pass + + if text.strip(): + return "text" + + return "unknown" + + +def format_answer_for_display(answer: str, detected_format: str) -> str: + """Format answer for display based on detected format.""" + if detected_format == "fraction": + num = normalize_numeric_answer(answer) + if num is not None: + return f"{answer} ≈ {num:.6f}" + + if detected_format == "numeric": + try: + num = float(answer.replace(",", "")) + return f"{num:g}" + except ValueError: + pass + + return answer + -- cgit v1.2.3