From bd6222774edcec1608a6842d0b06a637a4acef59 Mon Sep 17 00:00:00 2001 From: One Date: Wed, 9 Jul 2025 10:13:51 +0800 Subject: Release --- pretrain.py | 454 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 454 insertions(+) create mode 100644 pretrain.py (limited to 'pretrain.py') diff --git a/pretrain.py b/pretrain.py new file mode 100644 index 0000000..b939318 --- /dev/null +++ b/pretrain.py @@ -0,0 +1,454 @@ +from typing import Optional, Any, Sequence, List +from dataclasses import dataclass +import os +import math +import yaml +import shutil + +import torch +import torch.distributed as dist +from torch import nn +from torch.utils.data import DataLoader + +import tqdm +import wandb +import coolname +import hydra +import pydantic +from omegaconf import DictConfig +from wandb.util import make_artifact_name_safe +from adam_atan2 import AdamATan2 + +from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata +from utils.functions import load_model_class, get_model_source_path +from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed + + +class LossConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='allow') + + name: str + + +class ArchConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='allow') + + name: str + loss: LossConfig + + +class PretrainConfig(pydantic.BaseModel): + # Config + arch: ArchConfig + # Data + data_path: str + + # Hyperparams + global_batch_size: int + epochs: int + + lr: float + lr_min_ratio: float + lr_warmup_steps: int + + weight_decay: float + beta1: float + beta2: float + + # Puzzle embedding + puzzle_emb_lr: float + puzzle_emb_weight_decay: float + + # Names + project_name: Optional[str] = None + run_name: Optional[str] = None + checkpoint_path: Optional[str] = None + + # Extras + seed: int = 0 + checkpoint_every_eval: bool = False + eval_interval: Optional[int] = None + eval_save_outputs: List[str] = [] + + +@dataclass +class TrainState: + model: nn.Module + optimizers: Sequence[torch.optim.Optimizer] + optimizer_lrs: Sequence[float] + carry: Any + + step: int + total_steps: int + + +def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs): + dataset = PuzzleDataset(PuzzleDatasetConfig( + seed=config.seed, + + dataset_path=config.data_path, + + rank=rank, + num_replicas=world_size, + + **kwargs + ), split=split) + dataloader = DataLoader( + dataset, + batch_size=None, + + num_workers=1, + prefetch_factor=8, + + pin_memory=True, + persistent_workers=True + ) + return dataloader, dataset.metadata + + +def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): + model_cfg = dict( + **config.arch.__pydantic_extra__, # type: ignore + + batch_size=config.global_batch_size // world_size, + + vocab_size=train_metadata.vocab_size, + seq_len=train_metadata.seq_len, + num_puzzle_identifiers=train_metadata.num_puzzle_identifiers, + causal=False # Non-autoregressive + ) + + # Instantiate model with loss head + model_cls = load_model_class(config.arch.name) + loss_head_cls = load_model_class(config.arch.loss.name) + + with torch.device("cuda"): + model: nn.Module = model_cls(model_cfg) + model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore + if "DISABLE_COMPILE" not in os.environ: + model = torch.compile(model, dynamic=False, fullgraph=True) # type: ignore + + # Broadcast parameters from rank 0 + if world_size > 1: + with torch.no_grad(): + for param in list(model.parameters()) + list(model.buffers()): + dist.broadcast(param, src=0) + + # Optimizers and lr + optimizers = [ + CastedSparseEmbeddingSignSGD_Distributed( + model.model.puzzle_emb.buffers(), # type: ignore + + lr=0, # Needs to be set by scheduler + weight_decay=config.puzzle_emb_weight_decay, + + world_size=world_size + ), + AdamATan2( + model.parameters(), + + lr=0, # Needs to be set by scheduler + weight_decay=config.weight_decay, + betas=(config.beta1, config.beta2) + ) + ] + optimizer_lrs = [ + config.puzzle_emb_lr, + config.lr + ] + + return model, optimizers, optimizer_lrs + + +def cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 +): + if current_step < num_warmup_steps: + return base_lr * float(current_step) / float(max(1, num_warmup_steps)) + + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))) + + +def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): + # Estimated total training steps + total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size) + + # Model + model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size) + + return TrainState( + step=0, + total_steps=total_steps, + + model=model, + optimizers=optimizers, + optimizer_lrs=optimizer_lrs, + carry=None + ) + + +def save_train_state(config: PretrainConfig, train_state: TrainState): + # FIXME: Only saved model. + if config.checkpoint_path is None: + return + + os.makedirs(config.checkpoint_path, exist_ok=True) + torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) + + +def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState): + return cosine_schedule_with_warmup_lr_lambda( + current_step=train_state.step, + base_lr=base_lr, + num_warmup_steps=round(config.lr_warmup_steps), + num_training_steps=train_state.total_steps, + min_ratio=config.lr_min_ratio + ) + + +def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): + train_state.step += 1 + if train_state.step > train_state.total_steps: # At most train_total_steps + return + + # To device + batch = {k: v.cuda() for k, v in batch.items()} + + # Init carry if it is None + if train_state.carry is None: + with torch.device("cuda"): + train_state.carry = train_state.model.initial_carry(batch) # type: ignore + + # Forward + train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[]) + + ((1 / global_batch_size) * loss).backward() + + # Allreduce + if world_size > 1: + for param in train_state.model.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad) + + # Apply optimizer + lr_this_step = None + for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): + lr_this_step = compute_lr(base_lr, config, train_state) + + for param_group in optim.param_groups: + param_group['lr'] = lr_this_step + + optim.step() + optim.zero_grad() + + # Reduce metrics + if len(metrics): + assert not any(v.requires_grad for v in metrics.values()) + + metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. + # Reduce and reconstruct + metric_values = torch.stack([metrics[k] for k in metric_keys]) + if world_size > 1: + dist.reduce(metric_values, dst=0) + + if rank == 0: + metric_values = metric_values.cpu().numpy() + reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)} + + # Postprocess + count = max(reduced_metrics["count"], 1) # Avoid NaNs + reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()} + + reduced_metrics["train/lr"] = lr_this_step + return reduced_metrics + + +def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): + with torch.inference_mode(): + set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} + + all_preds = {} + + metric_keys = [] + metric_values = None + metric_global_batch_size = [0 for _ in range(len(set_ids))] + + carry = None + for set_name, batch, global_batch_size in eval_loader: + # To device + batch = {k: v.cuda() for k, v in batch.items()} + with torch.device("cuda"): + carry = train_state.model.initial_carry(batch) # type: ignore + + # Forward + while True: + carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs) + + if all_finish: + break + + for collection in (batch, preds): + for k, v in collection.items(): + if k in config.eval_save_outputs: + all_preds.setdefault(k, []) + all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory + + del carry, preds, batch, all_finish + + # Aggregate + set_id = set_ids[set_name] + + if metric_values is None: + metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. + metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda") + + metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) + metric_global_batch_size[set_id] += global_batch_size + + if len(all_preds) and config.checkpoint_path is not None: + all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()} + + os.makedirs(config.checkpoint_path, exist_ok=True) + torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")) + + # Logging + # Reduce to rank 0 + if metric_values is not None: + if world_size > 1: + dist.reduce(metric_values, dst=0) + + if rank == 0: + reduced_metrics = metric_values.cpu().numpy() + reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)} + for set_id, set_name in enumerate(set_ids)} + + # Postprocess + for set_name, metrics in reduced_metrics.items(): + count = metrics.pop("count") + reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()} + + return reduced_metrics + + +def save_code_and_config(config: PretrainConfig): + if config.checkpoint_path is None or wandb.run is None: + return + + os.makedirs(config.checkpoint_path, exist_ok=True) + + # Copy code + code_list = [ + get_model_source_path(config.arch.name), + get_model_source_path(config.arch.loss.name) + ] + for code_file in code_list: + if code_file is not None: + code_name = os.path.basename(code_file) + + shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name)) + + # Dump config as yaml + config_file = os.path.join(config.checkpoint_path, "all_config.yaml") + with open(config_file, "wt") as f: + yaml.dump(config.model_dump(), f) + + # Log code + wandb.run.log_code(config.checkpoint_path) + + +def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig: + objects = [None] + if rank == 0: + config = PretrainConfig(**hydra_config) # type: ignore + + # Naming + if config.project_name is None: + config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch" + if config.run_name is None: + config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}" + if config.checkpoint_path is None: + config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name) + + objects = [config] + + if world_size > 1: + dist.broadcast_object_list(objects, src=0) + + return objects[0] # type: ignore + + +@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None) +def launch(hydra_config: DictConfig): + RANK = 0 + WORLD_SIZE = 1 + + # Initialize distributed training if in distributed environment (e.g. torchrun) + if "LOCAL_RANK" in os.environ: + # Initialize distributed, default device and dtype + dist.init_process_group(backend="nccl") + + RANK = dist.get_rank() + WORLD_SIZE = dist.get_world_size() + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + # Load sync'ed config + config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE) + + # Seed RNGs to ensure consistency + torch.random.manual_seed(config.seed + RANK) + + # Dataset + train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs + total_iters = config.epochs // train_epochs_per_iter + + assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs." + + train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + + # Train state + train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) + + # Progress bar and logger + progress_bar = None + if RANK == 0: + progress_bar = tqdm.tqdm(total=train_state.total_steps) + + wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore + wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0) + save_code_and_config(config) + + # Training Loop + for _iter_id in range(total_iters): + print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}") + + ############ Train Iter + train_state.model.train() + for set_name, batch, global_batch_size in train_loader: + metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE) + + if RANK == 0 and metrics is not None: + wandb.log(metrics, step=train_state.step) + progress_bar.update(train_state.step - progress_bar.n) # type: ignore + + ############ Evaluation + train_state.model.eval() + metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) + + if RANK == 0 and metrics is not None: + wandb.log(metrics, step=train_state.step) + + ############ Checkpointing + if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)): + save_train_state(config, train_state) + + # finalize + if dist.is_initialized(): + dist.destroy_process_group() + wandb.finish() + + +if __name__ == "__main__": + launch() -- cgit v1.2.3