#!/usr/bin/env python3
# train_rlvr.py
"""
RLVR Training Script with a DAPO-style objective.

This script implements the training loop for reinforcement learning with
verifiable rewards (RLVR) using a simplified, clip-only variant of the DAPO
algorithm (see ``DAPOTrainer``).

It supports two precision configurations:
- P-high (fp32): High precision master weights for low numerical noise
- P-bf16: Default RLVR configuration with bf16 master weights

Training runs either as a single process (the frozen reference model goes on
a second GPU when one is available) or data-parallel through DeepSpeed when a
DeepSpeed JSON config is passed with ``--deepspeed``.

Usage:
    python train_rlvr.py \\
        --precision_mode fp32 \\
        --seed 1 \\
        --output_dir results/train_logs/fp32_seed1 \\
        --train_dataset_path data/dm_train.json
"""

import argparse
import json
import logging
import os
import random
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Tuple

import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import (
    PrecisionConfig,
    TrainingConfig,
    make_precision_config,
    make_training_config,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# ============================================================================
# Distributed helpers
# ============================================================================

def _is_main_process() -> bool:
    """Return True on rank 0, or when torch.distributed is not initialized."""
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0


def _ceil_div(numerator: int, denominator: int) -> int:
    """Integer ceiling division for non-negative ``numerator``."""
    return (numerator + denominator - 1) // denominator


# ============================================================================
# Seed and Determinism Utilities
# ============================================================================

def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Args:
        seed: Value applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The running interpreter's str-hash seed is fixed at startup, so this
    # only affects child processes launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    logger.info(f"Set random seed to {seed}")


def configure_torch_deterministic(deterministic: bool) -> None:
    """Toggle PyTorch deterministic kernels and cuDNN autotuning.

    Args:
        deterministic: If True, request deterministic algorithms (warning,
            not raising, for ops without a deterministic implementation) and
            disable cuDNN benchmarking. If False, restore the fast defaults.
    """
    if deterministic:
        # Deterministic cuBLAS GEMMs require this workspace setting; set it
        # before any deterministic kernel can run.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True, warn_only=True)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        logger.info("Enabled deterministic algorithms")
    else:
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        logger.info("Using non-deterministic algorithms (default)")


# ============================================================================
# Model Utilities
# ============================================================================

def get_torch_dtype(dtype_str: str) -> torch.dtype:
    """Map ``"float32"``, ``"float16"`` or ``"bfloat16"`` to a torch dtype.

    Raises:
        ValueError: If ``dtype_str`` is not one of the supported names.
    """
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype_str not in dtype_map:
        raise ValueError(f"Unknown dtype: {dtype_str}")
    return dtype_map[dtype_str]


def cast_model_param_dtype(
    model: torch.nn.Module, param_dtype: str
) -> torch.nn.Module:
    """Cast the model's parameters in place to ``param_dtype`` and return it."""
    dtype = get_torch_dtype(param_dtype)
    model.to(dtype=dtype)
    logger.info(f"Cast model parameters to {param_dtype}")
    return model


def disable_dropout(model: torch.nn.Module) -> None:
    """Set ``p = 0`` on every ``torch.nn.Dropout`` submodule.

    Only ``nn.Dropout`` module instances are touched; dropout applied
    functionally from a float attribute is not affected.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0
    logger.info("Disabled all dropout layers")


def count_parameters(model: torch.nn.Module) -> int:
    """Return the number of elements in parameters with ``requires_grad``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ============================================================================
# Data Loading
# ============================================================================

def load_training_data(dataset_path: str) -> List[Dict[str, Any]]:
    """Load the training examples from a UTF-8 JSON file.

    Downstream code reads ``example["prompt"]`` and, optionally,
    ``example["answer"]`` from each entry.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    logger.info(f"Loaded {len(data)} training examples from {dataset_path}")
    return data


def sample_batch(
    data: List[Dict[str, Any]],
    batch_size: int,
    rng: np.random.Generator,
) -> List[Dict[str, Any]]:
    """Sample up to ``batch_size`` distinct examples without replacement."""
    indices = rng.choice(len(data), size=min(batch_size, len(data)), replace=False)
    return [data[i] for i in indices]


def num_prompts_per_step(train_config: TrainingConfig, num_examples: int) -> int:
    """Return how many prompts one training step samples.

    ``global_batch_size`` counts rollouts, so it is divided by the rollouts
    generated per prompt (keeping at least one prompt) and capped by the
    dataset size, which matches the length ``sample_batch`` returns.
    """
    per_step = max(
        1, train_config.global_batch_size // train_config.num_rollouts_per_prompt
    )
    return min(per_step, num_examples)


# ============================================================================
# Reward Function (Math Verifier)
# ============================================================================

def compute_math_reward(
    prompt: str,
    response: str,
    ground_truth: Optional[str] = None,
) -> float:
    """
    Score a math response against its expected answer.

    Placeholder rule-based verifier: the response counts as correct when the
    lower-cased, stripped ground truth occurs anywhere in the lower-cased
    response. This is a loose check (e.g. "1" is found inside "10"); in
    production it should be replaced with Eval-Chemy or a similar verifier.

    Args:
        prompt: The math problem prompt (currently unused).
        response: The model's generated response.
        ground_truth: The expected answer. Non-string values (e.g. numbers)
            are compared through ``str()``.

    Returns:
        +1.0 if correct, -1.0 if incorrect, and 0.0 when no non-empty ground
        truth is available (the response cannot be verified).
    """
    # TODO: Replace with actual math verifier (Eval-Chemy)
    if ground_truth is None:
        return 0.0

    gt_lower = str(ground_truth).lower().strip()
    if not gt_lower:
        # An empty answer is a substring of every response and would
        # otherwise reward everything.
        return 0.0

    return 1.0 if gt_lower in response.lower() else -1.0


# ============================================================================
# DAPO Algorithm Implementation
# ============================================================================

class DAPOTrainer:
    """
    Simplified DAPO-style trainer.

    DAPO stands for "Decoupled Clip and Dynamic sAmpling Policy Optimization".
    This implementation is a clip-only simplification. It:

    - computes a sequence-level importance ratio of the policy against the
      frozen reference model;
    - applies a symmetric ``clip_ratio`` clip;
    - adds no explicit KL penalty;
    - uses per-prompt group-normalized rewards as advantages;
    - skips steps whose rollouts all received the same reward.

    For production use VeRL's DAPO trainer.

    In DeepSpeed mode, ``train_step`` calls ``engine.step()`` after every
    micro-batch ``backward()``. ``main()`` therefore configures the engine's
    gradient accumulation steps to the number of micro-batches per training
    step, which gives one optimizer update per step.
    """

    def __init__(
        self,
        model_engine,  # DeepSpeed engine or raw model
        ref_model: torch.nn.Module,
        tokenizer,
        train_config: TrainingConfig,
        precision_config: PrecisionConfig,
        device: torch.device,
        ref_device: Optional[torch.device] = None,
        use_deepspeed: bool = False,
    ) -> None:
        self.use_deepspeed = use_deepspeed
        if use_deepspeed:
            self.model_engine = model_engine
            self.model = model_engine.module
        else:
            self.model = model_engine
            self.model_engine = None

        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.train_config = train_config
        self.precision_config = precision_config
        self.device = device
        self.ref_device = ref_device if ref_device is not None else device

        # Freeze reference model
        for param in self.ref_model.parameters():
            param.requires_grad = False
        self.ref_model.eval()

        # DeepSpeed owns the optimizer in DeepSpeed mode.
        if not use_deepspeed:
            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=train_config.learning_rate,
                betas=(train_config.beta1, train_config.beta2),
                weight_decay=train_config.weight_decay,
            )
        else:
            self.optimizer = None

        # Loss scaling is only needed for fp16 autocast outside DeepSpeed.
        self.scaler = None
        if (
            not use_deepspeed
            and precision_config.use_amp
            and precision_config.amp_dtype == "float16"
        ):
            self.scaler = GradScaler()

        # Training state
        self.global_step = 0
        self.rng = np.random.default_rng(train_config.seed)

        # One entry per non-skipped training step.
        self.metrics_history: List[Dict[str, float]] = []

    def compute_log_probs(
        self,
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Return the log-probability of each next label token.

        Output position ``t`` holds ``log p(labels[:, t + 1] | tokens[:, :t + 1])``.
        The result therefore has one fewer position than ``labels``.

        Args:
            model: Causal LM to score with (policy or reference).
            input_ids: Token ids fed to the model.
            attention_mask: 1 for real tokens, 0 for padding.
            labels: Token ids to score; callers in this file pass ``input_ids``.

        Returns:
            Tensor shaped like ``labels[:, 1:]``.
        """
        with autocast(
            enabled=self.precision_config.use_amp,
            dtype=get_torch_dtype(self.precision_config.amp_dtype),
        ):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )
            logits = outputs.logits
            log_probs = torch.log_softmax(logits, dim=-1)

            # Shift for autoregressive scoring: position t predicts token t+1.
            shift_log_probs = log_probs[:, :-1, :]
            shift_labels = labels[:, 1:]
            token_log_probs = torch.gather(
                shift_log_probs, dim=-1, index=shift_labels.unsqueeze(-1)
            ).squeeze(-1)

        return token_log_probs

    def compute_dapo_loss(
        self,
        policy_log_probs: torch.Tensor,
        ref_log_probs: torch.Tensor,
        rewards: torch.Tensor,
        response_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the clip-only surrogate loss.

        There is no explicit KL penalty (beta=0). Clipping the policy/reference
        ratio acts as the implicit KL constraint (Gate I).

        Args:
            policy_log_probs: Token log-probs under the policy (grad-tracking).
            ref_log_probs: Token log-probs under the frozen reference model.
            rewards: One advantage per sequence; ``train_step`` passes
                group-normalized rewards.
            response_mask: 1.0 on positions scoring response tokens, else 0.0.

        Returns:
            ``(loss, metrics)``. The metrics include the mean per-token log
            ratio ("approx_kl") and the fraction of sequences outside the
            clip range.
        """
        log_ratios = policy_log_probs - ref_log_probs

        # Sequence-level log ratio over response tokens only.
        sequence_log_ratios = (log_ratios * response_mask).sum(dim=-1)
        # NOTE(review): exp of a sequence-summed log ratio can overflow to inf
        # once the policy drifts far from the reference; inf * 0 advantages
        # then yields NaN. Consider token-level ratios if this shows up.
        ratios = torch.exp(sequence_log_ratios)

        clip_ratio = self.train_config.clip_ratio
        advantages = rewards  # already normalized by the caller

        unclipped = ratios * advantages
        clipped = torch.clamp(ratios, 1 - clip_ratio, 1 + clip_ratio) * advantages
        # Pessimistic (minimum) surrogate, negated for gradient descent.
        loss = -torch.min(unclipped, clipped).mean()

        with torch.no_grad():
            # clamp_min guards against batches with no response tokens.
            num_tokens = response_mask.sum().clamp_min(1.0)
            approx_kl = (log_ratios * response_mask).sum() / num_tokens
            clip_fraction = ((ratios - 1).abs() > clip_ratio).float().mean()

        metrics = {
            "loss": loss.item(),
            "approx_kl": approx_kl.item(),
            "clip_fraction": clip_fraction.item(),
            "mean_ratio": ratios.mean().item(),
            "mean_advantage": advantages.mean().item(),
        }
        return loss, metrics

    def generate_rollouts(
        self,
        prompts: List[str],
        num_samples: int,
    ) -> List[Dict[str, Any]]:
        """Sample ``num_samples`` completions for each prompt.

        Sampling uses temperature 0.7 and top-p 0.8, in eval mode and without
        gradients. The model is switched back to train mode afterwards. Results
        are ordered prompt-major: all samples of prompt 0, then prompt 1, and so
        on.

        Returns:
            Dicts with keys ``prompt``, ``response`` (decoded text),
            ``input_ids`` (prompt ids), ``response_ids`` and ``full_ids``
            (prompt followed by response).
        """
        rollouts = []
        self.model.eval()

        with torch.no_grad():
            for prompt in prompts:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=self.train_config.max_seq_len // 2,
                ).to(self.device)

                for _ in range(num_samples):
                    with autocast(
                        enabled=self.precision_config.use_amp,
                        dtype=get_torch_dtype(self.precision_config.amp_dtype),
                    ):
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=self.train_config.max_seq_len // 2,
                            do_sample=True,
                            temperature=0.7,
                            top_p=0.8,
                            pad_token_id=self.tokenizer.eos_token_id,
                        )

                    response_ids = outputs[0, inputs["input_ids"].shape[1]:]
                    response_text = self.tokenizer.decode(
                        response_ids, skip_special_tokens=True
                    )

                    rollouts.append({
                        "prompt": prompt,
                        "response": response_text,
                        "input_ids": inputs["input_ids"][0],
                        "response_ids": response_ids,
                        "full_ids": outputs[0],
                    })

        self.model.train()
        return rollouts

    @staticmethod
    def _group_normalize(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
        """Center and scale rewards within each consecutive group of ``group_size``."""
        grouped = rewards.view(-1, group_size)
        centered = grouped - grouped.mean(dim=1, keepdim=True)
        if group_size > 1:
            scale = grouped.std(dim=1, keepdim=True)
        else:
            # std of a single sample is NaN; the centered value is already 0.
            scale = torch.zeros_like(centered)
        return (centered / (scale + 1e-8)).view(-1)

    def _collate(
        self, rollouts: List[Dict[str, Any]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Right-pad rollouts into (input_ids, attention_mask, response_mask)."""
        batch = len(rollouts)
        max_len = max(len(r["full_ids"]) for r in rollouts)
        input_ids = torch.zeros(batch, max_len, dtype=torch.long, device=self.device)
        attention_mask = torch.zeros(batch, max_len, dtype=torch.long, device=self.device)
        # Aligned with the shifted log-probs: position t scores token t + 1.
        response_mask = torch.zeros(
            batch, max_len - 1, dtype=torch.float32, device=self.device
        )

        for row, rollout in enumerate(rollouts):
            seq_len = len(rollout["full_ids"])
            prompt_len = len(rollout["input_ids"])
            input_ids[row, :seq_len] = rollout["full_ids"]
            attention_mask[row, :seq_len] = 1
            response_mask[row, prompt_len - 1:seq_len - 1] = 1

        return input_ids, attention_mask, response_mask

    def _reference_log_probs(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Score with the frozen reference model; result is on ``self.device``."""
        with torch.no_grad():
            if self.ref_device != self.device:
                ref_input_ids = input_ids.to(self.ref_device)
                ref_attention_mask = attention_mask.to(self.ref_device)
            else:
                ref_input_ids = input_ids
                ref_attention_mask = attention_mask

            ref_log_probs = self.compute_log_probs(
                self.ref_model, ref_input_ids, ref_attention_mask, ref_input_ids
            )

        if self.ref_device != self.device:
            ref_log_probs = ref_log_probs.to(self.device)
        return ref_log_probs

    def train_step(
        self,
        batch: List[Dict[str, Any]],
    ) -> Dict[str, float]:
        """Run one training step: rollouts, rewards, and one optimizer update.

        Returns:
            ``{"skipped": 1.0}`` when the rollouts carry no learning signal.
            Otherwise returns per-micro-batch-averaged loss metrics plus:
            ``mean_reward`` (the raw verifier reward), ``total_loss`` (the sum
            of micro-batch losses) and ``step``.
        """
        self.model.train()
        group_size = self.train_config.num_rollouts_per_prompt

        prompts = [ex["prompt"] for ex in batch]
        ground_truths = [ex.get("answer", None) for ex in batch]
        rollouts = self.generate_rollouts(prompts, group_size)

        # Rollouts are prompt-major, so rollout i belongs to prompt i // group_size.
        rewards = [
            compute_math_reward(
                rollout["prompt"], rollout["response"], ground_truths[i // group_size]
            )
            for i, rollout in enumerate(rollouts)
        ]
        rewards_tensor = torch.tensor(rewards, device=self.device, dtype=torch.float32)

        # No learning signal when all rewards are equal. std() of fewer than
        # two values is NaN, so guard that case explicitly.
        if rewards_tensor.numel() < 2 or rewards_tensor.std() < 1e-6:
            return {"skipped": 1.0}

        advantages = self._group_normalize(rewards_tensor, group_size)

        if not self.use_deepspeed:
            self.optimizer.zero_grad()

        num_rollouts = len(rollouts)
        micro_batch_size = self.train_config.micro_batch_size
        num_micro_batches = _ceil_div(num_rollouts, micro_batch_size)

        total_loss = 0.0
        all_metrics: Dict[str, List[float]] = {}

        for mb_start in range(0, num_rollouts, micro_batch_size):
            mb_end = min(mb_start + micro_batch_size, num_rollouts)
            input_ids, attention_mask, response_mask = self._collate(
                rollouts[mb_start:mb_end]
            )

            policy_log_probs = self.compute_log_probs(
                self.model, input_ids, attention_mask, input_ids
            )
            ref_log_probs = self._reference_log_probs(input_ids, attention_mask)

            loss, metrics = self.compute_dapo_loss(
                policy_log_probs,
                ref_log_probs,
                advantages[mb_start:mb_end],
                response_mask,
            )

            if self.use_deepspeed:
                # DeepSpeed divides by its gradient_accumulation_steps and
                # applies the update only on accumulation boundaries, so
                # step() must follow every backward().
                self.model_engine.backward(loss)
                self.model_engine.step()
            else:
                # Average gradients over the micro-batches actually processed.
                scaled_loss = loss / num_micro_batches
                if self.scaler is not None:
                    self.scaler.scale(scaled_loss).backward()
                else:
                    scaled_loss.backward()

            total_loss += loss.item()
            for key, value in metrics.items():
                all_metrics.setdefault(key, []).append(value)

        if not self.use_deepspeed:
            if self.scaler is not None:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

        self.global_step += 1

        step_metrics: Dict[str, float] = {
            key: float(np.mean(values)) for key, values in all_metrics.items()
        }
        step_metrics["mean_reward"] = rewards_tensor.mean().item()
        step_metrics["total_loss"] = total_loss
        step_metrics["step"] = self.global_step

        self.metrics_history.append(step_metrics)
        return step_metrics

    def train(
        self,
        train_data: List[Dict[str, Any]],
        save_checkpoints: bool = True,
    ) -> None:
        """Run ``num_steps`` training steps.

        Logs every 10 steps and checkpoints after each step listed in
        ``checkpoint_steps`` (1-based).
        """
        num_steps = self.train_config.num_steps
        logger.info(f"Starting training for {num_steps} steps")

        checkpoint_steps = set(self.train_config.checkpoint_steps)
        prompts_per_step = num_prompts_per_step(self.train_config, len(train_data))

        for step in range(num_steps):
            batch = sample_batch(train_data, prompts_per_step, self.rng)
            metrics = self.train_step(batch)

            if step % 10 == 0:
                if metrics.get("skipped"):
                    logger.info(
                        f"Step {step}/{num_steps} | skipped (no reward variance)"
                    )
                else:
                    logger.info(
                        f"Step {step}/{num_steps} | "
                        f"Loss: {metrics.get('total_loss', 0):.4f} | "
                        f"KL: {metrics.get('approx_kl', 0):.4f} | "
                        f"Reward: {metrics.get('mean_reward', 0):.4f}"
                    )

            if save_checkpoints and (step + 1) in checkpoint_steps:
                self.save_checkpoint(step + 1)

    def save_checkpoint(self, step: int) -> None:
        """Save a checkpoint to ``<output_dir>/checkpoint_step<step>``.

        In DeepSpeed mode this must be called on every rank, because DeepSpeed's
        ``save_checkpoint`` is collective. Otherwise it writes the model and a
        ``training_state.pt`` containing the optimizer (and GradScaler) state
        and the metrics history.
        """
        ckpt_dir = os.path.join(
            self.train_config.output_dir,
            f"checkpoint_step{step}",
        )
        os.makedirs(ckpt_dir, exist_ok=True)

        if self.use_deepspeed:
            # Collective: DeepSpeed writes model + partitioned optimizer state.
            self.model_engine.save_checkpoint(ckpt_dir, tag=f"step{step}")
        else:
            self.model.save_pretrained(ckpt_dir)
            state = {
                "step": step,
                "optimizer_state_dict": self.optimizer.state_dict(),
                "metrics_history": self.metrics_history,
            }
            if self.scaler is not None:
                state["scaler_state_dict"] = self.scaler.state_dict()
            torch.save(state, os.path.join(ckpt_dir, "training_state.pt"))

        # Avoid concurrent writes of identical files from every rank.
        if _is_main_process():
            self.tokenizer.save_pretrained(ckpt_dir)
        logger.info(f"Saved checkpoint at step {step} to {ckpt_dir}")

    def save_final(self) -> str:
        """Save the final model, tokenizer and metrics history.

        In DeepSpeed mode the model is written in DeepSpeed's checkpoint format
        (tag ``final``), not as a HuggingFace ``save_pretrained`` directory.

        Returns:
            Path of the final-model directory.
        """
        final_dir = os.path.join(self.train_config.output_dir, "final_model")
        os.makedirs(final_dir, exist_ok=True)

        if self.use_deepspeed:
            # Collective: every rank must call DeepSpeed's save_checkpoint.
            self.model_engine.save_checkpoint(final_dir, tag="final")
        else:
            self.model.save_pretrained(final_dir)

        if _is_main_process():
            self.tokenizer.save_pretrained(final_dir)
            with open(
                os.path.join(final_dir, "metrics_history.json"), "w", encoding="utf-8"
            ) as f:
                json.dump(self.metrics_history, f, indent=2)

        logger.info(f"Saved final model to {final_dir}")
        return final_dir


# ============================================================================
# Main Entry Point
# ============================================================================

def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="RLVR Training with DAPO Algorithm"
    )

    parser.add_argument(
        "--precision_mode",
        type=str,
        required=True,
        choices=["fp32", "bf16"],
        help="Precision mode: fp32 (high precision) or bf16 (default RLVR)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        required=True,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save outputs",
    )
    parser.add_argument(
        "--train_dataset_path",
        type=str,
        required=True,
        help="Path to training dataset JSON",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="Qwen/Qwen2.5-Math-7B",
        help="HuggingFace model identifier",
    )
    parser.add_argument(
        "--num_steps",
        type=int,
        default=300,
        help="Number of training steps",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for training",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help="Path to DeepSpeed config file",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (set by DeepSpeed)",
    )

    return parser.parse_args()


def main() -> None:
    """Build configs, models and data, then train and save the final model."""
    args = parse_args()

    # Create configurations
    train_config = make_training_config(
        precision_mode=args.precision_mode,
        seed=args.seed,
        output_dir=args.output_dir,
        train_dataset_path=args.train_dataset_path,
        model_name=args.model_name,
    )
    train_config.num_steps = args.num_steps
    precision_config = make_precision_config(args.precision_mode)

    # Setup output directory
    os.makedirs(train_config.output_dir, exist_ok=True)

    # Save configurations
    with open(os.path.join(train_config.output_dir, "train_config.json"), "w") as f:
        json.dump(asdict(train_config), f, indent=2)
    with open(os.path.join(train_config.output_dir, "precision_config.json"), "w") as f:
        json.dump(asdict(precision_config), f, indent=2)

    # Set seeds and determinism
    set_seed(train_config.seed)
    configure_torch_deterministic(precision_config.deterministic)

    use_deepspeed = args.deepspeed is not None

    # Setup devices
    num_gpus = torch.cuda.device_count()
    if use_deepspeed:
        deepspeed.init_distributed()
        local_rank = (
            args.local_rank
            if args.local_rank >= 0
            else int(os.environ.get("LOCAL_RANK", 0))
        )
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        # Each rank keeps its own bf16 reference model on its training GPU;
        # ZeRO-2 shards optimizer states, leaving room for it (~14GB).
        ref_device = device
        logger.info(f"DeepSpeed: rank {local_rank}, training on {device}, ref model on {ref_device}")
    else:
        device = torch.device(args.device if torch.cuda.is_available() else "cpu")
        if num_gpus >= 2:
            ref_device = torch.device("cuda:1")
            logger.info(f"Using device: {device} for training, {ref_device} for reference model")
        else:
            ref_device = device
            logger.info(f"Using device: {device} for both training and reference model")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name,
        use_fast=True,
        trust_remote_code=True,
    )
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Loaded before the DeepSpeed setup because its batch config depends on
    # how many prompts a step can actually sample.
    train_data = load_training_data(train_config.train_dataset_path)

    # Load model
    logger.info(f"Loading model: {train_config.model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        train_config.model_name,
        torch_dtype=torch.float32,  # Load in FP32 first
        device_map=None,
        trust_remote_code=True,
    )

    # Cast to target precision
    model = cast_model_param_dtype(model, precision_config.param_dtype)

    # Trade compute for memory.
    model.gradient_checkpointing_enable()
    logger.info("Enabled gradient checkpointing to reduce memory usage")

    if not train_config.use_dropout:
        disable_dropout(model)

    logger.info(f"Model loaded with {count_parameters(model):,} trainable parameters")

    if use_deepspeed:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=train_config.learning_rate,
            betas=(train_config.beta1, train_config.beta2),
            weight_decay=train_config.weight_decay,
        )

        with open(args.deepspeed, "r") as f:
            ds_config = json.load(f)

        # DeepSpeed requires: train_batch_size = micro_batch * grad_acc * world_size.
        # The reference model shares each rank's GPU, so every initialized rank
        # is a training rank.
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
        else:
            world_size = int(os.environ.get("WORLD_SIZE", 1))
        micro_batch = train_config.micro_batch_size

        # Every rank generates and trains on the full per-step rollout set, and
        # train_step() calls engine.step() after each micro-batch backward.
        # One optimizer update per training step therefore needs
        # grad_acc == micro-batches per step.
        # NOTE(review): all ranks use the same seed and batch sampling RNG, so
        # they likely produce identical rollouts; consider sharding prompts.
        rollouts_per_step = (
            num_prompts_per_step(train_config, len(train_data))
            * train_config.num_rollouts_per_prompt
        )
        grad_acc = max(1, _ceil_div(rollouts_per_step, micro_batch))
        actual_global = micro_batch * grad_acc * world_size

        ds_config["train_batch_size"] = actual_global
        ds_config["train_micro_batch_size_per_gpu"] = micro_batch
        ds_config["gradient_accumulation_steps"] = grad_acc

        logger.info(f"DeepSpeed batch config: global={actual_global}, micro={micro_batch}, grad_acc={grad_acc}, world_size={world_size}")

        model_engine, optimizer, _, _ = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            config=ds_config,
        )
        logger.info(f"DeepSpeed ZeRO-2 initialized with {world_size} GPUs")
    else:
        model = model.to(device)
        model_engine = model

    # The reference model is always bf16 (frozen, only used for ratios/KL) so
    # that reference precision is a controlled variable across both modes.
    logger.info("Loading reference model (bf16 for both modes - controlled variable)")
    ref_model = AutoModelForCausalLM.from_pretrained(
        train_config.model_name,
        torch_dtype=torch.bfloat16,  # Always bf16 to save memory & control variable
        device_map=None,
        trust_remote_code=True,
    )
    ref_model = ref_model.to(ref_device)
    ref_model.eval()

    trainer = DAPOTrainer(
        model_engine=model_engine,
        ref_model=ref_model,
        tokenizer=tokenizer,
        train_config=train_config,
        precision_config=precision_config,
        device=device,
        ref_device=ref_device,
        use_deepspeed=use_deepspeed,
    )

    trainer.train(train_data, save_checkpoints=True)

    final_path = trainer.save_final()
    logger.info(f"Training complete. Final model saved to: {final_path}")


if __name__ == "__main__":
    main()