Diffstat (limited to 'config.py')
 config.py | 328 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 328 insertions(+), 0 deletions(-)
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..898776c
--- /dev/null
+++ b/config.py
@@ -0,0 +1,328 @@
+# config.py
+"""
+Configuration definitions for RLVR floating-point precision experiments.
+
+This module defines configurations for:
+- Training parameters (DAPO algorithm, hyperparameters)
+- Precision settings (FP32 vs bf16)
+- Evaluation task specifications
+"""
+
+from dataclasses import dataclass, field
+from typing import List, Dict, Any, Optional
+import os
+
+
+@dataclass
+class TrainingConfig:
+ """Configuration for RLVR training with DAPO algorithm."""
+
+ # Model specification
+ model_name: str = "Qwen/Qwen2.5-Math-7B"
+
+ # Precision mode: "fp32" for high precision, "bf16" for default RLVR
+ precision_mode: str = "bf16"
+
+ # Batch configuration (sized for single H200 GPU with gradient checkpointing)
+ global_batch_size: int = 32 # Reduced from 256 for single GPU
+ micro_batch_size: int = 4 # Reduced further for fp32 memory safety
+ grad_accumulation_steps: int = 8 # Increased to maintain effective batch size
+
+ # Rollout configuration
+ num_rollouts_per_prompt: int = 4 # Reduced from 16 for speed
+ max_seq_len: int = 2048 # Reduced from 8192 (GSM8K answers are short)
+
+ # Training steps and checkpointing
+ # Note: With sequential generation, each step takes ~18 min on H200
+ # 150 steps ≈ 45 hours, fits in 2-day limit with buffer
+ num_steps: int = 150
+ checkpoint_steps: List[int] = field(default_factory=lambda: [0, 50, 100, 150])
+
+ # Optimizer configuration (AdamW)
+ learning_rate: float = 1e-6
+ beta1: float = 0.9
+ beta2: float = 0.999
+ weight_decay: float = 0.01
+
+ # RL algorithm
+ rl_algorithm: str = "dapo"
+ clip_ratio: float = 0.2 # DAPO clip parameter
+ kl_coef: float = 0.0 # DAPO uses clip-only, no explicit KL penalty
+
+ # Reproducibility
+ # IMPORTANT: Keep these constant across precision modes to isolate precision effects
+ # We maximize determinism so the ONLY variance source is floating-point precision
+ seed: int = 1
+ use_dropout: bool = False # Disabled to reduce stochasticity
+ use_deterministic_algorithms: bool = True # Enabled for reproducibility
+
+ # Paths
+ output_dir: str = "./results/train_logs"
+ train_dataset_path: str = "./data/dm_train.json"
+
+ # GPU configuration (single GPU for this implementation)
+ num_gpus: int = 1 # Current implementation is single-GPU
+
+ def __post_init__(self):
+ """
+ Note: We intentionally keep dropout and determinism settings CONSTANT
+ across precision modes to isolate the effect of floating-point precision.
+
+ Previously this coupled dropout=False with fp32 and dropout=True with bf16,
+ which confounded precision effects with stochasticity effects.
+
+ To study pure precision effects:
+ - Both modes use the SAME dropout setting (default: False)
+        - Both modes use the SAME determinism setting (default: True, matching
+          use_deterministic_algorithms above)
+
+ The only difference between fp32 and bf16 should be param_dtype.
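+
+        Sketch (doctest-style, illustrative only): two configs that differ
+        solely in precision_mode share every reproducibility setting.
+
+        >>> a = TrainingConfig(precision_mode="fp32")
+        >>> b = TrainingConfig(precision_mode="bf16")
+        >>> (a.seed, a.use_dropout) == (b.seed, b.use_dropout)
+        True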
+ """
+ # Don't modify settings based on precision_mode - keep them independent
+ pass
+
+
+@dataclass
+class PrecisionConfig:
+ """Configuration for floating-point precision settings."""
+
+ # Parameter storage dtype
+ param_dtype: str = "bfloat16" # "float32" or "bfloat16"
+
+ # Automatic mixed precision
+ use_amp: bool = True
+ amp_dtype: str = "bfloat16" # "float16" or "bfloat16"
+
+ # Gradient and optimizer state always in FP32
+ grad_dtype: str = "float32"
+ optimizer_dtype: str = "float32"
+
+ # Deterministic algorithms
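+    # NOTE: the defaults here and for the cuDNN flags below favor speed;
+    # make_precision_config() overrides all three to fully deterministic
+    # settings for the experiments.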
+ deterministic: bool = False
+
+    # cuDNN settings
+ cudnn_benchmark: bool = True
+ cudnn_deterministic: bool = False
+
+
+@dataclass
+class EvalTaskConfig:
+ """Configuration for a single evaluation task."""
+
+ # Task identification
+ name: str = ""
+ task_type: str = "math" # "math", "code", "qa", "general"
+
+ # Dataset
+ dataset_path: str = ""
+ num_samples: int = -1 # -1 means use all samples
+
+ # Whether task has verifiable answers (math problems)
+ is_verifiable: bool = True
+
+    # Metric type (primarily relevant for non-verifiable tasks)
+ metric_type: str = "accuracy" # "accuracy", "bleu", "rouge", "score"
+
+ # Generation parameters
+ max_gen_len: int = 2048
+ temperature: float = 0.7
+ top_p: float = 0.8
+ num_samples_per_prompt: int = 1
+
+
+@dataclass
+class ExperimentConfig:
+ """Master configuration for the entire experiment."""
+
+ # Experiment identification
+ experiment_name: str = "fp_precision_rlvr"
+
+ # Seeds for multiple runs
+ seeds: List[int] = field(default_factory=lambda: [1, 2, 3, 4, 5])
+
+ # Precision modes to compare
+ precision_modes: List[str] = field(default_factory=lambda: ["fp32", "bf16"])
+
+ # Base model checkpoint (shared starting point)
+ base_model_path: str = "Qwen/Qwen2.5-Math-7B"
+
+ # Output directories
+ base_output_dir: str = "./results"
+ train_logs_dir: str = "./results/train_logs"
+ checkpoints_dir: str = "./results/checkpoints"
+ eval_metrics_dir: str = "./results/eval_metrics"
+
+ # Evaluation configuration
+ eval_tasks_config_path: str = "./configs/eval_tasks_config.json"
+
+ # bf16 sparsity analysis
+ bf16_sparsity_eta: float = 1e-3
+
+
+def make_precision_config(precision_mode: str) -> PrecisionConfig:
+ """
+ Create precision configuration based on mode.
+
+ IMPORTANT: Only the precision-related settings differ between modes.
+ All other settings (determinism, cudnn) are kept CONSTANT to isolate
+ the effect of floating-point precision on training outcomes.
+
+ Args:
+ precision_mode: "fp32" for high precision, "bf16" for default RLVR
+
+ Returns:
+ PrecisionConfig with appropriate settings
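+
+    Example (doctest-style sketch):
+        >>> cfg = make_precision_config("fp32")
+        >>> (cfg.param_dtype, cfg.use_amp)
+        ('float32', False)
+        >>> make_precision_config("bf16").amp_dtype
+        'bfloat16'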
+ """
+ # Common settings for both modes (to avoid confounds)
+ # Maximize determinism so precision is the ONLY source of variance
+ common_settings = {
+ "grad_dtype": "float32",
+ "optimizer_dtype": "float32",
+ "deterministic": True, # Enable deterministic algorithms
+ "cudnn_benchmark": False, # Disable for reproducibility
+ "cudnn_deterministic": True, # Enable for reproducibility
+ }
+
+ if precision_mode == "fp32":
+ return PrecisionConfig(
+ param_dtype="float32",
+ use_amp=False, # No AMP needed for fp32
+ amp_dtype="float32",
+ **common_settings
+ )
+ elif precision_mode == "bf16":
+ return PrecisionConfig(
+ param_dtype="bfloat16",
+ use_amp=True,
+ amp_dtype="bfloat16",
+ **common_settings
+ )
+ else:
+ raise ValueError(f"Unknown precision_mode: {precision_mode}")
+
+
+def make_training_config(
+ precision_mode: str,
+ seed: int,
+ output_dir: str,
+ train_dataset_path: str,
+ model_name: str = "Qwen/Qwen2.5-Math-7B"
+) -> TrainingConfig:
+ """
+ Create training configuration for a specific run.
+
+ Args:
+ precision_mode: "fp32" or "bf16"
+ seed: Random seed for this run
+ output_dir: Directory to save outputs
+ train_dataset_path: Path to training data
+ model_name: HuggingFace model identifier
+
+ Returns:
+ TrainingConfig with all parameters set
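+
+    Example (doctest-style sketch, using the module's default dataset path):
+        >>> cfg = make_training_config(
+        ...     "bf16", seed=2,
+        ...     output_dir="./results/train_logs/bf16_seed2",
+        ...     train_dataset_path="./data/dm_train.json")
+        >>> (cfg.precision_mode, cfg.seed, cfg.model_name)
+        ('bf16', 2, 'Qwen/Qwen2.5-Math-7B')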
+ """
+ config = TrainingConfig(
+ model_name=model_name,
+ precision_mode=precision_mode,
+ seed=seed,
+ output_dir=output_dir,
+ train_dataset_path=train_dataset_path
+ )
+ return config
+
+
+def get_run_output_dir(
+ base_dir: str,
+ precision_mode: str,
+ seed: int
+) -> str:
+ """Get output directory for a specific run."""
+ return os.path.join(base_dir, f"{precision_mode}_seed{seed}")
+
+
+def get_checkpoint_path(
+ output_dir: str,
+ step: Optional[int] = None
+) -> str:
+ """Get checkpoint path for a specific step (None = final)."""
+ if step is None:
+ return os.path.join(output_dir, "final_model")
+ return os.path.join(output_dir, f"checkpoint_step{step}")
+
+
+# Default evaluation tasks for the experiment
+DEFAULT_EVAL_TASKS = [
+ # On-task: Training distribution
+ EvalTaskConfig(
+ name="dm_val",
+ task_type="math",
+ dataset_path="./data/dm_val.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=2048,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ # In-domain OOD: Math benchmarks
+ EvalTaskConfig(
+ name="aime24",
+ task_type="math",
+ dataset_path="./data/aime24.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=4096,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ EvalTaskConfig(
+ name="aime25",
+ task_type="math",
+ dataset_path="./data/aime25.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=4096,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ EvalTaskConfig(
+ name="amc23",
+ task_type="math",
+ dataset_path="./data/amc23.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=2048,
+ temperature=0.7,
+ top_p=0.8
+ ),
+ EvalTaskConfig(
+ name="math500",
+ task_type="math",
+ dataset_path="./data/math500.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=2048,
+ temperature=0.7,
+ top_p=0.8
+ ),
+    # Off-domain: general QA and code tasks
+ EvalTaskConfig(
+ name="mmlu_stem",
+ task_type="qa",
+ dataset_path="./data/mmlu_stem.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=512,
+ temperature=0.3,
+ top_p=0.9
+ ),
+ EvalTaskConfig(
+ name="humaneval",
+ task_type="code",
+ dataset_path="./data/humaneval.json",
+ is_verifiable=True,
+ metric_type="accuracy",
+ max_gen_len=1024,
+ temperature=0.2,
+ top_p=0.95
+ ),
+]
+
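+
+if __name__ == "__main__":
+    # Illustrative sketch, not part of the original pipeline: enumerate the
+    # seed x precision grid and show where each run would write its outputs.
+    exp = ExperimentConfig()
+    for mode in exp.precision_modes:
+        for seed in exp.seeds:
+            run_dir = get_run_output_dir(exp.train_logs_dir, mode, seed)
+            cfg = make_training_config(
+                mode, seed, run_dir, train_dataset_path="./data/dm_train.json")
+            print(f"{cfg.precision_mode} seed={cfg.seed} -> {run_dir}")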