author    YurenHao0426 <blackhao0426@gmail.com>  2026-02-04 18:59:35 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2026-02-04 18:59:35 -0600
commit    f1c2cc22d46a6976df3555391e667c7e61592fad (patch)
tree      0b37b52c8ff91042a742d3b3ec54542cb6d6e2f6 /train_rlvr.py
Initial commit: RL floating-point noise project (HEAD, main)
Diffstat (limited to 'train_rlvr.py')
-rw-r--r--  train_rlvr.py  849
1 file changed, 849 insertions(+), 0 deletions(-)
diff --git a/train_rlvr.py b/train_rlvr.py
new file mode 100644
index 0000000..1076df8
--- /dev/null
+++ b/train_rlvr.py
@@ -0,0 +1,849 @@
+#!/usr/bin/env python3
+# train_rlvr.py
+"""
+RLVR Training Script with DAPO Algorithm.
+
+This script implements the training loop for reinforcement learning with
+verifiable rewards (RLVR) using the DAPO algorithm. It supports two precision
+configurations:
+
+- P-high (fp32): High precision master weights for low numerical noise
+- P-bf16: Default RLVR configuration with bf16 master weights
+
+The script integrates with the VeRL framework for distributed RL training.
+
+Usage:
+ python train_rlvr.py \
+ --precision_mode fp32 \
+ --seed 1 \
+ --output_dir results/train_logs/fp32_seed1 \
+ --train_dataset_path data/dm_train.json
+"""
+
+import argparse
+import json
+import os
+import random
+import logging
+from typing import Dict, Any, List, Optional, Tuple
+from dataclasses import asdict
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.cuda.amp import autocast, GradScaler
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import deepspeed
+
+from config import (
+ make_training_config,
+ make_precision_config,
+ TrainingConfig,
+ PrecisionConfig,
+)
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Seed and Determinism Utilities
+# ============================================================================
+
+def set_seed(seed: int) -> None:
+ """Set random seeds for reproducibility."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+    # PYTHONHASHSEED only takes effect for interpreters started after this
+    # point (e.g. dataloader worker subprocesses); the current process's hash
+    # randomization was fixed at interpreter startup.
+    os.environ["PYTHONHASHSEED"] = str(seed)
+ logger.info(f"Set random seed to {seed}")
+
+
+def configure_torch_deterministic(deterministic: bool) -> None:
+ """Configure PyTorch deterministic algorithms."""
+ if deterministic:
+ torch.use_deterministic_algorithms(True, warn_only=True)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
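+        # Note: CUBLAS_WORKSPACE_CONFIG must be set before the first cuBLAS
+        # call in the process for cuBLAS determinism to take effect, so this
+        # function should run before any CUDA matmuls are issued.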
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+ logger.info("Enabled deterministic algorithms")
+ else:
+ torch.use_deterministic_algorithms(False)
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.deterministic = False
+ logger.info("Using non-deterministic algorithms (default)")
+
+
+# ============================================================================
+# Model Utilities
+# ============================================================================
+
+def get_torch_dtype(dtype_str: str) -> torch.dtype:
+ """Convert string dtype to torch.dtype."""
+ dtype_map = {
+ "float32": torch.float32,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ }
+ if dtype_str not in dtype_map:
+ raise ValueError(f"Unknown dtype: {dtype_str}")
+ return dtype_map[dtype_str]
+
+
+def cast_model_param_dtype(
+ model: torch.nn.Module,
+ param_dtype: str
+) -> torch.nn.Module:
+ """Cast model parameters to specified dtype."""
+ dtype = get_torch_dtype(param_dtype)
+ model.to(dtype=dtype)
+ logger.info(f"Cast model parameters to {param_dtype}")
+ return model
+
+
+def disable_dropout(model: torch.nn.Module) -> None:
+ """Disable all dropout layers in the model."""
+ for module in model.modules():
+ if isinstance(module, torch.nn.Dropout):
+ module.p = 0.0
+ logger.info("Disabled all dropout layers")
+
+
+def count_parameters(model: torch.nn.Module) -> int:
+ """Count trainable parameters."""
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_training_data(dataset_path: str) -> List[Dict[str, Any]]:
+ """Load training dataset from JSON file."""
+ with open(dataset_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ logger.info(f"Loaded {len(data)} training examples from {dataset_path}")
+ return data
+
+
+def sample_batch(
+ data: List[Dict[str, Any]],
+ batch_size: int,
+ rng: np.random.Generator
+) -> List[Dict[str, Any]]:
+ """Sample a batch of prompts from the dataset."""
+ indices = rng.choice(len(data), size=min(batch_size, len(data)), replace=False)
+ return [data[i] for i in indices]
+
+
+# ============================================================================
+# Reward Function (Math Verifier)
+# ============================================================================
+
+def compute_math_reward(
+ prompt: str,
+ response: str,
+ ground_truth: Optional[str] = None
+) -> float:
+ """
+ Compute reward for a math problem response.
+
+ Uses a simple rule-based verifier. In production, this should be replaced
+    with Eval-Chemy or a similar math-verification system.
+
+ Args:
+ prompt: The math problem prompt
+ response: The model's generated response
+ ground_truth: The expected answer (if available)
+
+ Returns:
+        +1.0 if correct, -1.0 if incorrect, 0.0 if no ground truth is available
+ """
+ # TODO: Replace with actual math verifier (Eval-Chemy)
+ # This is a placeholder implementation
+
+ if ground_truth is None:
+ # Cannot verify without ground truth
+ return 0.0
+
+ # Extract final answer from response (simple heuristic)
+ response_lower = response.lower().strip()
+ gt_lower = ground_truth.lower().strip()
+
+ # Check for common answer formats
+ answer_markers = ["the answer is", "therefore", "=", "\\boxed{"]
+
+ for marker in answer_markers:
+ if marker in response_lower:
+ idx = response_lower.rfind(marker)
+ potential_answer = response_lower[idx:].strip()
+ if gt_lower in potential_answer:
+ return 1.0
+
+ # Direct containment check as fallback
+ if gt_lower in response_lower:
+ return 1.0
+
+ return -1.0
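+
+
+def _extract_boxed_answer(response: str) -> Optional[str]:
+    """Hedged sketch, not wired into the reward above: pull out the last
+    \\boxed{...} span, the answer format Qwen-style math models usually emit.
+    Braces are matched manually because the boxed content may nest them."""
+    start = response.rfind("\\boxed{")
+    if start == -1:
+        return None
+    depth = 1
+    chars = []
+    for ch in response[start + len("\\boxed{"):]:
+        if ch == "{":
+            depth += 1
+        elif ch == "}":
+            depth -= 1
+            if depth == 0:
+                return "".join(chars).strip()
+        chars.append(ch)
+    return None  # unbalanced braces: no complete \boxed{...} span found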
+
+
+# ============================================================================
+# DAPO Algorithm Implementation
+# ============================================================================
+
+class DAPOTrainer:
+ """
+    DAPO (Decoupled Clip and Dynamic sAmpling Policy Optimization) Trainer.
+
+ Implements clip-only DAPO with implicit KL constraint through ratio clipping.
+ This is a simplified implementation - for production use VeRL's DapoTrainer.
+ """
+
+ def __init__(
+ self,
+ model_engine, # DeepSpeed engine or raw model
+ ref_model: torch.nn.Module,
+ tokenizer,
+ train_config: TrainingConfig,
+ precision_config: PrecisionConfig,
+ device: torch.device,
+ ref_device: Optional[torch.device] = None,
+ use_deepspeed: bool = False
+ ) -> None:
+ self.use_deepspeed = use_deepspeed
+ if use_deepspeed:
+ self.model_engine = model_engine
+ self.model = model_engine.module
+ else:
+ self.model = model_engine
+ self.model_engine = None
+ self.ref_model = ref_model
+ self.tokenizer = tokenizer
+ self.train_config = train_config
+ self.precision_config = precision_config
+ self.device = device
+ self.ref_device = ref_device if ref_device is not None else device
+
+ # Freeze reference model
+ for param in self.ref_model.parameters():
+ param.requires_grad = False
+ self.ref_model.eval()
+
+ # Setup optimizer (only if not using DeepSpeed)
+ if not use_deepspeed:
+ self.optimizer = torch.optim.AdamW(
+ self.model.parameters(),
+ lr=train_config.learning_rate,
+ betas=(train_config.beta1, train_config.beta2),
+ weight_decay=train_config.weight_decay
+ )
+ else:
+ self.optimizer = None # DeepSpeed manages optimizer
+
+ # Setup AMP scaler if using fp16 (not needed with DeepSpeed)
+ self.scaler = None
+ if not use_deepspeed and precision_config.use_amp and precision_config.amp_dtype == "float16":
+ self.scaler = GradScaler()
+
+ # Training state
+ self.global_step = 0
+ self.rng = np.random.default_rng(train_config.seed)
+
+ # Metrics tracking
+ self.metrics_history: List[Dict[str, float]] = []
+
+ def compute_log_probs(
+ self,
+ model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ labels: torch.Tensor
+ ) -> torch.Tensor:
+ """Compute token-level log probabilities."""
+ with autocast(
+ enabled=self.precision_config.use_amp,
+ dtype=get_torch_dtype(self.precision_config.amp_dtype)
+ ):
+ outputs = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ use_cache=False
+ )
+
+ logits = outputs.logits
+ log_probs = torch.log_softmax(logits, dim=-1)
+
+ # Gather log probs for actual tokens
+ # Shift for autoregressive: predict next token
+ shift_log_probs = log_probs[:, :-1, :]
+ shift_labels = labels[:, 1:]
+
+ token_log_probs = torch.gather(
+ shift_log_probs,
+ dim=-1,
+ index=shift_labels.unsqueeze(-1)
+ ).squeeze(-1)
+
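+        # Shape note (illustrative): logits are [B, T, V]; after the shift,
+        # token_log_probs[b, t] is the log-probability of labels[b, t + 1]
+        # given the first t + 1 tokens, so the result has length T - 1 and
+        # the response masks built later are sized max_len - 1 to match.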
+ return token_log_probs
+
+ def compute_dapo_loss(
+ self,
+ policy_log_probs: torch.Tensor,
+ ref_log_probs: torch.Tensor,
+ rewards: torch.Tensor,
+ response_mask: torch.Tensor
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
+ """
+ Compute DAPO clip-only loss.
+
+ DAPO uses ratio clipping without explicit KL penalty (beta=0).
+ The clipping provides implicit KL constraint (Gate I).
+ """
+ # Compute log ratios
+ log_ratios = policy_log_probs - ref_log_probs
+
+ # Sum log ratios over response tokens
+ masked_log_ratios = log_ratios * response_mask
+ sequence_log_ratios = masked_log_ratios.sum(dim=-1)
+
+ # Compute importance sampling ratios
+ ratios = torch.exp(sequence_log_ratios)
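+        # Note: this simplified variant forms a single sequence-level ratio
+        # (exp of the summed token log-ratios); the full DAPO recipe instead
+        # clips per-token importance ratios, which is less prone to the ratio
+        # saturating on long responses.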
+
+ # DAPO objective with clipping
+ clip_ratio = self.train_config.clip_ratio
+
+ # Advantage estimation (simplified: just use rewards)
+ advantages = rewards
+
+ # Clipped surrogate objective
+ unclipped = ratios * advantages
+ clipped = torch.clamp(ratios, 1 - clip_ratio, 1 + clip_ratio) * advantages
+
+ # Take minimum for pessimistic update
+ loss = -torch.min(unclipped, clipped).mean()
+
+ # Compute metrics
+ with torch.no_grad():
+ approx_kl = (log_ratios * response_mask).sum() / response_mask.sum()
+ clip_fraction = ((ratios - 1).abs() > clip_ratio).float().mean()
+
+ metrics = {
+ "loss": loss.item(),
+ "approx_kl": approx_kl.item(),
+ "clip_fraction": clip_fraction.item(),
+ "mean_ratio": ratios.mean().item(),
+ "mean_reward": rewards.mean().item(),
+ }
+
+ return loss, metrics
+
+ def generate_rollouts(
+ self,
+ prompts: List[str],
+ num_samples: int
+ ) -> List[Dict[str, Any]]:
+ """Generate rollouts for a batch of prompts."""
+ rollouts = []
+
+ self.model.eval()
+ with torch.no_grad():
+ for prompt in prompts:
+ inputs = self.tokenizer(
+ prompt,
+ return_tensors="pt",
+ truncation=True,
+ max_length=self.train_config.max_seq_len // 2
+ ).to(self.device)
+
+ for _ in range(num_samples):
+ with autocast(
+ enabled=self.precision_config.use_amp,
+ dtype=get_torch_dtype(self.precision_config.amp_dtype)
+ ):
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=self.train_config.max_seq_len // 2,
+ do_sample=True,
+ temperature=0.7,
+ top_p=0.8,
+ pad_token_id=self.tokenizer.eos_token_id
+ )
+
+ response_ids = outputs[0, inputs["input_ids"].shape[1]:]
+ response_text = self.tokenizer.decode(
+ response_ids,
+ skip_special_tokens=True
+ )
+
+ rollouts.append({
+ "prompt": prompt,
+ "response": response_text,
+ "input_ids": inputs["input_ids"][0],
+ "response_ids": response_ids,
+ "full_ids": outputs[0]
+ })
+
+ self.model.train()
+ return rollouts
+
+ def train_step(
+ self,
+ batch: List[Dict[str, Any]]
+ ) -> Dict[str, float]:
+ """Execute one training step on a batch."""
+ self.model.train()
+
+ # Generate rollouts
+ prompts = [ex["prompt"] for ex in batch]
+ ground_truths = [ex.get("answer", None) for ex in batch]
+
+ rollouts = self.generate_rollouts(
+ prompts,
+ self.train_config.num_rollouts_per_prompt
+ )
+
+ # Compute rewards
+ rewards = []
+ for i, rollout in enumerate(rollouts):
+ prompt_idx = i // self.train_config.num_rollouts_per_prompt
+ gt = ground_truths[prompt_idx] if prompt_idx < len(ground_truths) else None
+ reward = compute_math_reward(rollout["prompt"], rollout["response"], gt)
+ rewards.append(reward)
+
+ rewards_tensor = torch.tensor(rewards, device=self.device, dtype=torch.float32)
+
+ # Skip if all rewards are the same (no learning signal)
+ if rewards_tensor.std() < 1e-6:
+ return {"skipped": 1.0}
+
+ # Normalize rewards per prompt (advantage estimation)
+ rewards_per_prompt = rewards_tensor.view(-1, self.train_config.num_rollouts_per_prompt)
+ normalized_rewards = (rewards_per_prompt - rewards_per_prompt.mean(dim=1, keepdim=True))
+ normalized_rewards = normalized_rewards / (rewards_per_prompt.std(dim=1, keepdim=True) + 1e-8)
+ normalized_rewards = normalized_rewards.view(-1)
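+        # Example (illustrative): with 4 rollouts per prompt and rewards
+        # [+1, -1, -1, -1], the group mean is -0.5 and the std 1.0, so the
+        # correct rollout gets advantage +1.5 and each incorrect one -0.5;
+        # a group where every rollout scores the same collapses to ~0.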
+
+ # Prepare for training (DeepSpeed handles zero_grad internally)
+ if not self.use_deepspeed:
+ self.optimizer.zero_grad()
+
+ total_loss = 0.0
+ all_metrics: Dict[str, List[float]] = {}
+
+ # Process rollouts in micro-batches
+ num_rollouts = len(rollouts)
+ micro_batch_size = self.train_config.micro_batch_size
+
+ for mb_start in range(0, num_rollouts, micro_batch_size):
+ mb_end = min(mb_start + micro_batch_size, num_rollouts)
+ mb_rollouts = rollouts[mb_start:mb_end]
+ mb_rewards = normalized_rewards[mb_start:mb_end]
+
+ # Prepare batch tensors
+ max_len = max(len(r["full_ids"]) for r in mb_rollouts)
+ batch_input_ids = torch.zeros(len(mb_rollouts), max_len, dtype=torch.long, device=self.device)
+ batch_attention_mask = torch.zeros(len(mb_rollouts), max_len, dtype=torch.long, device=self.device)
+ batch_response_mask = torch.zeros(len(mb_rollouts), max_len - 1, dtype=torch.float32, device=self.device)
+
+ for i, rollout in enumerate(mb_rollouts):
+ seq_len = len(rollout["full_ids"])
+ batch_input_ids[i, :seq_len] = rollout["full_ids"]
+ batch_attention_mask[i, :seq_len] = 1
+ prompt_len = len(rollout["input_ids"])
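+                # The -1 offsets align the mask with the shifted log-probs:
+                # token_log_probs[:, t] scores the token at position t + 1,
+                # so the first response token (index prompt_len in full_ids)
+                # is scored at index prompt_len - 1 of the shifted sequence.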
+ batch_response_mask[i, prompt_len-1:seq_len-1] = 1
+
+ # Compute log probs for policy and reference
+ policy_log_probs = self.compute_log_probs(
+ self.model,
+ batch_input_ids,
+ batch_attention_mask,
+ batch_input_ids
+ )
+
+ with torch.no_grad():
+ # Move tensors to ref_device if different from training device
+ if self.ref_device != self.device:
+ ref_input_ids = batch_input_ids.to(self.ref_device)
+ ref_attention_mask = batch_attention_mask.to(self.ref_device)
+ else:
+ ref_input_ids = batch_input_ids
+ ref_attention_mask = batch_attention_mask
+
+ ref_log_probs = self.compute_log_probs(
+ self.ref_model,
+ ref_input_ids,
+ ref_attention_mask,
+ ref_input_ids
+ )
+
+ # Move ref_log_probs back to training device
+ if self.ref_device != self.device:
+ ref_log_probs = ref_log_probs.to(self.device)
+
+ # Compute DAPO loss
+ loss, metrics = self.compute_dapo_loss(
+ policy_log_probs,
+ ref_log_probs,
+ mb_rewards,
+ batch_response_mask
+ )
+
+ # Scale loss for gradient accumulation (DeepSpeed handles this internally)
+ if self.use_deepspeed:
+ scaled_loss = loss
+ else:
+ scaled_loss = loss / self.train_config.grad_accumulation_steps
+
+ # Backward pass
+ if self.use_deepspeed:
+ self.model_engine.backward(scaled_loss)
+ elif self.scaler is not None:
+ self.scaler.scale(scaled_loss).backward()
+ else:
+ scaled_loss.backward()
+
+ total_loss += loss.item()
+
+ for k, v in metrics.items():
+ if k not in all_metrics:
+ all_metrics[k] = []
+ all_metrics[k].append(v)
+
+ # Optimizer step
+ if self.use_deepspeed:
+ self.model_engine.step()
+ elif self.scaler is not None:
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ self.optimizer.step()
+
+ self.global_step += 1
+
+ # Aggregate metrics
+ step_metrics = {k: np.mean(v) for k, v in all_metrics.items()}
+ step_metrics["total_loss"] = total_loss
+ step_metrics["step"] = self.global_step
+
+ self.metrics_history.append(step_metrics)
+
+ return step_metrics
+
+ def train(
+ self,
+ train_data: List[Dict[str, Any]],
+ save_checkpoints: bool = True
+ ) -> None:
+ """Run the full training loop."""
+ logger.info(f"Starting training for {self.train_config.num_steps} steps")
+
+ checkpoint_steps = set(self.train_config.checkpoint_steps)
+
+ for step in range(self.train_config.num_steps):
+ # Sample batch
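+            # e.g. (illustrative) a global batch of 512 rollouts with 16
+            # rollouts per prompt means 512 // 16 = 32 prompts per step.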
+ batch = sample_batch(
+ train_data,
+ self.train_config.global_batch_size // self.train_config.num_rollouts_per_prompt,
+ self.rng
+ )
+
+ # Training step
+ metrics = self.train_step(batch)
+
+ # Logging
+ if step % 10 == 0:
+ logger.info(
+ f"Step {step}/{self.train_config.num_steps} | "
+ f"Loss: {metrics.get('total_loss', 0):.4f} | "
+ f"KL: {metrics.get('approx_kl', 0):.4f} | "
+ f"Reward: {metrics.get('mean_reward', 0):.4f}"
+ )
+
+ # Checkpointing
+ if save_checkpoints and (step + 1) in checkpoint_steps:
+ self.save_checkpoint(step + 1)
+
+ def save_checkpoint(self, step: int) -> None:
+ """Save model checkpoint."""
+ ckpt_dir = os.path.join(
+ self.train_config.output_dir,
+ f"checkpoint_step{step}"
+ )
+ os.makedirs(ckpt_dir, exist_ok=True)
+
+ if self.use_deepspeed:
+ # Use DeepSpeed's save_checkpoint for proper ZeRO handling
+ self.model_engine.save_checkpoint(ckpt_dir, tag=f"step{step}")
+ else:
+ self.model.save_pretrained(ckpt_dir)
+ # Save training state
+ state = {
+ "step": step,
+ "optimizer_state_dict": self.optimizer.state_dict(),
+ "metrics_history": self.metrics_history,
+ }
+ torch.save(state, os.path.join(ckpt_dir, "training_state.pt"))
+
+ self.tokenizer.save_pretrained(ckpt_dir)
+ logger.info(f"Saved checkpoint at step {step} to {ckpt_dir}")
+
+ def save_final(self) -> str:
+ """Save final model."""
+ final_dir = os.path.join(self.train_config.output_dir, "final_model")
+ os.makedirs(final_dir, exist_ok=True)
+
+ if self.use_deepspeed:
+            # Use DeepSpeed's save_checkpoint so sharded (ZeRO) weight and
+            # optimizer state is gathered and written correctly
+ self.model_engine.save_checkpoint(final_dir, tag="final")
+ else:
+ self.model.save_pretrained(final_dir)
+
+ self.tokenizer.save_pretrained(final_dir)
+
+ # Save metrics
+ with open(os.path.join(final_dir, "metrics_history.json"), "w") as f:
+ json.dump(self.metrics_history, f, indent=2)
+
+ logger.info(f"Saved final model to {final_dir}")
+ return final_dir
+
+
+# ============================================================================
+# Main Entry Point
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="RLVR Training with DAPO Algorithm"
+ )
+ parser.add_argument(
+ "--precision_mode",
+ type=str,
+ required=True,
+ choices=["fp32", "bf16"],
+ help="Precision mode: fp32 (high precision) or bf16 (default RLVR)"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ required=True,
+ help="Random seed for reproducibility"
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Directory to save outputs"
+ )
+ parser.add_argument(
+ "--train_dataset_path",
+ type=str,
+ required=True,
+ help="Path to training dataset JSON"
+ )
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="Qwen/Qwen2.5-Math-7B",
+ help="HuggingFace model identifier"
+ )
+ parser.add_argument(
+ "--num_steps",
+ type=int,
+ default=300,
+ help="Number of training steps"
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="Device to use for training"
+ )
+ parser.add_argument(
+ "--deepspeed",
+ type=str,
+ default=None,
+ help="Path to DeepSpeed config file"
+ )
+ parser.add_argument(
+ "--local_rank",
+ type=int,
+ default=-1,
+ help="Local rank for distributed training (set by DeepSpeed)"
+ )
+ return parser.parse_args()
+
+
+def main() -> None:
+ """Main training function."""
+ args = parse_args()
+
+ # Create configurations
+ train_config = make_training_config(
+ precision_mode=args.precision_mode,
+ seed=args.seed,
+ output_dir=args.output_dir,
+ train_dataset_path=args.train_dataset_path,
+ model_name=args.model_name
+ )
+ train_config.num_steps = args.num_steps
+
+ precision_config = make_precision_config(args.precision_mode)
+
+ # Setup output directory
+ os.makedirs(train_config.output_dir, exist_ok=True)
+
+ # Save configurations
+ with open(os.path.join(train_config.output_dir, "train_config.json"), "w") as f:
+ json.dump(asdict(train_config), f, indent=2)
+ with open(os.path.join(train_config.output_dir, "precision_config.json"), "w") as f:
+ json.dump(asdict(precision_config), f, indent=2)
+
+ # Set seeds and determinism
+ set_seed(train_config.seed)
+ configure_torch_deterministic(precision_config.deterministic)
+
+ # Check if using DeepSpeed
+ use_deepspeed = args.deepspeed is not None
+
+ # Setup devices
+ num_gpus = torch.cuda.device_count()
+
+ if use_deepspeed:
+ # Initialize DeepSpeed distributed
+ deepspeed.init_distributed()
+ local_rank = args.local_rank if args.local_rank >= 0 else int(os.environ.get("LOCAL_RANK", 0))
+ device = torch.device(f"cuda:{local_rank}")
+ torch.cuda.set_device(device)
+ # Put reference model on same GPU as training (each rank has its own copy)
+ # ZeRO-2 shards optimizer states, so there's room for the bf16 ref model (~14GB)
+ ref_device = device
+ logger.info(f"DeepSpeed: rank {local_rank}, training on {device}, ref model on {ref_device}")
+ else:
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ if num_gpus >= 2:
+ ref_device = torch.device("cuda:1")
+ logger.info(f"Using device: {device} for training, {ref_device} for reference model")
+ else:
+ ref_device = device
+ logger.info(f"Using device: {device} for both training and reference model")
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ train_config.model_name,
+ use_fast=True,
+ trust_remote_code=True
+ )
+ tokenizer.padding_side = "left"
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Load model
+ logger.info(f"Loading model: {train_config.model_name}")
+ model = AutoModelForCausalLM.from_pretrained(
+ train_config.model_name,
+ torch_dtype=torch.float32, # Load in FP32 first
+ device_map=None,
+ trust_remote_code=True
+ )
+
+ # Cast to target precision
+ model = cast_model_param_dtype(model, precision_config.param_dtype)
+
+ # Enable gradient checkpointing to save memory (trades compute for memory)
+ model.gradient_checkpointing_enable()
+ logger.info("Enabled gradient checkpointing to reduce memory usage")
+
+ # Disable dropout if needed
+ if not train_config.use_dropout:
+ disable_dropout(model)
+
+ logger.info(f"Model loaded with {count_parameters(model):,} trainable parameters")
+
+ # Initialize DeepSpeed or move model to device
+ if use_deepspeed:
+ # Create optimizer for DeepSpeed
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=train_config.learning_rate,
+ betas=(train_config.beta1, train_config.beta2),
+ weight_decay=train_config.weight_decay
+ )
+
+ # Load DeepSpeed config
+ with open(args.deepspeed, 'r') as f:
+ ds_config = json.load(f)
+
+ # Compute batch sizes compatible with DeepSpeed
+ # DeepSpeed requires: train_batch_size = micro_batch * grad_acc * world_size
+ world_size = int(os.environ.get("WORLD_SIZE", num_gpus - 1)) # -1 for ref model GPU
+ micro_batch = train_config.micro_batch_size
+
+ # Compute grad_acc to get closest to desired global batch size
+ desired_global = train_config.global_batch_size
+ grad_acc = max(1, desired_global // (micro_batch * world_size))
+ actual_global = micro_batch * grad_acc * world_size
+
+ ds_config["train_batch_size"] = actual_global
+ ds_config["train_micro_batch_size_per_gpu"] = micro_batch
+ ds_config["gradient_accumulation_steps"] = grad_acc
+
+ logger.info(f"DeepSpeed batch config: global={actual_global}, micro={micro_batch}, grad_acc={grad_acc}, world_size={world_size}")
+
+ # Initialize DeepSpeed engine
+ model_engine, optimizer, _, _ = deepspeed.initialize(
+ model=model,
+ optimizer=optimizer,
+ config=ds_config
+ )
+ logger.info(f"DeepSpeed ZeRO-2 initialized with {world_size} GPUs")
+ else:
+ model = model.to(device)
+ model_engine = model
+
+ # Load reference model (frozen copy)
+ # Always use bf16 for reference model - it's frozen and only used for KL computation.
+ # This is a controlled variable: same precision for both fp32 and bf16 training modes.
+ # The experiment tests training precision effects, not reference model precision.
+ logger.info("Loading reference model (bf16 for both modes - controlled variable)")
+ ref_model = AutoModelForCausalLM.from_pretrained(
+ train_config.model_name,
+ torch_dtype=torch.bfloat16, # Always bf16 to save memory & control variable
+ device_map=None,
+ trust_remote_code=True
+ )
+ ref_model = ref_model.to(ref_device)
+ ref_model.eval()
+
+ # Load training data
+ train_data = load_training_data(train_config.train_dataset_path)
+
+ # Initialize trainer
+ trainer = DAPOTrainer(
+ model_engine=model_engine,
+ ref_model=ref_model,
+ tokenizer=tokenizer,
+ train_config=train_config,
+ precision_config=precision_config,
+ device=device,
+ ref_device=ref_device,
+ use_deepspeed=use_deepspeed
+ )
+
+ # Run training
+ trainer.train(train_data, save_checkpoints=True)
+
+ # Save final model
+ final_path = trainer.save_final()
+ logger.info(f"Training complete. Final model saved to: {final_path}")
+
+
+if __name__ == "__main__":
+ main()
+