| field | value | date |
|---|---|---|
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-04 18:59:35 -0600 |
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-04 18:59:35 -0600 |
| commit | f1c2cc22d46a6976df3555391e667c7e61592fad (patch) | |
| tree | 0b37b52c8ff91042a742d3b3ec54542cb6d6e2f6 /train_rlvr.py | |
Diffstat (limited to 'train_rlvr.py')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | train_rlvr.py | 849 |
1 file changed, 849 insertions(+), 0 deletions(-)
diff --git a/train_rlvr.py b/train_rlvr.py
new file mode 100644
index 0000000..1076df8
--- /dev/null
+++ b/train_rlvr.py
@@ -0,0 +1,849 @@
+#!/usr/bin/env python3
+# train_rlvr.py
+"""
+RLVR Training Script with DAPO Algorithm.
+
+This script implements the training loop for reinforcement learning with
+verifiable rewards (RLVR) using the DAPO algorithm. It supports two precision
+configurations:
+
+- P-high (fp32): High precision master weights for low numerical noise
+- P-bf16: Default RLVR configuration with bf16 master weights
+
+The script integrates with the VeRL framework for distributed RL training.
+
+Usage:
+    python train_rlvr.py \
+        --precision_mode fp32 \
+        --seed 1 \
+        --output_dir results/train_logs/fp32_seed1 \
+        --train_dataset_path data/dm_train.json
+"""
+
+import argparse
+import json
+import os
+import random
+import logging
+from typing import Dict, Any, List, Optional, Tuple
+from dataclasses import asdict
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.cuda.amp import autocast, GradScaler
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import deepspeed
+
+from config import (
+    make_training_config,
+    make_precision_config,
+    TrainingConfig,
+    PrecisionConfig,
+)
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Seed and Determinism Utilities
+# ============================================================================
+
+def set_seed(seed: int) -> None:
+    """Set random seeds for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # For hash-based operations
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    logger.info(f"Set random seed to {seed}")
+
+
+def configure_torch_deterministic(deterministic: bool) -> None:
+    """Configure PyTorch deterministic algorithms."""
+    if deterministic:
+        torch.use_deterministic_algorithms(True, warn_only=True)
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+        logger.info("Enabled deterministic algorithms")
+    else:
+        torch.use_deterministic_algorithms(False)
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.deterministic = False
+        logger.info("Using non-deterministic algorithms (default)")
+
+
+# ============================================================================
+# Model Utilities
+# ============================================================================
+
+def get_torch_dtype(dtype_str: str) -> torch.dtype:
+    """Convert string dtype to torch.dtype."""
+    dtype_map = {
+        "float32": torch.float32,
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+    }
+    if dtype_str not in dtype_map:
+        raise ValueError(f"Unknown dtype: {dtype_str}")
+    return dtype_map[dtype_str]
+
+
+def cast_model_param_dtype(
+    model: torch.nn.Module,
+    param_dtype: str
+) -> torch.nn.Module:
+    """Cast model parameters to specified dtype."""
+    dtype = get_torch_dtype(param_dtype)
+    model.to(dtype=dtype)
+    logger.info(f"Cast model parameters to {param_dtype}")
+    return model
+
+
+def disable_dropout(model: torch.nn.Module) -> None:
+    """Disable all dropout layers in the model."""
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0.0
+    logger.info("Disabled all dropout layers")
+
+
+def count_parameters(model: torch.nn.Module) -> int:
+    """Count trainable parameters."""
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_training_data(dataset_path: str) -> List[Dict[str, Any]]:
+    """Load training dataset from JSON file."""
+    with open(dataset_path, "r", encoding="utf-8") as f:
+        data = json.load(f)
+    logger.info(f"Loaded {len(data)} training examples from {dataset_path}")
+    return data
+
+
+def sample_batch(
+    data: List[Dict[str, Any]],
+    batch_size: int,
+    rng: np.random.Generator
+) -> List[Dict[str, Any]]:
+    """Sample a batch of prompts from the dataset."""
+    indices = rng.choice(len(data), size=min(batch_size, len(data)), replace=False)
+    return [data[i] for i in indices]
+
+
+# ============================================================================
+# Reward Function (Math Verifier)
+# ============================================================================
+
+def compute_math_reward(
+    prompt: str,
+    response: str,
+    ground_truth: Optional[str] = None
+) -> float:
+    """
+    Compute reward for a math problem response.
+
+    Uses a simple rule-based verifier. In production, this should be replaced
+    with Eval-Chemy or a similar math verification system.
+
+    Args:
+        prompt: The math problem prompt
+        response: The model's generated response
+        ground_truth: The expected answer (if available)
+
+    Returns:
+        +1.0 if correct, -1.0 if incorrect
+    """
+    # TODO: Replace with actual math verifier (Eval-Chemy)
+    # This is a placeholder implementation
+
+    if ground_truth is None:
+        # Cannot verify without ground truth
+        return 0.0
+
+    # Extract final answer from response (simple heuristic)
+    response_lower = response.lower().strip()
+    gt_lower = ground_truth.lower().strip()
+
+    # Check for common answer formats
+    answer_markers = ["the answer is", "therefore", "=", "\\boxed{"]
+
+    for marker in answer_markers:
+        if marker in response_lower:
+            idx = response_lower.rfind(marker)
+            potential_answer = response_lower[idx:].strip()
+            if gt_lower in potential_answer:
+                return 1.0
+
+    # Direct containment check as fallback
+    if gt_lower in response_lower:
+        return 1.0
+
+    return -1.0
+
+
+# ============================================================================
+# DAPO Algorithm Implementation
+# ============================================================================
+
+class DAPOTrainer:
+    """
+    DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) Trainer.
+
+    Implements clip-only DAPO with implicit KL constraint through ratio clipping.
+    This is a simplified implementation - for production use VeRL's DapoTrainer.
+ """ + + def __init__( + self, + model_engine, # DeepSpeed engine or raw model + ref_model: torch.nn.Module, + tokenizer, + train_config: TrainingConfig, + precision_config: PrecisionConfig, + device: torch.device, + ref_device: Optional[torch.device] = None, + use_deepspeed: bool = False + ) -> None: + self.use_deepspeed = use_deepspeed + if use_deepspeed: + self.model_engine = model_engine + self.model = model_engine.module + else: + self.model = model_engine + self.model_engine = None + self.ref_model = ref_model + self.tokenizer = tokenizer + self.train_config = train_config + self.precision_config = precision_config + self.device = device + self.ref_device = ref_device if ref_device is not None else device + + # Freeze reference model + for param in self.ref_model.parameters(): + param.requires_grad = False + self.ref_model.eval() + + # Setup optimizer (only if not using DeepSpeed) + if not use_deepspeed: + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=train_config.learning_rate, + betas=(train_config.beta1, train_config.beta2), + weight_decay=train_config.weight_decay + ) + else: + self.optimizer = None # DeepSpeed manages optimizer + + # Setup AMP scaler if using fp16 (not needed with DeepSpeed) + self.scaler = None + if not use_deepspeed and precision_config.use_amp and precision_config.amp_dtype == "float16": + self.scaler = GradScaler() + + # Training state + self.global_step = 0 + self.rng = np.random.default_rng(train_config.seed) + + # Metrics tracking + self.metrics_history: List[Dict[str, float]] = [] + + def compute_log_probs( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + labels: torch.Tensor + ) -> torch.Tensor: + """Compute token-level log probabilities.""" + with autocast( + enabled=self.precision_config.use_amp, + dtype=get_torch_dtype(self.precision_config.amp_dtype) + ): + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False + ) + + logits = outputs.logits + log_probs = torch.log_softmax(logits, dim=-1) + + # Gather log probs for actual tokens + # Shift for autoregressive: predict next token + shift_log_probs = log_probs[:, :-1, :] + shift_labels = labels[:, 1:] + + token_log_probs = torch.gather( + shift_log_probs, + dim=-1, + index=shift_labels.unsqueeze(-1) + ).squeeze(-1) + + return token_log_probs + + def compute_dapo_loss( + self, + policy_log_probs: torch.Tensor, + ref_log_probs: torch.Tensor, + rewards: torch.Tensor, + response_mask: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, float]]: + """ + Compute DAPO clip-only loss. + + DAPO uses ratio clipping without explicit KL penalty (beta=0). + The clipping provides implicit KL constraint (Gate I). 
+ """ + # Compute log ratios + log_ratios = policy_log_probs - ref_log_probs + + # Sum log ratios over response tokens + masked_log_ratios = log_ratios * response_mask + sequence_log_ratios = masked_log_ratios.sum(dim=-1) + + # Compute importance sampling ratios + ratios = torch.exp(sequence_log_ratios) + + # DAPO objective with clipping + clip_ratio = self.train_config.clip_ratio + + # Advantage estimation (simplified: just use rewards) + advantages = rewards + + # Clipped surrogate objective + unclipped = ratios * advantages + clipped = torch.clamp(ratios, 1 - clip_ratio, 1 + clip_ratio) * advantages + + # Take minimum for pessimistic update + loss = -torch.min(unclipped, clipped).mean() + + # Compute metrics + with torch.no_grad(): + approx_kl = (log_ratios * response_mask).sum() / response_mask.sum() + clip_fraction = ((ratios - 1).abs() > clip_ratio).float().mean() + + metrics = { + "loss": loss.item(), + "approx_kl": approx_kl.item(), + "clip_fraction": clip_fraction.item(), + "mean_ratio": ratios.mean().item(), + "mean_reward": rewards.mean().item(), + } + + return loss, metrics + + def generate_rollouts( + self, + prompts: List[str], + num_samples: int + ) -> List[Dict[str, Any]]: + """Generate rollouts for a batch of prompts.""" + rollouts = [] + + self.model.eval() + with torch.no_grad(): + for prompt in prompts: + inputs = self.tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=self.train_config.max_seq_len // 2 + ).to(self.device) + + for _ in range(num_samples): + with autocast( + enabled=self.precision_config.use_amp, + dtype=get_torch_dtype(self.precision_config.amp_dtype) + ): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.train_config.max_seq_len // 2, + do_sample=True, + temperature=0.7, + top_p=0.8, + pad_token_id=self.tokenizer.eos_token_id + ) + + response_ids = outputs[0, inputs["input_ids"].shape[1]:] + response_text = self.tokenizer.decode( + response_ids, + skip_special_tokens=True + ) + + rollouts.append({ + "prompt": prompt, + "response": response_text, + "input_ids": inputs["input_ids"][0], + "response_ids": response_ids, + "full_ids": outputs[0] + }) + + self.model.train() + return rollouts + + def train_step( + self, + batch: List[Dict[str, Any]] + ) -> Dict[str, float]: + """Execute one training step on a batch.""" + self.model.train() + + # Generate rollouts + prompts = [ex["prompt"] for ex in batch] + ground_truths = [ex.get("answer", None) for ex in batch] + + rollouts = self.generate_rollouts( + prompts, + self.train_config.num_rollouts_per_prompt + ) + + # Compute rewards + rewards = [] + for i, rollout in enumerate(rollouts): + prompt_idx = i // self.train_config.num_rollouts_per_prompt + gt = ground_truths[prompt_idx] if prompt_idx < len(ground_truths) else None + reward = compute_math_reward(rollout["prompt"], rollout["response"], gt) + rewards.append(reward) + + rewards_tensor = torch.tensor(rewards, device=self.device, dtype=torch.float32) + + # Skip if all rewards are the same (no learning signal) + if rewards_tensor.std() < 1e-6: + return {"skipped": 1.0} + + # Normalize rewards per prompt (advantage estimation) + rewards_per_prompt = rewards_tensor.view(-1, self.train_config.num_rollouts_per_prompt) + normalized_rewards = (rewards_per_prompt - rewards_per_prompt.mean(dim=1, keepdim=True)) + normalized_rewards = normalized_rewards / (rewards_per_prompt.std(dim=1, keepdim=True) + 1e-8) + normalized_rewards = normalized_rewards.view(-1) + + # Prepare for training (DeepSpeed handles zero_grad internally) + 
+        if not self.use_deepspeed:
+            self.optimizer.zero_grad()
+
+        total_loss = 0.0
+        all_metrics: Dict[str, List[float]] = {}
+
+        # Process rollouts in micro-batches
+        num_rollouts = len(rollouts)
+        micro_batch_size = self.train_config.micro_batch_size
+
+        for mb_start in range(0, num_rollouts, micro_batch_size):
+            mb_end = min(mb_start + micro_batch_size, num_rollouts)
+            mb_rollouts = rollouts[mb_start:mb_end]
+            mb_rewards = normalized_rewards[mb_start:mb_end]
+
+            # Prepare batch tensors
+            max_len = max(len(r["full_ids"]) for r in mb_rollouts)
+            batch_input_ids = torch.zeros(len(mb_rollouts), max_len, dtype=torch.long, device=self.device)
+            batch_attention_mask = torch.zeros(len(mb_rollouts), max_len, dtype=torch.long, device=self.device)
+            batch_response_mask = torch.zeros(len(mb_rollouts), max_len - 1, dtype=torch.float32, device=self.device)
+
+            for i, rollout in enumerate(mb_rollouts):
+                seq_len = len(rollout["full_ids"])
+                batch_input_ids[i, :seq_len] = rollout["full_ids"]
+                batch_attention_mask[i, :seq_len] = 1
+                prompt_len = len(rollout["input_ids"])
+                batch_response_mask[i, prompt_len-1:seq_len-1] = 1
+
+            # Compute log probs for policy and reference
+            policy_log_probs = self.compute_log_probs(
+                self.model,
+                batch_input_ids,
+                batch_attention_mask,
+                batch_input_ids
+            )
+
+            with torch.no_grad():
+                # Move tensors to ref_device if different from training device
+                if self.ref_device != self.device:
+                    ref_input_ids = batch_input_ids.to(self.ref_device)
+                    ref_attention_mask = batch_attention_mask.to(self.ref_device)
+                else:
+                    ref_input_ids = batch_input_ids
+                    ref_attention_mask = batch_attention_mask
+
+                ref_log_probs = self.compute_log_probs(
+                    self.ref_model,
+                    ref_input_ids,
+                    ref_attention_mask,
+                    ref_input_ids
+                )
+
+                # Move ref_log_probs back to training device
+                if self.ref_device != self.device:
+                    ref_log_probs = ref_log_probs.to(self.device)
+
+            # Compute DAPO loss
+            loss, metrics = self.compute_dapo_loss(
+                policy_log_probs,
+                ref_log_probs,
+                mb_rewards,
+                batch_response_mask
+            )
+
+            # Scale loss for gradient accumulation (DeepSpeed handles this internally)
+            if self.use_deepspeed:
+                scaled_loss = loss
+            else:
+                scaled_loss = loss / self.train_config.grad_accumulation_steps
+
+            # Backward pass
+            if self.use_deepspeed:
+                self.model_engine.backward(scaled_loss)
+            elif self.scaler is not None:
+                self.scaler.scale(scaled_loss).backward()
+            else:
+                scaled_loss.backward()
+
+            total_loss += loss.item()
+
+            for k, v in metrics.items():
+                if k not in all_metrics:
+                    all_metrics[k] = []
+                all_metrics[k].append(v)
+
+        # Optimizer step
+        if self.use_deepspeed:
+            self.model_engine.step()
+        elif self.scaler is not None:
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+        else:
+            self.optimizer.step()
+
+        self.global_step += 1
+
+        # Aggregate metrics
+        step_metrics = {k: np.mean(v) for k, v in all_metrics.items()}
+        step_metrics["total_loss"] = total_loss
+        step_metrics["step"] = self.global_step
+
+        self.metrics_history.append(step_metrics)
+
+        return step_metrics
+
+    def train(
+        self,
+        train_data: List[Dict[str, Any]],
+        save_checkpoints: bool = True
+    ) -> None:
+        """Run the full training loop."""
+        logger.info(f"Starting training for {self.train_config.num_steps} steps")
+
+        checkpoint_steps = set(self.train_config.checkpoint_steps)
+
+        for step in range(self.train_config.num_steps):
+            # Sample batch
+            batch = sample_batch(
+                train_data,
+                self.train_config.global_batch_size // self.train_config.num_rollouts_per_prompt,
+                self.rng
+            )
+
+            # Training step
+            metrics = self.train_step(batch)
+
+            # Logging
+            if step % 10 == 0:
+                logger.info(
+                    f"Step {step}/{self.train_config.num_steps} | "
+                    f"Loss: {metrics.get('total_loss', 0):.4f} | "
+                    f"KL: {metrics.get('approx_kl', 0):.4f} | "
+                    f"Reward: {metrics.get('mean_reward', 0):.4f}"
+                )
+
+            # Checkpointing
+            if save_checkpoints and (step + 1) in checkpoint_steps:
+                self.save_checkpoint(step + 1)
+
+    def save_checkpoint(self, step: int) -> None:
+        """Save model checkpoint."""
+        ckpt_dir = os.path.join(
+            self.train_config.output_dir,
+            f"checkpoint_step{step}"
+        )
+        os.makedirs(ckpt_dir, exist_ok=True)
+
+        if self.use_deepspeed:
+            # Use DeepSpeed's save_checkpoint for proper ZeRO handling
+            self.model_engine.save_checkpoint(ckpt_dir, tag=f"step{step}")
+        else:
+            self.model.save_pretrained(ckpt_dir)
+            # Save training state
+            state = {
+                "step": step,
+                "optimizer_state_dict": self.optimizer.state_dict(),
+                "metrics_history": self.metrics_history,
+            }
+            torch.save(state, os.path.join(ckpt_dir, "training_state.pt"))
+
+        self.tokenizer.save_pretrained(ckpt_dir)
+        logger.info(f"Saved checkpoint at step {step} to {ckpt_dir}")
+
+    def save_final(self) -> str:
+        """Save final model."""
+        final_dir = os.path.join(self.train_config.output_dir, "final_model")
+        os.makedirs(final_dir, exist_ok=True)
+
+        if self.use_deepspeed:
+            # Use DeepSpeed's save_checkpoint for proper ZeRO-3 weight gathering
+            self.model_engine.save_checkpoint(final_dir, tag="final")
+        else:
+            self.model.save_pretrained(final_dir)
+
+        self.tokenizer.save_pretrained(final_dir)
+
+        # Save metrics
+        with open(os.path.join(final_dir, "metrics_history.json"), "w") as f:
+            json.dump(self.metrics_history, f, indent=2)
+
+        logger.info(f"Saved final model to {final_dir}")
+        return final_dir
+
+
+# ============================================================================
+# Main Entry Point
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="RLVR Training with DAPO Algorithm"
+    )
+    parser.add_argument(
+        "--precision_mode",
+        type=str,
+        required=True,
+        choices=["fp32", "bf16"],
+        help="Precision mode: fp32 (high precision) or bf16 (default RLVR)"
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        required=True,
+        help="Random seed for reproducibility"
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        required=True,
+        help="Directory to save outputs"
+    )
+    parser.add_argument(
+        "--train_dataset_path",
+        type=str,
+        required=True,
+        help="Path to training dataset JSON"
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="Qwen/Qwen2.5-Math-7B",
+        help="HuggingFace model identifier"
+    )
+    parser.add_argument(
+        "--num_steps",
+        type=int,
+        default=300,
+        help="Number of training steps"
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device to use for training"
+    )
+    parser.add_argument(
+        "--deepspeed",
+        type=str,
+        default=None,
+        help="Path to DeepSpeed config file"
+    )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=-1,
+        help="Local rank for distributed training (set by DeepSpeed)"
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    """Main training function."""
+    args = parse_args()
+
+    # Create configurations
+    train_config = make_training_config(
+        precision_mode=args.precision_mode,
+        seed=args.seed,
+        output_dir=args.output_dir,
+        train_dataset_path=args.train_dataset_path,
+        model_name=args.model_name
+    )
+    train_config.num_steps = args.num_steps
+
+    precision_config = make_precision_config(args.precision_mode)
+
+    # Setup output directory
+    os.makedirs(train_config.output_dir, exist_ok=True)
+
+    # Save configurations
+    with open(os.path.join(train_config.output_dir, "train_config.json"), "w") as f:
+        json.dump(asdict(train_config), f, indent=2)
+    with open(os.path.join(train_config.output_dir, "precision_config.json"), "w") as f:
+        json.dump(asdict(precision_config), f, indent=2)
+
+    # Set seeds and determinism
+    set_seed(train_config.seed)
+    configure_torch_deterministic(precision_config.deterministic)
+
+    # Check if using DeepSpeed
+    use_deepspeed = args.deepspeed is not None
+
+    # Setup devices
+    num_gpus = torch.cuda.device_count()
+
+    if use_deepspeed:
+        # Initialize DeepSpeed distributed
+        deepspeed.init_distributed()
+        local_rank = args.local_rank if args.local_rank >= 0 else int(os.environ.get("LOCAL_RANK", 0))
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
+        # Put reference model on same GPU as training (each rank has its own copy)
+        # ZeRO-2 shards optimizer states, so there's room for the bf16 ref model (~14GB)
+        ref_device = device
+        logger.info(f"DeepSpeed: rank {local_rank}, training on {device}, ref model on {ref_device}")
+    else:
+        device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+        if num_gpus >= 2:
+            ref_device = torch.device("cuda:1")
+            logger.info(f"Using device: {device} for training, {ref_device} for reference model")
+        else:
+            ref_device = device
+            logger.info(f"Using device: {device} for both training and reference model")
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        train_config.model_name,
+        use_fast=True,
+        trust_remote_code=True
+    )
+    tokenizer.padding_side = "left"
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Load model
+    logger.info(f"Loading model: {train_config.model_name}")
+    model = AutoModelForCausalLM.from_pretrained(
+        train_config.model_name,
+        torch_dtype=torch.float32,  # Load in FP32 first
+        device_map=None,
+        trust_remote_code=True
+    )
+
+    # Cast to target precision
+    model = cast_model_param_dtype(model, precision_config.param_dtype)
+
+    # Enable gradient checkpointing to save memory (trades compute for memory)
+    model.gradient_checkpointing_enable()
+    logger.info("Enabled gradient checkpointing to reduce memory usage")
+
+    # Disable dropout if needed
+    if not train_config.use_dropout:
+        disable_dropout(model)
+
+    logger.info(f"Model loaded with {count_parameters(model):,} trainable parameters")
+
+    # Initialize DeepSpeed or move model to device
+    if use_deepspeed:
+        # Create optimizer for DeepSpeed
+        optimizer = torch.optim.AdamW(
+            model.parameters(),
+            lr=train_config.learning_rate,
+            betas=(train_config.beta1, train_config.beta2),
+            weight_decay=train_config.weight_decay
+        )
+
+        # Load DeepSpeed config
+        with open(args.deepspeed, 'r') as f:
+            ds_config = json.load(f)
+
+        # Compute batch sizes compatible with DeepSpeed
+        # DeepSpeed requires: train_batch_size = micro_batch * grad_acc * world_size
+        world_size = int(os.environ.get("WORLD_SIZE", num_gpus - 1))  # -1 for ref model GPU
+        micro_batch = train_config.micro_batch_size
+
+        # Compute grad_acc to get closest to desired global batch size
+        desired_global = train_config.global_batch_size
+        grad_acc = max(1, desired_global // (micro_batch * world_size))
+        actual_global = micro_batch * grad_acc * world_size
+
+        ds_config["train_batch_size"] = actual_global
ds_config["train_micro_batch_size_per_gpu"] = micro_batch + ds_config["gradient_accumulation_steps"] = grad_acc + + logger.info(f"DeepSpeed batch config: global={actual_global}, micro={micro_batch}, grad_acc={grad_acc}, world_size={world_size}") + + # Initialize DeepSpeed engine + model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config + ) + logger.info(f"DeepSpeed ZeRO-2 initialized with {world_size} GPUs") + else: + model = model.to(device) + model_engine = model + + # Load reference model (frozen copy) + # Always use bf16 for reference model - it's frozen and only used for KL computation. + # This is a controlled variable: same precision for both fp32 and bf16 training modes. + # The experiment tests training precision effects, not reference model precision. + logger.info("Loading reference model (bf16 for both modes - controlled variable)") + ref_model = AutoModelForCausalLM.from_pretrained( + train_config.model_name, + torch_dtype=torch.bfloat16, # Always bf16 to save memory & control variable + device_map=None, + trust_remote_code=True + ) + ref_model = ref_model.to(ref_device) + ref_model.eval() + + # Load training data + train_data = load_training_data(train_config.train_dataset_path) + + # Initialize trainer + trainer = DAPOTrainer( + model_engine=model_engine, + ref_model=ref_model, + tokenizer=tokenizer, + train_config=train_config, + precision_config=precision_config, + device=device, + ref_device=ref_device, + use_deepspeed=use_deepspeed + ) + + # Run training + trainer.train(train_data, save_checkpoints=True) + + # Save final model + final_path = trainer.save_final() + logger.info(f"Training complete. Final model saved to: {final_path}") + + +if __name__ == "__main__": + main() + |
