-rw-r--r--   .gitignore                        5
-rw-r--r--   README.md                       196
-rw-r--r--   analyze_results.py              741
-rw-r--r--   config.py                       328
-rw-r--r--   configs/deepspeed_zero2.json     31
-rw-r--r--   configs/deepspeed_zero3.json     38
-rw-r--r--   configs/eval_tasks_config.json   99
-rw-r--r--   eval_policy.py                  621
-rw-r--r--   requirements.txt                 39
-rw-r--r--   run_experiments.py              601
-rwxr-xr-x   scripts/prepare_data.py         258
-rwxr-xr-x   scripts/run_evaluation.sh        58
-rwxr-xr-x   scripts/run_full_experiment.sh  106
-rwxr-xr-x   scripts/run_training.sh          50
-rwxr-xr-x   scripts/setup_env.sh             79
-rwxr-xr-x   scripts/slurm_train.sh          145
-rwxr-xr-x   scripts/submit_all_jobs.sh       66
-rwxr-xr-x   scripts/submit_single_job.sh     32
-rw-r--r--   scripts/test_quick.sh            78
-rw-r--r--   train_rlvr.py                   849
-rw-r--r--   utils_bf16_sparsity.py          459
-rw-r--r--   utils_kl.py                     419
-rw-r--r--   utils_math_eval.py              367
23 files changed, 5665 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a76904f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+__pycache__/
+*.pyc
+data/
+results/
+.claude/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f1695f6
--- /dev/null
+++ b/README.md
@@ -0,0 +1,196 @@
+# RLVR Floating-Point Precision Experiment
+
+This repository implements experiments to study the effects of floating-point precision (FP32 vs bf16) on RLVR (Reinforcement Learning with Verifiable Rewards) training.
+
+## Overview
+
+The experiment tests three key hypotheses drawn from the RLVR Three-Gate Theory:
+
+1. **On-task performance is insensitive to floating-point noise**: Training task performance should be similar between FP32 and bf16 precision modes.
+
+2. **Off-task performance is sensitive to floating-point noise**: Out-of-distribution task performance should show higher variance in bf16 mode due to numerical noise accumulation.
+
+3. **KL divergence patterns differ by task type**: On-task KL should be constrained by DAPO's implicit leash (Gate I), while off-task KL may drift more in bf16 mode.
+
+## Project Structure
+
+```
+rl-floating-noise/
+├── config.py # Configuration definitions
+├── train_rlvr.py # Training script with DAPO algorithm
+├── eval_policy.py # Evaluation script (J_k, KL_k)
+├── utils_math_eval.py # Math answer verification utilities
+├── utils_kl.py # KL divergence computation utilities
+├── utils_bf16_sparsity.py # bf16-aware update sparsity analysis
+├── run_experiments.py # Experiment orchestration script
+├── analyze_results.py # Results analysis and hypothesis testing
+├── requirements.txt # Python dependencies
+├── configs/
+│ └── eval_tasks_config.json # Evaluation task configurations
+├── scripts/
+│ ├── prepare_data.py # Dataset preparation script
+│ ├── run_training.sh # Single training job script
+│ ├── run_evaluation.sh # Single evaluation job script
+│ └── run_full_experiment.sh # Full experiment pipeline
+├── data/ # Training and evaluation datasets
+└── results/ # Experiment outputs
+ ├── train_logs/ # Training checkpoints and logs
+ ├── eval_metrics/ # Evaluation results
+ └── analysis/ # Analysis outputs and plots
+```
+
+## Installation
+
+```bash
+# Create conda environment
+conda create -n rlvr-fp python=3.10 -y
+conda activate rlvr-fp
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Install VeRL (optional, for full DAPO implementation)
+pip install git+https://github.com/volcengine/verl.git
+```
+
+## Quick Start
+
+### 1. Prepare Data
+
+Generate sample datasets for development:
+
+```bash
+python scripts/prepare_data.py --output_dir ./data
+```
+
+For production experiments, download the actual datasets:
+- DM (DAPO-Math-17k + MATH)
+- AIME24, AIME25, AMC23, MATH-500
+- GSM8K, MMLU-STEM, HumanEval
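+
+As a sketch of what dataset preparation can look like (the Hugging Face
+`datasets` id and field names below are assumptions; `scripts/prepare_data.py`
+is the supported entry point), GSM8K could be converted to this repo's JSON
+layout like so:
+
+```python
+import json
+
+from datasets import load_dataset
+
+# Assumed hub id/config; writes the "question"/"answer" keys that
+# eval_policy.py reads.
+gsm8k = load_dataset("gsm8k", "main", split="test")
+examples = [{"question": ex["question"], "answer": ex["answer"]} for ex in gsm8k]
+with open("data/gsm8k.json", "w") as f:
+    json.dump(examples, f, indent=2)
+```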
+
+### 2. Run Single Training Job
+
+```bash
+# Train with bf16 precision
+python train_rlvr.py \
+ --precision_mode bf16 \
+ --seed 1 \
+ --output_dir results/train_logs/bf16_seed1 \
+ --train_dataset_path data/dm_train.json
+
+# Train with fp32 precision
+python train_rlvr.py \
+ --precision_mode fp32 \
+ --seed 1 \
+ --output_dir results/train_logs/fp32_seed1 \
+ --train_dataset_path data/dm_train.json
+```
+
+### 3. Run Evaluation
+
+```bash
+python eval_policy.py \
+ --base_ckpt Qwen/Qwen2.5-Math-7B \
+ --ft_ckpt results/train_logs/bf16_seed1/final_model \
+ --eval_tasks_config configs/eval_tasks_config.json \
+ --output_path results/eval_metrics/bf16_seed1.json \
+ --eval_base
+```
+
+### 4. Run Full Experiment
+
+```bash
+# Run complete experiment pipeline
+bash scripts/run_full_experiment.sh
+
+# Or use Python orchestrator
+python run_experiments.py --mode full --seeds 1 2 3 4 5
+```
+
+### 5. Analyze Results
+
+```bash
+python analyze_results.py \
+ --results_dir results/eval_metrics \
+ --output_dir results/analysis \
+ --on_task dm_val \
+ --off_task aime24 aime25 amc23 math500 mmlu_stem humaneval
+```
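+
+This writes `full_analysis.json` plus two plots (`delta_j_comparison.png`,
+`variance_comparison.png`) to the output directory.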
+
+## Precision Configurations
+
+### P-high (FP32)
+- Master weights stored in FP32 (no AMP)
+- Minimal numerical noise
+
+### P-bf16 (Default RLVR)
+- Master weights stored in bf16 (bf16 autocast)
+- Higher numerical noise (Gate III effects)
+
+In this implementation, dropout is disabled and deterministic algorithms are
+enabled in **both** modes (see `config.py`), so the only difference between
+the two configurations is the parameter dtype. This avoids confounding
+precision effects with other sources of stochasticity.
+
+## Key Metrics
+
+### Performance (J_k)
+- Pass@1 accuracy for verifiable tasks
+- Computed via Monte Carlo sampling
+
+### Performance Delta (ΔJ_k)
+```
+ΔJ_k = J_k(θ_T) - J_k(θ_0)
+```
+
+### KL Divergence
+```
+KL_k ≈ E_x E_{y~π_θ} [log π_θ(y|x) - log π_0(y|x)]
+```
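+
+A minimal token-level estimator of this quantity (a sketch; the full version
+with response masking lives in `eval_policy.py` and `utils_kl.py`):
+
+```python
+import torch
+
+def sequence_kl(ft_logits, base_logits, labels):
+    """Sum of log π_θ(y_t) - log π_0(y_t) over the sampled tokens.
+
+    Assumes the logits are already aligned with `labels` (the autoregressive
+    shift is handled by the caller).
+    """
+    ft_lp = torch.log_softmax(ft_logits, dim=-1)
+    base_lp = torch.log_softmax(base_logits, dim=-1)
+    idx = labels.unsqueeze(-1)
+    per_token = (ft_lp.gather(-1, idx) - base_lp.gather(-1, idx)).squeeze(-1)
+    return per_token.sum()
+```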
+
+### bf16 Update Sparsity
+```
+sparsity = 1 - |{i: |w_i - w'_i| > η·max(|w_i|, |w'_i|)}| / n
+```
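+
+A minimal sketch of this computation (flat tensors assumed; the full analysis
+lives in `utils_bf16_sparsity.py`, with η defaulting to 1e-3 per
+`ExperimentConfig.bf16_sparsity_eta`):
+
+```python
+import torch
+
+def update_sparsity(w: torch.Tensor, w_prime: torch.Tensor, eta: float = 1e-3) -> float:
+    """Fraction of weights whose relative change stays below eta."""
+    diff = (w.float() - w_prime.float()).abs()
+    scale = torch.maximum(w.float().abs(), w_prime.float().abs())
+    changed = diff > eta * scale
+    return 1.0 - changed.sum().item() / w.numel()
+```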
+
+## Expected Results
+
+Based on RLVR theory predictions:
+
+| Metric | On-task | Off-task |
+|--------|---------|----------|
+| E[ΔJ] difference | Small (~0) | Variable |
+| Var[ΔJ] (bf16 vs fp32) | Similar | bf16 >> fp32 |
+| KL divergence | Constrained | Higher variance |
+| bf16 sparsity | 36-92% | - |
+
+## Configuration
+
+### Training Hyperparameters (RLVR defaults)
+- Model: Qwen2.5-Math-7B
+- Algorithm: DAPO (clip-only, β=0)
+- Batch size: 256
+- Learning rate: 1e-6
+- Training steps: 300
+- Rollouts per prompt: 16
+
+Note that `config.py` scales these down for a single H200 GPU: global batch
+size 32, 150 steps, and 4 rollouts per prompt. See `TrainingConfig` for the
+values `train_rlvr.py` actually uses.
+
+### Evaluation Settings
+- Temperature: 0.7
+- Top-p: 0.8
+- Max generation length: 2048-4096 (task-dependent)
+
+## Citation
+
+If you use this code, please cite the RLVR paper:
+
+```bibtex
+@article{rlvr2024,
+ title={Reinforcement Learning with Verifiable Rewards},
+ author={...},
+ year={2024}
+}
+```
+
+## License
+
+MIT License
+
diff --git a/analyze_results.py b/analyze_results.py
new file mode 100644
index 0000000..265c005
--- /dev/null
+++ b/analyze_results.py
@@ -0,0 +1,741 @@
+#!/usr/bin/env python3
+# analyze_results.py
+"""
+Results Analysis for RLVR Floating-Point Precision Experiments.
+
+This script analyzes the experimental results to verify the hypotheses:
+1. On-task performance is insensitive to floating-point noise
+2. Off-task performance and KL divergence are sensitive to precision
+
+Computes:
+- Mean and variance of ΔJ_k for each precision mode
+- KL divergence patterns for on-task vs off-task
+- Statistical tests for significance
+- Visualizations
+
+Usage:
+ python analyze_results.py \
+ --results_dir results/eval_metrics \
+ --output_dir results/analysis
+"""
+
+import argparse
+import json
+import os
+import glob
+from typing import Dict, Any, List, Tuple, Optional
+from collections import defaultdict
+import logging
+
+import numpy as np
+from scipy import stats
+
+# Optional: for visualizations
+try:
+    import matplotlib
+    matplotlib.use('Agg')  # Non-interactive backend; must be set before importing pyplot
+    import matplotlib.pyplot as plt
+ HAS_MATPLOTLIB = True
+except ImportError:
+ HAS_MATPLOTLIB = False
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_eval_results(results_dir: str) -> Dict[str, Dict[str, Any]]:
+ """
+ Load all evaluation results from a directory.
+
+ Returns:
+ Dictionary mapping "{precision_mode}_seed{seed}" to results
+ """
+ results = {}
+
+ pattern = os.path.join(results_dir, "*.json")
+ for filepath in glob.glob(pattern):
+ filename = os.path.basename(filepath)
+
+ # Skip sparsity files
+ if "sparsity" in filename:
+ continue
+
+ # Parse filename: {precision_mode}_seed{seed}.json
+ name = filename.replace(".json", "")
+
+ try:
+ with open(filepath, "r") as f:
+ data = json.load(f)
+ results[name] = data
+ logger.info(f"Loaded {filepath}")
+ except Exception as e:
+ logger.warning(f"Failed to load {filepath}: {e}")
+
+ return results
+
+
+def load_sparsity_results(results_dir: str) -> Dict[str, Dict[str, Any]]:
+ """Load bf16 sparsity analysis results."""
+ results = {}
+
+ pattern = os.path.join(results_dir, "*_sparsity.json")
+ for filepath in glob.glob(pattern):
+ filename = os.path.basename(filepath)
+ name = filename.replace("_sparsity.json", "")
+
+ try:
+ with open(filepath, "r") as f:
+ data = json.load(f)
+ results[name] = data
+ logger.info(f"Loaded sparsity: {filepath}")
+ except Exception as e:
+ logger.warning(f"Failed to load {filepath}: {e}")
+
+ return results
+
+
+def parse_run_name(name: str) -> Tuple[str, int]:
+ """Parse run name to get precision mode and seed."""
+ # Format: {precision_mode}_seed{seed}
+ parts = name.split("_seed")
+ if len(parts) == 2:
+ precision_mode = parts[0]
+ seed = int(parts[1])
+ return precision_mode, seed
+
+ raise ValueError(f"Cannot parse run name: {name}")
+
+
+# ============================================================================
+# Metrics Aggregation
+# ============================================================================
+
+def aggregate_by_precision(
+ results: Dict[str, Dict[str, Any]]
+) -> Dict[str, Dict[str, List[float]]]:
+ """
+ Aggregate results by precision mode.
+
+ Returns:
+ {precision_mode: {task_name: [scores from different seeds]}}
+ """
+ aggregated: Dict[str, Dict[str, List[float]]] = defaultdict(
+ lambda: defaultdict(list)
+ )
+
+ for run_name, run_results in results.items():
+ try:
+ precision_mode, seed = parse_run_name(run_name)
+ except ValueError:
+ continue
+
+ tasks = run_results.get("tasks", {})
+ for task_name, task_data in tasks.items():
+ # Score
+ if "ft_avg_score" in task_data:
+ aggregated[precision_mode][f"{task_name}_score"].append(
+ task_data["ft_avg_score"]
+ )
+
+ # Delta J
+ if "delta_j" in task_data:
+ aggregated[precision_mode][f"{task_name}_delta_j"].append(
+ task_data["delta_j"]
+ )
+
+ # KL
+ if "avg_kl" in task_data:
+ aggregated[precision_mode][f"{task_name}_kl"].append(
+ task_data["avg_kl"]
+ )
+
+ return dict(aggregated)
+
+
+def compute_statistics(values: List[float]) -> Dict[str, float]:
+ """Compute statistics for a list of values."""
+ if not values:
+ return {
+ "mean": 0.0,
+ "std": 0.0,
+ "var": 0.0,
+ "min": 0.0,
+ "max": 0.0,
+ "n": 0,
+ }
+
+ arr = np.array(values)
+ return {
+ "mean": float(np.mean(arr)),
+ "std": float(np.std(arr)),
+ "var": float(np.var(arr)),
+ "min": float(np.min(arr)),
+ "max": float(np.max(arr)),
+ "n": len(arr),
+ }
+
+
+# ============================================================================
+# Hypothesis Testing
+# ============================================================================
+
+def test_variance_ratio(
+ values1: List[float],
+ values2: List[float],
+ alpha: float = 0.05
+) -> Dict[str, Any]:
+ """
+ Test if variance of values2 is significantly greater than values1.
+
+ Uses F-test (Levene's test is more robust but F-test is simpler).
+ """
+ if len(values1) < 2 or len(values2) < 2:
+ return {
+ "test": "variance_ratio",
+ "valid": False,
+ "reason": "Not enough samples",
+ }
+
+ var1 = np.var(values1, ddof=1)
+ var2 = np.var(values2, ddof=1)
+
+    # F statistic (ratio of sample variances)
+    if var1 > 0:
+        f_stat = var2 / var1
+    else:
+        f_stat = float("inf")
+
+    # Degrees of freedom: numerator from values2, denominator from values1
+    df1 = len(values2) - 1
+    df2 = len(values1) - 1
+
+    # One-tailed p-value for var2 > var1; sf() is the survival function,
+    # numerically more stable in the tail than 1 - cdf()
+    p_value = stats.f.sf(f_stat, df1, df2)
+
+    return {
+        "test": "variance_ratio",
+        "valid": True,
+        "var1": float(var1),
+        "var2": float(var2),
+        "f_statistic": float(f_stat),
+        "p_value": float(p_value),
+        "significant": bool(p_value < alpha),
+        "alpha": alpha,
+    }
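+
+# Note: as the docstring above mentions, Levene's test is more robust to
+# non-normal data; a drop-in sketch using scipy (two-sided test of spread):
+#
+#     stat, p = stats.levene(values1, values2)
+#
+# The F-test is kept because its one-tailed var2 > var1 form maps directly
+# onto the hypotheses under test.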
+
+
+def test_mean_difference(
+ values1: List[float],
+ values2: List[float],
+ alpha: float = 0.05
+) -> Dict[str, Any]:
+ """
+ Test if means are significantly different (two-tailed t-test).
+ """
+ if len(values1) < 2 or len(values2) < 2:
+ return {
+ "test": "mean_difference",
+ "valid": False,
+ "reason": "Not enough samples",
+ }
+
+ t_stat, p_value = stats.ttest_ind(values1, values2)
+
+ return {
+ "test": "mean_difference",
+ "valid": True,
+ "mean1": float(np.mean(values1)),
+ "mean2": float(np.mean(values2)),
+ "t_statistic": float(t_stat),
+ "p_value": float(p_value),
+ "significant": p_value < alpha,
+ "alpha": alpha,
+ }
+
+
+# ============================================================================
+# Hypothesis Verification
+# ============================================================================
+
+def verify_hypothesis_1(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ on_task_names: List[str]
+) -> Dict[str, Any]:
+ """
+ Verify Hypothesis 1: On-task performance is insensitive to precision.
+
+ Expected:
+ - E[ΔJ_0^{fp32}] ≈ E[ΔJ_0^{bf16}] > 0
+ - Var[ΔJ_0^{fp32}] and Var[ΔJ_0^{bf16}] are both small
+ """
+ results = {
+ "hypothesis": "On-task performance insensitive to precision",
+ "tasks": {},
+ }
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ for task_name in on_task_names:
+ key = f"{task_name}_delta_j"
+
+ fp32_values = fp32_data.get(key, [])
+ bf16_values = bf16_data.get(key, [])
+
+ task_result = {
+ "fp32": compute_statistics(fp32_values),
+ "bf16": compute_statistics(bf16_values),
+ }
+
+ # Test mean difference (should NOT be significant)
+ mean_test = test_mean_difference(fp32_values, bf16_values)
+ task_result["mean_test"] = mean_test
+
+ # Test variance ratio (should NOT show bf16 >> fp32)
+ var_test = test_variance_ratio(fp32_values, bf16_values)
+ task_result["variance_test"] = var_test
+
+ # Verify expected pattern
+ fp32_mean = task_result["fp32"]["mean"]
+ bf16_mean = task_result["bf16"]["mean"]
+ fp32_var = task_result["fp32"]["var"]
+ bf16_var = task_result["bf16"]["var"]
+
+ task_result["verification"] = {
+ "means_similar": abs(fp32_mean - bf16_mean) < 0.05, # Within 5%
+ "both_positive": fp32_mean > 0 and bf16_mean > 0,
+ "variances_small": fp32_var < 0.01 and bf16_var < 0.01,
+ "hypothesis_supported": (
+ abs(fp32_mean - bf16_mean) < 0.05 and
+ fp32_mean > 0 and bf16_mean > 0
+ ),
+ }
+
+ results["tasks"][task_name] = task_result
+
+ # Overall verdict
+ all_supported = all(
+ t["verification"]["hypothesis_supported"]
+ for t in results["tasks"].values()
+ )
+ results["overall_supported"] = all_supported
+
+ return results
+
+
+def verify_hypothesis_2(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ off_task_names: List[str],
+ on_task_names: List[str]
+) -> Dict[str, Any]:
+ """
+ Verify Hypothesis 2: Off-task performance is sensitive to precision.
+
+ Expected:
+ - Var[ΔJ_k^{bf16}] >> Var[ΔJ_k^{fp32}] >> Var[ΔJ_0]
+ """
+ results = {
+ "hypothesis": "Off-task performance sensitive to precision",
+ "tasks": {},
+ }
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ # Get on-task variance for comparison
+ on_task_variances = []
+ for task_name in on_task_names:
+ key = f"{task_name}_delta_j"
+ for precision_data in [fp32_data, bf16_data]:
+ values = precision_data.get(key, [])
+ if values:
+ on_task_variances.append(np.var(values))
+
+ on_task_avg_var = np.mean(on_task_variances) if on_task_variances else 0.0
+
+ for task_name in off_task_names:
+ key = f"{task_name}_delta_j"
+
+ fp32_values = fp32_data.get(key, [])
+ bf16_values = bf16_data.get(key, [])
+
+ task_result = {
+ "fp32": compute_statistics(fp32_values),
+ "bf16": compute_statistics(bf16_values),
+ "on_task_avg_var": on_task_avg_var,
+ }
+
+ # Test variance ratio (bf16 should be >> fp32)
+ var_test = test_variance_ratio(fp32_values, bf16_values)
+ task_result["variance_test"] = var_test
+
+ # Verify expected pattern
+ fp32_var = task_result["fp32"]["var"]
+ bf16_var = task_result["bf16"]["var"]
+
+ task_result["verification"] = {
+ "bf16_var_gt_fp32": bf16_var > fp32_var,
+ "bf16_var_gt_fp32_by_5x": bf16_var > 5 * fp32_var if fp32_var > 0 else False,
+ "fp32_var_gt_ontask": fp32_var > on_task_avg_var,
+ "variance_ratio": bf16_var / fp32_var if fp32_var > 0 else float("inf"),
+ "hypothesis_supported": bf16_var > fp32_var,
+ }
+
+ results["tasks"][task_name] = task_result
+
+ # Count how many tasks show expected pattern
+ supported_count = sum(
+ 1 for t in results["tasks"].values()
+ if t["verification"]["hypothesis_supported"]
+ )
+ results["num_tasks_supported"] = supported_count
+ results["num_tasks_total"] = len(off_task_names)
+ results["overall_supported"] = supported_count > len(off_task_names) // 2
+
+ return results
+
+
+def verify_hypothesis_3(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ on_task_names: List[str],
+ off_task_names: List[str]
+) -> Dict[str, Any]:
+ """
+ Verify Hypothesis 3: KL divergence patterns.
+
+ Expected:
+ - On-task KL is similar between fp32 and bf16 (DAPO implicit leash)
+ - Off-task KL has higher variance in bf16
+ """
+ results = {
+ "hypothesis": "KL divergence patterns differ by task type",
+ "on_task": {},
+ "off_task": {},
+ }
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ # On-task KL analysis
+ for task_name in on_task_names:
+ key = f"{task_name}_kl"
+
+ fp32_values = fp32_data.get(key, [])
+ bf16_values = bf16_data.get(key, [])
+
+ task_result = {
+ "fp32": compute_statistics(fp32_values),
+ "bf16": compute_statistics(bf16_values),
+ }
+
+ mean_test = test_mean_difference(fp32_values, bf16_values)
+ task_result["mean_test"] = mean_test
+
+ # Verify: KL should be similar (implicit leash working)
+ task_result["kl_similar"] = not mean_test.get("significant", True)
+
+ results["on_task"][task_name] = task_result
+
+ # Off-task KL analysis
+ for task_name in off_task_names:
+ key = f"{task_name}_kl"
+
+ fp32_values = fp32_data.get(key, [])
+ bf16_values = bf16_data.get(key, [])
+
+ task_result = {
+ "fp32": compute_statistics(fp32_values),
+ "bf16": compute_statistics(bf16_values),
+ }
+
+ var_test = test_variance_ratio(fp32_values, bf16_values)
+ task_result["variance_test"] = var_test
+
+ # Verify: bf16 should have higher variance
+ task_result["bf16_higher_variance"] = var_test.get("significant", False)
+
+ results["off_task"][task_name] = task_result
+
+ # Overall assessment
+ on_task_similar = all(
+ t.get("kl_similar", False) for t in results["on_task"].values()
+ )
+ off_task_variance_higher = sum(
+ 1 for t in results["off_task"].values()
+ if t.get("bf16_higher_variance", False)
+ )
+
+ results["summary"] = {
+ "on_task_kl_similar": on_task_similar,
+ "off_task_higher_variance_count": off_task_variance_higher,
+ "off_task_total": len(off_task_names),
+ }
+
+ return results
+
+
+# ============================================================================
+# Visualization
+# ============================================================================
+
+def plot_delta_j_comparison(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ task_names: List[str],
+ output_path: str
+) -> None:
+ """Plot ΔJ comparison between precision modes."""
+ if not HAS_MATPLOTLIB:
+ logger.warning("matplotlib not available, skipping plot")
+ return
+
+ fig, ax = plt.subplots(figsize=(12, 6))
+
+ x = np.arange(len(task_names))
+ width = 0.35
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ fp32_means = []
+ fp32_stds = []
+ bf16_means = []
+ bf16_stds = []
+
+ for task_name in task_names:
+ key = f"{task_name}_delta_j"
+
+ fp32_values = fp32_data.get(key, [0])
+ bf16_values = bf16_data.get(key, [0])
+
+ fp32_means.append(np.mean(fp32_values))
+ fp32_stds.append(np.std(fp32_values))
+ bf16_means.append(np.mean(bf16_values))
+ bf16_stds.append(np.std(bf16_values))
+
+ ax.bar(x - width/2, fp32_means, width, yerr=fp32_stds,
+ label='FP32', color='steelblue', capsize=5)
+ ax.bar(x + width/2, bf16_means, width, yerr=bf16_stds,
+ label='bf16', color='coral', capsize=5)
+
+ ax.set_ylabel('ΔJ (Performance Delta)')
+ ax.set_xlabel('Task')
+ ax.set_title('Performance Delta by Precision Mode')
+ ax.set_xticks(x)
+ ax.set_xticklabels(task_names, rotation=45, ha='right')
+ ax.legend()
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150)
+ plt.close()
+
+ logger.info(f"Saved plot to {output_path}")
+
+
+def plot_variance_comparison(
+ aggregated: Dict[str, Dict[str, List[float]]],
+ task_names: List[str],
+ output_path: str
+) -> None:
+ """Plot variance comparison between precision modes."""
+ if not HAS_MATPLOTLIB:
+ logger.warning("matplotlib not available, skipping plot")
+ return
+
+ fig, ax = plt.subplots(figsize=(12, 6))
+
+ x = np.arange(len(task_names))
+ width = 0.35
+
+ fp32_data = aggregated.get("fp32", {})
+ bf16_data = aggregated.get("bf16", {})
+
+ fp32_vars = []
+ bf16_vars = []
+
+ for task_name in task_names:
+ key = f"{task_name}_delta_j"
+
+ fp32_values = fp32_data.get(key, [0])
+ bf16_values = bf16_data.get(key, [0])
+
+ fp32_vars.append(np.var(fp32_values))
+ bf16_vars.append(np.var(bf16_values))
+
+ ax.bar(x - width/2, fp32_vars, width, label='FP32', color='steelblue')
+ ax.bar(x + width/2, bf16_vars, width, label='bf16', color='coral')
+
+ ax.set_ylabel('Variance of ΔJ')
+ ax.set_xlabel('Task')
+ ax.set_title('Variance of Performance Delta by Precision Mode')
+ ax.set_xticks(x)
+ ax.set_xticklabels(task_names, rotation=45, ha='right')
+ ax.legend()
+ ax.set_yscale('log')
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150)
+ plt.close()
+
+ logger.info(f"Saved plot to {output_path}")
+
+
+# ============================================================================
+# Main Analysis
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Analyze RLVR floating-point precision experiment results"
+ )
+ parser.add_argument(
+ "--results_dir",
+ type=str,
+ required=True,
+ help="Directory containing evaluation results"
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Directory to save analysis outputs"
+ )
+ parser.add_argument(
+ "--on_task",
+ type=str,
+ nargs="+",
+ default=["dm_val"],
+ help="On-task (training distribution) task names"
+ )
+ parser.add_argument(
+ "--off_task",
+ type=str,
+ nargs="+",
+ default=["aime24", "aime25", "amc23", "math500", "mmlu_stem", "humaneval"],
+ help="Off-task task names"
+ )
+ return parser.parse_args()
+
+
+def main() -> None:
+ """Main analysis function."""
+ args = parse_args()
+
+ # Create output directory
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load results
+ logger.info(f"Loading results from {args.results_dir}")
+ eval_results = load_eval_results(args.results_dir)
+ sparsity_results = load_sparsity_results(args.results_dir)
+
+ if not eval_results:
+ logger.error("No evaluation results found!")
+ return
+
+ # Aggregate by precision mode
+ aggregated = aggregate_by_precision(eval_results)
+
+ # Get all task names from results
+ all_tasks = set()
+ for run_results in eval_results.values():
+ all_tasks.update(run_results.get("tasks", {}).keys())
+
+ # Filter task names to those present in results
+ on_task_names = [t for t in args.on_task if t in all_tasks]
+ off_task_names = [t for t in args.off_task if t in all_tasks]
+
+ logger.info(f"On-task: {on_task_names}")
+ logger.info(f"Off-task: {off_task_names}")
+
+ # Verify hypotheses
+ analysis = {}
+
+ logger.info("\n" + "="*60)
+ logger.info("HYPOTHESIS 1: On-task insensitivity")
+ logger.info("="*60)
+ h1_result = verify_hypothesis_1(aggregated, on_task_names)
+ analysis["hypothesis_1"] = h1_result
+ logger.info(f"Supported: {h1_result['overall_supported']}")
+
+ logger.info("\n" + "="*60)
+ logger.info("HYPOTHESIS 2: Off-task sensitivity")
+ logger.info("="*60)
+ h2_result = verify_hypothesis_2(aggregated, off_task_names, on_task_names)
+ analysis["hypothesis_2"] = h2_result
+ logger.info(f"Supported: {h2_result['overall_supported']} "
+ f"({h2_result['num_tasks_supported']}/{h2_result['num_tasks_total']} tasks)")
+
+ logger.info("\n" + "="*60)
+ logger.info("HYPOTHESIS 3: KL divergence patterns")
+ logger.info("="*60)
+ h3_result = verify_hypothesis_3(aggregated, on_task_names, off_task_names)
+ analysis["hypothesis_3"] = h3_result
+ logger.info(f"On-task KL similar: {h3_result['summary']['on_task_kl_similar']}")
+
+ # Add sparsity analysis
+ if sparsity_results:
+ sparsity_summary = {}
+ for name, data in sparsity_results.items():
+ sparsity_info = data.get("sparsity", {})
+ sparsity_summary[name] = {
+ "sparsity_percent": sparsity_info.get("sparsity_percent", 0),
+ "num_changed": sparsity_info.get("num_changed", 0),
+ }
+ analysis["bf16_sparsity"] = sparsity_summary
+
+ # Save full analysis
+ analysis_path = os.path.join(args.output_dir, "full_analysis.json")
+ with open(analysis_path, "w") as f:
+ json.dump(analysis, f, indent=2, default=str)
+ logger.info(f"Saved analysis to {analysis_path}")
+
+ # Generate plots
+ all_task_names = on_task_names + off_task_names
+
+ plot_delta_j_comparison(
+ aggregated,
+ all_task_names,
+ os.path.join(args.output_dir, "delta_j_comparison.png")
+ )
+
+ plot_variance_comparison(
+ aggregated,
+ all_task_names,
+ os.path.join(args.output_dir, "variance_comparison.png")
+ )
+
+ # Print summary
+ print("\n" + "="*80)
+ print("ANALYSIS SUMMARY")
+ print("="*80)
+
+ print("\nHypothesis 1 (On-task insensitivity):")
+ print(f" Supported: {h1_result['overall_supported']}")
+
+ print("\nHypothesis 2 (Off-task sensitivity):")
+ print(f" Supported: {h2_result['overall_supported']}")
+ print(f" Tasks showing expected pattern: {h2_result['num_tasks_supported']}/{h2_result['num_tasks_total']}")
+
+ print("\nHypothesis 3 (KL patterns):")
+ print(f" On-task KL similar across precision: {h3_result['summary']['on_task_kl_similar']}")
+ print(f" Off-task with higher bf16 variance: {h3_result['summary']['off_task_higher_variance_count']}/{h3_result['summary']['off_task_total']}")
+
+ if sparsity_results:
+ print("\nbf16 Sparsity:")
+ for name, data in sorted(sparsity_summary.items()):
+ print(f" {name}: {data['sparsity_percent']:.1f}% sparse")
+
+ print("="*80)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..898776c
--- /dev/null
+++ b/config.py
@@ -0,0 +1,328 @@
+# config.py
+"""
+Configuration definitions for RLVR floating-point precision experiments.
+
+This module defines configurations for:
+- Training parameters (DAPO algorithm, hyperparameters)
+- Precision settings (FP32 vs bf16)
+- Evaluation task specifications
+"""
+
+from dataclasses import dataclass, field
+from typing import List, Dict, Any, Optional
+import os
+
+
+@dataclass
+class TrainingConfig:
+ """Configuration for RLVR training with DAPO algorithm."""
+
+ # Model specification
+ model_name: str = "Qwen/Qwen2.5-Math-7B"
+
+ # Precision mode: "fp32" for high precision, "bf16" for default RLVR
+ precision_mode: str = "bf16"
+
+ # Batch configuration (sized for single H200 GPU with gradient checkpointing)
+ global_batch_size: int = 32 # Reduced from 256 for single GPU
+ micro_batch_size: int = 4 # Reduced further for fp32 memory safety
+ grad_accumulation_steps: int = 8 # Increased to maintain effective batch size
+
+ # Rollout configuration
+ num_rollouts_per_prompt: int = 4 # Reduced from 16 for speed
+ max_seq_len: int = 2048 # Reduced from 8192 (GSM8K answers are short)
+
+ # Training steps and checkpointing
+ # Note: With sequential generation, each step takes ~18 min on H200
+ # 150 steps ≈ 45 hours, fits in 2-day limit with buffer
+ num_steps: int = 150
+ checkpoint_steps: List[int] = field(default_factory=lambda: [0, 50, 100, 150])
+
+ # Optimizer configuration (AdamW)
+ learning_rate: float = 1e-6
+ beta1: float = 0.9
+ beta2: float = 0.999
+ weight_decay: float = 0.01
+
+ # RL algorithm
+ rl_algorithm: str = "dapo"
+ clip_ratio: float = 0.2 # DAPO clip parameter
+ kl_coef: float = 0.0 # DAPO uses clip-only, no explicit KL penalty
+
+ # Reproducibility
+ # IMPORTANT: Keep these constant across precision modes to isolate precision effects
+ # We maximize determinism so the ONLY variance source is floating-point precision
+ seed: int = 1
+ use_dropout: bool = False # Disabled to reduce stochasticity
+ use_deterministic_algorithms: bool = True # Enabled for reproducibility
+
+ # Paths
+ output_dir: str = "./results/train_logs"
+ train_dataset_path: str = "./data/dm_train.json"
+
+ # GPU configuration (single GPU for this implementation)
+ num_gpus: int = 1 # Current implementation is single-GPU
+
+ def __post_init__(self):
+ """
+ Note: We intentionally keep dropout and determinism settings CONSTANT
+ across precision modes to isolate the effect of floating-point precision.
+
+ Previously this coupled dropout=False with fp32 and dropout=True with bf16,
+ which confounded precision effects with stochasticity effects.
+
+ To study pure precision effects:
+ - Both modes use the SAME dropout setting (default: False)
+        - Both modes use the SAME determinism setting (default: True, matching
+          use_deterministic_algorithms above)
+
+ The only difference between fp32 and bf16 should be param_dtype.
+ """
+ # Don't modify settings based on precision_mode - keep them independent
+ pass
+
+
+@dataclass
+class PrecisionConfig:
+ """Configuration for floating-point precision settings."""
+
+ # Parameter storage dtype
+ param_dtype: str = "bfloat16" # "float32" or "bfloat16"
+
+ # Automatic mixed precision
+ use_amp: bool = True
+ amp_dtype: str = "bfloat16" # "float16" or "bfloat16"
+
+ # Gradient and optimizer state always in FP32
+ grad_dtype: str = "float32"
+ optimizer_dtype: str = "float32"
+
+ # Deterministic algorithms
+ deterministic: bool = False
+
+ # CUDNN settings
+ cudnn_benchmark: bool = True
+ cudnn_deterministic: bool = False
+
+
+@dataclass
+class EvalTaskConfig:
+ """Configuration for a single evaluation task."""
+
+ # Task identification
+ name: str = ""
+ task_type: str = "math" # "math", "code", "qa", "general"
+
+ # Dataset
+ dataset_path: str = ""
+ num_samples: int = -1 # -1 means use all samples
+
+ # Whether task has verifiable answers (math problems)
+ is_verifiable: bool = True
+
+ # Metric type for non-verifiable tasks
+ metric_type: str = "accuracy" # "accuracy", "bleu", "rouge", "score"
+
+ # Generation parameters
+ max_gen_len: int = 2048
+ temperature: float = 0.7
+ top_p: float = 0.8
+ num_samples_per_prompt: int = 1
+
+
+@dataclass
+class ExperimentConfig:
+ """Master configuration for the entire experiment."""
+
+ # Experiment identification
+ experiment_name: str = "fp_precision_rlvr"
+
+ # Seeds for multiple runs
+ seeds: List[int] = field(default_factory=lambda: [1, 2, 3, 4, 5])
+
+ # Precision modes to compare
+ precision_modes: List[str] = field(default_factory=lambda: ["fp32", "bf16"])
+
+ # Base model checkpoint (shared starting point)
+ base_model_path: str = "Qwen/Qwen2.5-Math-7B"
+
+ # Output directories
+ base_output_dir: str = "./results"
+ train_logs_dir: str = "./results/train_logs"
+ checkpoints_dir: str = "./results/checkpoints"
+ eval_metrics_dir: str = "./results/eval_metrics"
+
+ # Evaluation configuration
+ eval_tasks_config_path: str = "./configs/eval_tasks_config.json"
+
+ # bf16 sparsity analysis
+ bf16_sparsity_eta: float = 1e-3
+
+
+def make_precision_config(precision_mode: str) -> PrecisionConfig:
+ """
+ Create precision configuration based on mode.
+
+ IMPORTANT: Only the precision-related settings differ between modes.
+ All other settings (determinism, cudnn) are kept CONSTANT to isolate
+ the effect of floating-point precision on training outcomes.
+
+ Args:
+ precision_mode: "fp32" for high precision, "bf16" for default RLVR
+
+ Returns:
+ PrecisionConfig with appropriate settings
+ """
+ # Common settings for both modes (to avoid confounds)
+ # Maximize determinism so precision is the ONLY source of variance
+ common_settings = {
+ "grad_dtype": "float32",
+ "optimizer_dtype": "float32",
+ "deterministic": True, # Enable deterministic algorithms
+ "cudnn_benchmark": False, # Disable for reproducibility
+ "cudnn_deterministic": True, # Enable for reproducibility
+ }
+
+ if precision_mode == "fp32":
+ return PrecisionConfig(
+ param_dtype="float32",
+ use_amp=False, # No AMP needed for fp32
+ amp_dtype="float32",
+ **common_settings
+ )
+ elif precision_mode == "bf16":
+ return PrecisionConfig(
+ param_dtype="bfloat16",
+ use_amp=True,
+ amp_dtype="bfloat16",
+ **common_settings
+ )
+ else:
+ raise ValueError(f"Unknown precision_mode: {precision_mode}")
+
+
+def make_training_config(
+ precision_mode: str,
+ seed: int,
+ output_dir: str,
+ train_dataset_path: str,
+ model_name: str = "Qwen/Qwen2.5-Math-7B"
+) -> TrainingConfig:
+ """
+ Create training configuration for a specific run.
+
+ Args:
+ precision_mode: "fp32" or "bf16"
+ seed: Random seed for this run
+ output_dir: Directory to save outputs
+ train_dataset_path: Path to training data
+ model_name: HuggingFace model identifier
+
+ Returns:
+ TrainingConfig with all parameters set
+ """
+ config = TrainingConfig(
+ model_name=model_name,
+ precision_mode=precision_mode,
+ seed=seed,
+ output_dir=output_dir,
+ train_dataset_path=train_dataset_path
+ )
+ return config
+
+
+def get_run_output_dir(
+ base_dir: str,
+ precision_mode: str,
+ seed: int
+) -> str:
+ """Get output directory for a specific run."""
+ return os.path.join(base_dir, f"{precision_mode}_seed{seed}")
+
+
+def get_checkpoint_path(
+ output_dir: str,
+ step: Optional[int] = None
+) -> str:
+ """Get checkpoint path for a specific step (None = final)."""
+ if step is None:
+ return os.path.join(output_dir, "final_model")
+ return os.path.join(output_dir, f"checkpoint_step{step}")
+
+
+# Default evaluation tasks for the experiment
+DEFAULT_EVAL_TASKS = [
+ # On-task: Training distribution
+ EvalTaskConfig(
+ name="dm_val",
+ task_type="math",
+ dataset_path="./data/dm_val.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=2048,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ # In-domain OOD: Math benchmarks
+ EvalTaskConfig(
+ name="aime24",
+ task_type="math",
+ dataset_path="./data/aime24.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=4096,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ EvalTaskConfig(
+ name="aime25",
+ task_type="math",
+ dataset_path="./data/aime25.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=4096,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ EvalTaskConfig(
+ name="amc23",
+ task_type="math",
+ dataset_path="./data/amc23.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=2048,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ EvalTaskConfig(
+ name="math500",
+ task_type="math",
+ dataset_path="./data/math500.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=2048,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ # Off-domain: General tasks
+ EvalTaskConfig(
+ name="mmlu_stem",
+ task_type="qa",
+ dataset_path="./data/mmlu_stem.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=512,
+ temperature=0.3,
+ top_p=0.9
+ ),
+ EvalTaskConfig(
+ name="humaneval",
+ task_type="code",
+ dataset_path="./data/humaneval.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=1024,
+ temperature=0.2,
+ top_p=0.95
+ ),
+]
+
diff --git a/configs/deepspeed_zero2.json b/configs/deepspeed_zero2.json
new file mode 100644
index 0000000..bb7f7aa
--- /dev/null
+++ b/configs/deepspeed_zero2.json
@@ -0,0 +1,31 @@
+{
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "gradient_accumulation_steps": "auto",
+
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {
+ "device": "none"
+ },
+ "contiguous_gradients": true,
+ "overlap_comm": true,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 5e8,
+ "allgather_bucket_size": 5e8
+ },
+
+ "bf16": {
+ "enabled": false
+ },
+
+ "fp16": {
+ "enabled": false
+ },
+
+ "gradient_clipping": 1.0,
+
+ "zero_allow_untested_optimizer": true,
+
+ "wall_clock_breakdown": false
+}
diff --git a/configs/deepspeed_zero3.json b/configs/deepspeed_zero3.json
new file mode 100644
index 0000000..6e68c8f
--- /dev/null
+++ b/configs/deepspeed_zero3.json
@@ -0,0 +1,38 @@
+{
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "gradient_accumulation_steps": "auto",
+
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "none"
+ },
+ "offload_param": {
+ "device": "none"
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ },
+
+ "bf16": {
+ "enabled": false
+ },
+
+ "fp16": {
+ "enabled": false
+ },
+
+ "gradient_clipping": 1.0,
+
+ "zero_allow_untested_optimizer": true,
+
+ "wall_clock_breakdown": false
+}
diff --git a/configs/eval_tasks_config.json b/configs/eval_tasks_config.json
new file mode 100644
index 0000000..e0dda43
--- /dev/null
+++ b/configs/eval_tasks_config.json
@@ -0,0 +1,99 @@
+[
+ {
+ "name": "dm_val",
+ "task_type": "math",
+ "dataset_path": "./data/dm_val.json",
+ "is_verifiable": true,
+ "metric_type": "accuracy",
+ "num_samples": -1,
+ "max_gen_len": 2048,
+ "temperature": 0.7,
+ "top_p": 0.8,
+ "num_samples_per_prompt": 1
+ },
+ {
+ "name": "aime24",
+ "task_type": "math",
+ "dataset_path": "./data/aime24.json",
+ "is_verifiable": true,
+ "metric_type": "accuracy",
+ "num_samples": -1,
+ "max_gen_len": 4096,
+ "temperature": 0.7,
+ "top_p": 0.8,
+ "num_samples_per_prompt": 1
+ },
+ {
+ "name": "aime25",
+ "task_type": "math",
+ "dataset_path": "./data/aime25.json",
+ "is_verifiable": true,
+ "metric_type": "accuracy",
+ "num_samples": -1,
+ "max_gen_len": 4096,
+ "temperature": 0.7,
+ "top_p": 0.8,
+ "num_samples_per_prompt": 1
+ },
+ {
+ "name": "amc23",
+ "task_type": "math",
+ "dataset_path": "./data/amc23.json",
+ "is_verifiable": true,
+ "metric_type": "accuracy",
+ "num_samples": -1,
+ "max_gen_len": 2048,
+ "temperature": 0.7,
+ "top_p": 0.8,
+ "num_samples_per_prompt": 1
+ },
+ {
+ "name": "math500",
+ "task_type": "math",
+ "dataset_path": "./data/math500.json",
+ "is_verifiable": true,
+ "metric_type": "accuracy",
+ "num_samples": 500,
+ "max_gen_len": 2048,
+ "temperature": 0.7,
+ "top_p": 0.8,
+ "num_samples_per_prompt": 1
+ },
+ {
+ "name": "gsm8k",
+ "task_type": "math",
+ "dataset_path": "./data/gsm8k.json",
+ "is_verifiable": true,
+ "metric_type": "accuracy",
+ "num_samples": 500,
+ "max_gen_len": 1024,
+ "temperature": 0.7,
+ "top_p": 0.8,
+ "num_samples_per_prompt": 1
+ },
+ {
+ "name": "mmlu_stem",
+ "task_type": "qa",
+ "dataset_path": "./data/mmlu_stem.json",
+ "is_verifiable": true,
+ "metric_type": "accuracy",
+ "num_samples": 500,
+ "max_gen_len": 512,
+ "temperature": 0.3,
+ "top_p": 0.9,
+ "num_samples_per_prompt": 1
+ },
+ {
+ "name": "humaneval",
+ "task_type": "code",
+ "dataset_path": "./data/humaneval.json",
+ "is_verifiable": true,
+ "metric_type": "accuracy",
+ "num_samples": 164,
+ "max_gen_len": 1024,
+ "temperature": 0.2,
+ "top_p": 0.95,
+ "num_samples_per_prompt": 1
+ }
+]
+
diff --git a/eval_policy.py b/eval_policy.py
new file mode 100644
index 0000000..cc30209
--- /dev/null
+++ b/eval_policy.py
@@ -0,0 +1,621 @@
+#!/usr/bin/env python3
+# eval_policy.py
+"""
+Policy Evaluation Script for RLVR Experiments.
+
+This script evaluates trained models on multiple tasks, computing:
+- J_k: Task performance (pass@1 accuracy for verifiable tasks)
+- KL_k: KL divergence from base model
+
+Usage:
+ python eval_policy.py \
+ --base_ckpt Qwen/Qwen2.5-Math-7B \
+ --ft_ckpt results/train_logs/fp32_seed1/final_model \
+ --eval_tasks_config configs/eval_tasks_config.json \
+ --output_path results/eval_metrics/fp32_seed1.json
+"""
+
+import argparse
+import json
+import os
+import re
+import logging
+from typing import Dict, Any, List, Optional, Tuple
+from dataclasses import dataclass, asdict
+
+import numpy as np
+import torch
+from torch.cuda.amp import autocast
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from tqdm import tqdm
+
+from config import EvalTaskConfig
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_eval_tasks(eval_config_path: str) -> List[EvalTaskConfig]:
+ """Load evaluation task configurations from JSON file."""
+ with open(eval_config_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ tasks: List[EvalTaskConfig] = []
+ for task_item in data:
+ task = EvalTaskConfig(
+ name=task_item.get("name", ""),
+ task_type=task_item.get("task_type", "math"),
+ dataset_path=task_item.get("dataset_path", ""),
+ is_verifiable=task_item.get("is_verifiable", True),
+ metric_type=task_item.get("metric_type", "accuracy"),
+ num_samples=task_item.get("num_samples", -1),
+ max_gen_len=task_item.get("max_gen_len", 2048),
+ temperature=task_item.get("temperature", 0.7),
+ top_p=task_item.get("top_p", 0.8),
+ num_samples_per_prompt=task_item.get("num_samples_per_prompt", 1)
+ )
+ tasks.append(task)
+
+ logger.info(f"Loaded {len(tasks)} evaluation tasks from {eval_config_path}")
+ return tasks
+
+
+def load_dataset(dataset_path: str, num_samples: int = -1) -> List[Dict[str, Any]]:
+ """Load evaluation dataset from JSON file."""
+ with open(dataset_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ if num_samples > 0 and num_samples < len(data):
+ data = data[:num_samples]
+
+ logger.info(f"Loaded {len(data)} examples from {dataset_path}")
+ return data
+
+
+# ============================================================================
+# Answer Verification
+# ============================================================================
+
+def extract_boxed_answer(text: str) -> Optional[str]:
+    """Extract answer from \\boxed{} format."""
+    # Find all \boxed{...} patterns. Note: [^}]* does not handle nested
+    # braces (e.g. \boxed{\frac{1}{2}}), which is a known limitation of
+    # this simplified verifier.
+    pattern = r"\\boxed\{([^}]*)\}"
+    matches = re.findall(pattern, text)
+
+ if matches:
+ return matches[-1].strip() # Return last match
+
+ return None
+
+
+def extract_final_answer(text: str) -> Optional[str]:
+ """Extract final answer using various heuristics."""
+ # Try boxed format first
+ boxed = extract_boxed_answer(text)
+ if boxed:
+ return boxed
+
+ # Common answer patterns
+ patterns = [
+ r"[Tt]he (?:final )?answer is[:\s]+(.+?)(?:\.|$)",
+ r"[Tt]herefore[,:\s]+(?:the answer is[:\s]+)?(.+?)(?:\.|$)",
+ r"[Ss]o[,:\s]+(?:the answer is[:\s]+)?(.+?)(?:\.|$)",
+ r"[Hh]ence[,:\s]+(.+?)(?:\.|$)",
+ r"=\s*(.+?)$",
+ ]
+
+    for pattern in patterns:
+ match = re.search(pattern, text, re.MULTILINE)
+ if match:
+ return match.group(1).strip()
+
+ return None
+
+
+def normalize_answer(answer: str) -> str:
+ """Normalize answer for comparison."""
+ if answer is None:
+ return ""
+
+ # Convert to lowercase, remove whitespace
+ normalized = answer.lower().strip()
+
+ # Remove common formatting
+ normalized = normalized.replace(",", "")
+ normalized = normalized.replace("$", "")
+ normalized = normalized.replace("%", "")
+
+    # Try to extract numeric value
+    numeric_match = re.search(r"-?\d+\.?\d*", normalized)
+ if numeric_match:
+ return numeric_match.group()
+
+ return normalized
+
+
+def verify_math_answer(
+ response: str,
+ ground_truth: str
+) -> bool:
+ """
+ Verify if the response contains the correct answer.
+
+ This is a simplified verifier. For production use, replace with
+ Eval-Chemy or a more sophisticated verification system.
+ """
+ # Extract answers
+ predicted = extract_final_answer(response)
+
+ if predicted is None:
+ return False
+
+ # Normalize for comparison
+ pred_normalized = normalize_answer(predicted)
+ gt_normalized = normalize_answer(ground_truth)
+
+ # Direct comparison
+ if pred_normalized == gt_normalized:
+ return True
+
+ # Try numeric comparison
+ try:
+ pred_num = float(pred_normalized)
+ gt_num = float(gt_normalized)
+ if abs(pred_num - gt_num) < 1e-6:
+ return True
+ except ValueError:
+ pass
+
+ return False
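+
+# Quick sanity check (a sketch):
+#
+#     verify_math_answer("The answer is 42.", "42")   # True
+#     verify_math_answer("\\boxed{3.5}", "7/2")       # False: no symbolic equivalence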
+
+
+# ============================================================================
+# KL Divergence Computation
+# ============================================================================
+
+def compute_sequence_kl(
+ finetuned_model: torch.nn.Module,
+ base_model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ response_start_idx: int,
+ device: torch.device
+) -> Tuple[float, int]:
+ """
+ Compute KL divergence for a single sequence.
+
+ KL(π_ft || π_base) ≈ Σ_t [log π_ft(y_t|x,y_{<t}) - log π_base(y_t|x,y_{<t})]
+
+ Returns:
+ Tuple of (kl_sum, num_tokens)
+ """
+ with torch.no_grad():
+ # Get logits from both models
+ ft_outputs = finetuned_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask
+ )
+ base_outputs = base_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask
+ )
+
+ ft_logits = ft_outputs.logits
+ base_logits = base_outputs.logits
+
+ # Compute log probabilities
+ ft_log_probs = torch.log_softmax(ft_logits, dim=-1)
+ base_log_probs = torch.log_softmax(base_logits, dim=-1)
+
+ # Get log probs for actual tokens (shifted for autoregressive)
+ shift_ft_log_probs = ft_log_probs[:, :-1, :]
+ shift_base_log_probs = base_log_probs[:, :-1, :]
+ shift_labels = input_ids[:, 1:]
+
+ ft_token_log_probs = torch.gather(
+ shift_ft_log_probs,
+ dim=-1,
+ index=shift_labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ base_token_log_probs = torch.gather(
+ shift_base_log_probs,
+ dim=-1,
+ index=shift_labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ # Compute KL only for response tokens
+ kl_per_token = ft_token_log_probs - base_token_log_probs
+
+ # Mask for response tokens only
+ response_mask = torch.zeros_like(kl_per_token)
+ response_mask[:, response_start_idx-1:] = 1.0
+
+ # Apply attention mask
+ valid_mask = attention_mask[:, 1:].float() * response_mask
+
+ kl_sum = (kl_per_token * valid_mask).sum().item()
+ num_tokens = valid_mask.sum().item()
+
+ return kl_sum, int(num_tokens)
+
+
+# ============================================================================
+# Evaluation Functions
+# ============================================================================
+
+@dataclass
+class TaskResult:
+ """Results for a single evaluation task."""
+ task_name: str
+ task_type: str
+ num_examples: int
+ avg_score: float
+ std_score: float
+ avg_kl: float
+ std_kl: float
+ avg_response_length: float
+ scores: List[float]
+ kl_values: List[float]
+
+
+def evaluate_task(
+ base_model: torch.nn.Module,
+ base_tokenizer,
+ finetuned_model: torch.nn.Module,
+ finetuned_tokenizer,
+ task_config: EvalTaskConfig,
+ device: torch.device,
+ use_amp: bool = True
+) -> TaskResult:
+ """
+ Evaluate a single task.
+
+ Computes:
+ - avg_score: Mean accuracy (for verifiable tasks)
+ - avg_kl: Mean KL divergence from base model
+ """
+ dataset = load_dataset(task_config.dataset_path, task_config.num_samples)
+
+ scores: List[float] = []
+ kl_values: List[float] = []
+ response_lengths: List[int] = []
+
+ finetuned_model.eval()
+ base_model.eval()
+
+ amp_dtype = torch.bfloat16 if use_amp else torch.float32
+
+ for example in tqdm(dataset, desc=f"Evaluating {task_config.name}"):
+ prompt = example.get("prompt", example.get("question", ""))
+ ground_truth = example.get("answer", example.get("solution", None))
+
+ # Tokenize prompt
+ inputs = finetuned_tokenizer(
+ prompt,
+ return_tensors="pt",
+ truncation=True,
+ max_length=4096
+ )
+ input_ids = inputs["input_ids"].to(device)
+ attention_mask = inputs["attention_mask"].to(device)
+ prompt_len = input_ids.shape[1]
+
+ # Generate response
+ with torch.no_grad():
+ with autocast(enabled=use_amp, dtype=amp_dtype):
+ generated_ids = finetuned_model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_new_tokens=task_config.max_gen_len,
+ do_sample=True,
+ temperature=task_config.temperature,
+ top_p=task_config.top_p,
+ pad_token_id=finetuned_tokenizer.eos_token_id
+ )
+
+ # Decode response
+ response_ids = generated_ids[:, prompt_len:]
+ response_text = finetuned_tokenizer.batch_decode(
+ response_ids,
+ skip_special_tokens=True
+ )[0]
+
+ response_lengths.append(len(response_ids[0]))
+
+ # Compute score (accuracy for verifiable tasks)
+ if task_config.is_verifiable and ground_truth is not None:
+ is_correct = verify_math_answer(response_text, str(ground_truth))
+ score = 1.0 if is_correct else 0.0
+ else:
+ # For non-verifiable tasks, use placeholder
+ score = 0.0
+
+ scores.append(score)
+
+ # Compute KL divergence
+ full_ids = generated_ids
+ full_attention = torch.ones_like(full_ids, device=device)
+
+ kl_sum, num_tokens = compute_sequence_kl(
+ finetuned_model=finetuned_model,
+ base_model=base_model,
+ input_ids=full_ids,
+ attention_mask=full_attention,
+ response_start_idx=prompt_len,
+ device=device
+ )
+
+        # Record the sequence-total KL. (A per-token average would be
+        # kl_sum / num_tokens; the aggregate analysis uses the total.)
+        kl_values.append(kl_sum)
+
+ # Compute statistics
+ result = TaskResult(
+ task_name=task_config.name,
+ task_type=task_config.task_type,
+ num_examples=len(dataset),
+ avg_score=float(np.mean(scores)) if scores else 0.0,
+ std_score=float(np.std(scores)) if scores else 0.0,
+ avg_kl=float(np.mean(kl_values)) if kl_values else 0.0,
+ std_kl=float(np.std(kl_values)) if kl_values else 0.0,
+ avg_response_length=float(np.mean(response_lengths)) if response_lengths else 0.0,
+ scores=scores,
+ kl_values=kl_values
+ )
+
+ logger.info(
+ f"Task {task_config.name}: "
+ f"Score={result.avg_score:.4f} (±{result.std_score:.4f}), "
+ f"KL={result.avg_kl:.4f} (±{result.std_kl:.4f})"
+ )
+
+ return result
+
+
+def evaluate_base_model(
+ base_model: torch.nn.Module,
+ base_tokenizer,
+ task_config: EvalTaskConfig,
+ device: torch.device,
+ use_amp: bool = True
+) -> Dict[str, float]:
+ """Evaluate the base model (for computing ΔJ)."""
+ dataset = load_dataset(task_config.dataset_path, task_config.num_samples)
+
+ scores: List[float] = []
+ base_model.eval()
+
+ amp_dtype = torch.bfloat16 if use_amp else torch.float32
+
+ for example in tqdm(dataset, desc=f"Evaluating base on {task_config.name}"):
+ prompt = example.get("prompt", example.get("question", ""))
+ ground_truth = example.get("answer", example.get("solution", None))
+
+ inputs = base_tokenizer(
+ prompt,
+ return_tensors="pt",
+ truncation=True,
+ max_length=4096
+ )
+ input_ids = inputs["input_ids"].to(device)
+ attention_mask = inputs["attention_mask"].to(device)
+
+ with torch.no_grad():
+ with autocast(enabled=use_amp, dtype=amp_dtype):
+ generated_ids = base_model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_new_tokens=task_config.max_gen_len,
+ do_sample=True,
+ temperature=task_config.temperature,
+ top_p=task_config.top_p,
+ pad_token_id=base_tokenizer.eos_token_id
+ )
+
+ response_ids = generated_ids[:, input_ids.shape[1]:]
+ response_text = base_tokenizer.batch_decode(
+ response_ids,
+ skip_special_tokens=True
+ )[0]
+
+ if task_config.is_verifiable and ground_truth is not None:
+ is_correct = verify_math_answer(response_text, str(ground_truth))
+ score = 1.0 if is_correct else 0.0
+ else:
+ score = 0.0
+
+ scores.append(score)
+
+ return {
+ "avg_score": float(np.mean(scores)) if scores else 0.0,
+ "std_score": float(np.std(scores)) if scores else 0.0,
+ "num_examples": len(scores)
+ }
+
+
+# ============================================================================
+# Main Evaluation Pipeline
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Evaluate RLVR trained models on multiple tasks"
+ )
+ parser.add_argument(
+ "--base_ckpt",
+ type=str,
+ required=True,
+ help="Path to base model checkpoint"
+ )
+ parser.add_argument(
+ "--ft_ckpt",
+ type=str,
+ required=True,
+ help="Path to finetuned model checkpoint"
+ )
+ parser.add_argument(
+ "--eval_tasks_config",
+ type=str,
+ required=True,
+ help="Path to evaluation tasks configuration JSON"
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ required=True,
+ help="Path to save evaluation results"
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="Device to use for evaluation"
+ )
+ parser.add_argument(
+ "--eval_base",
+ action="store_true",
+ help="Also evaluate base model (for computing delta J)"
+ )
+    parser.add_argument(
+        "--use_amp",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Use automatic mixed precision (pass --no-use_amp to disable)"
+    )
+ return parser.parse_args()
+
+
+def main() -> None:
+ """Main evaluation function."""
+ args = parse_args()
+
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {device}")
+
+ # Load tokenizers
+ logger.info(f"Loading base tokenizer from {args.base_ckpt}")
+ base_tokenizer = AutoTokenizer.from_pretrained(
+ args.base_ckpt,
+ use_fast=True,
+ trust_remote_code=True
+ )
+ if base_tokenizer.pad_token is None:
+ base_tokenizer.pad_token = base_tokenizer.eos_token
+
+ logger.info(f"Loading finetuned tokenizer from {args.ft_ckpt}")
+ ft_tokenizer = AutoTokenizer.from_pretrained(
+ args.ft_ckpt,
+ use_fast=True,
+ trust_remote_code=True
+ )
+ if ft_tokenizer.pad_token is None:
+ ft_tokenizer.pad_token = ft_tokenizer.eos_token
+
+ # Load models
+ logger.info(f"Loading base model from {args.base_ckpt}")
+ base_model = AutoModelForCausalLM.from_pretrained(
+ args.base_ckpt,
+ torch_dtype=torch.bfloat16,
+ device_map=None,
+ trust_remote_code=True
+ ).to(device)
+ base_model.eval()
+
+ logger.info(f"Loading finetuned model from {args.ft_ckpt}")
+ ft_model = AutoModelForCausalLM.from_pretrained(
+ args.ft_ckpt,
+ torch_dtype=torch.bfloat16,
+ device_map=None,
+ trust_remote_code=True
+ ).to(device)
+ ft_model.eval()
+
+ # Load evaluation tasks
+ eval_tasks = load_eval_tasks(args.eval_tasks_config)
+
+ # Evaluate on all tasks
+ all_results: Dict[str, Any] = {
+ "base_ckpt": args.base_ckpt,
+ "ft_ckpt": args.ft_ckpt,
+ "tasks": {}
+ }
+
+ for task in eval_tasks:
+ logger.info(f"\n{'='*60}")
+ logger.info(f"Evaluating task: {task.name}")
+ logger.info(f"{'='*60}")
+
+ # Evaluate finetuned model
+ result = evaluate_task(
+ base_model=base_model,
+ base_tokenizer=base_tokenizer,
+ finetuned_model=ft_model,
+ finetuned_tokenizer=ft_tokenizer,
+ task_config=task,
+ device=device,
+ use_amp=args.use_amp
+ )
+
+ task_results = {
+ "ft_avg_score": result.avg_score,
+ "ft_std_score": result.std_score,
+ "avg_kl": result.avg_kl,
+ "std_kl": result.std_kl,
+ "avg_response_length": result.avg_response_length,
+ "num_examples": result.num_examples,
+ }
+
+ # Optionally evaluate base model
+ if args.eval_base:
+ base_result = evaluate_base_model(
+ base_model=base_model,
+ base_tokenizer=base_tokenizer,
+ task_config=task,
+ device=device,
+ use_amp=args.use_amp
+ )
+ task_results["base_avg_score"] = base_result["avg_score"]
+ task_results["base_std_score"] = base_result["std_score"]
+ task_results["delta_j"] = result.avg_score - base_result["avg_score"]
+
+ all_results["tasks"][task.name] = task_results
+
+ # Save results
+ os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
+ with open(args.output_path, "w", encoding="utf-8") as f:
+ json.dump(all_results, f, indent=2)
+
+ logger.info(f"\nResults saved to {args.output_path}")
+
+ # Print summary
+ print("\n" + "="*80)
+ print("EVALUATION SUMMARY")
+ print("="*80)
+ for task_name, task_result in all_results["tasks"].items():
+ print(f"\n{task_name}:")
+ print(f" Score: {task_result['ft_avg_score']:.4f} (±{task_result['ft_std_score']:.4f})")
+ print(f" KL: {task_result['avg_kl']:.4f} (±{task_result['std_kl']:.4f})")
+ if "delta_j" in task_result:
+ print(f" ΔJ: {task_result['delta_j']:+.4f}")
+ print("="*80)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c108ec1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,39 @@
+# RLVR Floating-Point Precision Experiment Dependencies
+# Core ML frameworks
+torch>=2.0.0
+transformers>=4.36.0
+accelerate>=0.25.0
+
+# RL framework (VeRL)
+# Install from source: pip install git+https://github.com/volcengine/verl.git
+# verl
+
+# Inference
+vllm>=0.2.0
+
+# Numerical computation
+numpy>=1.24.0
+scipy>=1.10.0
+
+# Visualization
+matplotlib>=3.7.0
+
+# Progress tracking
+tqdm>=4.65.0
+
+# Data handling
+datasets>=2.14.0
+
+# Utilities
+pyyaml>=6.0
+jsonlines>=3.1.0
+
+# Distributed training (optional, usually comes with torch)
+# torch-distributed
+
+# Flash attention (optional, for faster inference)
+# flash-attn>=2.3.0
+
+# Evaluation utilities
+# eval-chemy # Math verifier (install from RLVR repo)
+
diff --git a/run_experiments.py b/run_experiments.py
new file mode 100644
index 0000000..0cbcd67
--- /dev/null
+++ b/run_experiments.py
@@ -0,0 +1,601 @@
+#!/usr/bin/env python3
+# run_experiments.py
+"""
+Experiment Runner for RLVR Floating-Point Precision Study.
+
+This script orchestrates the full experimental pipeline:
+1. Training models with FP32 and bf16 precision across multiple seeds
+2. Evaluating trained models on on-task and off-task benchmarks
+3. Computing KL divergence and bf16 sparsity metrics
+
+Usage:
+ # Run full experiment
+ python run_experiments.py --mode full
+
+ # Run training only
+ python run_experiments.py --mode train --precision_mode bf16 --seed 1
+
+ # Run evaluation only
+ python run_experiments.py --mode eval
+
+ # Run analysis only
+ python run_experiments.py --mode analyze
+"""
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+import logging
+from typing import Dict, Any, List, Optional
+from dataclasses import asdict
+from concurrent.futures import ProcessPoolExecutor
+import time
+
+from config import (
+ ExperimentConfig,
+ make_training_config,
+ make_precision_config,
+ get_run_output_dir,
+ get_checkpoint_path,
+)
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Training Functions
+# ============================================================================
+
+def run_single_training(
+ precision_mode: str,
+ seed: int,
+ config: ExperimentConfig,
+ train_dataset_path: str,
+ dry_run: bool = False
+) -> Dict[str, Any]:
+ """
+ Run a single training job.
+
+ Args:
+ precision_mode: "fp32" or "bf16"
+ seed: Random seed
+ config: Experiment configuration
+ train_dataset_path: Path to training data
+ dry_run: If True, only print command without running
+
+ Returns:
+ Dictionary with job status
+ """
+ output_dir = get_run_output_dir(config.train_logs_dir, precision_mode, seed)
+
+ cmd = [
+ sys.executable,
+ "train_rlvr.py",
+ "--precision_mode", precision_mode,
+ "--seed", str(seed),
+ "--output_dir", output_dir,
+ "--train_dataset_path", train_dataset_path,
+ "--model_name", config.base_model_path,
+ ]
+
+ logger.info(f"Running training: {precision_mode} seed={seed}")
+ logger.info(f"Command: {' '.join(cmd)}")
+
+ if dry_run:
+ return {
+ "status": "dry_run",
+ "precision_mode": precision_mode,
+ "seed": seed,
+ "output_dir": output_dir,
+ "command": " ".join(cmd),
+ }
+
+ # Create output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Run training
+ start_time = time.time()
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ check=True
+ )
+ duration = time.time() - start_time
+
+ return {
+ "status": "success",
+ "precision_mode": precision_mode,
+ "seed": seed,
+ "output_dir": output_dir,
+ "duration_seconds": duration,
+ "stdout": result.stdout[-1000:], # Last 1000 chars
+ }
+ except subprocess.CalledProcessError as e:
+ return {
+ "status": "failed",
+ "precision_mode": precision_mode,
+ "seed": seed,
+ "output_dir": output_dir,
+ "error": str(e),
+ "stderr": e.stderr[-1000:] if e.stderr else None,
+ }
+
+
+def run_all_training(
+ config: ExperimentConfig,
+ train_dataset_path: str,
+ dry_run: bool = False,
+ parallel: bool = False
+) -> List[Dict[str, Any]]:
+ """
+ Run training for all precision modes and seeds.
+ """
+ jobs = []
+ for precision_mode in config.precision_modes:
+ for seed in config.seeds:
+ jobs.append((precision_mode, seed))
+
+ results = []
+
+ if parallel and not dry_run:
+ # Run in parallel. Note: all child jobs inherit this process's
+ # CUDA_VISIBLE_DEVICES, so assign one GPU per job externally
+ # (e.g., via the SLURM scripts) if running on a shared node.
+ with ProcessPoolExecutor(max_workers=len(jobs)) as executor:
+ futures = [
+ executor.submit(
+ run_single_training,
+ pm, s, config, train_dataset_path, dry_run
+ )
+ for pm, s in jobs
+ ]
+ results = [f.result() for f in futures]
+ else:
+ # Run sequentially
+ for precision_mode, seed in jobs:
+ result = run_single_training(
+ precision_mode, seed, config, train_dataset_path, dry_run
+ )
+ results.append(result)
+
+ return results
+
+
+# ============================================================================
+# Evaluation Functions
+# ============================================================================
+
+def run_single_evaluation(
+ precision_mode: str,
+ seed: int,
+ config: ExperimentConfig,
+ eval_base: bool = True,
+ dry_run: bool = False
+) -> Dict[str, Any]:
+ """
+ Run evaluation for a single trained model.
+ """
+ run_dir = get_run_output_dir(config.train_logs_dir, precision_mode, seed)
+ ft_ckpt = get_checkpoint_path(run_dir)
+ output_path = os.path.join(
+ config.eval_metrics_dir,
+ f"{precision_mode}_seed{seed}.json"
+ )
+
+ cmd = [
+ sys.executable,
+ "eval_policy.py",
+ "--base_ckpt", config.base_model_path,
+ "--ft_ckpt", ft_ckpt,
+ "--eval_tasks_config", config.eval_tasks_config_path,
+ "--output_path", output_path,
+ ]
+
+ if eval_base:
+ cmd.append("--eval_base")
+
+ logger.info(f"Running evaluation: {precision_mode} seed={seed}")
+ logger.info(f"Command: {' '.join(cmd)}")
+
+ if dry_run:
+ return {
+ "status": "dry_run",
+ "precision_mode": precision_mode,
+ "seed": seed,
+ "output_path": output_path,
+ "command": " ".join(cmd),
+ }
+
+ # Create output directory
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ # Check if checkpoint exists
+ if not os.path.exists(ft_ckpt):
+ return {
+ "status": "skipped",
+ "precision_mode": precision_mode,
+ "seed": seed,
+ "reason": f"Checkpoint not found: {ft_ckpt}",
+ }
+
+ # Run evaluation
+ start_time = time.time()
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ check=True
+ )
+ duration = time.time() - start_time
+
+ return {
+ "status": "success",
+ "precision_mode": precision_mode,
+ "seed": seed,
+ "output_path": output_path,
+ "duration_seconds": duration,
+ }
+ except subprocess.CalledProcessError as e:
+ return {
+ "status": "failed",
+ "precision_mode": precision_mode,
+ "seed": seed,
+ "error": str(e),
+ "stderr": e.stderr[-1000:] if e.stderr else None,
+ }
+
+
+def run_all_evaluations(
+ config: ExperimentConfig,
+ eval_base: bool = True,
+ dry_run: bool = False
+) -> List[Dict[str, Any]]:
+ """
+ Run evaluation for all trained models.
+ """
+ results = []
+
+ for precision_mode in config.precision_modes:
+ for seed in config.seeds:
+ result = run_single_evaluation(
+ precision_mode, seed, config, eval_base, dry_run
+ )
+ results.append(result)
+
+ return results
+
+
+# ============================================================================
+# bf16 Sparsity Analysis
+# ============================================================================
+
+def run_sparsity_analysis(
+ config: ExperimentConfig,
+ dry_run: bool = False
+) -> List[Dict[str, Any]]:
+ """
+ Compute bf16 sparsity for all bf16 runs.
+ """
+ import torch
+ from transformers import AutoModelForCausalLM
+ from utils_bf16_sparsity import compute_bf16_sparsity, analyze_update_magnitudes
+
+ results = []
+
+ # Only analyze bf16 runs
+ if "bf16" not in config.precision_modes:
+ logger.info("No bf16 runs to analyze for sparsity")
+ return results
+
+ if dry_run:
+ for seed in config.seeds:
+ results.append({
+ "status": "dry_run",
+ "precision_mode": "bf16",
+ "seed": seed,
+ })
+ return results
+
+ # Load base model once
+ logger.info(f"Loading base model: {config.base_model_path}")
+ base_model = AutoModelForCausalLM.from_pretrained(
+ config.base_model_path,
+ torch_dtype=torch.float32,
+ device_map="cpu"
+ )
+
+ for seed in config.seeds:
+ run_dir = get_run_output_dir(config.train_logs_dir, "bf16", seed)
+ ft_ckpt = get_checkpoint_path(run_dir)
+
+ if not os.path.exists(ft_ckpt):
+ results.append({
+ "status": "skipped",
+ "precision_mode": "bf16",
+ "seed": seed,
+ "reason": f"Checkpoint not found: {ft_ckpt}",
+ })
+ continue
+
+ logger.info(f"Computing sparsity for bf16 seed={seed}")
+
+ # Load finetuned model
+ ft_model = AutoModelForCausalLM.from_pretrained(
+ ft_ckpt,
+ torch_dtype=torch.float32,
+ device_map="cpu"
+ )
+
+ # Compute sparsity
+ sparsity_result = compute_bf16_sparsity(
+ base_model=base_model,
+ finetuned_model=ft_model,
+ eta=config.bf16_sparsity_eta,
+ include_layer_stats=True
+ )
+
+ # Analyze update magnitudes
+ magnitude_result = analyze_update_magnitudes(
+ base_model=base_model,
+ finetuned_model=ft_model
+ )
+
+ # Save results
+ output_path = os.path.join(
+ config.eval_metrics_dir,
+ f"bf16_seed{seed}_sparsity.json"
+ )
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ full_result = {
+ "precision_mode": "bf16",
+ "seed": seed,
+ "sparsity": sparsity_result,
+ "magnitudes": magnitude_result,
+ }
+
+ with open(output_path, "w") as f:
+ # Strip tensor shapes from layer_stats before serializing; copy the
+ # nested dicts first so sparsity_result is not mutated through the
+ # shared reference
+ serializable = dict(full_result)
+ if "layer_stats" in serializable.get("sparsity", {}):
+ serializable["sparsity"] = dict(serializable["sparsity"])
+ serializable["sparsity"]["layer_stats"] = {
+ k: {kk: vv for kk, vv in v.items() if kk != "shape"}
+ for k, v in serializable["sparsity"]["layer_stats"].items()
+ }
+ json.dump(serializable, f, indent=2, default=str)
+
+ results.append({
+ "status": "success",
+ "precision_mode": "bf16",
+ "seed": seed,
+ "sparsity_percent": sparsity_result["sparsity_percent"],
+ "output_path": output_path,
+ })
+
+ # Free memory
+ del ft_model
+
+ return results
+
+
+# ============================================================================
+# Main Entry Point
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Run RLVR floating-point precision experiments"
+ )
+
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="full",
+ choices=["full", "train", "eval", "analyze", "sparsity"],
+ help="Execution mode"
+ )
+
+ # For single job mode
+ parser.add_argument(
+ "--precision_mode",
+ type=str,
+ choices=["fp32", "bf16"],
+ help="Precision mode (for train mode)"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ help="Random seed (for train mode)"
+ )
+
+ # Paths
+ parser.add_argument(
+ "--train_dataset_path",
+ type=str,
+ default="./data/dm_train.json",
+ help="Path to training dataset"
+ )
+ parser.add_argument(
+ "--base_output_dir",
+ type=str,
+ default="./results",
+ help="Base output directory"
+ )
+ parser.add_argument(
+ "--eval_tasks_config",
+ type=str,
+ default="./configs/eval_tasks_config.json",
+ help="Path to evaluation tasks config"
+ )
+ parser.add_argument(
+ "--base_model",
+ type=str,
+ default="Qwen/Qwen2.5-Math-7B",
+ help="Base model path or HuggingFace ID"
+ )
+
+ # Execution options
+ parser.add_argument(
+ "--dry_run",
+ action="store_true",
+ help="Print commands without executing"
+ )
+ parser.add_argument(
+ "--parallel",
+ action="store_true",
+ help="Run training jobs in parallel"
+ )
+ parser.add_argument(
+ "--seeds",
+ type=int,
+ nargs="+",
+ default=[1, 2, 3, 4, 5],
+ help="Seeds to use"
+ )
+
+ return parser.parse_args()
+
+
+def main() -> None:
+ """Main entry point."""
+ args = parse_args()
+
+ # Create experiment configuration
+ config = ExperimentConfig(
+ experiment_name="fp_precision_rlvr",
+ seeds=args.seeds,
+ precision_modes=["fp32", "bf16"],
+ base_model_path=args.base_model,
+ base_output_dir=args.base_output_dir,
+ train_logs_dir=os.path.join(args.base_output_dir, "train_logs"),
+ checkpoints_dir=os.path.join(args.base_output_dir, "checkpoints"),
+ eval_metrics_dir=os.path.join(args.base_output_dir, "eval_metrics"),
+ eval_tasks_config_path=args.eval_tasks_config,
+ )
+
+ # Create directories
+ os.makedirs(config.train_logs_dir, exist_ok=True)
+ os.makedirs(config.checkpoints_dir, exist_ok=True)
+ os.makedirs(config.eval_metrics_dir, exist_ok=True)
+
+ # Save experiment config
+ config_path = os.path.join(args.base_output_dir, "experiment_config.json")
+ with open(config_path, "w") as f:
+ json.dump(asdict(config), f, indent=2)
+ logger.info(f"Saved experiment config to {config_path}")
+
+ # Execute based on mode
+ if args.mode == "train":
+ if args.precision_mode and args.seed is not None:
+ # Single training job
+ result = run_single_training(
+ args.precision_mode,
+ args.seed,
+ config,
+ args.train_dataset_path,
+ args.dry_run
+ )
+ print(json.dumps(result, indent=2))
+ else:
+ # All training jobs
+ results = run_all_training(
+ config,
+ args.train_dataset_path,
+ args.dry_run,
+ args.parallel
+ )
+ print(json.dumps(results, indent=2))
+
+ elif args.mode == "eval":
+ results = run_all_evaluations(config, eval_base=True, dry_run=args.dry_run)
+ print(json.dumps(results, indent=2))
+
+ elif args.mode == "sparsity":
+ results = run_sparsity_analysis(config, dry_run=args.dry_run)
+ print(json.dumps(results, indent=2))
+
+ elif args.mode == "analyze":
+ # Run analysis script
+ analyze_cmd = [
+ sys.executable,
+ "analyze_results.py",
+ "--results_dir", config.eval_metrics_dir,
+ "--output_dir", os.path.join(args.base_output_dir, "analysis"),
+ ]
+ if args.dry_run:
+ print(f"Would run: {' '.join(analyze_cmd)}")
+ else:
+ subprocess.run(analyze_cmd, check=True)
+
+ elif args.mode == "full":
+ logger.info("="*60)
+ logger.info("RUNNING FULL EXPERIMENT PIPELINE")
+ logger.info("="*60)
+
+ # Step 1: Training
+ logger.info("\n" + "="*60)
+ logger.info("STEP 1: Training")
+ logger.info("="*60)
+ train_results = run_all_training(
+ config,
+ args.train_dataset_path,
+ args.dry_run,
+ args.parallel
+ )
+
+ # Step 2: Evaluation
+ logger.info("\n" + "="*60)
+ logger.info("STEP 2: Evaluation")
+ logger.info("="*60)
+ eval_results = run_all_evaluations(config, dry_run=args.dry_run)
+
+ # Step 3: Sparsity Analysis
+ logger.info("\n" + "="*60)
+ logger.info("STEP 3: Sparsity Analysis")
+ logger.info("="*60)
+ sparsity_results = run_sparsity_analysis(config, dry_run=args.dry_run)
+
+ # Step 4: Results Analysis
+ logger.info("\n" + "="*60)
+ logger.info("STEP 4: Results Analysis")
+ logger.info("="*60)
+ if not args.dry_run:
+ analyze_cmd = [
+ sys.executable,
+ "analyze_results.py",
+ "--results_dir", config.eval_metrics_dir,
+ "--output_dir", os.path.join(args.base_output_dir, "analysis"),
+ ]
+ subprocess.run(analyze_cmd, check=True)
+
+ # Summary
+ logger.info("\n" + "="*60)
+ logger.info("EXPERIMENT COMPLETE")
+ logger.info("="*60)
+
+ all_results = {
+ "training": train_results,
+ "evaluation": eval_results,
+ "sparsity": sparsity_results,
+ }
+
+ summary_path = os.path.join(args.base_output_dir, "experiment_summary.json")
+ with open(summary_path, "w") as f:
+ json.dump(all_results, f, indent=2)
+ logger.info(f"Saved summary to {summary_path}")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py
new file mode 100755
index 0000000..5ef3c29
--- /dev/null
+++ b/scripts/prepare_data.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+"""
+Prepare REAL datasets for RLVR floating-point precision experiments.
+
+Downloads from HuggingFace:
+- Training: GSM8K train (7473 samples)
+- Evaluation: GSM8K test, MATH-500, AIME, AMC, MMLU-STEM, HumanEval
+
+Usage:
+ python scripts/prepare_data.py
+"""
+
+import json
+import os
+import random
+from pathlib import Path
+from datasets import load_dataset
+from tqdm import tqdm
+
+DATA_DIR = Path("data")
+DATA_DIR.mkdir(exist_ok=True)
+
+
+def save_json(data: list, path: Path):
+ """Save data as JSON file."""
+ with open(path, "w") as f:
+ json.dump(data, f, indent=2)
+ print(f" Saved {len(data)} samples to {path}")
+
+
+def prepare_gsm8k_train():
+ """Prepare GSM8K training data."""
+ print("\n=== Downloading GSM8K Train ===")
+ ds = load_dataset("openai/gsm8k", "main", split="train")
+
+ data = []
+ for i, sample in enumerate(tqdm(ds, desc="Processing")):
+ # Extract answer from "#### N" format
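+ # e.g. "...so she makes $18 every day. #### 18" -> "18"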
+ answer = sample["answer"].split("####")[-1].strip()
+ data.append({
+ "id": f"gsm8k_train_{i}",
+ "prompt": sample["question"],
+ "answer": answer,
+ "solution": sample["answer"],
+ "source": "gsm8k_train"
+ })
+
+ save_json(data, DATA_DIR / "dm_train.json")
+ return data
+
+
+def prepare_gsm8k_test():
+ """Prepare GSM8K test data for evaluation."""
+ print("\n=== Downloading GSM8K Test ===")
+ ds = load_dataset("openai/gsm8k", "main", split="test")
+
+ data = []
+ for i, sample in enumerate(tqdm(ds, desc="Processing")):
+ answer = sample["answer"].split("####")[-1].strip()
+ data.append({
+ "id": f"gsm8k_test_{i}",
+ "prompt": sample["question"],
+ "answer": answer,
+ "solution": sample["answer"],
+ "source": "gsm8k"
+ })
+
+ save_json(data, DATA_DIR / "gsm8k.json")
+
+ # Also create dm_val as a subset (first 500 for on-task eval)
+ save_json(data[:500], DATA_DIR / "dm_val.json")
+ return data
+
+
+def prepare_math500():
+ """Prepare MATH-500 dataset."""
+ print("\n=== Downloading MATH-500 ===")
+ ds = load_dataset("HuggingFaceH4/MATH-500", split="test")
+
+ data = []
+ for i, sample in enumerate(tqdm(ds, desc="Processing")):
+ data.append({
+ "id": f"math500_{i}",
+ "prompt": sample["problem"],
+ "answer": sample["answer"],
+ "solution": sample["solution"],
+ "subject": sample.get("subject", ""),
+ "level": sample.get("level", ""),
+ "source": "math500"
+ })
+
+ save_json(data, DATA_DIR / "math500.json")
+ return data
+
+
+def prepare_aime():
+ """Prepare AIME dataset from AI-MO."""
+ print("\n=== Downloading AIME ===")
+ ds = load_dataset("AI-MO/aimo-validation-aime", split="train")
+
+ data = []
+ for i, sample in enumerate(tqdm(ds, desc="Processing")):
+ data.append({
+ "id": f"aime_{i}",
+ "prompt": sample["problem"],
+ "answer": str(sample["answer"]),
+ "solution": sample.get("solution", ""),
+ "url": sample.get("url", ""),
+ "source": "aime"
+ })
+
+ # Split into aime24 and aime25 (real AIME has 15 problems per contest,
+ # 2 contests per year = 30/year). The split below is purely positional;
+ # check each sample's "url" field to confirm which contest years actually
+ # land in each file.
+ save_json(data[:30], DATA_DIR / "aime24.json")
+ save_json(data[30:60], DATA_DIR / "aime25.json")
+ save_json(data, DATA_DIR / "aime_all.json")
+ return data
+
+
+def prepare_amc():
+ """Prepare AMC dataset from AI-MO."""
+ print("\n=== Downloading AMC ===")
+ ds = load_dataset("AI-MO/aimo-validation-amc", split="train")
+
+ data = []
+ for i, sample in enumerate(tqdm(ds, desc="Processing")):
+ data.append({
+ "id": f"amc_{i}",
+ "prompt": sample["problem"],
+ "answer": str(sample["answer"]),
+ "solution": sample.get("solution", ""),
+ "source": "amc"
+ })
+
+ save_json(data, DATA_DIR / "amc23.json")
+ return data
+
+
+def prepare_mmlu_stem():
+ """Prepare MMLU-STEM subset."""
+ print("\n=== Downloading MMLU-STEM ===")
+
+ stem_subjects = [
+ "abstract_algebra", "astronomy", "college_biology", "college_chemistry",
+ "college_computer_science", "college_mathematics", "college_physics",
+ "computer_security", "conceptual_physics", "electrical_engineering",
+ "elementary_mathematics", "high_school_biology", "high_school_chemistry",
+ "high_school_computer_science", "high_school_mathematics", "high_school_physics",
+ "high_school_statistics", "machine_learning"
+ ]
+
+ data = []
+ for subject in tqdm(stem_subjects, desc="Loading subjects"):
+ try:
+ ds = load_dataset("cais/mmlu", subject, split="test")
+ for i, sample in enumerate(ds):
+ choices = sample["choices"]
+ correct_idx = sample["answer"]
+ # Format as multiple choice
+ prompt = f"{sample['question']}\n"
+ for j, choice in enumerate(choices):
+ prompt += f"({chr(65+j)}) {choice}\n"
+
+ data.append({
+ "id": f"mmlu_{subject}_{i}",
+ "prompt": prompt,
+ "answer": chr(65 + correct_idx),
+ "subject": subject,
+ "source": "mmlu_stem"
+ })
+ except Exception as e:
+ print(f" Warning: Skipping {subject}: {e}")
+
+ # Take a random subset of 500
+ random.seed(42)
+ if len(data) > 500:
+ data = random.sample(data, 500)
+
+ save_json(data, DATA_DIR / "mmlu_stem.json")
+ return data
+
+
+def prepare_humaneval():
+ """Prepare HumanEval code dataset."""
+ print("\n=== Downloading HumanEval ===")
+ ds = load_dataset("openai/openai_humaneval", split="test")
+
+ data = []
+ for i, sample in enumerate(tqdm(ds, desc="Processing")):
+ data.append({
+ "id": f"humaneval_{i}",
+ "prompt": sample["prompt"],
+ "answer": sample["canonical_solution"],
+ "entry_point": sample["entry_point"],
+ "test": sample["test"],
+ "source": "humaneval"
+ })
+
+ save_json(data, DATA_DIR / "humaneval.json")
+ return data
+
+
+def verify_data():
+ """Verify downloaded data quality."""
+ print("\n" + "=" * 60)
+ print("Verifying Data Quality")
+ print("=" * 60)
+
+ for f in sorted(DATA_DIR.glob("*.json")):
+ with open(f) as fp:
+ data = json.load(fp)
+
+ # Check for unique prompts
+ prompts = [d["prompt"] for d in data]
+ unique = len(set(prompts))
+
+ status = "OK" if unique == len(prompts) else f"WARN: {len(prompts)-unique} duplicates"
+ print(f" {f.name}: {len(data)} samples, {unique} unique [{status}]")
+
+ # Show first example
+ if data:
+ print(f" Example: {data[0]['prompt'][:60]}...")
+
+
+def main():
+ print("=" * 60)
+ print("RLVR Real Data Preparation")
+ print("=" * 60)
+
+ # Backup old data
+ backup_dir = DATA_DIR / "backup_synthetic"
+ if not backup_dir.exists() and any(DATA_DIR.glob("*.json")):
+ backup_dir.mkdir(exist_ok=True)
+ for f in DATA_DIR.glob("*.json"):
+ f.rename(backup_dir / f.name)
+ print(f"Backed up synthetic data to {backup_dir}")
+
+ # Training data
+ prepare_gsm8k_train()
+
+ # Evaluation data
+ prepare_gsm8k_test()
+ prepare_math500()
+ prepare_aime()
+ prepare_amc()
+ prepare_mmlu_stem()
+ prepare_humaneval()
+
+ # Verify
+ verify_data()
+
+ print("\n" + "=" * 60)
+ print("Data preparation complete!")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/run_evaluation.sh b/scripts/run_evaluation.sh
new file mode 100755
index 0000000..b39c230
--- /dev/null
+++ b/scripts/run_evaluation.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+# run_evaluation.sh
+# Script to run evaluation on trained models
+
+set -e
+set -o pipefail # Properly capture exit codes through pipes
+
+# Configuration
+export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"}
+
+# HuggingFace cache - use shared HDD storage to avoid quota issues
+export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache"
+export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub"
+mkdir -p "$HF_HOME" "$HF_HUB_CACHE"
+
+# Default values
+PRECISION_MODE=${1:-"bf16"}
+SEED=${2:-1}
+BASE_MODEL=${BASE_MODEL:-"Qwen/Qwen2.5-Math-7B"}
+TRAIN_LOGS_DIR=${TRAIN_LOGS_DIR:-"./results/train_logs"}
+EVAL_METRICS_DIR=${EVAL_METRICS_DIR:-"./results/eval_metrics"}
+EVAL_CONFIG=${EVAL_CONFIG:-"./configs/eval_tasks_config.json"}
+
+# Paths
+FT_CKPT="${TRAIN_LOGS_DIR}/${PRECISION_MODE}_seed${SEED}/final_model"
+OUTPUT_PATH="${EVAL_METRICS_DIR}/${PRECISION_MODE}_seed${SEED}.json"
+
+# Create output directory
+mkdir -p "$EVAL_METRICS_DIR"
+
+echo "=============================================="
+echo "Model Evaluation"
+echo "=============================================="
+echo "Precision Mode: $PRECISION_MODE"
+echo "Seed: $SEED"
+echo "Base Model: $BASE_MODEL"
+echo "Finetuned Model: $FT_CKPT"
+echo "Output: $OUTPUT_PATH"
+echo "=============================================="
+
+# Check if checkpoint exists
+if [ ! -d "$FT_CKPT" ]; then
+ echo "Error: Checkpoint not found at $FT_CKPT"
+ exit 1
+fi
+
+# Run evaluation
+python eval_policy.py \
+ --base_ckpt "$BASE_MODEL" \
+ --ft_ckpt "$FT_CKPT" \
+ --eval_tasks_config "$EVAL_CONFIG" \
+ --output_path "$OUTPUT_PATH" \
+ --eval_base \
+ --use_amp \
+ 2>&1 | tee "${EVAL_METRICS_DIR}/${PRECISION_MODE}_seed${SEED}_eval.log"
+
+echo "Evaluation complete. Results saved to: $OUTPUT_PATH"
+
diff --git a/scripts/run_full_experiment.sh b/scripts/run_full_experiment.sh
new file mode 100755
index 0000000..43e9dd5
--- /dev/null
+++ b/scripts/run_full_experiment.sh
@@ -0,0 +1,106 @@
+#!/bin/bash
+# run_full_experiment.sh
+# Master script to run the complete RLVR floating-point precision experiment
+
+set -e
+
+# Configuration
+SEEDS=(1 2 3 4 5)
+PRECISION_MODES=("fp32" "bf16")
+TRAIN_DATA=${TRAIN_DATA:-"./data/dm_train.json"}
+OUTPUT_BASE=${OUTPUT_BASE:-"./results"}
+
+echo "=============================================="
+echo "RLVR Floating-Point Precision Experiment"
+echo "=============================================="
+echo "Seeds: ${SEEDS[*]}"
+echo "Precision Modes: ${PRECISION_MODES[*]}"
+echo "Output: $OUTPUT_BASE"
+echo "=============================================="
+
+# Create directories
+mkdir -p "$OUTPUT_BASE/train_logs"
+mkdir -p "$OUTPUT_BASE/eval_metrics"
+mkdir -p "$OUTPUT_BASE/analysis"
+
+# Phase 1: Training
+echo ""
+echo "=============================================="
+echo "PHASE 1: TRAINING"
+echo "=============================================="
+
+for precision in "${PRECISION_MODES[@]}"; do
+ for seed in "${SEEDS[@]}"; do
+ echo "Training: precision=$precision, seed=$seed"
+
+ OUTPUT_DIR="$OUTPUT_BASE/train_logs/${precision}_seed${seed}"
+
+ # Skip if already completed
+ if [ -d "$OUTPUT_DIR/final_model" ]; then
+ echo " -> Skipping (already completed)"
+ continue
+ fi
+
+ # Run training
+ bash scripts/run_training.sh "$precision" "$seed"
+ done
+done
+
+# Phase 2: Evaluation
+echo ""
+echo "=============================================="
+echo "PHASE 2: EVALUATION"
+echo "=============================================="
+
+for precision in "${PRECISION_MODES[@]}"; do
+ for seed in "${SEEDS[@]}"; do
+ echo "Evaluating: precision=$precision, seed=$seed"
+
+ OUTPUT_PATH="$OUTPUT_BASE/eval_metrics/${precision}_seed${seed}.json"
+
+ # Skip if already completed
+ if [ -f "$OUTPUT_PATH" ]; then
+ echo " -> Skipping (already completed)"
+ continue
+ fi
+
+ # Run evaluation
+ bash scripts/run_evaluation.sh "$precision" "$seed"
+ done
+done
+
+# Phase 3: bf16 Sparsity Analysis
+echo ""
+echo "=============================================="
+echo "PHASE 3: BF16 SPARSITY ANALYSIS"
+echo "=============================================="
+
+python run_experiments.py --mode sparsity \
+ --base_output_dir "$OUTPUT_BASE" \
+ --seeds "${SEEDS[@]}"
+
+# Phase 4: Results Analysis
+echo ""
+echo "=============================================="
+echo "PHASE 4: RESULTS ANALYSIS"
+echo "=============================================="
+
+python analyze_results.py \
+ --results_dir "$OUTPUT_BASE/eval_metrics" \
+ --output_dir "$OUTPUT_BASE/analysis" \
+ --on_task dm_val \
+ --off_task aime24 aime25 amc23 math500 mmlu_stem humaneval
+
+echo ""
+echo "=============================================="
+echo "EXPERIMENT COMPLETE"
+echo "=============================================="
+echo "Results saved to: $OUTPUT_BASE"
+echo ""
+echo "Key output files:"
+echo " - Training logs: $OUTPUT_BASE/train_logs/"
+echo " - Evaluation metrics: $OUTPUT_BASE/eval_metrics/"
+echo " - Analysis: $OUTPUT_BASE/analysis/full_analysis.json"
+echo " - Plots: $OUTPUT_BASE/analysis/*.png"
+echo "=============================================="
+
diff --git a/scripts/run_training.sh b/scripts/run_training.sh
new file mode 100755
index 0000000..38b2fc8
--- /dev/null
+++ b/scripts/run_training.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+# run_training.sh
+# Script to run RLVR training experiments with different precision modes
+
+set -e
+set -o pipefail # Properly capture exit codes through pipes
+
+# Configuration
+export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1"}
+
+# HuggingFace cache - use shared HDD storage to avoid quota issues
+export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache"
+export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub"
+mkdir -p "$HF_HOME" "$HF_HUB_CACHE"
+
+# Default values
+PRECISION_MODE=${1:-"bf16"}
+SEED=${2:-1}
+TRAIN_DATA=${TRAIN_DATA:-"./data/dm_train.json"}
+OUTPUT_BASE=${OUTPUT_BASE:-"./results/train_logs"}
+MODEL_NAME=${MODEL_NAME:-"Qwen/Qwen2.5-Math-7B"}
+NUM_STEPS=${NUM_STEPS:-300}
+
+# Create output directory
+OUTPUT_DIR="${OUTPUT_BASE}/${PRECISION_MODE}_seed${SEED}"
+mkdir -p "$OUTPUT_DIR"
+
+echo "=============================================="
+echo "RLVR Training"
+echo "=============================================="
+echo "Precision Mode: $PRECISION_MODE"
+echo "Seed: $SEED"
+echo "Model: $MODEL_NAME"
+echo "Training Data: $TRAIN_DATA"
+echo "Output: $OUTPUT_DIR"
+echo "Num Steps: $NUM_STEPS"
+echo "=============================================="
+
+# Run training
+python train_rlvr.py \
+ --precision_mode "$PRECISION_MODE" \
+ --seed "$SEED" \
+ --output_dir "$OUTPUT_DIR" \
+ --train_dataset_path "$TRAIN_DATA" \
+ --model_name "$MODEL_NAME" \
+ --num_steps "$NUM_STEPS" \
+ 2>&1 | tee "${OUTPUT_DIR}/training.log"
+
+echo "Training complete. Output saved to: $OUTPUT_DIR"
+
diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh
new file mode 100755
index 0000000..743f99a
--- /dev/null
+++ b/scripts/setup_env.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+# setup_env.sh
+# One-time setup script for the RLVR floating-point precision experiment
+# Run this BEFORE submitting any jobs
+
+set -e
+
+CONDA_ENV="rlvr-fp"
+PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise"
+
+echo "============================================"
+echo "RLVR Environment Setup"
+echo "============================================"
+
+# Setup HuggingFace cache directories
+echo ""
+echo "Setting up HuggingFace cache..."
+HF_CACHE_DIR="/work/hdd/bfqt/yurenh2/huggingface_cache"
+mkdir -p "$HF_CACHE_DIR/hub" "$HF_CACHE_DIR/transformers"
+echo " Cache directory: $HF_CACHE_DIR"
+
+# Add to shell profile if not already present
+PROFILE_FILE="$HOME/.bashrc"
+if ! grep -q "HF_HOME.*huggingface_cache" "$PROFILE_FILE" 2>/dev/null; then
+ echo ""
+ echo "Adding HuggingFace cache settings to $PROFILE_FILE..."
+ cat >> "$PROFILE_FILE" << 'EOF'
+
+# HuggingFace cache - shared across all projects (added by RLVR setup)
+export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache"
+export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub"
+export TRANSFORMERS_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/transformers"
+EOF
+ echo " Added to $PROFILE_FILE"
+else
+ echo " HuggingFace settings already in $PROFILE_FILE"
+fi
+
+# Source to apply changes
+source "$PROFILE_FILE"
+
+# Check if conda environment exists
+echo ""
+echo "Checking conda environment..."
+
+if conda env list | grep -q "^${CONDA_ENV} "; then
+ echo " Environment '$CONDA_ENV' already exists"
+ echo " To recreate, run: conda env remove -n $CONDA_ENV && $0"
+else
+ echo " Creating conda environment: $CONDA_ENV"
+ conda create -n "$CONDA_ENV" python=3.10 -y
+
+ echo ""
+ echo "Installing dependencies..."
+ conda activate "$CONDA_ENV"
+ cd "$PROJECT_DIR"
+ pip install -r requirements.txt
+
+ echo ""
+ echo "Verifying installation..."
+ python -c "import torch; print(f'PyTorch: {torch.__version__}')"
+ python -c "import transformers; print(f'Transformers: {transformers.__version__}')"
+fi
+
+echo ""
+echo "============================================"
+echo "Setup complete!"
+echo "============================================"
+echo ""
+echo "To activate the environment:"
+echo " conda activate $CONDA_ENV"
+echo ""
+echo "To run experiments:"
+echo " ./scripts/submit_all_jobs.sh"
+echo ""
+echo "HuggingFace cache location: $HF_CACHE_DIR"
+echo " (1TB quota, shared across all projects)"
+echo "============================================"
diff --git a/scripts/slurm_train.sh b/scripts/slurm_train.sh
new file mode 100755
index 0000000..36bd5b1
--- /dev/null
+++ b/scripts/slurm_train.sh
@@ -0,0 +1,145 @@
+#!/bin/bash
+#SBATCH --job-name=rlvr_fp_exp
+#SBATCH --account=bfqt-delta-gpu
+#SBATCH --partition=gpuH200x8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=16
+#SBATCH --gres=gpu:h200:4
+#SBATCH --mem=200G
+#SBATCH --time=2-00:00:00
+#SBATCH --output=results/slurm_logs/%x_%j.out
+#SBATCH --error=results/slurm_logs/%x_%j.err
+#SBATCH --mail-type=BEGIN,END,FAIL
+# NOTE: #SBATCH directives are parsed before the shell runs, so variables
+# like $USER are NOT expanded here; use a literal address.
+#SBATCH --mail-user=your_email@example.com
+
+# Propagate exit codes through pipes. Deliberately no `set -e`: the script
+# captures the training exit code below and reports it instead of aborting.
+set -o pipefail
+
+# ============================================
+# RLVR Floating-Point Precision Experiment
+# H200x8 SLURM Job Script
+# ============================================
+
+# Configuration - modify these as needed
+PRECISION_MODE=${PRECISION_MODE:-"bf16"}
+SEED=${SEED:-1}
+NUM_STEPS=${NUM_STEPS:-150} # ~45 hours on H200 with sequential generation
+
+# Paths
+PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise"
+CONDA_ENV="rlvr-fp" # Change to your conda env name
+MODEL_NAME="Qwen/Qwen2.5-Math-7B"
+TRAIN_DATA="${PROJECT_DIR}/data/dm_train.json"
+
+# ============================================
+# HuggingFace cache configuration
+# Use shared HDD storage to avoid home directory quota issues
+# This cache is shared across all projects
+# ============================================
+export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache"
+export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub"
+mkdir -p "$HF_HOME" "$HF_HUB_CACHE"
+
+# Print job info
+echo "============================================"
+echo "SLURM Job ID: $SLURM_JOB_ID"
+echo "Running on: $(hostname)"
+echo "Start time: $(date)"
+echo "============================================"
+echo "Precision Mode: $PRECISION_MODE"
+echo "Seed: $SEED"
+echo "Num Steps: $NUM_STEPS"
+echo "GPUs: $CUDA_VISIBLE_DEVICES"
+echo "HF Cache: $HF_HOME"
+echo "============================================"
+
+# Setup environment
+cd "$PROJECT_DIR"
+mkdir -p results/slurm_logs
+
+# Activate conda environment
+source ~/.bashrc
+
+# Check if conda environment exists
+if conda env list | grep -q "^${CONDA_ENV} "; then
+ echo "Activating existing conda environment: $CONDA_ENV"
+ conda activate "$CONDA_ENV"
+else
+ echo "ERROR: Conda environment '$CONDA_ENV' does not exist!"
+ echo "Please create it first by running:"
+ echo " conda create -n $CONDA_ENV python=3.10 -y"
+ echo " conda activate $CONDA_ENV"
+ echo " pip install -r requirements.txt"
+ exit 1
+fi
+
+# Verify activation succeeded
+if [[ "$CONDA_DEFAULT_ENV" != "$CONDA_ENV" ]]; then
+ echo "ERROR: Failed to activate conda environment '$CONDA_ENV'"
+ exit 1
+fi
+echo "Conda environment activated: $CONDA_DEFAULT_ENV"
+
+# Check GPU availability
+nvidia-smi
+echo "CUDA devices: $(python -c 'import torch; print(torch.cuda.device_count())')"
+
+# Output directory (use /work for large checkpoints - /projects is limited)
+OUTPUT_DIR="/work/hdd/bfqt/yurenh2/rlvr_results/${PRECISION_MODE}_seed${SEED}"
+mkdir -p "$OUTPUT_DIR"
+
+# DeepSpeed config (ZeRO-3 for full sharding of model/optimizer/gradients)
+DEEPSPEED_CONFIG="${PROJECT_DIR}/configs/deepspeed_zero3.json"
+
+# Number of GPUs for DeepSpeed (all GPUs, ref model on same GPU as training per rank)
+NUM_GPUS=$(python -c 'import torch; print(torch.cuda.device_count())')
+echo "Using $NUM_GPUS GPUs for DeepSpeed training (ref model on each rank's GPU)"
+
+# Use random port to avoid conflicts with other jobs
+MASTER_PORT=$((29500 + RANDOM % 1000))
+echo "Using master port: $MASTER_PORT"
+
+# Run training with DeepSpeed
+echo ""
+echo "Starting training with DeepSpeed ZeRO-3..."
+echo "============================================"
+
+deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train_rlvr.py \
+ --precision_mode "$PRECISION_MODE" \
+ --seed "$SEED" \
+ --output_dir "$OUTPUT_DIR" \
+ --train_dataset_path "$TRAIN_DATA" \
+ --model_name "$MODEL_NAME" \
+ --num_steps "$NUM_STEPS" \
+ --deepspeed "$DEEPSPEED_CONFIG" \
+ 2>&1 | tee "${OUTPUT_DIR}/training_slurm.log"
+
+# IMPORTANT: With set -o pipefail, $? now captures python's exit code, not tee's
+TRAIN_EXIT_CODE=$?
+
+echo ""
+echo "============================================"
+echo "Training completed with exit code: $TRAIN_EXIT_CODE"
+echo "End time: $(date)"
+echo "============================================"
+
+# If training succeeded, run evaluation
+if [ $TRAIN_EXIT_CODE -eq 0 ]; then
+ echo ""
+ echo "Starting evaluation..."
+ echo "============================================"
+
+ python eval_policy.py \
+ --base_ckpt "$MODEL_NAME" \
+ --ft_ckpt "${OUTPUT_DIR}/final_model" \
+ --eval_tasks_config configs/eval_tasks_config.json \
+ --output_path "results/eval_metrics/${PRECISION_MODE}_seed${SEED}.json" \
+ --eval_base \
+ --use_amp \
+ 2>&1 | tee "${OUTPUT_DIR}/eval_slurm.log"
+fi
+
+echo ""
+echo "Job completed at: $(date)"
+
diff --git a/scripts/submit_all_jobs.sh b/scripts/submit_all_jobs.sh
new file mode 100755
index 0000000..86c0f5d
--- /dev/null
+++ b/scripts/submit_all_jobs.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+# submit_all_jobs.sh
+# Submit all experiment jobs to SLURM queue
+# Jobs will run automatically when resources become available
+
+set -e
+
+PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise"
+cd "$PROJECT_DIR"
+
+# Create log directory
+mkdir -p results/slurm_logs
+
+# Configuration
+SEEDS=(1 2 3 4 5)
+PRECISION_MODES=("fp32" "bf16")
+
+echo "============================================"
+echo "Submitting RLVR Experiment Jobs"
+echo "============================================"
+echo "Seeds: ${SEEDS[*]}"
+echo "Precision Modes: ${PRECISION_MODES[*]}"
+echo "Total jobs: $((${#SEEDS[@]} * ${#PRECISION_MODES[@]}))"
+echo "============================================"
+
+# Track submitted job IDs
+declare -a JOB_IDS
+
+for precision in "${PRECISION_MODES[@]}"; do
+ for seed in "${SEEDS[@]}"; do
+ JOB_NAME="rlvr_${precision}_s${seed}"
+
+ echo "Submitting: $JOB_NAME"
+
+ # Submit job with environment variables
+ JOB_ID=$(sbatch \
+ --job-name="$JOB_NAME" \
+ --export=ALL,PRECISION_MODE="$precision",SEED="$seed" \
+ scripts/slurm_train.sh | awk '{print $4}')
+
+ JOB_IDS+=("$JOB_ID")
+ echo " -> Job ID: $JOB_ID"
+ done
+done
+
+echo ""
+echo "============================================"
+echo "All jobs submitted!"
+echo "Job IDs: ${JOB_IDS[*]}"
+echo "============================================"
+echo ""
+echo "Monitor with:"
+echo " squeue -u $USER"
+echo " squeue -j $(IFS=,; echo "${JOB_IDS[*]}")"
+echo ""
+echo "View logs:"
+echo " tail -f results/slurm_logs/rlvr_*.out"
+echo ""
+echo "Cancel all:"
+echo " scancel ${JOB_IDS[*]}"
+echo "============================================"
+
+# Save job IDs for reference
+echo "${JOB_IDS[*]}" > results/slurm_logs/submitted_jobs.txt
+echo "Job IDs saved to: results/slurm_logs/submitted_jobs.txt"
+
diff --git a/scripts/submit_single_job.sh b/scripts/submit_single_job.sh
new file mode 100755
index 0000000..7fe7492
--- /dev/null
+++ b/scripts/submit_single_job.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# submit_single_job.sh
+# Submit a single training job
+# Usage: ./submit_single_job.sh <precision_mode> <seed>
+# Example: ./submit_single_job.sh bf16 1
+
+PRECISION_MODE=${1:-"bf16"}
+SEED=${2:-1}
+
+PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise"
+cd "$PROJECT_DIR"
+
+mkdir -p results/slurm_logs
+
+JOB_NAME="rlvr_${PRECISION_MODE}_s${SEED}"
+
+echo "Submitting job: $JOB_NAME"
+echo " Precision: $PRECISION_MODE"
+echo " Seed: $SEED"
+
+JOB_ID=$(sbatch \
+ --job-name="$JOB_NAME" \
+ --export=ALL,PRECISION_MODE="$PRECISION_MODE",SEED="$SEED" \
+ scripts/slurm_train.sh | awk '{print $4}')
+
+echo ""
+echo "Submitted! Job ID: $JOB_ID"
+echo ""
+echo "Monitor with: squeue -j $JOB_ID"
+echo "View output: tail -f results/slurm_logs/${JOB_NAME}_${JOB_ID}.out"
+echo "Cancel: scancel $JOB_ID"
+
diff --git a/scripts/test_quick.sh b/scripts/test_quick.sh
new file mode 100644
index 0000000..f66e73b
--- /dev/null
+++ b/scripts/test_quick.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+#SBATCH --job-name=rlvr_test
+#SBATCH --account=bfqt-delta-gpu
+#SBATCH --partition=gpuH200x8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=16
+#SBATCH --gres=gpu:h200:4
+#SBATCH --mem=200G
+#SBATCH --time=04:00:00
+#SBATCH --output=results/slurm_logs/test_%j.out
+#SBATCH --error=results/slurm_logs/test_%j.err
+
+set -o pipefail
+
+PROJECT_DIR="/projects/bfqt/users/yurenh2/ml-projects/rl-floating-noise"
+cd "$PROJECT_DIR"
+
+source ~/.bashrc
+conda activate rlvr-fp
+
+export HF_HOME="/work/hdd/bfqt/yurenh2/huggingface_cache"
+export HF_HUB_CACHE="/work/hdd/bfqt/yurenh2/huggingface_cache/hub"
+
+echo "============================================"
+echo "Quick test on $(hostname)"
+echo "SLURM Job ID: $SLURM_JOB_ID"
+nvidia-smi
+echo "============================================"
+
+NUM_GPUS=$(python -c 'import torch; print(torch.cuda.device_count())')
+echo "Using $NUM_GPUS GPUs for DeepSpeed"
+
+# Use random port to avoid conflicts
+MASTER_PORT=$((29500 + RANDOM % 1000))
+echo "Using master port: $MASTER_PORT"
+
+# Test fp32 with just 3 steps
+echo "Testing fp32..."
+deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train_rlvr.py \
+ --precision_mode fp32 \
+ --seed 1 \
+ --output_dir /work/hdd/bfqt/yurenh2/rlvr_results/test_fp32 \
+ --train_dataset_path data/dm_train.json \
+ --model_name Qwen/Qwen2.5-Math-7B \
+ --num_steps 3 \
+ --deepspeed configs/deepspeed_zero3.json
+
+FP32_EXIT=$?
+echo "fp32 test exit code: $FP32_EXIT"
+
+if [ $FP32_EXIT -eq 0 ]; then
+ echo "fp32 test PASSED"
+
+ # Also test bf16
+ echo "Testing bf16..."
+ deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train_rlvr.py \
+ --precision_mode bf16 \
+ --seed 1 \
+ --output_dir /work/hdd/bfqt/yurenh2/rlvr_results/test_bf16 \
+ --train_dataset_path data/dm_train.json \
+ --model_name Qwen/Qwen2.5-Math-7B \
+ --num_steps 3 \
+ --deepspeed configs/deepspeed_zero3.json
+
+ BF16_EXIT=$?
+ echo "bf16 test exit code: $BF16_EXIT"
+
+ if [ $BF16_EXIT -eq 0 ]; then
+ echo "============================================"
+ echo "ALL TESTS PASSED!"
+ echo "============================================"
+ else
+ echo "bf16 test FAILED"
+ fi
+else
+ echo "fp32 test FAILED"
+fi
diff --git a/train_rlvr.py b/train_rlvr.py
new file mode 100644
index 0000000..1076df8
--- /dev/null
+++ b/train_rlvr.py
@@ -0,0 +1,849 @@
+#!/usr/bin/env python3
+# train_rlvr.py
+"""
+RLVR Training Script with DAPO Algorithm.
+
+This script implements the training loop for reinforcement learning with
+verifiable rewards (RLVR) using the DAPO algorithm. It supports two precision
+configurations:
+
+- P-high (fp32): High precision master weights for low numerical noise
+- P-bf16: Default RLVR configuration with bf16 master weights
+
+The script integrates with VeRL framework for distributed RL training.
+
+Usage:
+ python train_rlvr.py \
+ --precision_mode fp32 \
+ --seed 1 \
+ --output_dir results/train_logs/fp32_seed1 \
+ --train_dataset_path data/dm_train.json
+"""
+
+import argparse
+import json
+import os
+import random
+import logging
+from typing import Dict, Any, List, Optional, Tuple
+from dataclasses import asdict
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.cuda.amp import autocast, GradScaler
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import deepspeed
+
+from config import (
+ make_training_config,
+ make_precision_config,
+ TrainingConfig,
+ PrecisionConfig,
+)
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Seed and Determinism Utilities
+# ============================================================================
+
+def set_seed(seed: int) -> None:
+ """Set random seeds for reproducibility."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # For hash-based operations
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ logger.info(f"Set random seed to {seed}")
+
+
+def configure_torch_deterministic(deterministic: bool) -> None:
+ """Configure PyTorch deterministic algorithms."""
+ if deterministic:
+ torch.use_deterministic_algorithms(True, warn_only=True)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+ logger.info("Enabled deterministic algorithms")
+ else:
+ torch.use_deterministic_algorithms(False)
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.deterministic = False
+ logger.info("Using non-deterministic algorithms (default)")
+
+
+# ============================================================================
+# Model Utilities
+# ============================================================================
+
+def get_torch_dtype(dtype_str: str) -> torch.dtype:
+ """Convert string dtype to torch.dtype."""
+ dtype_map = {
+ "float32": torch.float32,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ }
+ if dtype_str not in dtype_map:
+ raise ValueError(f"Unknown dtype: {dtype_str}")
+ return dtype_map[dtype_str]
+
+
+def cast_model_param_dtype(
+ model: torch.nn.Module,
+ param_dtype: str
+) -> torch.nn.Module:
+ """Cast model parameters to specified dtype."""
+ dtype = get_torch_dtype(param_dtype)
+ model.to(dtype=dtype)
+ logger.info(f"Cast model parameters to {param_dtype}")
+ return model
+
+
+def disable_dropout(model: torch.nn.Module) -> None:
+ """Disable all dropout layers in the model."""
+ for module in model.modules():
+ if isinstance(module, torch.nn.Dropout):
+ module.p = 0.0
+ logger.info("Disabled all dropout layers")
+
+
+def count_parameters(model: torch.nn.Module) -> int:
+ """Count trainable parameters."""
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_training_data(dataset_path: str) -> List[Dict[str, Any]]:
+ """Load training dataset from JSON file."""
+ with open(dataset_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ logger.info(f"Loaded {len(data)} training examples from {dataset_path}")
+ return data
+
+
+def sample_batch(
+ data: List[Dict[str, Any]],
+ batch_size: int,
+ rng: np.random.Generator
+) -> List[Dict[str, Any]]:
+ """Sample a batch of prompts from the dataset."""
+ indices = rng.choice(len(data), size=min(batch_size, len(data)), replace=False)
+ return [data[i] for i in indices]
+
+
+# ============================================================================
+# Reward Function (Math Verifier)
+# ============================================================================
+
+def compute_math_reward(
+ prompt: str,
+ response: str,
+ ground_truth: Optional[str] = None
+) -> float:
+ """
+ Compute reward for a math problem response.
+
+ Uses a simple rule-based verifier. In production, this should be replaced
+ with Eval-Chemy or a similar math-verification system.
+
+ Args:
+ prompt: The math problem prompt
+ response: The model's generated response
+ ground_truth: The expected answer (if available)
+
+ Returns:
+ +1.0 if correct, -1.0 if incorrect
+ """
+ # TODO: Replace with actual math verifier (Eval-Chemy)
+ # This is a placeholder implementation
+
+ if ground_truth is None:
+ # Cannot verify without ground truth
+ return 0.0
+
+ # Extract final answer from response (simple heuristic)
+ response_lower = response.lower().strip()
+ gt_lower = ground_truth.lower().strip()
+
+ # Check for common answer formats
+ answer_markers = ["the answer is", "therefore", "=", "\\boxed{"]
+
+ for marker in answer_markers:
+ if marker in response_lower:
+ idx = response_lower.rfind(marker)
+ potential_answer = response_lower[idx:].strip()
+ if gt_lower in potential_answer:
+ return 1.0
+
+ # Direct containment check as fallback
+ if gt_lower in response_lower:
+ return 1.0
+
+ return -1.0
+
+
+# ============================================================================
+# DAPO Algorithm Implementation
+# ============================================================================
+
+class DAPOTrainer:
+ """
+ DAPO (Decoupled Clip and Dynamic sAmpling Policy Optimization) Trainer.
+
+ Implements clip-only DAPO with an implicit KL constraint through ratio
+ clipping. This is a simplified implementation; for production, use VeRL's
+ DapoTrainer.
+ """
+
+ def __init__(
+ self,
+ model_engine, # DeepSpeed engine or raw model
+ ref_model: torch.nn.Module,
+ tokenizer,
+ train_config: TrainingConfig,
+ precision_config: PrecisionConfig,
+ device: torch.device,
+ ref_device: Optional[torch.device] = None,
+ use_deepspeed: bool = False
+ ) -> None:
+ self.use_deepspeed = use_deepspeed
+ if use_deepspeed:
+ self.model_engine = model_engine
+ self.model = model_engine.module
+ else:
+ self.model = model_engine
+ self.model_engine = None
+ self.ref_model = ref_model
+ self.tokenizer = tokenizer
+ self.train_config = train_config
+ self.precision_config = precision_config
+ self.device = device
+ self.ref_device = ref_device if ref_device is not None else device
+
+ # Freeze reference model
+ for param in self.ref_model.parameters():
+ param.requires_grad = False
+ self.ref_model.eval()
+
+ # Setup optimizer (only if not using DeepSpeed)
+ if not use_deepspeed:
+ self.optimizer = torch.optim.AdamW(
+ self.model.parameters(),
+ lr=train_config.learning_rate,
+ betas=(train_config.beta1, train_config.beta2),
+ weight_decay=train_config.weight_decay
+ )
+ else:
+ self.optimizer = None # DeepSpeed manages optimizer
+
+ # Setup AMP scaler if using fp16 (not needed with DeepSpeed)
+ self.scaler = None
+ if not use_deepspeed and precision_config.use_amp and precision_config.amp_dtype == "float16":
+ self.scaler = GradScaler()
+
+ # Training state
+ self.global_step = 0
+ self.rng = np.random.default_rng(train_config.seed)
+
+ # Metrics tracking
+ self.metrics_history: List[Dict[str, float]] = []
+
+ def compute_log_probs(
+ self,
+ model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ labels: torch.Tensor
+ ) -> torch.Tensor:
+ """Compute token-level log probabilities."""
+ with autocast(
+ enabled=self.precision_config.use_amp,
+ dtype=get_torch_dtype(self.precision_config.amp_dtype)
+ ):
+ outputs = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ use_cache=False
+ )
+
+ logits = outputs.logits
+ log_probs = torch.log_softmax(logits, dim=-1)
+
+ # Gather log probs for actual tokens
+ # Shift for autoregressive: predict next token
+ shift_log_probs = log_probs[:, :-1, :]
+ shift_labels = labels[:, 1:]
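+ # Position i of shift_log_probs scores token i+1 of labels, so the
+ # returned tensor has length seq_len - 1 along the time dimension.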
+
+ token_log_probs = torch.gather(
+ shift_log_probs,
+ dim=-1,
+ index=shift_labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ return token_log_probs
+
+ def compute_dapo_loss(
+ self,
+ policy_log_probs: torch.Tensor,
+ ref_log_probs: torch.Tensor,
+ rewards: torch.Tensor,
+ response_mask: torch.Tensor
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
+ """
+ Compute DAPO clip-only loss.
+
+ DAPO uses ratio clipping without explicit KL penalty (beta=0).
+ The clipping provides implicit KL constraint (Gate I).
+ """
+ # Compute log ratios
+ log_ratios = policy_log_probs - ref_log_probs
+
+ # Sum log ratios over response tokens
+ masked_log_ratios = log_ratios * response_mask
+ sequence_log_ratios = masked_log_ratios.sum(dim=-1)
+
+ # Compute importance sampling ratios
+ ratios = torch.exp(sequence_log_ratios)
+
+ # DAPO objective with clipping
+ clip_ratio = self.train_config.clip_ratio
+
+ # Advantage estimation (simplified: just use rewards)
+ advantages = rewards
+
+ # Clipped surrogate objective
+ unclipped = ratios * advantages
+ clipped = torch.clamp(ratios, 1 - clip_ratio, 1 + clip_ratio) * advantages
+
+ # Take minimum for pessimistic update
+ loss = -torch.min(unclipped, clipped).mean()
+
+ # Compute metrics
+ with torch.no_grad():
+ approx_kl = (log_ratios * response_mask).sum() / response_mask.sum()
+ clip_fraction = ((ratios - 1).abs() > clip_ratio).float().mean()
+
+ metrics = {
+ "loss": loss.item(),
+ "approx_kl": approx_kl.item(),
+ "clip_fraction": clip_fraction.item(),
+ "mean_ratio": ratios.mean().item(),
+ "mean_reward": rewards.mean().item(),
+ }
+
+ return loss, metrics
+
+ def generate_rollouts(
+ self,
+ prompts: List[str],
+ num_samples: int
+ ) -> List[Dict[str, Any]]:
+ """Generate rollouts for a batch of prompts."""
+ rollouts = []
+
+ self.model.eval()
+ with torch.no_grad():
+ for prompt in prompts:
+ inputs = self.tokenizer(
+ prompt,
+ return_tensors="pt",
+ truncation=True,
+ max_length=self.train_config.max_seq_len // 2
+ ).to(self.device)
+
+ for _ in range(num_samples):
+ with autocast(
+ enabled=self.precision_config.use_amp,
+ dtype=get_torch_dtype(self.precision_config.amp_dtype)
+ ):
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=self.train_config.max_seq_len // 2,
+ do_sample=True,
+ temperature=0.7,
+ top_p=0.8,
+ pad_token_id=self.tokenizer.eos_token_id
+ )
+
+ response_ids = outputs[0, inputs["input_ids"].shape[1]:]
+ response_text = self.tokenizer.decode(
+ response_ids,
+ skip_special_tokens=True
+ )
+
+ rollouts.append({
+ "prompt": prompt,
+ "response": response_text,
+ "input_ids": inputs["input_ids"][0],
+ "response_ids": response_ids,
+ "full_ids": outputs[0]
+ })
+
+ self.model.train()
+ return rollouts
+
+ def train_step(
+ self,
+ batch: List[Dict[str, Any]]
+ ) -> Dict[str, float]:
+ """Execute one training step on a batch."""
+ self.model.train()
+
+ # Generate rollouts
+ prompts = [ex["prompt"] for ex in batch]
+ ground_truths = [ex.get("answer", None) for ex in batch]
+
+ rollouts = self.generate_rollouts(
+ prompts,
+ self.train_config.num_rollouts_per_prompt
+ )
+
+ # Compute rewards
+ rewards = []
+ for i, rollout in enumerate(rollouts):
+ prompt_idx = i // self.train_config.num_rollouts_per_prompt
+ gt = ground_truths[prompt_idx] if prompt_idx < len(ground_truths) else None
+ reward = compute_math_reward(rollout["prompt"], rollout["response"], gt)
+ rewards.append(reward)
+
+ rewards_tensor = torch.tensor(rewards, device=self.device, dtype=torch.float32)
+
+ # Skip if all rewards are the same (no learning signal)
+ if rewards_tensor.std() < 1e-6:
+ return {"skipped": 1.0}
+
+ # Normalize rewards per prompt (advantage estimation)
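+ # (group-relative baseline: A_i = (r_i - mean_group) / (std_group + 1e-8))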
+ rewards_per_prompt = rewards_tensor.view(-1, self.train_config.num_rollouts_per_prompt)
+ normalized_rewards = (rewards_per_prompt - rewards_per_prompt.mean(dim=1, keepdim=True))
+ normalized_rewards = normalized_rewards / (rewards_per_prompt.std(dim=1, keepdim=True) + 1e-8)
+ normalized_rewards = normalized_rewards.view(-1)
+
+ # Prepare for training (DeepSpeed handles zero_grad internally)
+ if not self.use_deepspeed:
+ self.optimizer.zero_grad()
+
+ total_loss = 0.0
+ all_metrics: Dict[str, List[float]] = {}
+
+ # Process rollouts in micro-batches
+ num_rollouts = len(rollouts)
+ micro_batch_size = self.train_config.micro_batch_size
+
+ for mb_start in range(0, num_rollouts, micro_batch_size):
+ mb_end = min(mb_start + micro_batch_size, num_rollouts)
+ mb_rollouts = rollouts[mb_start:mb_end]
+ mb_rewards = normalized_rewards[mb_start:mb_end]
+
+ # Prepare batch tensors
+ max_len = max(len(r["full_ids"]) for r in mb_rollouts)
+ batch_input_ids = torch.zeros(len(mb_rollouts), max_len, dtype=torch.long, device=self.device)
+ batch_attention_mask = torch.zeros(len(mb_rollouts), max_len, dtype=torch.long, device=self.device)
+ batch_response_mask = torch.zeros(len(mb_rollouts), max_len - 1, dtype=torch.float32, device=self.device)
+
+ for i, rollout in enumerate(mb_rollouts):
+ seq_len = len(rollout["full_ids"])
+ batch_input_ids[i, :seq_len] = rollout["full_ids"]
+ batch_attention_mask[i, :seq_len] = 1
+ prompt_len = len(rollout["input_ids"])
+ batch_response_mask[i, prompt_len-1:seq_len-1] = 1
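+ # Mask is in shifted coordinates: token t's log prob sits at index t-1
+ # after the shift in compute_log_probs, hence the -1 offsets.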
+
+ # Compute log probs for policy and reference
+ policy_log_probs = self.compute_log_probs(
+ self.model,
+ batch_input_ids,
+ batch_attention_mask,
+ batch_input_ids
+ )
+
+ with torch.no_grad():
+ # Move tensors to ref_device if different from training device
+ if self.ref_device != self.device:
+ ref_input_ids = batch_input_ids.to(self.ref_device)
+ ref_attention_mask = batch_attention_mask.to(self.ref_device)
+ else:
+ ref_input_ids = batch_input_ids
+ ref_attention_mask = batch_attention_mask
+
+ ref_log_probs = self.compute_log_probs(
+ self.ref_model,
+ ref_input_ids,
+ ref_attention_mask,
+ ref_input_ids
+ )
+
+ # Move ref_log_probs back to training device
+ if self.ref_device != self.device:
+ ref_log_probs = ref_log_probs.to(self.device)
+
+ # Compute DAPO loss
+ loss, metrics = self.compute_dapo_loss(
+ policy_log_probs,
+ ref_log_probs,
+ mb_rewards,
+ batch_response_mask
+ )
+
+ # Scale loss for gradient accumulation (DeepSpeed handles this internally)
+ if self.use_deepspeed:
+ scaled_loss = loss
+ else:
+ scaled_loss = loss / self.train_config.grad_accumulation_steps
+
+ # Backward pass
+ if self.use_deepspeed:
+ self.model_engine.backward(scaled_loss)
+ elif self.scaler is not None:
+ self.scaler.scale(scaled_loss).backward()
+ else:
+ scaled_loss.backward()
+
+ total_loss += loss.item()
+
+ for k, v in metrics.items():
+ if k not in all_metrics:
+ all_metrics[k] = []
+ all_metrics[k].append(v)
+
+ # Optimizer step
+ if self.use_deepspeed:
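+ # Note (assumption about DeepSpeed's convention): the engine expects
+ # backward()/step() once per micro-batch, with gradient_accumulation_steps
+ # in the config deciding when the optimizer actually steps; calling step()
+ # once per outer batch, as here, relies on grad_acc being set to match.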
+ self.model_engine.step()
+ elif self.scaler is not None:
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ self.optimizer.step()
+
+ self.global_step += 1
+
+ # Aggregate metrics
+ # Cast to plain float so metrics_history stays JSON-serializable
+ step_metrics = {k: float(np.mean(v)) for k, v in all_metrics.items()}
+ step_metrics["total_loss"] = total_loss
+ step_metrics["step"] = self.global_step
+
+ self.metrics_history.append(step_metrics)
+
+ return step_metrics
+
+ def train(
+ self,
+ train_data: List[Dict[str, Any]],
+ save_checkpoints: bool = True
+ ) -> None:
+ """Run the full training loop."""
+ logger.info(f"Starting training for {self.train_config.num_steps} steps")
+
+ checkpoint_steps = set(self.train_config.checkpoint_steps)
+
+ for step in range(self.train_config.num_steps):
+ # Sample batch
+ batch = sample_batch(
+ train_data,
+ self.train_config.global_batch_size // self.train_config.num_rollouts_per_prompt,
+ self.rng
+ )
+
+ # Training step
+ metrics = self.train_step(batch)
+
+ # Logging
+ if step % 10 == 0:
+ logger.info(
+ f"Step {step}/{self.train_config.num_steps} | "
+ f"Loss: {metrics.get('total_loss', 0):.4f} | "
+ f"KL: {metrics.get('approx_kl', 0):.4f} | "
+ f"Reward: {metrics.get('mean_reward', 0):.4f}"
+ )
+
+ # Checkpointing
+ if save_checkpoints and (step + 1) in checkpoint_steps:
+ self.save_checkpoint(step + 1)
+
+ def save_checkpoint(self, step: int) -> None:
+ """Save model checkpoint."""
+ ckpt_dir = os.path.join(
+ self.train_config.output_dir,
+ f"checkpoint_step{step}"
+ )
+ os.makedirs(ckpt_dir, exist_ok=True)
+
+ if self.use_deepspeed:
+ # Use DeepSpeed's save_checkpoint for proper ZeRO handling
+ self.model_engine.save_checkpoint(ckpt_dir, tag=f"step{step}")
+ else:
+ self.model.save_pretrained(ckpt_dir)
+ # Save training state
+ state = {
+ "step": step,
+ "optimizer_state_dict": self.optimizer.state_dict(),
+ "metrics_history": self.metrics_history,
+ }
+ torch.save(state, os.path.join(ckpt_dir, "training_state.pt"))
+
+ self.tokenizer.save_pretrained(ckpt_dir)
+ logger.info(f"Saved checkpoint at step {step} to {ckpt_dir}")
+
+ def save_final(self) -> str:
+ """Save final model."""
+ final_dir = os.path.join(self.train_config.output_dir, "final_model")
+ os.makedirs(final_dir, exist_ok=True)
+
+ if self.use_deepspeed:
+ # Use DeepSpeed's save_checkpoint for proper ZeRO-3 weight gathering
+ self.model_engine.save_checkpoint(final_dir, tag="final")
+ else:
+ self.model.save_pretrained(final_dir)
+
+ self.tokenizer.save_pretrained(final_dir)
+
+ # Save metrics
+ with open(os.path.join(final_dir, "metrics_history.json"), "w") as f:
+ json.dump(self.metrics_history, f, indent=2)
+
+ logger.info(f"Saved final model to {final_dir}")
+ return final_dir
+
+
+# ============================================================================
+# Main Entry Point
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="RLVR Training with DAPO Algorithm"
+ )
+ parser.add_argument(
+ "--precision_mode",
+ type=str,
+ required=True,
+ choices=["fp32", "bf16"],
+ help="Precision mode: fp32 (high precision) or bf16 (default RLVR)"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ required=True,
+ help="Random seed for reproducibility"
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Directory to save outputs"
+ )
+ parser.add_argument(
+ "--train_dataset_path",
+ type=str,
+ required=True,
+ help="Path to training dataset JSON"
+ )
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="Qwen/Qwen2.5-Math-7B",
+ help="HuggingFace model identifier"
+ )
+ parser.add_argument(
+ "--num_steps",
+ type=int,
+ default=300,
+ help="Number of training steps"
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="Device to use for training"
+ )
+ parser.add_argument(
+ "--deepspeed",
+ type=str,
+ default=None,
+ help="Path to DeepSpeed config file"
+ )
+ parser.add_argument(
+ "--local_rank",
+ type=int,
+ default=-1,
+ help="Local rank for distributed training (set by DeepSpeed)"
+ )
+ return parser.parse_args()
+
+
+def main() -> None:
+ """Main training function."""
+ args = parse_args()
+
+ # Create configurations
+ train_config = make_training_config(
+ precision_mode=args.precision_mode,
+ seed=args.seed,
+ output_dir=args.output_dir,
+ train_dataset_path=args.train_dataset_path,
+ model_name=args.model_name
+ )
+ train_config.num_steps = args.num_steps
+
+ precision_config = make_precision_config(args.precision_mode)
+
+ # Setup output directory
+ os.makedirs(train_config.output_dir, exist_ok=True)
+
+ # Save configurations
+ with open(os.path.join(train_config.output_dir, "train_config.json"), "w") as f:
+ json.dump(asdict(train_config), f, indent=2)
+ with open(os.path.join(train_config.output_dir, "precision_config.json"), "w") as f:
+ json.dump(asdict(precision_config), f, indent=2)
+
+ # Set seeds and determinism
+ set_seed(train_config.seed)
+ configure_torch_deterministic(precision_config.deterministic)
+
+ # Check if using DeepSpeed
+ use_deepspeed = args.deepspeed is not None
+
+ # Setup devices
+ num_gpus = torch.cuda.device_count()
+
+ if use_deepspeed:
+ # Initialize DeepSpeed distributed
+ deepspeed.init_distributed()
+ local_rank = args.local_rank if args.local_rank >= 0 else int(os.environ.get("LOCAL_RANK", 0))
+ device = torch.device(f"cuda:{local_rank}")
+ torch.cuda.set_device(device)
+ # Put reference model on same GPU as training (each rank has its own copy)
+ # ZeRO-2 shards optimizer states, so there's room for the bf16 ref model (~14GB)
+ ref_device = device
+ logger.info(f"DeepSpeed: rank {local_rank}, training on {device}, ref model on {ref_device}")
+ else:
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ if num_gpus >= 2:
+ ref_device = torch.device("cuda:1")
+ logger.info(f"Using device: {device} for training, {ref_device} for reference model")
+ else:
+ ref_device = device
+ logger.info(f"Using device: {device} for both training and reference model")
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ train_config.model_name,
+ use_fast=True,
+ trust_remote_code=True
+ )
+ tokenizer.padding_side = "left"
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Load model
+ logger.info(f"Loading model: {train_config.model_name}")
+ model = AutoModelForCausalLM.from_pretrained(
+ train_config.model_name,
+ torch_dtype=torch.float32, # Load in FP32 first
+ device_map=None,
+ trust_remote_code=True
+ )
+
+ # Cast to target precision
+ model = cast_model_param_dtype(model, precision_config.param_dtype)
+
+ # Enable gradient checkpointing to save memory (trades compute for memory)
+ model.gradient_checkpointing_enable()
+ logger.info("Enabled gradient checkpointing to reduce memory usage")
+
+ # Disable dropout if needed
+ if not train_config.use_dropout:
+ disable_dropout(model)
+
+ logger.info(f"Model loaded with {count_parameters(model):,} trainable parameters")
+
+ # Initialize DeepSpeed or move model to device
+ if use_deepspeed:
+ # Create optimizer for DeepSpeed
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=train_config.learning_rate,
+ betas=(train_config.beta1, train_config.beta2),
+ weight_decay=train_config.weight_decay
+ )
+
+ # Load DeepSpeed config
+ with open(args.deepspeed, 'r') as f:
+ ds_config = json.load(f)
+
+ # Compute batch sizes compatible with DeepSpeed
+ # DeepSpeed requires: train_batch_size = micro_batch * grad_acc * world_size
+ world_size = int(os.environ.get("WORLD_SIZE", num_gpus)) # set by the launcher; fall back to all visible GPUs (the ref model shares each rank's GPU)
+ micro_batch = train_config.micro_batch_size
+
+ # Compute grad_acc to get closest to desired global batch size
+ desired_global = train_config.global_batch_size
+ grad_acc = max(1, desired_global // (micro_batch * world_size))
+ actual_global = micro_batch * grad_acc * world_size
+
+ ds_config["train_batch_size"] = actual_global
+ ds_config["train_micro_batch_size_per_gpu"] = micro_batch
+ ds_config["gradient_accumulation_steps"] = grad_acc
+
+ logger.info(f"DeepSpeed batch config: global={actual_global}, micro={micro_batch}, grad_acc={grad_acc}, world_size={world_size}")
+
+ # Initialize DeepSpeed engine
+ model_engine, optimizer, _, _ = deepspeed.initialize(
+ model=model,
+ optimizer=optimizer,
+ config=ds_config
+ )
+ logger.info(f"DeepSpeed ZeRO-2 initialized with {world_size} GPUs")
+ else:
+ model = model.to(device)
+ model_engine = model
+
+ # Load reference model (frozen copy)
+ # Always use bf16 for reference model - it's frozen and only used for KL computation.
+ # This is a controlled variable: same precision for both fp32 and bf16 training modes.
+ # The experiment tests training precision effects, not reference model precision.
+ logger.info("Loading reference model (bf16 for both modes - controlled variable)")
+ ref_model = AutoModelForCausalLM.from_pretrained(
+ train_config.model_name,
+ torch_dtype=torch.bfloat16, # Always bf16 to save memory & control variable
+ device_map=None,
+ trust_remote_code=True
+ )
+ ref_model = ref_model.to(ref_device)
+ ref_model.eval()
+
+ # Load training data
+ train_data = load_training_data(train_config.train_dataset_path)
+
+ # Initialize trainer
+ trainer = DAPOTrainer(
+ model_engine=model_engine,
+ ref_model=ref_model,
+ tokenizer=tokenizer,
+ train_config=train_config,
+ precision_config=precision_config,
+ device=device,
+ ref_device=ref_device,
+ use_deepspeed=use_deepspeed
+ )
+
+ # Run training
+ trainer.train(train_data, save_checkpoints=True)
+
+ # Save final model
+ final_path = trainer.save_final()
+ logger.info(f"Training complete. Final model saved to: {final_path}")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/utils_bf16_sparsity.py b/utils_bf16_sparsity.py
new file mode 100644
index 0000000..2e0729d
--- /dev/null
+++ b/utils_bf16_sparsity.py
@@ -0,0 +1,459 @@
+# utils_bf16_sparsity.py
+"""
+bf16-Aware Update Sparsity Utilities.
+
+This module implements the bf16-aware update sparsity metric from the RLVR paper,
+which measures how many parameter updates are "visible" after bf16 quantization.
+
+Key concepts:
+- Due to bf16's limited precision (7 mantissa bits), small updates may be "swallowed"
+- The bf16 ULP (Unit in Last Place) creates a minimum relative update threshold
+- Updates below roughly half a ULP (~0.2-0.4% relative) may not survive rounding into the bf16 representation
+
+Reference:
+- RLVR paper Definition 2.1 & 2.2
+- bf16 ULP analysis showing relative update threshold of 2^{-8} to 2^{-7}
+"""
+
+import torch
+import numpy as np
+from typing import Dict, Any, Tuple, List, Optional
+import logging
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# bf16 Equality Check
+# ============================================================================
+
+def bf16_approximately_equal(
+ w: torch.Tensor,
+ w_hat: torch.Tensor,
+ eta: float = 1e-3
+) -> torch.Tensor:
+ """
+ Check if two tensors are approximately equal under bf16 precision.
+
+ From RLVR Definition 2.1:
+ Two values w and w_hat are considered bf16-equal if:
+ |w_hat - w| <= eta * max(|w|, |w_hat|)
+
+ When eta < 2^{-9}, this is equivalent to bit-wise bf16 equality.
+
+ Args:
+ w: Original weights tensor
+ w_hat: Updated weights tensor
+ eta: Relative tolerance (default 1e-3 as in RLVR)
+
+ Returns:
+ Boolean mask where True indicates bf16-equal
+ """
+ max_abs = torch.maximum(w.abs(), w_hat.abs())
+ diff = (w_hat - w).abs()
+
+ # Handle zero weights (avoid division by zero in relative comparison)
+ # For zeros, use absolute comparison
+ zero_mask = max_abs < 1e-10
+
+ # Relative comparison
+ relative_equal = diff <= eta * max_abs
+
+ # For zeros, check if both are effectively zero
+ both_zero = w.abs() < 1e-10
+ both_zero = both_zero & (w_hat.abs() < 1e-10)
+
+ # Combine: either relatively equal, or both effectively zero
+ equal_mask = relative_equal | (zero_mask & both_zero)
+
+ return equal_mask
+
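+# Worked example (illustrative): with eta = 1e-3, w = 1.0 and w_hat = 1.0005
+# differ by 5e-4 <= 1e-3 * max(|w|, |w_hat|) and count as bf16-equal, while
+# w_hat = 1.002 differs by 2e-3 > 1e-3 * max(|w|, |w_hat|) and counts as a
+# visible update.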
+
+def bf16_bitwise_equal(
+ w: torch.Tensor,
+ w_hat: torch.Tensor
+) -> torch.Tensor:
+ """
+ Check if two tensors are bitwise equal in bf16 representation.
+
+ This is the strictest equality check - values must have identical
+ bf16 bit patterns.
+
+ Args:
+ w: Original weights tensor
+ w_hat: Updated weights tensor
+
+ Returns:
+ Boolean mask where True indicates bitwise bf16 equality
+ """
+ # Convert to bf16 and compare
+ w_bf16 = w.to(torch.bfloat16)
+ w_hat_bf16 = w_hat.to(torch.bfloat16)
+
+ # Bitwise comparison via view as int16
+ w_bits = w_bf16.view(torch.int16)
+ w_hat_bits = w_hat_bf16.view(torch.int16)
+
+ return w_bits == w_hat_bits
+
+
+# ============================================================================
+# Update Count and Sparsity
+# ============================================================================
+
+def compute_bf16_update_count(
+ w: torch.Tensor,
+ w_hat: torch.Tensor,
+ eta: float = 1e-3
+) -> Tuple[int, int, int]:
+ """
+ Compute bf16-aware update count.
+
+ From RLVR Definition 2.2:
+ ‖θ_1 - θ_0‖_{0,bf16,η} = #{i : w_hat_i ≉_{bf16,η} w_i}
+
+ Args:
+ w: Original weights tensor
+ w_hat: Updated weights tensor
+ eta: Relative tolerance
+
+ Returns:
+ Tuple of (num_changed, num_unchanged, total)
+ """
+ equal_mask = bf16_approximately_equal(w, w_hat, eta=eta)
+
+ total = int(equal_mask.numel())
+ num_unchanged = int(equal_mask.sum().item())
+ num_changed = total - num_unchanged
+
+ return num_changed, num_unchanged, total
+
+
+def compute_bf16_sparsity(
+ base_model: torch.nn.Module,
+ finetuned_model: torch.nn.Module,
+ eta: float = 1e-3,
+ include_layer_stats: bool = False
+) -> Dict[str, Any]:
+ """
+ Compute bf16-aware update sparsity between two models.
+
+ Sparsity = 1 - ‖θ_1 - θ_0‖_{0,bf16,η} / n
+
+ Values closer to 1 mean more sparse (fewer visible updates).
+ Values closer to 0 mean more dense (more visible updates).
+
+ RLVR Table 1 reports sparsity in range 36%-92% for their experiments.
+
+ Args:
+ base_model: Original model (θ_0)
+ finetuned_model: Updated model (θ_1)
+ eta: Relative tolerance
+ include_layer_stats: If True, include per-layer statistics
+
+ Returns:
+ Dictionary with sparsity metrics
+ """
+ base_params = dict(base_model.named_parameters())
+ ft_params = dict(finetuned_model.named_parameters())
+
+ total_elements = 0
+ changed_elements = 0
+
+ layer_stats: Dict[str, Dict[str, Any]] = {}
+
+ for name, base_param in base_params.items():
+ if name not in ft_params:
+ logger.warning(f"Parameter {name} not found in finetuned model")
+ continue
+
+ ft_param = ft_params[name]
+
+ if base_param.shape != ft_param.shape:
+ logger.warning(
+ f"Shape mismatch for {name}: "
+ f"{base_param.shape} vs {ft_param.shape}"
+ )
+ continue
+
+ # Move to CPU for computation
+ w = base_param.detach().cpu().float().flatten()
+ w_hat = ft_param.detach().cpu().float().flatten()
+
+ # Compute update count
+ num_changed, num_unchanged, total = compute_bf16_update_count(
+ w, w_hat, eta=eta
+ )
+
+ total_elements += total
+ changed_elements += num_changed
+
+ if include_layer_stats:
+ layer_sparsity = 1.0 - num_changed / total if total > 0 else 1.0
+ layer_stats[name] = {
+ "num_changed": num_changed,
+ "num_unchanged": num_unchanged,
+ "total": total,
+ "sparsity": layer_sparsity,
+ "shape": list(base_param.shape)
+ }
+
+ # Compute overall sparsity
+ overall_sparsity = 1.0 - changed_elements / total_elements if total_elements > 0 else 1.0
+
+ result = {
+ "sparsity": overall_sparsity,
+ "sparsity_percent": overall_sparsity * 100,
+ "num_changed": changed_elements,
+ "num_unchanged": total_elements - changed_elements,
+ "total_parameters": total_elements,
+ "eta": eta,
+ "update_fraction": changed_elements / total_elements if total_elements > 0 else 0.0,
+ }
+
+ if include_layer_stats:
+ result["layer_stats"] = layer_stats
+
+ return result
+
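+# Usage sketch (illustrative): given a base and a finetuned model already
+# loaded on CPU,
+#
+#   stats = compute_bf16_sparsity(base_model, finetuned_model, eta=1e-3)
+#   print(f"{stats['sparsity_percent']:.1f}% of parameters unchanged")
+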
+
+# ============================================================================
+# Update Magnitude Analysis
+# ============================================================================
+
+def analyze_update_magnitudes(
+ base_model: torch.nn.Module,
+ finetuned_model: torch.nn.Module
+) -> Dict[str, Any]:
+ """
+ Analyze the distribution of update magnitudes.
+
+ This helps understand which updates are below the bf16 ULP threshold.
+
+ Returns statistics about:
+ - Absolute update magnitudes
+ - Relative update magnitudes
+ - Distribution relative to bf16 ULP
+ """
+ base_params = dict(base_model.named_parameters())
+ ft_params = dict(finetuned_model.named_parameters())
+
+ all_relative_updates: List[float] = []
+ all_absolute_updates: List[float] = []
+
+ for name, base_param in base_params.items():
+ if name not in ft_params:
+ continue
+
+ ft_param = ft_params[name]
+ if base_param.shape != ft_param.shape:
+ continue
+
+ w = base_param.detach().cpu().float().flatten()
+ w_hat = ft_param.detach().cpu().float().flatten()
+
+ # Absolute updates
+ abs_updates = (w_hat - w).abs()
+
+ # Relative updates (avoid division by zero)
+ max_abs = torch.maximum(w.abs(), w_hat.abs())
+ valid_mask = max_abs > 1e-10
+
+ rel_updates = torch.zeros_like(abs_updates)
+ rel_updates[valid_mask] = abs_updates[valid_mask] / max_abs[valid_mask]
+
+ # Sample for statistics (avoid memory issues)
+ sample_size = min(10000, len(abs_updates))
+ indices = np.random.choice(len(abs_updates), sample_size, replace=False)
+
+ all_absolute_updates.extend(abs_updates[indices].tolist())
+ all_relative_updates.extend(rel_updates[indices].tolist())
+
+ abs_array = np.array(all_absolute_updates)
+ rel_array = np.array(all_relative_updates)
+
+ # bf16 ULP threshold (approximately 2^{-8} to 2^{-7}, i.e. ~0.4% to ~0.8%)
+ bf16_ulp_low = 2 ** -8 # ~0.39%
+ bf16_ulp_high = 2 ** -7 # ~0.78%
+
+ # Fraction of updates below ULP threshold
+ below_low = (rel_array < bf16_ulp_low).mean()
+ below_high = (rel_array < bf16_ulp_high).mean()
+
+ result = {
+ "absolute_updates": {
+ "mean": float(np.mean(abs_array)),
+ "std": float(np.std(abs_array)),
+ "median": float(np.median(abs_array)),
+ "min": float(np.min(abs_array)),
+ "max": float(np.max(abs_array)),
+ "percentiles": {
+ "p25": float(np.percentile(abs_array, 25)),
+ "p50": float(np.percentile(abs_array, 50)),
+ "p75": float(np.percentile(abs_array, 75)),
+ "p90": float(np.percentile(abs_array, 90)),
+ "p99": float(np.percentile(abs_array, 99)),
+ }
+ },
+ "relative_updates": {
+ "mean": float(np.mean(rel_array)),
+ "std": float(np.std(rel_array)),
+ "median": float(np.median(rel_array)),
+ "min": float(np.min(rel_array)),
+ "max": float(np.max(rel_array)),
+ "percentiles": {
+ "p25": float(np.percentile(rel_array, 25)),
+ "p50": float(np.percentile(rel_array, 50)),
+ "p75": float(np.percentile(rel_array, 75)),
+ "p90": float(np.percentile(rel_array, 90)),
+ "p99": float(np.percentile(rel_array, 99)),
+ }
+ },
+ "bf16_ulp_analysis": {
+ "ulp_low_threshold": bf16_ulp_low,
+ "ulp_high_threshold": bf16_ulp_high,
+ "fraction_below_low": float(below_low),
+ "fraction_below_high": float(below_high),
+ "estimated_swallowed_fraction": float(below_low),
+ }
+ }
+
+ return result
+
+
+# ============================================================================
+# Sparsity Trajectory
+# ============================================================================
+
+def compute_sparsity_trajectory(
+ base_model_path: str,
+ checkpoint_paths: List[str],
+ eta: float = 1e-3
+) -> List[Dict[str, Any]]:
+ """
+ Compute bf16 sparsity for a sequence of checkpoints.
+
+ Useful for understanding how sparsity evolves during training.
+
+ Args:
+ base_model_path: Path to base model
+ checkpoint_paths: List of checkpoint paths (in training order)
+ eta: Relative tolerance
+
+ Returns:
+ List of sparsity results for each checkpoint
+ """
+ from transformers import AutoModelForCausalLM
+
+ # Load base model
+ logger.info(f"Loading base model from {base_model_path}")
+ base_model = AutoModelForCausalLM.from_pretrained(
+ base_model_path,
+ torch_dtype=torch.float32,
+ device_map="cpu"
+ )
+
+ trajectory = []
+
+ for ckpt_path in tqdm(checkpoint_paths, desc="Computing sparsity"):
+ # Load checkpoint
+ ckpt_model = AutoModelForCausalLM.from_pretrained(
+ ckpt_path,
+ torch_dtype=torch.float32,
+ device_map="cpu"
+ )
+
+ # Compute sparsity
+ sparsity_result = compute_bf16_sparsity(
+ base_model=base_model,
+ finetuned_model=ckpt_model,
+ eta=eta,
+ include_layer_stats=False
+ )
+
+ trajectory.append({
+ "checkpoint": ckpt_path,
+ "sparsity": sparsity_result["sparsity"],
+ "sparsity_percent": sparsity_result["sparsity_percent"],
+ "num_changed": sparsity_result["num_changed"],
+ })
+
+ # Free memory
+ del ckpt_model
+
+ return trajectory
+
+
+# ============================================================================
+# Layer-wise Sparsity Analysis
+# ============================================================================
+
+def analyze_layer_sparsity_patterns(
+ base_model: torch.nn.Module,
+ finetuned_model: torch.nn.Module,
+ eta: float = 1e-3
+) -> Dict[str, Any]:
+ """
+ Analyze sparsity patterns across different layer types.
+
+ Groups layers by type (attention, MLP, embeddings, etc.) and
+ reports aggregate sparsity statistics.
+ """
+ sparsity_result = compute_bf16_sparsity(
+ base_model=base_model,
+ finetuned_model=finetuned_model,
+ eta=eta,
+ include_layer_stats=True
+ )
+
+ layer_stats = sparsity_result.get("layer_stats", {})
+
+ # Group by layer type
+ groups: Dict[str, List[Dict[str, Any]]] = {
+ "attention": [],
+ "mlp": [],
+ "embedding": [],
+ "norm": [],
+ "other": []
+ }
+
+ for name, stats in layer_stats.items():
+ name_lower = name.lower()
+
+ if any(k in name_lower for k in ["attn", "attention", "self_attn"]):
+ groups["attention"].append(stats)
+ elif any(k in name_lower for k in ["mlp", "fc", "dense", "linear"]):
+ groups["mlp"].append(stats)
+ elif any(k in name_lower for k in ["embed", "wte", "wpe"]):
+ groups["embedding"].append(stats)
+ elif any(k in name_lower for k in ["norm", "ln", "layer_norm"]):
+ groups["norm"].append(stats)
+ else:
+ groups["other"].append(stats)
+
+ # Compute aggregate statistics per group
+ group_analysis = {}
+ for group_name, layer_list in groups.items():
+ if not layer_list:
+ continue
+
+ sparsities = [l["sparsity"] for l in layer_list]
+ total_params = sum(l["total"] for l in layer_list)
+ total_changed = sum(l["num_changed"] for l in layer_list)
+
+ group_analysis[group_name] = {
+ "num_layers": len(layer_list),
+ "total_params": total_params,
+ "mean_sparsity": float(np.mean(sparsities)),
+ "std_sparsity": float(np.std(sparsities)),
+ "min_sparsity": float(np.min(sparsities)),
+ "max_sparsity": float(np.max(sparsities)),
+ "aggregate_sparsity": 1.0 - total_changed / total_params if total_params > 0 else 1.0,
+ }
+
+ return {
+ "overall_sparsity": sparsity_result["sparsity"],
+ "group_analysis": group_analysis,
+ }
+
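+
+if __name__ == "__main__":
+ # Minimal self-check sketch (illustrative, not part of the experiment
+ # pipeline): perturb a random tensor by relative amounts on either side
+ # of the eta threshold and confirm the visibility split.
+ torch.manual_seed(0)
+ w = torch.randn(1000)
+ w_hat = w.clone()
+ w_hat[:500] *= 1 + 2 ** -6 # ~1.6% relative change: visible under eta=1e-3
+ w_hat[500:] *= 1 + 2 ** -12 # ~0.02% relative change: swallowed
+ changed, unchanged, total = compute_bf16_update_count(w, w_hat, eta=1e-3)
+ print(f"changed={changed}, unchanged={unchanged}, total={total}")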
diff --git a/utils_kl.py b/utils_kl.py
new file mode 100644
index 0000000..2be50a0
--- /dev/null
+++ b/utils_kl.py
@@ -0,0 +1,419 @@
+# utils_kl.py
+"""
+KL Divergence Utilities for RLVR Experiments.
+
+This module provides utilities for computing KL divergence between
+policy distributions, including:
+- Token-level KL computation
+- Sequence-level KL aggregation
+- Dataset-level KL estimation
+"""
+
+import torch
+import torch.nn.functional as F
+from typing import Dict, Any, List, Tuple, Optional
+import numpy as np
+import logging
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Token-Level KL Computation
+# ============================================================================
+
+def compute_token_log_probs(
+ model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ labels: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ """
+ Compute token-level log probabilities.
+
+ Args:
+ model: Language model
+ input_ids: Input token IDs [batch, seq_len]
+ attention_mask: Attention mask [batch, seq_len]
+ labels: Token labels for which to compute log probs (default: input_ids)
+
+ Returns:
+ Token log probabilities [batch, seq_len - 1]
+ """
+ if labels is None:
+ labels = input_ids
+
+ with torch.no_grad():
+ outputs = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ use_cache=False
+ )
+
+ logits = outputs.logits # [batch, seq_len, vocab]
+
+ # Shift for autoregressive: predict token t from tokens 0..t-1
+ shift_logits = logits[:, :-1, :] # [batch, seq_len-1, vocab]
+ shift_labels = labels[:, 1:] # [batch, seq_len-1]
+
+ # Compute log probabilities
+ log_probs = F.log_softmax(shift_logits, dim=-1)
+
+ # Gather log probs for actual tokens
+ token_log_probs = torch.gather(
+ log_probs,
+ dim=-1,
+ index=shift_labels.unsqueeze(-1)
+ ).squeeze(-1) # [batch, seq_len-1]
+
+ return token_log_probs
+
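+# Shape note (illustrative): for input_ids of shape [B, T] the returned
+# tensor has shape [B, T-1]; position t holds log p(y_{t+1} | y_{<=t}).
+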
+
+def compute_kl_per_token(
+ policy_log_probs: torch.Tensor,
+ ref_log_probs: torch.Tensor
+) -> torch.Tensor:
+ """
+ Compute per-token KL divergence.
+
+ KL(π || π_ref) at token t = log π(y_t) - log π_ref(y_t)
+
+ Note: This is the forward KL from policy to reference.
+ """
+ return policy_log_probs - ref_log_probs
+
+
+def compute_reverse_kl_per_token(
+ policy_logits: torch.Tensor,
+ ref_logits: torch.Tensor,
+ temperature: float = 1.0
+) -> torch.Tensor:
+ """
+ Compute per-token reverse KL divergence using full distributions.
+
+ KL(π || π_ref) = Σ_y π(y) [log π(y) - log π_ref(y)]
+
+ This is more expensive but gives the true KL.
+ """
+ policy_probs = F.softmax(policy_logits / temperature, dim=-1)
+ policy_log_probs = F.log_softmax(policy_logits / temperature, dim=-1)
+ ref_log_probs = F.log_softmax(ref_logits / temperature, dim=-1)
+
+ # KL = Σ p(x) log(p(x)/q(x)) = Σ p(x) [log p(x) - log q(x)]
+ kl = (policy_probs * (policy_log_probs - ref_log_probs)).sum(dim=-1)
+
+ return kl
+
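+# Relationship (illustrative): the sampled log-ratio in compute_kl_per_token
+# is an unbiased estimator of the full-distribution KL computed above, i.e.
+# E_{y ~ π}[log π(y) - log π_ref(y)] = KL(π || π_ref); the former is cheap
+# and noisy, the latter exact per position.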
+
+# ============================================================================
+# Sequence-Level KL
+# ============================================================================
+
+def compute_sequence_kl(
+ policy_model: torch.nn.Module,
+ ref_model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ response_start_idx: int = 0,
+ normalize_by_length: bool = False
+) -> Dict[str, float]:
+ """
+ Compute KL divergence for a single sequence.
+
+ Args:
+ policy_model: Finetuned policy model
+ ref_model: Reference model
+ input_ids: Full sequence (prompt + response) [1, seq_len]
+ attention_mask: Attention mask [1, seq_len]
+ response_start_idx: Index where response starts
+ normalize_by_length: If True, return average KL per token
+
+ Returns:
+ Dictionary with KL metrics
+ """
+ # Get log probs from both models
+ policy_log_probs = compute_token_log_probs(
+ policy_model, input_ids, attention_mask
+ )
+ ref_log_probs = compute_token_log_probs(
+ ref_model, input_ids, attention_mask
+ )
+
+ # Compute per-token KL
+ kl_per_token = compute_kl_per_token(policy_log_probs, ref_log_probs)
+
+ # Create mask for response tokens only
+ seq_len = kl_per_token.shape[1]
+ response_mask = torch.zeros(1, seq_len, device=input_ids.device)
+ if response_start_idx > 0:
+ response_mask[:, response_start_idx-1:] = 1.0
+ else:
+ response_mask[:, :] = 1.0
+
+ # Apply attention mask
+ valid_mask = attention_mask[:, 1:].float() * response_mask
+
+ # Compute statistics
+ masked_kl = kl_per_token * valid_mask
+ num_tokens = valid_mask.sum().item()
+ total_kl = masked_kl.sum().item()
+
+ result = {
+ "total_kl": total_kl,
+ "num_tokens": int(num_tokens),
+ "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0,
+ "max_kl": (kl_per_token * valid_mask).max().item() if num_tokens > 0 else 0.0,
+ "min_kl": (kl_per_token * valid_mask).min().item() if num_tokens > 0 else 0.0,
+ }
+
+ if normalize_by_length:
+ result["kl"] = result["mean_kl"]
+ else:
+ result["kl"] = result["total_kl"]
+
+ return result
+
+
+# ============================================================================
+# Dataset-Level KL Estimation
+# ============================================================================
+
+def estimate_dataset_kl(
+ policy_model: torch.nn.Module,
+ ref_model: torch.nn.Module,
+ tokenizer,
+ prompts: List[str],
+ responses: List[str],
+ device: torch.device,
+ max_seq_len: int = 4096,
+ normalize_by_length: bool = False,
+ show_progress: bool = True
+) -> Dict[str, Any]:
+ """
+ Estimate KL divergence over a dataset.
+
+ Args:
+ policy_model: Finetuned policy model
+ ref_model: Reference model
+ tokenizer: Tokenizer for both models
+ prompts: List of prompts
+ responses: List of corresponding responses
+ device: Device to use
+ max_seq_len: Maximum sequence length
+ normalize_by_length: If True, use mean KL per token
+ show_progress: Show progress bar
+
+ Returns:
+ Dictionary with dataset-level KL statistics
+ """
+ assert len(prompts) == len(responses), \
+ "Number of prompts must match responses"
+
+ policy_model.eval()
+ ref_model.eval()
+
+ all_kl_values: List[float] = []
+ all_num_tokens: List[int] = []
+
+ iterator = zip(prompts, responses)
+ if show_progress:
+ iterator = tqdm(
+ list(iterator),
+ desc="Computing KL"
+ )
+
+ for prompt, response in iterator:
+ # Tokenize prompt
+ prompt_tokens = tokenizer(
+ prompt,
+ return_tensors="pt",
+ truncation=True,
+ max_length=max_seq_len // 2
+ )
+ prompt_len = prompt_tokens["input_ids"].shape[1]
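+ # Note: prompt_len is an approximation if the tokenizer merges tokens
+ # across the prompt/response boundary of the concatenated text.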
+
+ # Tokenize full sequence
+ full_text = prompt + response
+ full_tokens = tokenizer(
+ full_text,
+ return_tensors="pt",
+ truncation=True,
+ max_length=max_seq_len
+ )
+
+ input_ids = full_tokens["input_ids"].to(device)
+ attention_mask = full_tokens["attention_mask"].to(device)
+
+ # Compute sequence KL
+ with torch.no_grad():
+ kl_result = compute_sequence_kl(
+ policy_model=policy_model,
+ ref_model=ref_model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ response_start_idx=prompt_len,
+ normalize_by_length=normalize_by_length
+ )
+
+ all_kl_values.append(kl_result["kl"])
+ all_num_tokens.append(kl_result["num_tokens"])
+
+ # Aggregate statistics
+ kl_array = np.array(all_kl_values)
+
+ result = {
+ "mean_kl": float(np.mean(kl_array)),
+ "std_kl": float(np.std(kl_array)),
+ "median_kl": float(np.median(kl_array)),
+ "min_kl": float(np.min(kl_array)),
+ "max_kl": float(np.max(kl_array)),
+ "total_samples": len(prompts),
+ "total_tokens": sum(all_num_tokens),
+ "kl_values": all_kl_values,
+ }
+
+ return result
+
+
+# ============================================================================
+# On-Task vs Off-Task KL Analysis
+# ============================================================================
+
+def analyze_kl_by_task(
+ kl_results: Dict[str, Dict[str, Any]],
+ on_task_names: List[str],
+ off_task_names: List[str]
+) -> Dict[str, Any]:
+ """
+ Analyze KL divergence patterns for on-task vs off-task.
+
+ Args:
+ kl_results: Dictionary mapping task names to KL results
+ on_task_names: List of on-task (training distribution) names
+ off_task_names: List of off-task names
+
+ Returns:
+ Analysis of KL patterns
+ """
+ on_task_kl = []
+ off_task_kl = []
+
+ for name in on_task_names:
+ if name in kl_results:
+ on_task_kl.append(kl_results[name]["mean_kl"])
+
+ for name in off_task_names:
+ if name in kl_results:
+ off_task_kl.append(kl_results[name]["mean_kl"])
+
+ analysis = {
+ "on_task": {
+ "mean": float(np.mean(on_task_kl)) if on_task_kl else 0.0,
+ "std": float(np.std(on_task_kl)) if on_task_kl else 0.0,
+ "values": on_task_kl,
+ },
+ "off_task": {
+ "mean": float(np.mean(off_task_kl)) if off_task_kl else 0.0,
+ "std": float(np.std(off_task_kl)) if off_task_kl else 0.0,
+ "values": off_task_kl,
+ },
+ }
+
+ # Compute ratio
+ if analysis["on_task"]["mean"] > 0:
+ analysis["off_to_on_ratio"] = (
+ analysis["off_task"]["mean"] / analysis["on_task"]["mean"]
+ )
+ else:
+ analysis["off_to_on_ratio"] = float("inf")
+
+ return analysis
+
+
+# ============================================================================
+# KL Contribution Analysis
+# ============================================================================
+
+def analyze_kl_contribution_by_layer(
+ model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor
+) -> Dict[str, Any]:
+ """
+ Analyze which layers contribute most to the final prediction.
+
+ This is a simplified analysis - for full KL attribution,
+ you would need layer-wise probing.
+ """
+ # This is a placeholder for more sophisticated analysis
+ # Full implementation would require modifying the model
+ # to output intermediate representations
+
+ return {
+ "note": "Layer-wise KL attribution not implemented",
+ }
+
+
+def compute_kl_trajectory(
+ checkpoints: List[str],
+ ref_model_path: str,
+ tokenizer_path: str,
+ prompts: List[str],
+ responses: List[str],
+ device: torch.device
+) -> List[Dict[str, Any]]:
+ """
+ Compute KL divergence trajectory over training checkpoints.
+
+ Useful for understanding how KL evolves during training.
+ """
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Load reference model
+ ref_model = AutoModelForCausalLM.from_pretrained(
+ ref_model_path,
+ torch_dtype=torch.bfloat16,
+ device_map=None
+ ).to(device)
+ ref_model.eval()
+
+ trajectory = []
+
+ for ckpt_path in tqdm(checkpoints, desc="Computing KL trajectory"):
+ # Load checkpoint
+ policy_model = AutoModelForCausalLM.from_pretrained(
+ ckpt_path,
+ torch_dtype=torch.bfloat16,
+ device_map=None
+ ).to(device)
+ policy_model.eval()
+
+ # Estimate KL
+ kl_result = estimate_dataset_kl(
+ policy_model=policy_model,
+ ref_model=ref_model,
+ tokenizer=tokenizer,
+ prompts=prompts,
+ responses=responses,
+ device=device,
+ show_progress=False
+ )
+
+ trajectory.append({
+ "checkpoint": ckpt_path,
+ "mean_kl": kl_result["mean_kl"],
+ "std_kl": kl_result["std_kl"],
+ })
+
+ # Free memory
+ del policy_model
+ torch.cuda.empty_cache()
+
+ return trajectory
+
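+
+if __name__ == "__main__":
+ # Toy sanity check (illustrative): identical policy and reference
+ # log-probs must give exactly zero per-token KL.
+ fake_log_probs = torch.randn(2, 7)
+ kl = compute_kl_per_token(fake_log_probs, fake_log_probs)
+ assert torch.all(kl == 0), "identical distributions should give zero KL"
+ print("compute_kl_per_token identity check passed")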
diff --git a/utils_math_eval.py b/utils_math_eval.py
new file mode 100644
index 0000000..d4a1db2
--- /dev/null
+++ b/utils_math_eval.py
@@ -0,0 +1,367 @@
+# utils_math_eval.py
+"""
+Math Evaluation Utilities for RLVR Experiments.
+
+This module provides utilities for:
+- Extracting answers from model responses
+- Verifying mathematical answers
+- Computing accuracy metrics
+"""
+
+import re
+from typing import Optional, List, Dict, Any, Tuple
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Answer Extraction
+# ============================================================================
+
+def extract_boxed_content(text: str) -> List[str]:
+ """
+ Extract all content from \\boxed{} patterns.
+
+ Handles nested braces correctly.
+ """
+ results = []
+ i = 0
+
+ while i < len(text):
+ # Find \boxed{
+ idx = text.find("\\boxed{", i)
+ if idx == -1:
+ break
+
+ # Find matching closing brace
+ start = idx + 7 # len("\\boxed{")
+ depth = 1
+ j = start
+
+ while j < len(text) and depth > 0:
+ if text[j] == "{":
+ depth += 1
+ elif text[j] == "}":
+ depth -= 1
+ j += 1
+
+ if depth == 0:
+ content = text[start:j-1]
+ results.append(content.strip())
+
+ i = j
+
+ return results
+
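+# Example (illustrative): extract_boxed_content("so \\boxed{\\frac{1}{2}} wins")
+# returns ["\\frac{1}{2}"]; nested braces inside the box are preserved.
+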
+
+def extract_answer_from_boxed(text: str) -> Optional[str]:
+ """Extract the last boxed answer from text."""
+ boxed_contents = extract_boxed_content(text)
+ if boxed_contents:
+ return boxed_contents[-1]
+ return None
+
+
+def extract_answer_from_patterns(text: str) -> Optional[str]:
+ """
+ Extract answer using common natural language patterns.
+ """
+ # Patterns in order of priority
+ patterns = [
+ # Explicit answer statements
+ (r"[Tt]he\s+(?:final\s+)?answer\s+is\s*[:\s]*(.+?)(?:\.|,|$)", 1),
+ (r"[Aa]nswer\s*[:\s]+(.+?)(?:\.|,|$)", 1),
+
+ # Conclusion patterns
+ (r"[Tt]herefore\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
+ (r"[Hh]ence\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
+ (r"[Ss]o\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
+ (r"[Tt]hus\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
+
+ # Equation result
+ (r"=\s*(\S+)\s*$", 1),
+ ]
+
+ for pattern, group in patterns:
+ match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
+ if match:
+ answer = match.group(group).strip()
+ # Clean up trailing punctuation
+ answer = re.sub(r"[.,;:!?]+$", "", answer).strip()
+ if answer:
+ return answer
+
+ return None
+
+
+def extract_final_answer(text: str) -> Optional[str]:
+ """
+ Extract the final answer from a model response.
+
+ Priority:
+ 1. \\boxed{} format
+ 2. Natural language patterns
+ """
+ # Try boxed format first
+ boxed = extract_answer_from_boxed(text)
+ if boxed:
+ return boxed
+
+ # Try natural language patterns
+ pattern_answer = extract_answer_from_patterns(text)
+ if pattern_answer:
+ return pattern_answer
+
+ return None
+
+
+# ============================================================================
+# Answer Normalization
+# ============================================================================
+
+def normalize_numeric_answer(answer: str) -> Optional[float]:
+ """
+ Normalize a numeric answer for comparison.
+
+ Handles:
+ - Integers and decimals
+ - Fractions (simple forms like a/b)
+ - Scientific notation
+ - Percentages
+ """
+ if not answer:
+ return None
+
+ # Clean up the string
+ cleaned = answer.strip().lower()
+ cleaned = cleaned.replace(" ", "")
+ cleaned = cleaned.replace(",", "")
+
+ # Handle percentages
+ if cleaned.endswith("%"):
+ cleaned = cleaned[:-1]
+ try:
+ return float(cleaned) / 100
+ except ValueError:
+ pass
+
+ # Handle fractions (a/b)
+ if "/" in cleaned:
+ parts = cleaned.split("/")
+ if len(parts) == 2:
+ try:
+ num = float(parts[0])
+ denom = float(parts[1])
+ if denom != 0:
+ return num / denom
+ except ValueError:
+ pass
+
+ # Handle scientific notation and regular numbers
+ try:
+ return float(cleaned)
+ except ValueError:
+ pass
+
+ return None
+
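+# Examples (illustrative): "50%" -> 0.5, "3/4" -> 0.75, "1,234" -> 1234.0,
+# "2e-3" -> 0.002; anything unparseable returns None.
+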
+
+def normalize_text_answer(answer: str) -> str:
+ """
+ Normalize a text answer for comparison.
+
+ - Lowercase
+ - Remove extra whitespace
+ - Remove common formatting
+ """
+ if not answer:
+ return ""
+
+ normalized = answer.strip().lower()
+
+ # Remove LaTeX formatting
+ normalized = re.sub(r"\\[a-zA-Z]+", "", normalized)
+ normalized = re.sub(r"[{}$]", "", normalized)
+
+ # Normalize whitespace
+ normalized = " ".join(normalized.split())
+
+ # Remove common punctuation
+ normalized = re.sub(r"[.,;:!?]+$", "", normalized).strip()
+
+ return normalized
+
+
+# ============================================================================
+# Answer Comparison
+# ============================================================================
+
+def compare_numeric_answers(
+ predicted: str,
+ ground_truth: str,
+ tolerance: float = 1e-6
+) -> bool:
+ """
+ Compare two answers numerically.
+
+ Returns True if both can be parsed as numbers and are within tolerance.
+ """
+ pred_num = normalize_numeric_answer(predicted)
+ gt_num = normalize_numeric_answer(ground_truth)
+
+ if pred_num is None or gt_num is None:
+ return False
+
+ # Absolute tolerance for small numbers
+ if abs(gt_num) < 1e-6:
+ return abs(pred_num - gt_num) < tolerance
+
+ # Relative tolerance for larger numbers
+ rel_diff = abs(pred_num - gt_num) / abs(gt_num)
+ return rel_diff < tolerance
+
+
+def compare_text_answers(
+ predicted: str,
+ ground_truth: str
+) -> bool:
+ """Compare two text answers after normalization."""
+ pred_norm = normalize_text_answer(predicted)
+ gt_norm = normalize_text_answer(ground_truth)
+
+ return pred_norm == gt_norm
+
+
+def verify_answer(
+ response: str,
+ ground_truth: str,
+ task_type: str = "math"
+) -> Tuple[bool, Optional[str]]:
+ """
+ Verify if the response contains the correct answer.
+
+ Args:
+ response: Model's full response
+ ground_truth: Expected answer
+ task_type: Type of task ("math", "qa", "code")
+
+ Returns:
+ Tuple of (is_correct, extracted_answer)
+ """
+ # Extract predicted answer
+ predicted = extract_final_answer(response)
+
+ if predicted is None:
+ return False, None
+
+ # Try numeric comparison first
+ if compare_numeric_answers(predicted, ground_truth):
+ return True, predicted
+
+ # Try text comparison
+ if compare_text_answers(predicted, ground_truth):
+ return True, predicted
+
+ # Fall back to substring containment; lenient, and can yield false
+ # positives for short numeric answers (e.g. "1" matches inside "10")
+ gt_norm = normalize_text_answer(ground_truth)
+ pred_norm = normalize_text_answer(predicted)
+
+ if gt_norm and gt_norm in pred_norm:
+ return True, predicted
+
+ return False, predicted
+
+
+# ============================================================================
+# Batch Evaluation
+# ============================================================================
+
+def evaluate_batch(
+ responses: List[str],
+ ground_truths: List[str],
+ task_type: str = "math"
+) -> Dict[str, Any]:
+ """
+ Evaluate a batch of responses.
+
+ Args:
+ responses: List of model responses
+ ground_truths: List of expected answers
+ task_type: Type of task
+
+ Returns:
+ Dictionary with evaluation metrics
+ """
+ assert len(responses) == len(ground_truths), \
+ "Number of responses must match ground truths"
+
+ correct = 0
+ total = len(responses)
+ results = []
+
+ for response, gt in zip(responses, ground_truths):
+ is_correct, extracted = verify_answer(response, gt, task_type)
+ correct += int(is_correct)
+ results.append({
+ "is_correct": is_correct,
+ "extracted_answer": extracted,
+ "ground_truth": gt
+ })
+
+ accuracy = correct / total if total > 0 else 0.0
+
+ return {
+ "accuracy": accuracy,
+ "correct": correct,
+ "total": total,
+ "results": results
+ }
+
+
+# ============================================================================
+# Answer Format Detection
+# ============================================================================
+
+def detect_answer_format(text: str) -> str:
+ """
+ Detect the format of an answer.
+
+ Returns one of: "boxed", "numeric", "fraction", "text", "unknown"
+ """
+ if "\\boxed{" in text:
+ return "boxed"
+
+ # Check for fraction
+ if re.match(r"^-?\d+/\d+$", text.strip()):
+ return "fraction"
+
+ # Check for numeric
+ try:
+ float(text.strip().replace(",", ""))
+ return "numeric"
+ except ValueError:
+ pass
+
+ if text.strip():
+ return "text"
+
+ return "unknown"
+
+
+def format_answer_for_display(answer: str, detected_format: str) -> str:
+ """Format answer for display based on detected format."""
+ if detected_format == "fraction":
+ num = normalize_numeric_answer(answer)
+ if num is not None:
+ return f"{answer} ≈ {num:.6f}"
+
+ if detected_format == "numeric":
+ try:
+ num = float(answer.replace(",", ""))
+ return f"{num:g}"
+ except ValueError:
+ pass
+
+ return answer
+
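+
+if __name__ == "__main__":
+ # Smoke test (illustrative): exercise the extraction + verification path
+ # on a synthetic response.
+ sample = "We get 3x = 12, so x = 4. The answer is \\boxed{4}."
+ is_correct, extracted = verify_answer(sample, "4")
+ print(f"extracted={extracted!r}, correct={is_correct}")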