from typing import Optional, Any, Sequence, List
from dataclasses import dataclass, replace
import os
import math
import yaml
import shutil

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader

import tqdm
import wandb
import coolname
import hydra
import pydantic
from omegaconf import DictConfig
from adam_atan2 import AdamATan2

from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
from utils.functions import load_model_class, get_model_source_path
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed


class LossConfig(pydantic.BaseModel):
    """Loss-head selection. Extra fields are forwarded as kwargs to the loss-head constructor."""
    model_config = pydantic.ConfigDict(extra='allow')

    name: str


class ArchConfig(pydantic.BaseModel):
    """Model selection. Extra fields are merged into the config dict passed to the model class."""
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
    loss: LossConfig


class PretrainConfig(pydantic.BaseModel):
    """Top-level training configuration (built from the hydra config in ``load_synced_config``)."""
    # Config
    arch: ArchConfig
    # Data
    data_path: str

    # Hyperparams
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []

    # Trajectory augmentation: branch 0 is the clean carry stream, branches 1..n-1 are
    # streams whose latents are perturbed with noise whenever an example is reset.
    trajectory_augment: bool = False
    trajectory_n: int = 4  # number of carry streams (clamped to >= 1 by the training loops)
    trajectory_noise_std: float = 1e-3  # final noise std target (end of the ramp)
    trajectory_noise_min: Optional[float] = None  # loguniform lower bound (default: one decade below the upper bound)
    trajectory_noise_max: Optional[float] = None  # loguniform upper bound (default: trajectory_noise_std)
    trajectory_noise_sampling: str = "loguniform"  # "loguniform" per-example stds, else constant; "uniform" also switches unit noise to uniform
    trajectory_sigma_start: Optional[float] = 0.0  # noise std at step 0 of the linear ramp (None: no ramp)
    trajectory_sigma_ramp_steps: int = 5000  # <= 0 disables the ramp
    trajectory_perturb: str = "both"  # which latents to perturb: h, l, both, joint
    trajectory_micro_batch: int = 0  # not referenced in this file
    trajectory_parallel: bool = False  # run all streams in one concatenated forward pass


@dataclass
class TrainState:
    """Mutable state threaded through the training loop."""
    model: nn.Module
    optimizers: Sequence[torch.optim.Optimizer]
    optimizer_lrs: Sequence[float]  # base lr per optimizer, same order as ``optimizers``

    carry: Any  # model carry; a list of carries (one per stream) in trajectory mode

    step: int
    total_steps: int


def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
    """Build a ``DataLoader`` over ``PuzzleDataset`` for ``split`` and return it with the dataset metadata.

    Extra keyword arguments are forwarded to ``PuzzleDatasetConfig``. ``batch_size=None`` disables the
    DataLoader's automatic batching, so each item produced by the dataset is yielded as-is.
    """
    dataset = PuzzleDataset(PuzzleDatasetConfig(
        seed=config.seed,
        dataset_path=config.data_path,
        rank=rank,
        num_replicas=world_size,
        **kwargs
    ), split=split)
    dataloader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=1,
        prefetch_factor=8,
        pin_memory=True,
        persistent_workers=True
    )
    return dataloader, dataset.metadata


def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    """Instantiate the model and loss head on CUDA and create their optimizers.

    Returns:
        ``(model, optimizers, optimizer_lrs)``: ``optimizer_lrs`` holds the base learning rate of each
        optimizer in the same order. The optimizers are created with ``lr=0``; the per-step value is
        applied by the training loop via ``compute_lr``.
    """
    model_batch_size = config.global_batch_size // world_size
    if config.trajectory_augment and config.trajectory_parallel:
        # The parallel trajectory path concatenates all streams into one forward pass, so the model is
        # built for the enlarged batch. Clamp to >= 1 stream, matching the training loop.
        model_batch_size *= max(config.trajectory_n, 1)

    model_cfg = dict(
        **config.arch.__pydantic_extra__,  # type: ignore

        batch_size=model_batch_size,

        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False  # Non-autoregressive
    )

    # Instantiate model with loss head
    model_cls = load_model_class(config.arch.name)
    loss_head_cls = load_model_class(config.arch.loss.name)

    with torch.device("cuda"):
        model: nn.Module = model_cls(model_cfg)
        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model, dynamic=False)  # type: ignore

        # Broadcast parameters from rank 0
        if world_size > 1:
            with torch.no_grad():
                for param in list(model.parameters()) + list(model.buffers()):
                    dist.broadcast(param, src=0)

    # Optimizers and lr
    optimizers = [
        CastedSparseEmbeddingSignSGD_Distributed(
            model.model.puzzle_emb.buffers(),  # type: ignore
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.puzzle_emb_weight_decay,
            world_size=world_size
        ),
        AdamATan2(
            model.parameters(),
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.weight_decay,
            betas=(config.beta1, config.beta2)
        )
    ]
    optimizer_lrs = [
        config.puzzle_emb_lr,
        config.lr
    ]

    return model, optimizers, optimizer_lrs


def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
    """Return the learning rate at ``current_step``.

    Linear warmup from 0 to ``base_lr`` over ``num_warmup_steps``, then a cosine decay whose floor is
    ``base_lr * min_ratio`` (with the default ``num_cycles=0.5`` the decay is a single half cosine).
    """
    if current_step < num_warmup_steps:
        return base_lr * float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))


def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    """Create the model/optimizers and a fresh ``TrainState`` with an estimated total step count."""
    # Estimated total training steps
    total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)

    # Model
    model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)

    return TrainState(
        step=0,
        total_steps=total_steps,

        model=model,
        optimizers=optimizers,
        optimizer_lrs=optimizer_lrs,
        carry=None
    )


def save_train_state(config: PretrainConfig, train_state: TrainState):
    """Save the model ``state_dict`` to ``<checkpoint_path>/step_<step>``; no-op without a checkpoint path."""
    # FIXME: Only saved model.
    if config.checkpoint_path is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)
    torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))


def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
    """Learning rate for the current ``train_state.step`` under the configured warmup + cosine schedule."""
    return cosine_schedule_with_warmup_lr_lambda(
        current_step=train_state.step,
        base_lr=base_lr,
        num_warmup_steps=round(config.lr_warmup_steps),
        num_training_steps=train_state.total_steps,
        min_ratio=config.lr_min_ratio
    )


def _unwrap_loss_head(model: nn.Module):
    """Return the un-compiled module (``torch.compile`` wrappers keep the original in ``_orig_mod``)."""
    return getattr(model, "_orig_mod", model)


def _unit_noise_like(tensor: torch.Tensor, sampling: str):
    """Unit-variance noise with ``tensor``'s shape, device and dtype.

    ``sampling == "uniform"`` draws from U(-sqrt(3), sqrt(3)) (variance 1); anything else draws a
    standard normal. Sampling happens in float32 and is cast to ``tensor.dtype`` afterwards.
    """
    if sampling == "uniform":
        return ((2.0 * torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32) - 1.0) * math.sqrt(3.0)).to(tensor.dtype)
    return torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32).to(tensor.dtype)


def _trajectory_noise_target(config: PretrainConfig, step: int) -> float:
    """Noise std target at ``step``: linear ramp from ``trajectory_sigma_start`` to ``trajectory_noise_std``.

    The ramp lasts ``trajectory_sigma_ramp_steps`` steps; a non-positive ramp length, or
    ``trajectory_sigma_start=None``, yields a constant ``trajectory_noise_std``.
    """
    if config.trajectory_sigma_ramp_steps <= 0:
        return config.trajectory_noise_std
    start = config.trajectory_sigma_start if config.trajectory_sigma_start is not None else config.trajectory_noise_std
    frac = min(max(step / config.trajectory_sigma_ramp_steps, 0.0), 1.0)
    return float(start + frac * (config.trajectory_noise_std - start))


def _sample_noise_stds(config: PretrainConfig, batch_size: int, device: torch.device, step: int):
    """Sample one noise std per example as a float32 tensor of shape ``(batch_size,)``.

    In ``"loguniform"`` mode the stds are drawn log-uniformly from ``[lo, hi]``, where the configured
    (or default) bounds are multiplied once by the ramp factor ``target / trajectory_noise_std``.
    Otherwise every example gets the ramp target. A non-positive target yields all zeros.
    """
    target = _trajectory_noise_target(config, step)
    if target <= 0:
        return torch.zeros(batch_size, device=device, dtype=torch.float32)

    if config.trajectory_noise_sampling == "loguniform":
        scale = 1.0 if config.trajectory_noise_std <= 0 else target / config.trajectory_noise_std
        # Resolve the unscaled bounds first, then apply the ramp factor exactly once. (Deriving the
        # default lower bound from the already-scaled upper bound would scale it twice.)
        hi_raw = float(config.trajectory_noise_max if config.trajectory_noise_max is not None else config.trajectory_noise_std)
        lo_raw = float(config.trajectory_noise_min if config.trajectory_noise_min is not None else max(hi_raw / 10.0, 1e-8))
        hi = max(hi_raw * scale, 1e-12)
        lo = max(min(lo_raw * scale, hi), 1e-12)
        u = torch.rand(batch_size, device=device, dtype=torch.float32)
        return torch.exp(math.log(lo) + u * (math.log(hi) - math.log(lo)))

    return torch.full((batch_size,), target, device=device, dtype=torch.float32)


def _add_initial_noise(config: PretrainConfig, inner_carry: Any, noise_stds: torch.Tensor):
    """Return a copy of ``inner_carry`` with per-example Gaussian/uniform noise added to ``z_H``/``z_L``.

    ``noise_stds`` holds one std per leading-dim row. ``trajectory_perturb`` selects which latents are
    perturbed: ``"h"``, ``"l"``, ``"both"`` (full std on each), or ``"joint"``/``"both_norm"``/
    ``"joint_norm"`` (std / sqrt(2) on each). If every std is <= 0 the carry is returned unchanged.

    Raises:
        ValueError: for an unknown ``trajectory_perturb`` value.
    """
    if noise_stds.numel() == 0 or float(noise_stds.max().item()) <= 0:
        return inner_carry

    # Broadcast the per-example std over all trailing dims of the latent.
    view_shape = (noise_stds.shape[0],) + (1,) * (inner_carry.z_H.ndim - 1)
    scale = noise_stds.view(view_shape)
    perturb = config.trajectory_perturb.lower()
    z_h = inner_carry.z_H
    z_l = inner_carry.z_L

    if perturb == "h":
        z_h = z_h + scale.to(z_h.dtype) * _unit_noise_like(z_h, config.trajectory_noise_sampling)
    elif perturb == "l":
        z_l = z_l + scale.to(z_l.dtype) * _unit_noise_like(z_l, config.trajectory_noise_sampling)
    elif perturb == "both":
        z_h = z_h + scale.to(z_h.dtype) * _unit_noise_like(z_h, config.trajectory_noise_sampling)
        z_l = z_l + scale.to(z_l.dtype) * _unit_noise_like(z_l, config.trajectory_noise_sampling)
    elif perturb in ("joint", "both_norm", "joint_norm"):
        # Split the variance budget between the two latents.
        joint_scale = scale / math.sqrt(2.0)
        z_h = z_h + joint_scale.to(z_h.dtype) * _unit_noise_like(z_h, config.trajectory_noise_sampling)
        z_l = z_l + joint_scale.to(z_l.dtype) * _unit_noise_like(z_l, config.trajectory_noise_sampling)
    else:
        raise ValueError(f"Unknown trajectory_perturb={config.trajectory_perturb!r}; expected h, l, both, or joint")

    return replace(inner_carry, z_H=z_h, z_L=z_l)


def _token_loss(loss_fn, logits, labels, mask):
    """Call ``loss_fn(logits, labels, ignore_index=-100)``, adding ``valid_mask=mask`` if it accepts one.

    The signature is inspected via ``__code__``; callables without it (e.g. builtins) receive only
    ``ignore_index``.
    """
    kwargs = {"ignore_index": -100}
    code = getattr(loss_fn, "__code__", None)
    arg_names = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount] if code is not None else ()
    if "valid_mask" in arg_names:
        kwargs["valid_mask"] = mask
    return loss_fn(logits, labels, **kwargs)


def _fixed_unroll_branch_loss(config: PretrainConfig, head: nn.Module, batch: Any, noise_stds: Optional[torch.Tensor]):
    """Unroll ``base.inner`` for exactly ``halt_max_steps`` steps from a fresh (optionally noised) carry.

    Each step adds the LM loss plus half the Q-halt / Q-continue BCE losses, mirroring
    ``_branch_metrics_and_loss``. Not referenced by the training paths in this file.

    Returns:
        ``(mean_loss_per_step, exact_count_at_last_step, batch_size)``.
    """
    base = head.model  # type: ignore[attr-defined]
    with torch.device("cuda"):
        carry = base.initial_carry(batch)
    reset_flag = torch.ones_like(carry.halted)
    inner_carry = base.inner.reset_carry(reset_flag, carry.inner_carry)
    if noise_stds is not None:
        inner_carry = _add_initial_noise(config, inner_carry, noise_stds)

    batch_size = batch["inputs"].shape[0]
    loss_sum = torch.zeros((), device=batch["inputs"].device, dtype=torch.float32)
    last_exact = torch.zeros((), device=batch["inputs"].device, dtype=torch.float32)

    for act_step in range(base.config.halt_max_steps):
        inner_carry, logits, (q_halt_logits, q_continue_logits) = base.inner(inner_carry, batch)
        labels = batch["labels"]

        with torch.no_grad():
            mask = labels != -100
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)
            is_correct = mask & (torch.argmax(logits, dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts
            last_exact = seq_is_correct.to(torch.float32).sum()

        lm_loss = (_token_loss(head.loss_fn, logits, labels, mask) / loss_divisor).sum()
        q_halt_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            q_halt_logits,
            seq_is_correct.to(q_halt_logits.dtype),
            reduction="sum",
        )

        # Bootstrapped Q-continue target from one extra (gradient-free) inner step.
        with torch.no_grad():
            next_q_halt_logits, next_q_continue_logits = base.inner(inner_carry, batch)[-1]
            is_last = act_step + 1 >= base.config.halt_max_steps
            target_q_continue = torch.sigmoid(
                torch.where(
                    torch.full_like(next_q_halt_logits, is_last, dtype=torch.bool),
                    next_q_halt_logits,
                    torch.maximum(next_q_halt_logits, next_q_continue_logits),
                )
            )
        q_continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(q_continue_logits, target_q_continue, reduction="sum")

        loss_sum = loss_sum + lm_loss + 0.5 * (q_halt_loss + q_continue_loss)

    return loss_sum / max(base.config.halt_max_steps, 1), last_exact, batch_size


def _prepare_noisy_stream_carry(config: PretrainConfig, base: nn.Module, carry: Any, batch: Any, step: int):
    """Reset the halted examples of a noisy stream and perturb their freshly reset latents.

    Examples with ``carry.halted`` set are reset through ``base.inner.reset_carry``, receive a noise
    std from ``_sample_noise_stds`` (non-reset examples get std 0, i.e. no noise), restart their step
    counter at 0 and take their data from ``batch``. ``halted`` is cleared in the returned carry —
    presumably so the model's own reset logic does not reset (and erase the noise of) these examples
    again; NOTE(review): confirm against the ACT wrapper.
    """
    reset_mask = carry.halted
    new_inner = base.inner.reset_carry(reset_mask, carry.inner_carry)

    noise_stds = _sample_noise_stds(config, batch["inputs"].shape[0], batch["inputs"].device, step)
    noise_stds = torch.where(reset_mask, noise_stds, torch.zeros_like(noise_stds))
    new_inner = _add_initial_noise(config, new_inner, noise_stds)

    new_steps = torch.where(reset_mask, 0, carry.steps)
    new_current_data = {
        k: torch.where(reset_mask.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
        for k, v in carry.current_data.items()
    }
    return replace(
        carry,
        inner_carry=new_inner,
        steps=new_steps,
        halted=torch.zeros_like(reset_mask),
        current_data=new_current_data,
    )


def _merge_trajectory_carries(carries: Sequence[Any]):
    """Concatenate per-stream carries along dim 0 into one carry (other fields taken from ``carries[0]``)."""
    inner_type = type(carries[0].inner_carry)
    inner_fields = carries[0].inner_carry.__dataclass_fields__.keys()
    inner_carry = inner_type(**{
        name: torch.cat([getattr(c.inner_carry, name) for c in carries], dim=0)
        for name in inner_fields
    })
    current_data = {
        k: torch.cat([c.current_data[k] for c in carries], dim=0)
        for k in carries[0].current_data
    }
    return replace(
        carries[0],
        inner_carry=inner_carry,
        steps=torch.cat([c.steps for c in carries], dim=0),
        halted=torch.cat([c.halted for c in carries], dim=0),
        current_data=current_data,
    )


def _split_trajectory_carry(carry: Any, branch_size: int, num_branches: int):
    """Inverse of ``_merge_trajectory_carries``: split a merged carry into ``num_branches`` contiguous carries."""
    inner_type = type(carry.inner_carry)
    inner_fields = carry.inner_carry.__dataclass_fields__.keys()
    inner_chunks = {
        name: [chunk.contiguous() for chunk in getattr(carry.inner_carry, name).split(branch_size, dim=0)]
        for name in inner_fields
    }
    steps_chunks = [chunk.contiguous() for chunk in carry.steps.split(branch_size, dim=0)]
    halted_chunks = [chunk.contiguous() for chunk in carry.halted.split(branch_size, dim=0)]
    data_chunks = {
        k: [chunk.contiguous() for chunk in v.split(branch_size, dim=0)]
        for k, v in carry.current_data.items()
    }
    return [
        replace(
            carry,
            inner_carry=inner_type(**{name: inner_chunks[name][idx] for name in inner_fields}),
            steps=steps_chunks[idx],
            halted=halted_chunks[idx],
            current_data={k: chunks[idx] for k, chunks in data_chunks.items()},
        )
        for idx in range(num_branches)
    ]


def _repeat_batch_for_trajectories(batch: Any, num_branches: int):
    """Tile every tensor of ``batch`` ``num_branches`` times along dim 0."""
    return {
        k: torch.cat([v for _ in range(num_branches)], dim=0)
        for k, v in batch.items()
    }


def _split_outputs(outputs: Any, branch_size: int, num_branches: int):
    """Split a dict of merged outputs into per-stream dicts of ``branch_size`` rows (views, no copy)."""
    return [
        {
            k: v.narrow(0, idx * branch_size, branch_size)
            for k, v in outputs.items()
        }
        for idx in range(num_branches)
    ]


def _branch_metrics_and_loss(head: nn.Module, carry: Any, outputs: Any):
    """Recompute one stream's metrics and total loss from its carry and model outputs.

    Metrics are counted over examples that are halted and have at least one labelled token. The loss
    is ``lm_loss + 0.5 * (q_halt_loss + q_continue_loss)``; ``q_continue_loss`` is only non-zero (and
    only reported) when ``outputs`` contains ``"target_q_continue"``.

    Returns:
        ``(metrics, total_loss)``.
    """
    labels = carry.current_data["labels"]
    with torch.no_grad():
        mask = labels != -100
        loss_counts = mask.sum(-1)
        loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)
        is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
        seq_is_correct = is_correct.sum(-1) == loss_counts
        valid_metrics = carry.halted & (loss_counts > 0)
        metrics = {
            "count": valid_metrics.sum(),
            "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
            "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
            "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
            "steps": torch.where(valid_metrics, carry.steps, 0).sum(),
        }

    # Route through _token_loss so loss functions that take a valid_mask get one, as in
    # _fixed_unroll_branch_loss.
    lm_loss = (_token_loss(head.loss_fn, outputs["logits"], labels, mask) / loss_divisor).sum()
    q_halt_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        outputs["q_halt_logits"],
        seq_is_correct.to(outputs["q_halt_logits"].dtype),
        reduction="sum",
    )
    q_continue_loss = torch.zeros_like(q_halt_loss)
    if "target_q_continue" in outputs:
        q_continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            outputs["q_continue_logits"],
            outputs["target_q_continue"],
            reduction="sum",
        )

    metrics.update({
        "lm_loss": lm_loss,
        "q_halt_loss": q_halt_loss,
    })
    if "target_q_continue" in outputs:
        metrics["q_continue_loss"] = q_continue_loss

    total_loss = lm_loss + 0.5 * (q_halt_loss + q_continue_loss)
    return metrics, total_loss


def train_batch_trajectory_aug_parallel(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    """One optimizer step with all trajectory streams run in a single concatenated forward pass.

    Stream 0 is clean; streams 1..n-1 are reset + noised by ``_prepare_noisy_stream_carry``. The summed
    loss is divided by ``global_batch_size * num_branches``. On rank 0, returns stream-0 metrics plus
    clean/noisy loss summaries; otherwise (or past ``total_steps``) returns None.

    Raises:
        NotImplementedError: when ``world_size != 1``.
    """
    if world_size != 1:
        raise NotImplementedError("trajectory_parallel currently expects single-GPU baseline configs")

    train_state.step += 1
    if train_state.step > train_state.total_steps:
        return

    batch = {k: v.cuda() for k, v in batch.items()}
    head = _unwrap_loss_head(train_state.model)
    base = head.model  # type: ignore[attr-defined]
    num_branches = max(config.trajectory_n, 1)
    branch_size = batch["inputs"].shape[0]

    # (Re)initialise one carry per stream; also when the stored list does not match the stream count.
    if (
        train_state.carry is None
        or not isinstance(train_state.carry, list)
        or len(train_state.carry) != num_branches
    ):
        with torch.device("cuda"):
            train_state.carry = [train_state.model.initial_carry(batch) for _ in range(num_branches)]  # type: ignore

    for optim in train_state.optimizers:
        optim.zero_grad()

    carries = []
    for branch_idx, carry in enumerate(train_state.carry):
        if branch_idx > 0:
            carry = _prepare_noisy_stream_carry(config, base, carry, batch, train_state.step)
        carries.append(carry)

    parallel_carry = _merge_trajectory_carries(carries)
    parallel_batch = _repeat_batch_for_trajectories(batch, num_branches)
    return_keys = ["logits", "q_halt_logits", "q_continue_logits", "target_q_continue"]
    parallel_carry, loss, _metrics, outputs, _ = train_state.model(
        carry=parallel_carry,
        batch=parallel_batch,
        return_keys=return_keys,
    )
    train_state.carry = _split_trajectory_carry(parallel_carry, branch_size, num_branches)

    (loss / (global_batch_size * num_branches)).backward()

    lr_this_step = None
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)
        for param_group in optim.param_groups:
            param_group["lr"] = lr_this_step
        optim.step()
        optim.zero_grad()

    if rank == 0 and outputs is not None:
        split_outputs = _split_outputs(outputs, branch_size, num_branches)
        clean_metrics, clean_loss = _branch_metrics_and_loss(head, train_state.carry[0], split_outputs[0])
        noisy_loss = 0.0
        for branch_idx in range(1, num_branches):
            _, branch_loss = _branch_metrics_and_loss(head, train_state.carry[branch_idx], split_outputs[branch_idx])
            noisy_loss += float(branch_loss.detach().cpu())

        reduced_metrics = {f"train/{k}": float(v.detach().cpu()) for k, v in clean_metrics.items()}
        count = max(reduced_metrics.get("train/count", 1.0), 1.0)
        reduced_metrics = {
            k: v / (global_batch_size if k.endswith("loss") else count)
            for k, v in reduced_metrics.items()
        }
        reduced_metrics["train/lr"] = lr_this_step
        reduced_metrics["train/trajectory_clean_loss"] = float(clean_loss.detach().cpu()) / global_batch_size
        reduced_metrics["train/trajectory_noisy_loss"] = noisy_loss / max(num_branches - 1, 1) / global_batch_size
        reduced_metrics["train/trajectory_sigma"] = _trajectory_noise_target(config, train_state.step)
        reduced_metrics["train/trajectory_parallel"] = 1.0
        return reduced_metrics


def train_batch_trajectory_aug(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    """One optimizer step over ``trajectory_n`` carry streams, one forward/backward per stream.

    Dispatches to ``train_batch_trajectory_aug_parallel`` when ``trajectory_parallel`` is set. Gradients
    from all streams accumulate before a single optimizer step. At least one (clean) stream always runs.
    On rank 0, returns stream-0 metrics plus clean/noisy loss summaries; otherwise returns None.

    Raises:
        NotImplementedError: when ``world_size != 1``.
    """
    if config.trajectory_parallel:
        return train_batch_trajectory_aug_parallel(config, train_state, batch, global_batch_size, rank, world_size)
    if world_size != 1:
        raise NotImplementedError("trajectory_augment currently expects single-GPU baseline configs")

    train_state.step += 1
    if train_state.step > train_state.total_steps:
        return

    batch = {k: v.cuda() for k, v in batch.items()}
    head = _unwrap_loss_head(train_state.model)
    base = head.model  # type: ignore[attr-defined]
    # Clamp like the parallel path: with 0 streams there would be no gradient and no metrics.
    num_branches = max(config.trajectory_n, 1)

    if (
        train_state.carry is None
        or not isinstance(train_state.carry, list)
        or len(train_state.carry) != num_branches
    ):
        with torch.device("cuda"):
            train_state.carry = [train_state.model.initial_carry(batch) for _ in range(num_branches)]  # type: ignore

    for optim in train_state.optimizers:
        optim.zero_grad()

    clean_metrics = None
    clean_loss_value = 0.0
    noisy_loss_value = 0.0
    for branch_idx in range(num_branches):
        carry = train_state.carry[branch_idx]
        if branch_idx > 0:
            carry = _prepare_noisy_stream_carry(config, base, carry, batch, train_state.step)
        carry, loss, metrics, _, _ = train_state.model(carry=carry, batch=batch, return_keys=[])
        train_state.carry[branch_idx] = carry
        (loss / (global_batch_size * num_branches)).backward()
        if branch_idx == 0:
            clean_metrics = metrics
            clean_loss_value = float(loss.detach().cpu())
        else:
            noisy_loss_value += float(loss.detach().cpu())

    lr_this_step = None
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)
        for param_group in optim.param_groups:
            param_group["lr"] = lr_this_step
        optim.step()
        optim.zero_grad()

    if rank == 0 and clean_metrics is not None:
        reduced_metrics = {f"train/{k}": float(v.detach().cpu()) for k, v in clean_metrics.items()}
        count = max(reduced_metrics.get("train/count", 1.0), 1.0)
        reduced_metrics = {
            k: v / (global_batch_size if k.endswith("loss") else count)
            for k, v in reduced_metrics.items()
        }
        reduced_metrics["train/lr"] = lr_this_step
        reduced_metrics["train/trajectory_clean_loss"] = clean_loss_value / global_batch_size
        reduced_metrics["train/trajectory_noisy_loss"] = noisy_loss_value / max(num_branches - 1, 1) / global_batch_size
        reduced_metrics["train/trajectory_sigma"] = _trajectory_noise_target(config, train_state.step)
        return reduced_metrics


def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    """Run one training step and return rank-0 metrics (None on other ranks or past ``total_steps``).

    Delegates to ``train_batch_trajectory_aug`` when ``trajectory_augment`` is enabled. Otherwise:
    forward with the persistent carry, backward on ``loss / global_batch_size``, all-reduce gradients
    across ranks, set the scheduled lr, step every optimizer and reduce metrics to rank 0.
    """
    if config.trajectory_augment:
        return train_batch_trajectory_aug(config, train_state, batch, global_batch_size, rank, world_size)

    train_state.step += 1
    if train_state.step > train_state.total_steps:  # At most train_total_steps
        return

    # To device
    batch = {k: v.cuda() for k, v in batch.items()}

    # Init carry if it is None
    if train_state.carry is None:
        with torch.device("cuda"):
            train_state.carry = train_state.model.initial_carry(batch)  # type: ignore

    # Forward
    train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])

    ((1 / global_batch_size) * loss).backward()

    # Allreduce
    if world_size > 1:
        for param in train_state.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)

    # Apply optimizer
    lr_this_step = None
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)

        for param_group in optim.param_groups:
            param_group['lr'] = lr_this_step

        optim.step()
        optim.zero_grad()

    # Reduce metrics
    if len(metrics):
        assert not any(v.requires_grad for v in metrics.values())

        metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.
        # Reduce and reconstruct
        metric_values = torch.stack([metrics[k] for k in metric_keys])
        if world_size > 1:
            dist.reduce(metric_values, dst=0)

        if rank == 0:
            metric_values = metric_values.cpu().numpy()
            reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}

            # Postprocess
            count = max(reduced_metrics["count"], 1)  # Avoid NaNs
            reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}

            reduced_metrics["train/lr"] = lr_this_step
            return reduced_metrics


def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
    """Run the model to completion on every eval batch and return per-set averaged metrics on rank 0.

    Outputs named in ``eval_save_outputs`` (from the batch or the model predictions) are collected on
    CPU and saved to ``<checkpoint_path>/step_<step>_all_preds.<rank>``. Returns None on non-zero ranks
    or when the loader yields nothing.
    """
    with torch.inference_mode():
        set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}

        all_preds = {}

        metric_keys = []
        metric_values = None
        metric_global_batch_size = [0 for _ in range(len(set_ids))]

        carry = None
        for set_name, batch, global_batch_size in eval_loader:
            # To device
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.device("cuda"):
                carry = train_state.model.initial_carry(batch)  # type: ignore

            # Forward
            while True:
                carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)

                if all_finish:
                    break

            for collection in (batch, preds):
                for k, v in collection.items():
                    if k in config.eval_save_outputs:
                        all_preds.setdefault(k, [])
                        all_preds[k].append(v.cpu())  # Move to CPU for saving GPU memory

            del carry, preds, batch, all_finish

            # Aggregate
            set_id = set_ids[set_name]

            if metric_values is None:
                metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.
                metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")

            metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
            metric_global_batch_size[set_id] += global_batch_size

        if len(all_preds) and config.checkpoint_path is not None:
            all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}

            os.makedirs(config.checkpoint_path, exist_ok=True)
            torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))

        # Logging
        # Reduce to rank 0
        if metric_values is not None:
            if world_size > 1:
                dist.reduce(metric_values, dst=0)

            if rank == 0:
                reduced_metrics = metric_values.cpu().numpy()
                reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
                                   for set_id, set_name in enumerate(set_ids)}

                # Postprocess
                for set_name, metrics in reduced_metrics.items():
                    count = max(metrics.pop("count"), 1)  # Avoid NaNs for sets with no counted examples
                    reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}

                return reduced_metrics


def save_code_and_config(config: PretrainConfig):
    """Copy the model/loss source files and dump the config as YAML into the checkpoint dir, then log the code to wandb.

    No-op when there is no checkpoint path or no active wandb run.
    """
    if config.checkpoint_path is None or wandb.run is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)

    # Copy code
    code_list = [
        get_model_source_path(config.arch.name),
        get_model_source_path(config.arch.loss.name)
    ]
    for code_file in code_list:
        if code_file is not None:
            code_name = os.path.basename(code_file)

            shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))

    # Dump config as yaml
    config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
    with open(config_file, "wt") as f:
        yaml.dump(config.model_dump(), f)

    # Log code
    wandb.run.log_code(config.checkpoint_path)


def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
    """Build the ``PretrainConfig`` on rank 0 (filling default names/paths) and broadcast it to all ranks."""
    objects = [None]
    if rank == 0:
        config = PretrainConfig(**hydra_config)  # type: ignore

        # Naming
        if config.project_name is None:
            config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch"
        if config.run_name is None:
            config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
        if config.checkpoint_path is None:
            config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)

        objects = [config]

    if world_size > 1:
        dist.broadcast_object_list(objects, src=0)

    return objects[0]  # type: ignore


@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
    """Entry point: set up (optionally distributed) training, then alternate train / eval / checkpoint."""
    RANK = 0
    WORLD_SIZE = 1

    # Initialize distributed training if in distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Initialize distributed, default device and dtype
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Load sync'ed config
    config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)

    # Seed RNGs to ensure consistency
    torch.random.manual_seed(config.seed + RANK)

    # Dataset
    train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
    total_iters = config.epochs // train_epochs_per_iter

    assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."

    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)

    # Train state
    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)

    # Progress bar and logger
    progress_bar = None
    if RANK == 0:
        progress_bar = tqdm.tqdm(total=train_state.total_steps)

        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore
        wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
        save_code_and_config(config)

    # Training Loop
    for _iter_id in range(total_iters):
        print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")

        ############ Train Iter
        train_state.model.train()
        for set_name, batch, global_batch_size in train_loader:
            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)

            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore

        ############ Evaluation
        train_state.model.eval()
        metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)

        if RANK == 0 and metrics is not None:
            wandb.log(metrics, step=train_state.step)

        ############ Checkpointing
        if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
            save_train_state(config, train_state)

    # finalize
    if progress_bar is not None:
        progress_bar.close()
    if dist.is_initialized():
        dist.destroy_process_group()
    wandb.finish()


if __name__ == "__main__":
    launch()