| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-04 18:59:35 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-04 18:59:35 -0600 |
| commit | f1c2cc22d46a6976df3555391e667c7e61592fad (patch) | |
| tree | 0b37b52c8ff91042a742d3b3ec54542cb6d6e2f6 /config.py | |
Diffstat (limited to 'config.py')
| -rw-r--r-- | config.py | 328 |
1 file changed, 328 insertions, 0 deletions
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..898776c
--- /dev/null
+++ b/config.py
@@ -0,0 +1,328 @@
+# config.py
+"""
+Configuration definitions for RLVR floating-point precision experiments.
+
+This module defines configurations for:
+- Training parameters (DAPO algorithm, hyperparameters)
+- Precision settings (FP32 vs bf16)
+- Evaluation task specifications
+"""
+
+from dataclasses import dataclass, field
+from typing import List, Dict, Any, Optional
+import os
+
+
+@dataclass
+class TrainingConfig:
+    """Configuration for RLVR training with DAPO algorithm."""
+
+    # Model specification
+    model_name: str = "Qwen/Qwen2.5-Math-7B"
+
+    # Precision mode: "fp32" for high precision, "bf16" for default RLVR
+    precision_mode: str = "bf16"
+
+    # Batch configuration (sized for single H200 GPU with gradient checkpointing)
+    global_batch_size: int = 32  # Reduced from 256 for single GPU
+    micro_batch_size: int = 4  # Reduced further for fp32 memory safety
+    grad_accumulation_steps: int = 8  # Increased to maintain effective batch size
+
+    # Rollout configuration
+    num_rollouts_per_prompt: int = 4  # Reduced from 16 for speed
+    max_seq_len: int = 2048  # Reduced from 8192 (GSM8K answers are short)
+
+    # Training steps and checkpointing
+    # Note: With sequential generation, each step takes ~18 min on H200
+    # 150 steps ≈ 45 hours, fits in 2-day limit with buffer
+    num_steps: int = 150
+    checkpoint_steps: List[int] = field(default_factory=lambda: [0, 50, 100, 150])
+
+    # Optimizer configuration (AdamW)
+    learning_rate: float = 1e-6
+    beta1: float = 0.9
+    beta2: float = 0.999
+    weight_decay: float = 0.01
+
+    # RL algorithm
+    rl_algorithm: str = "dapo"
+    clip_ratio: float = 0.2  # DAPO clip parameter
+    kl_coef: float = 0.0  # DAPO uses clip-only, no explicit KL penalty
+
+    # Reproducibility
+    # IMPORTANT: Keep these constant across precision modes to isolate precision effects
+    # We maximize determinism so the ONLY variance source is floating-point precision
+    seed: int = 1
+    use_dropout: bool = False  # Disabled to reduce stochasticity
+    use_deterministic_algorithms: bool = True  # Enabled for reproducibility
+
+    # Paths
+    output_dir: str = "./results/train_logs"
+    train_dataset_path: str = "./data/dm_train.json"
+
+    # GPU configuration (single GPU for this implementation)
+    num_gpus: int = 1  # Current implementation is single-GPU
+
+    def __post_init__(self):
+        """
+        Note: We intentionally keep dropout and determinism settings CONSTANT
+        across precision modes to isolate the effect of floating-point precision.
+
+        Previously this coupled dropout=False with fp32 and dropout=True with bf16,
+        which confounded precision effects with stochasticity effects.
+
+        To study pure precision effects:
+        - Both modes use the SAME dropout setting (default: False)
+        - Both modes use the SAME determinism setting (default: True)
+
+        The only difference between fp32 and bf16 should be param_dtype.
+        """
+        # Don't modify settings based on precision_mode - keep them independent
+        pass
+
+
+@dataclass
+class PrecisionConfig:
+    """Configuration for floating-point precision settings."""
+
+    # Parameter storage dtype
+    param_dtype: str = "bfloat16"  # "float32" or "bfloat16"
+
+    # Automatic mixed precision
+    use_amp: bool = True
+    amp_dtype: str = "bfloat16"  # "float16" or "bfloat16"
+
+    # Gradient and optimizer state always in FP32
+    grad_dtype: str = "float32"
+    optimizer_dtype: str = "float32"
+
+    # Deterministic algorithms
+    deterministic: bool = False
+
+    # CUDNN settings
+    cudnn_benchmark: bool = True
+    cudnn_deterministic: bool = False
+
+
+@dataclass
+class EvalTaskConfig:
+    """Configuration for a single evaluation task."""
+
+    # Task identification
+    name: str = ""
+    task_type: str = "math"  # "math", "code", "qa", "general"
+
+    # Dataset
+    dataset_path: str = ""
+    num_samples: int = -1  # -1 means use all samples
+
+    # Whether task has verifiable answers (math problems)
+    is_verifiable: bool = True
+
+    # Metric type for non-verifiable tasks
+    metric_type: str = "accuracy"  # "accuracy", "bleu", "rouge", "score"
+
+    # Generation parameters
+    max_gen_len: int = 2048
+    temperature: float = 0.7
+    top_p: float = 0.8
+    num_samples_per_prompt: int = 1
+
+
+@dataclass
+class ExperimentConfig:
+    """Master configuration for the entire experiment."""
+
+    # Experiment identification
+    experiment_name: str = "fp_precision_rlvr"
+
+    # Seeds for multiple runs
+    seeds: List[int] = field(default_factory=lambda: [1, 2, 3, 4, 5])
+
+    # Precision modes to compare
+    precision_modes: List[str] = field(default_factory=lambda: ["fp32", "bf16"])
+
+    # Base model checkpoint (shared starting point)
+    base_model_path: str = "Qwen/Qwen2.5-Math-7B"
+
+    # Output directories
+    base_output_dir: str = "./results"
+    train_logs_dir: str = "./results/train_logs"
+    checkpoints_dir: str = "./results/checkpoints"
+    eval_metrics_dir: str = "./results/eval_metrics"
+
+    # Evaluation configuration
+    eval_tasks_config_path: str = "./configs/eval_tasks_config.json"
+
+    # bf16 sparsity analysis
+    bf16_sparsity_eta: float = 1e-3
+
+
+def make_precision_config(precision_mode: str) -> PrecisionConfig:
+    """
+    Create precision configuration based on mode.
+
+    IMPORTANT: Only the precision-related settings differ between modes.
+    All other settings (determinism, cudnn) are kept CONSTANT to isolate
+    the effect of floating-point precision on training outcomes.
+
+    Args:
+        precision_mode: "fp32" for high precision, "bf16" for default RLVR
+
+    Returns:
+        PrecisionConfig with appropriate settings
+    """
+    # Common settings for both modes (to avoid confounds)
+    # Maximize determinism so precision is the ONLY source of variance
+    common_settings = {
+        "grad_dtype": "float32",
+        "optimizer_dtype": "float32",
+        "deterministic": True,  # Enable deterministic algorithms
+        "cudnn_benchmark": False,  # Disable for reproducibility
+        "cudnn_deterministic": True,  # Enable for reproducibility
+    }
+
+    if precision_mode == "fp32":
+        return PrecisionConfig(
+            param_dtype="float32",
+            use_amp=False,  # No AMP needed for fp32
+            amp_dtype="float32",
+            **common_settings
+        )
+    elif precision_mode == "bf16":
+        return PrecisionConfig(
+            param_dtype="bfloat16",
+            use_amp=True,
+            amp_dtype="bfloat16",
+            **common_settings
+        )
+    else:
+        raise ValueError(f"Unknown precision_mode: {precision_mode}")
+
+
+def make_training_config(
+    precision_mode: str,
+    seed: int,
+    output_dir: str,
+    train_dataset_path: str,
+    model_name: str = "Qwen/Qwen2.5-Math-7B"
+) -> TrainingConfig:
+    """
+    Create training configuration for a specific run.
+
+    Args:
+        precision_mode: "fp32" or "bf16"
+        seed: Random seed for this run
+        output_dir: Directory to save outputs
+        train_dataset_path: Path to training data
+        model_name: HuggingFace model identifier
+
+    Returns:
+        TrainingConfig with all parameters set
+    """
+    config = TrainingConfig(
+        model_name=model_name,
+        precision_mode=precision_mode,
+        seed=seed,
+        output_dir=output_dir,
+        train_dataset_path=train_dataset_path
+    )
+    return config
+
+
+def get_run_output_dir(
+    base_dir: str,
+    precision_mode: str,
+    seed: int
+) -> str:
+    """Get output directory for a specific run."""
+    return os.path.join(base_dir, f"{precision_mode}_seed{seed}")
+
+
+def get_checkpoint_path(
+    output_dir: str,
+    step: Optional[int] = None
+) -> str:
+    """Get checkpoint path for a specific step (None = final)."""
+    if step is None:
+        return os.path.join(output_dir, "final_model")
+    return os.path.join(output_dir, f"checkpoint_step{step}")
+
+
+# Default evaluation tasks for the experiment
+DEFAULT_EVAL_TASKS = [
+    # On-task: Training distribution
+    EvalTaskConfig(
+        name="dm_val",
+        task_type="math",
+        dataset_path="./data/dm_val.json",
+        is_verifiable=True,
+        metric_type="accuracy",
+        max_gen_len=2048,
+        temperature=0.7,
+        top_p=0.8
+    ),
+    # In-domain OOD: Math benchmarks
+    EvalTaskConfig(
+        name="aime24",
+        task_type="math",
+        dataset_path="./data/aime24.json",
+        is_verifiable=True,
+        metric_type="accuracy",
+        max_gen_len=4096,
+        temperature=0.7,
+        top_p=0.8
+    ),
+    EvalTaskConfig(
+        name="aime25",
+        task_type="math",
+        dataset_path="./data/aime25.json",
+        is_verifiable=True,
+        metric_type="accuracy",
+        max_gen_len=4096,
+        temperature=0.7,
+        top_p=0.8
+    ),
+    EvalTaskConfig(
+        name="amc23",
+        task_type="math",
+        dataset_path="./data/amc23.json",
+        is_verifiable=True,
+        metric_type="accuracy",
+        max_gen_len=2048,
+        temperature=0.7,
+        top_p=0.8
+    ),
+    EvalTaskConfig(
+        name="math500",
+        task_type="math",
+        dataset_path="./data/math500.json",
+        is_verifiable=True,
+        metric_type="accuracy",
+        max_gen_len=2048,
+        temperature=0.7,
+        top_p=0.8
+    ),
+    # Off-domain: General tasks
+    EvalTaskConfig(
+        name="mmlu_stem",
+        task_type="qa",
+        dataset_path="./data/mmlu_stem.json",
+        is_verifiable=True,
+        metric_type="accuracy",
+        max_gen_len=512,
+        temperature=0.3,
+        top_p=0.9
+    ),
+    EvalTaskConfig(
+        name="humaneval",
+        task_type="code",
+        dataset_path="./data/humaneval.json",
+        is_verifiable=True,
+        metric_type="accuracy",
+        max_gen_len=1024,
+        temperature=0.2,
+        top_p=0.95
+    ),
+]
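For context, here is a minimal sketch of how the helpers added above could be combined to enumerate the full precision × seed grid described by `ExperimentConfig`. Only `ExperimentConfig`, `make_training_config`, `make_precision_config`, and `get_run_output_dir` come from this commit; the driver loop itself is an illustrative assumption and is not part of `config.py`.

```python
# sketch: hypothetical driver script, not included in this commit
from config import (
    ExperimentConfig,
    get_run_output_dir,
    make_precision_config,
    make_training_config,
)

exp = ExperimentConfig()

for precision_mode in exp.precision_modes:  # ["fp32", "bf16"]
    for seed in exp.seeds:                  # [1, 2, 3, 4, 5]
        run_dir = get_run_output_dir(exp.train_logs_dir, precision_mode, seed)
        train_cfg = make_training_config(
            precision_mode=precision_mode,
            seed=seed,
            output_dir=run_dir,
            train_dataset_path="./data/dm_train.json",
            model_name=exp.base_model_path,
        )
        prec_cfg = make_precision_config(precision_mode)
        # A real driver would hand (train_cfg, prec_cfg) to the trainer,
        # which is outside the scope of config.py.
        print(run_dir, train_cfg.precision_mode, prec_cfg.param_dtype)
```

With the defaults above, this enumerates 2 precision modes × 5 seeds = 10 runs, each writing to its own `./results/train_logs/{precision_mode}_seed{seed}` directory.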

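Likewise, a hedged sketch of how a training script might translate a `PrecisionConfig` into PyTorch settings. The commit only defines the dataclass, so the `apply_precision_config` helper below and its use of standard `torch` APIs are illustrative assumptions, not the project's actual trainer.

```python
# sketch: hypothetical precision setup, not part of this commit
import torch

from config import PrecisionConfig

# Map the string dtype fields in PrecisionConfig to torch dtypes
_DTYPES = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}


def apply_precision_config(model: torch.nn.Module, cfg: PrecisionConfig) -> torch.nn.Module:
    """Cast parameters and set determinism flags according to cfg (illustrative)."""
    # Determinism / cuDNN flags are identical for both modes per make_precision_config
    torch.use_deterministic_algorithms(cfg.deterministic)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    torch.backends.cudnn.deterministic = cfg.cudnn_deterministic

    # Parameter storage dtype is the only setting that differs between fp32 and bf16;
    # grad_dtype / optimizer_dtype (always fp32 here) would be handled in optimizer setup.
    return model.to(dtype=_DTYPES[cfg.param_dtype])


# During the forward pass, AMP would wrap compute in the configured dtype, e.g.:
#   with torch.autocast("cuda", dtype=_DTYPES[cfg.amp_dtype], enabled=cfg.use_amp):
#       loss = model(batch).loss
```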