diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /trm/pretrain.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'trm/pretrain.py')
| -rw-r--r-- | trm/pretrain.py | 1277 |
1 files changed, 1277 insertions, 0 deletions
diff --git a/trm/pretrain.py b/trm/pretrain.py new file mode 100644 index 0000000..b8f9ef0 --- /dev/null +++ b/trm/pretrain.py @@ -0,0 +1,1277 @@ +from typing import Optional, Any, Sequence, List +from dataclasses import dataclass, replace +import os +import math +import yaml +import shutil +import copy + +import torch +import torch.distributed as dist +from torch import nn +from torch.utils.data import DataLoader + +import tqdm +import wandb +import coolname +import hydra +import pydantic +from omegaconf import DictConfig +from adam_atan2 import AdamATan2 + +from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata +from utils.functions import load_model_class, get_model_source_path +from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed +from models.ema import EMAHelper + + +class LossConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='allow') + name: str + + +class ArchConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='allow') + name: str + loss: LossConfig + + +class EvaluatorConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra="allow") + name: str + + +class PretrainConfig(pydantic.BaseModel): + # Config + arch: ArchConfig + # Data + data_paths: List[str] + data_paths_test: List[str] = [] + # Evaluators + evaluators: List[EvaluatorConfig] = [] + + # Hyperparams + global_batch_size: int + epochs: int + + lr: float + lr_min_ratio: float + lr_warmup_steps: int + + weight_decay: float + beta1: float + beta2: float + + # Puzzle embedding + puzzle_emb_lr: float + puzzle_emb_weight_decay: float + + # Names + project_name: Optional[str] = None + run_name: Optional[str] = None + load_checkpoint: Optional[str] = None + checkpoint_path: Optional[str] = None + + # Extras + seed: int = 0 + checkpoint_every_eval: bool = False + eval_interval: Optional[int] = None + min_eval_interval: Optional[int] = 0 # when to start eval + eval_save_outputs: List[str] = [] + + ema: bool = False # use Exponential-Moving-Average + ema_rate: float = 0.999 # EMA-rate + freeze_weights: bool = False # If True, freeze weights and only learn the embeddings + + trajectory_augment: bool = False + trajectory_n: int = 4 + trajectory_noise_std: float = 1e-3 + trajectory_noise_min: Optional[float] = None + trajectory_noise_max: Optional[float] = None + trajectory_noise_sampling: str = "loguniform" + trajectory_sigma_start: Optional[float] = 0.0 + trajectory_sigma_ramp_steps: int = 5000 + trajectory_perturb: str = "both" + trajectory_micro_batch: int = 0 + trajectory_parallel: bool = False + trajectory_directional: bool = False + trajectory_directional_candidates: int = 4 + trajectory_directional_fd_eps: float = 3e-2 + trajectory_directional_horizon: int = 0 + trajectory_directional_sign: str = "alternate" + +@dataclass +class TrainState: + model: nn.Module + optimizers: Sequence[torch.optim.Optimizer] + optimizer_lrs: Sequence[float] + carry: Any + + step: int + total_steps: int + + +def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs): + dataset = PuzzleDataset(PuzzleDatasetConfig( + seed=config.seed, + dataset_paths=config.data_paths_test if len(config.data_paths_test)>0 and split=="test" else config.data_paths, + rank=rank, + num_replicas=world_size, + **kwargs + ), split=split) + dataloader = DataLoader( + dataset, + batch_size=None, + num_workers=1, + prefetch_factor=8, + pin_memory=True, + persistent_workers=True + ) + return dataloader, dataset.metadata + + +def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): + model_batch_size = config.global_batch_size // world_size + if config.trajectory_augment and config.trajectory_parallel: + model_batch_size *= config.trajectory_n + + model_cfg = dict( + **config.arch.__pydantic_extra__, # type: ignore + batch_size=model_batch_size, + vocab_size=train_metadata.vocab_size, + seq_len=train_metadata.seq_len, + num_puzzle_identifiers=train_metadata.num_puzzle_identifiers, + causal=False # Non-autoregressive + ) + + # Instantiate model with loss head + model_cls = load_model_class(config.arch.name) + loss_head_cls = load_model_class(config.arch.loss.name) + + with torch.device("cuda"): + model: nn.Module = model_cls(model_cfg) + print(model) + model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore + if "DISABLE_COMPILE" not in os.environ: + model = torch.compile(model) # type: ignore + + # Load checkpoint + if rank == 0: + load_checkpoint(model, config) + + # Broadcast parameters from rank 0 + if world_size > 1: + with torch.no_grad(): + for param in list(model.parameters()) + list(model.buffers()): + dist.broadcast(param, src=0) + + # Optimizers and lr + if config.arch.puzzle_emb_ndim == 0: + optimizers = [ + AdamATan2( + model.parameters(), + lr=0, # Needs to be set by scheduler + weight_decay=config.weight_decay, + betas=(config.beta1, config.beta2) + ) + ] + optimizer_lrs = [ + config.lr + ] + elif config.freeze_weights: + optimizers = [ + CastedSparseEmbeddingSignSGD_Distributed( + model.model.puzzle_emb.buffers(), # type: ignore + lr=0, # Needs to be set by scheduler + weight_decay=config.puzzle_emb_weight_decay, + world_size=world_size + ) + ] + optimizer_lrs = [ + config.puzzle_emb_lr + ] + else: + optimizers = [ + CastedSparseEmbeddingSignSGD_Distributed( + model.model.puzzle_emb.buffers(), # type: ignore + lr=0, # Needs to be set by scheduler + weight_decay=config.puzzle_emb_weight_decay, + world_size=world_size + ), + AdamATan2( + model.parameters(), + lr=0, # Needs to be set by scheduler + weight_decay=config.weight_decay, + betas=(config.beta1, config.beta2) + ) + ] + optimizer_lrs = [ + config.puzzle_emb_lr, + config.lr + ] + + return model, optimizers, optimizer_lrs + +def mix_weights_direct(device, alpha, net, nets): + sd = [] + for i in range(len(nets)): + sd += [nets[i].state_dict()] + sd_alpha = {} + for k in sd[0].keys(): + comb_net = alpha[0]*sd[0][k].to(device) + for i in range(1,len(nets)): + comb_net += alpha[i]*sd[i][k].to(device) + sd_alpha[k] = comb_net + net.load_state_dict(sd_alpha) + return net + +def cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 +): + if current_step < num_warmup_steps: + return base_lr * float(current_step) / float(max(1, num_warmup_steps)) + + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))) + + +def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): + # Estimated total training steps + total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size) + + # Model + model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size) + + return TrainState( + step=0, + total_steps=total_steps, + + model=model, + optimizers=optimizers, + optimizer_lrs=optimizer_lrs, + carry=None + ) + + +def save_train_state(config: PretrainConfig, train_state: TrainState): + # FIXME: Only saved model. + if config.checkpoint_path is None: + return + + os.makedirs(config.checkpoint_path, exist_ok=True) + torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) + + +def load_checkpoint(model: nn.Module, config: PretrainConfig): + if config.load_checkpoint is not None: + print(f"Loading checkpoint {config.load_checkpoint}") + + # Load state dict + state_dict = torch.load(config.load_checkpoint, map_location="cuda") + + # Resize and reset puzzle emb if needed + puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights" + expected_shape: torch.Size = model.model.puzzle_emb.weights.shape # type: ignore + if puzzle_emb_name in state_dict: + puzzle_emb = state_dict[puzzle_emb_name] + if puzzle_emb.shape != expected_shape: + print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, Expected {expected_shape}") + # Re-initialize using mean + state_dict[puzzle_emb_name] = ( + torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous() + ) + model.load_state_dict(state_dict, assign=True) + + +def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState): + return cosine_schedule_with_warmup_lr_lambda( + current_step=train_state.step, + base_lr=base_lr, + num_warmup_steps=round(config.lr_warmup_steps), + num_training_steps=train_state.total_steps, + min_ratio=config.lr_min_ratio + ) + + + +def _unwrap_loss_head(model: nn.Module): + return getattr(model, "_orig_mod", model) + + +def _unit_noise_like(tensor: torch.Tensor, sampling: str): + if sampling == "uniform": + return ((2.0 * torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32) - 1.0) * math.sqrt(3.0)).to(tensor.dtype) + return torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32).to(tensor.dtype) + + +def _trajectory_noise_target(config: PretrainConfig, step: int) -> float: + if config.trajectory_sigma_ramp_steps <= 0: + return config.trajectory_noise_std + start = config.trajectory_sigma_start if config.trajectory_sigma_start is not None else config.trajectory_noise_std + frac = min(max(step / config.trajectory_sigma_ramp_steps, 0.0), 1.0) + return float(start + frac * (config.trajectory_noise_std - start)) + + +def _sample_noise_stds(config: PretrainConfig, batch_size: int, device: torch.device, step: int): + target = _trajectory_noise_target(config, step) + if target <= 0: + return torch.zeros(batch_size, device=device, dtype=torch.float32) + if config.trajectory_noise_sampling == "loguniform": + scale = 1.0 if config.trajectory_noise_std <= 0 else target / config.trajectory_noise_std + hi = float(config.trajectory_noise_max if config.trajectory_noise_max is not None else config.trajectory_noise_std) * scale + lo = float(config.trajectory_noise_min if config.trajectory_noise_min is not None else max(hi / 10.0, 1e-8)) * scale + hi = max(hi, 1e-12) + lo = max(min(lo, hi), 1e-12) + u = torch.rand(batch_size, device=device, dtype=torch.float32) + return torch.exp(math.log(lo) + u * (math.log(hi) - math.log(lo))) + return torch.full((batch_size,), target, device=device, dtype=torch.float32) + + +def _add_initial_noise(config: PretrainConfig, inner_carry: Any, noise_stds: torch.Tensor): + if noise_stds.numel() == 0 or float(noise_stds.max().item()) <= 0: + return inner_carry + view_shape = (noise_stds.shape[0],) + (1,) * (inner_carry.z_H.ndim - 1) + scale = noise_stds.view(view_shape) + z_h = inner_carry.z_H + z_l = inner_carry.z_L + if config.trajectory_perturb in ("h", "both"): + z_h = z_h + scale.to(z_h.dtype) * _unit_noise_like(z_h, config.trajectory_noise_sampling) + if config.trajectory_perturb in ("l", "both"): + z_l = z_l + scale.to(z_l.dtype) * _unit_noise_like(z_l, config.trajectory_noise_sampling) + return replace(inner_carry, z_H=z_h, z_L=z_l) + + +def _directional_fields(config: PretrainConfig): + if config.trajectory_perturb == "h": + return ("z_H",) + if config.trajectory_perturb == "l": + return ("z_L",) + return ("z_H", "z_L") + + +def _index_inner_carry(inner_carry: Any, indices: torch.Tensor): + inner_type = type(inner_carry) + return inner_type(**{ + name: getattr(inner_carry, name).index_select(0, indices).contiguous() + for name in inner_carry.__dataclass_fields__.keys() + }) + + +def _cat_inner_carries(carries: Sequence[Any]): + inner_type = type(carries[0]) + return inner_type(**{ + name: torch.cat([getattr(c, name) for c in carries], dim=0) + for name in carries[0].__dataclass_fields__.keys() + }) + + +def _split_inner_carry(inner_carry: Any, first_size: int): + inner_type = type(inner_carry) + first = {} + second = {} + for name in inner_carry.__dataclass_fields__.keys(): + value = getattr(inner_carry, name) + first[name] = value[:first_size].contiguous() + second[name] = value[first_size:].contiguous() + return inner_type(**first), inner_type(**second) + + +def _repeat_batch_interleave(batch: Any, repeats: int): + return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} + + +def _cat_batches(a: Any, b: Any): + return {k: torch.cat([a[k], b[k]], dim=0) for k in a} + + +def _rand_unit_direction_candidates(config: PretrainConfig, inner_carry: Any, candidates: int): + fields = _directional_fields(config) + dirs = {} + norm_sq = None + for name in inner_carry.__dataclass_fields__.keys(): + value = getattr(inner_carry, name) + if name in fields: + direction = torch.randn( + (value.shape[0], candidates) + tuple(value.shape[1:]), + device=value.device, + dtype=torch.float32, + ) + term = direction.flatten(2).square().sum(-1) + norm_sq = term if norm_sq is None else norm_sq + term + dirs[name] = direction + else: + dirs[name] = torch.zeros( + (value.shape[0], candidates) + tuple(value.shape[1:]), + device=value.device, + dtype=torch.float32, + ) + + if norm_sq is None: + raise ValueError(f"unknown trajectory_perturb={config.trajectory_perturb!r}") + + norm = torch.sqrt(norm_sq).clamp_min(1e-30) + out = {} + for name, direction in dirs.items(): + view_shape = direction.shape[:2] + (1,) * (direction.ndim - 2) + out[name] = (direction / norm.view(view_shape)).to(getattr(inner_carry, name).dtype) + return out + + +def _make_shadow_inner(inner_carry: Any, dirs: dict[str, torch.Tensor], eps: float): + inner_type = type(inner_carry) + fields = {} + for name in inner_carry.__dataclass_fields__.keys(): + value = getattr(inner_carry, name) + perturbed = value[:, None] + eps * dirs[name] + fields[name] = perturbed.reshape((value.shape[0] * dirs[name].shape[1],) + tuple(value.shape[1:])).detach() + return inner_type(**fields) + + +def _separation(main: Any, shadow: Any, candidates: int): + sep_sq = None + batch_size = getattr(main, next(iter(main.__dataclass_fields__.keys()))).shape[0] + for name in main.__dataclass_fields__.keys(): + main_value = getattr(main, name) + shadow_value = getattr(shadow, name).reshape((batch_size, candidates) + tuple(main_value.shape[1:])) + diff_sq = (shadow_value.float() - main_value[:, None].float()).flatten(2).square().sum(-1) + sep_sq = diff_sq if sep_sq is None else sep_sq + diff_sq + return torch.sqrt(sep_sq).clamp_min(1e-30) + + +def _gather_direction(dirs: dict[str, torch.Tensor], best_idx: torch.Tensor): + batch_size = best_idx.shape[0] + row_idx = torch.arange(batch_size, device=best_idx.device) + return {name: value[row_idx, best_idx].contiguous() for name, value in dirs.items()} + + +def _directional_signs(config: PretrainConfig, branch_idx: int, count: int, device: torch.device): + mode = config.trajectory_directional_sign + if mode == "random": + return torch.where( + torch.rand(count, device=device, dtype=torch.float32) < 0.5, + torch.full((count,), -1.0, device=device, dtype=torch.float32), + torch.full((count,), 1.0, device=device, dtype=torch.float32), + ) + if mode == "negative": + return torch.full((count,), -1.0, device=device, dtype=torch.float32) + if mode == "positive": + return torch.full((count,), 1.0, device=device, dtype=torch.float32) + if mode != "alternate": + raise ValueError(f"unknown trajectory_directional_sign={mode!r}") + sign = 1.0 if branch_idx % 2 == 1 else -1.0 + return torch.full((count,), sign, device=device, dtype=torch.float32) + + +@torch.no_grad() +def _select_directional_perturbation( + config: PretrainConfig, + base: nn.Module, + inner_carry: Any, + batch: Any, + active_mask: torch.Tensor, +): + active_indices = torch.nonzero(active_mask, as_tuple=False).flatten() + active_count = int(active_indices.numel()) + if active_count == 0: + return None, None, { + "active": 0, + "growth_sum": 0.0, + "growth_count": 0, + } + + candidates = max(int(config.trajectory_directional_candidates), 1) + fd_eps = float(config.trajectory_directional_fd_eps) + horizon = int(config.trajectory_directional_horizon) + if horizon <= 0: + horizon = int(base.config.halt_max_steps) # type: ignore[attr-defined] + horizon = max(horizon, 1) + + active_inner = _index_inner_carry(inner_carry, active_indices) + active_batch = {k: v.index_select(0, active_indices).contiguous() for k, v in batch.items()} + dirs = _rand_unit_direction_candidates(config, active_inner, candidates) + main = type(active_inner)(**{ + name: getattr(active_inner, name).detach() + for name in active_inner.__dataclass_fields__.keys() + }) + shadow = _make_shadow_inner(main, dirs, fd_eps) + combined = _cat_inner_carries([main, shadow]) + combined_batch = _cat_batches(active_batch, _repeat_batch_interleave(active_batch, candidates)) + + was_training = base.inner.training # type: ignore[attr-defined] + base.inner.eval() # type: ignore[attr-defined] + try: + for _ in range(horizon): + combined, _logits, _q = base.inner(combined, combined_batch) # type: ignore[attr-defined] + finally: + if was_training: + base.inner.train() # type: ignore[attr-defined] + + main_final, shadow_final = _split_inner_carry(combined, active_count) + sep = _separation(main_final, shadow_final, candidates) + best_idx = sep.argmax(dim=1) + best_sep = sep.gather(1, best_idx[:, None]).squeeze(1) + best_dirs = _gather_direction(dirs, best_idx) + growth = torch.log(best_sep / max(fd_eps, 1e-30)).float() / horizon + return active_indices, best_dirs, { + "active": active_count, + "growth_sum": float(growth.sum().item()), + "growth_count": active_count, + } + + +def _add_directional_initial_noise( + config: PretrainConfig, + base: nn.Module, + inner_carry: Any, + batch: Any, + noise_stds: torch.Tensor, + active_mask: torch.Tensor, + branch_idx: int, +): + selected_indices, selected_dirs, stats = _select_directional_perturbation( + config, + base, + inner_carry, + batch, + active_mask, + ) + if selected_indices is None or selected_dirs is None: + return inner_carry, stats + + signs = _directional_signs(config, branch_idx, selected_indices.shape[0], noise_stds.device) + scales = noise_stds.index_select(0, selected_indices) * signs + fields = {} + for name in inner_carry.__dataclass_fields__.keys(): + value = getattr(inner_carry, name) + updated = value.clone() + direction = selected_dirs[name].to(value.dtype) + view_shape = (scales.shape[0],) + (1,) * (direction.ndim - 1) + updated.index_copy_(0, selected_indices, value.index_select(0, selected_indices) + scales.view(view_shape).to(value.dtype) * direction) + fields[name] = updated + return type(inner_carry)(**fields), stats + + +def _token_loss(loss_fn, logits, labels, mask): + kwargs = {"ignore_index": -100} + code = getattr(loss_fn, "__code__", None) + arg_names = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount] if code is not None else () + if "valid_mask" in arg_names: + kwargs["valid_mask"] = mask + return loss_fn(logits, labels, **kwargs) + + +def _fixed_unroll_branch_loss(config: PretrainConfig, head: nn.Module, batch: Any, noise_stds: Optional[torch.Tensor]): + base = head.model # type: ignore[attr-defined] + with torch.device("cuda"): + carry = base.initial_carry(batch) + reset_flag = torch.ones_like(carry.halted) + inner_carry = base.inner.reset_carry(reset_flag, carry.inner_carry) + if noise_stds is not None: + inner_carry = _add_initial_noise(config, inner_carry, noise_stds) + + batch_size = batch["inputs"].shape[0] + loss_sum = torch.zeros((), device=batch["inputs"].device, dtype=torch.float32) + last_exact = torch.zeros((), device=batch["inputs"].device, dtype=torch.float32) + for act_step in range(base.config.halt_max_steps): + inner_carry, logits, (q_halt_logits, q_continue_logits) = base.inner(inner_carry, batch) + labels = batch["labels"] + with torch.no_grad(): + mask = labels != -100 + loss_counts = mask.sum(-1) + loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) + is_correct = mask & (torch.argmax(logits, dim=-1) == labels) + seq_is_correct = is_correct.sum(-1) == loss_counts + last_exact = seq_is_correct.to(torch.float32).sum() + + lm_loss = (_token_loss(head.loss_fn, logits, labels, mask) / loss_divisor).sum() + q_halt_loss = torch.nn.functional.binary_cross_entropy_with_logits( + q_halt_logits, + seq_is_correct.to(q_halt_logits.dtype), + reduction="sum", + ) + q_continue_loss = torch.zeros_like(q_halt_loss) + if not getattr(base.config, "no_ACT_continue", False): + with torch.no_grad(): + next_q_halt_logits, next_q_continue_logits = base.inner(inner_carry, batch)[-1] + is_last = act_step + 1 >= base.config.halt_max_steps + target_q_continue = torch.sigmoid( + torch.where( + torch.full_like(next_q_halt_logits, is_last, dtype=torch.bool), + next_q_halt_logits, + torch.maximum(next_q_halt_logits, next_q_continue_logits), + ) + ) + q_continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(q_continue_logits, target_q_continue, reduction="sum") + loss_sum = loss_sum + lm_loss + 0.5 * (q_halt_loss + q_continue_loss) + + return loss_sum / max(base.config.halt_max_steps, 1), last_exact, batch_size + + +def _prepare_noisy_stream_carry(config: PretrainConfig, base: nn.Module, carry: Any, batch: Any, step: int, branch_idx: int): + reset_mask = carry.halted + new_inner = base.inner.reset_carry(reset_mask, carry.inner_carry) + noise_stds = _sample_noise_stds(config, batch["inputs"].shape[0], batch["inputs"].device, step) + noise_stds = torch.where(reset_mask, noise_stds, torch.zeros_like(noise_stds)) + directional_stats = { + "active": int((noise_stds > 0).sum().detach().cpu().item()), + "growth_sum": 0.0, + "growth_count": 0, + } + if config.trajectory_directional and float(noise_stds.max().item()) > 0: + new_inner, directional_stats = _add_directional_initial_noise( + config, + base, + new_inner, + batch, + noise_stds, + noise_stds > 0, + branch_idx, + ) + else: + new_inner = _add_initial_noise(config, new_inner, noise_stds) + new_steps = torch.where(reset_mask, 0, carry.steps) + new_current_data = { + k: torch.where(reset_mask.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) + for k, v in carry.current_data.items() + } + return ( + replace( + carry, + inner_carry=new_inner, + steps=new_steps, + halted=torch.zeros_like(reset_mask), + current_data=new_current_data, + ), + directional_stats, + ) + + +def _merge_trajectory_carries(carries: Sequence[Any]): + inner_type = type(carries[0].inner_carry) + inner_fields = carries[0].inner_carry.__dataclass_fields__.keys() + inner_carry = inner_type(**{ + name: torch.cat([getattr(c.inner_carry, name) for c in carries], dim=0) + for name in inner_fields + }) + current_data = { + k: torch.cat([c.current_data[k] for c in carries], dim=0) + for k in carries[0].current_data + } + return replace( + carries[0], + inner_carry=inner_carry, + steps=torch.cat([c.steps for c in carries], dim=0), + halted=torch.cat([c.halted for c in carries], dim=0), + current_data=current_data, + ) + + +def _split_trajectory_carry(carry: Any, branch_size: int, num_branches: int): + inner_type = type(carry.inner_carry) + inner_fields = carry.inner_carry.__dataclass_fields__.keys() + inner_chunks = { + name: [chunk.contiguous() for chunk in getattr(carry.inner_carry, name).split(branch_size, dim=0)] + for name in inner_fields + } + steps_chunks = [chunk.contiguous() for chunk in carry.steps.split(branch_size, dim=0)] + halted_chunks = [chunk.contiguous() for chunk in carry.halted.split(branch_size, dim=0)] + data_chunks = { + k: [chunk.contiguous() for chunk in v.split(branch_size, dim=0)] + for k, v in carry.current_data.items() + } + return [ + replace( + carry, + inner_carry=inner_type(**{name: inner_chunks[name][idx] for name in inner_fields}), + steps=steps_chunks[idx], + halted=halted_chunks[idx], + current_data={k: chunks[idx] for k, chunks in data_chunks.items()}, + ) + for idx in range(num_branches) + ] + + +def _repeat_batch_for_trajectories(batch: Any, num_branches: int): + return { + k: torch.cat([v for _ in range(num_branches)], dim=0) + for k, v in batch.items() + } + + +def _split_outputs(outputs: Any, branch_size: int, num_branches: int): + return [ + { + k: v.narrow(0, idx * branch_size, branch_size) + for k, v in outputs.items() + } + for idx in range(num_branches) + ] + + +def _branch_metrics_and_loss(head: nn.Module, carry: Any, outputs: Any): + labels = carry.current_data["labels"] + with torch.no_grad(): + mask = labels != -100 + loss_counts = mask.sum(-1) + loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) + is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels) + seq_is_correct = is_correct.sum(-1) == loss_counts + valid_metrics = carry.halted & (loss_counts > 0) + metrics = { + "count": valid_metrics.sum(), + "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(), + "exact_accuracy": (valid_metrics & seq_is_correct).sum(), + "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(), + "steps": torch.where(valid_metrics, carry.steps, 0).sum(), + } + + lm_loss = (head.loss_fn(outputs["logits"], labels, ignore_index=-100) / loss_divisor).sum() + q_halt_loss = torch.nn.functional.binary_cross_entropy_with_logits( + outputs["q_halt_logits"], + seq_is_correct.to(outputs["q_halt_logits"].dtype), + reduction="sum", + ) + q_continue_loss = torch.zeros_like(q_halt_loss) + if "target_q_continue" in outputs: + q_continue_loss = torch.nn.functional.binary_cross_entropy_with_logits( + outputs["q_continue_logits"], + outputs["target_q_continue"], + reduction="sum", + ) + metrics.update({ + "lm_loss": lm_loss, + "q_halt_loss": q_halt_loss, + }) + if "target_q_continue" in outputs: + metrics["q_continue_loss"] = q_continue_loss + total_loss = lm_loss + 0.5 * (q_halt_loss + q_continue_loss) + return metrics, total_loss + + +def train_batch_trajectory_aug_parallel(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): + if world_size != 1: + raise NotImplementedError("trajectory_parallel currently expects single-GPU baseline configs") + train_state.step += 1 + if train_state.step > train_state.total_steps: + return + + batch = {k: v.cuda() for k, v in batch.items()} + head = _unwrap_loss_head(train_state.model) + base = head.model # type: ignore[attr-defined] + num_branches = max(config.trajectory_n, 1) + branch_size = batch["inputs"].shape[0] + + if train_state.carry is None or not isinstance(train_state.carry, list): + with torch.device("cuda"): + train_state.carry = [train_state.model.initial_carry(batch) for _ in range(num_branches)] # type: ignore + + for optim in train_state.optimizers: + optim.zero_grad() + + carries = [] + directional_active = 0 + directional_growth_sum = 0.0 + directional_growth_count = 0 + for branch_idx, carry in enumerate(train_state.carry): + if branch_idx > 0: + carry, stats = _prepare_noisy_stream_carry(config, base, carry, batch, train_state.step, branch_idx) + directional_active += int(stats["active"]) + directional_growth_sum += float(stats["growth_sum"]) + directional_growth_count += int(stats["growth_count"]) + carries.append(carry) + + parallel_carry = _merge_trajectory_carries(carries) + parallel_batch = _repeat_batch_for_trajectories(batch, num_branches) + return_keys = ["logits", "q_halt_logits", "q_continue_logits", "target_q_continue"] + parallel_carry, loss, _metrics, outputs, _ = train_state.model( + carry=parallel_carry, + batch=parallel_batch, + return_keys=return_keys, + ) + train_state.carry = _split_trajectory_carry(parallel_carry, branch_size, num_branches) + (loss / (global_batch_size * num_branches)).backward() + + lr_this_step = None + for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): + lr_this_step = compute_lr(base_lr, config, train_state) + for param_group in optim.param_groups: + param_group["lr"] = lr_this_step + optim.step() + optim.zero_grad() + + if rank == 0 and outputs is not None: + split_outputs = _split_outputs(outputs, branch_size, num_branches) + clean_metrics, clean_loss = _branch_metrics_and_loss(head, train_state.carry[0], split_outputs[0]) + noisy_loss = 0.0 + for branch_idx in range(1, num_branches): + _, branch_loss = _branch_metrics_and_loss(head, train_state.carry[branch_idx], split_outputs[branch_idx]) + noisy_loss += float(branch_loss.detach().cpu()) + + reduced_metrics = {f"train/{k}": float(v.detach().cpu()) for k, v in clean_metrics.items()} + count = max(reduced_metrics.get("train/count", 1.0), 1.0) + reduced_metrics = { + k: v / (global_batch_size if k.endswith("loss") else count) + for k, v in reduced_metrics.items() + } + reduced_metrics["train/lr"] = lr_this_step + reduced_metrics["train/trajectory_clean_loss"] = float(clean_loss.detach().cpu()) / global_batch_size + reduced_metrics["train/trajectory_noisy_loss"] = noisy_loss / max(num_branches - 1, 1) / global_batch_size + reduced_metrics["train/trajectory_sigma"] = _trajectory_noise_target(config, train_state.step) + reduced_metrics["train/trajectory_parallel"] = 1.0 + reduced_metrics["train/trajectory_directional"] = 1.0 if config.trajectory_directional else 0.0 + reduced_metrics["train/trajectory_directional_active"] = float(directional_active) + if directional_growth_count > 0: + reduced_metrics["train/trajectory_directional_growth"] = directional_growth_sum / directional_growth_count + return reduced_metrics + + +def train_batch_trajectory_aug(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): + if config.trajectory_parallel: + return train_batch_trajectory_aug_parallel(config, train_state, batch, global_batch_size, rank, world_size) + if world_size != 1: + raise NotImplementedError("trajectory_augment currently expects single-GPU baseline configs") + train_state.step += 1 + if train_state.step > train_state.total_steps: + return + + batch = {k: v.cuda() for k, v in batch.items()} + head = _unwrap_loss_head(train_state.model) + base = head.model # type: ignore[attr-defined] + + if train_state.carry is None or not isinstance(train_state.carry, list): + with torch.device("cuda"): + train_state.carry = [train_state.model.initial_carry(batch) for _ in range(config.trajectory_n)] # type: ignore + + for optim in train_state.optimizers: + optim.zero_grad() + + clean_metrics = None + clean_loss_value = 0.0 + noisy_loss_value = 0.0 + directional_active = 0 + directional_growth_sum = 0.0 + directional_growth_count = 0 + for branch_idx in range(config.trajectory_n): + carry = train_state.carry[branch_idx] + if branch_idx > 0: + carry, stats = _prepare_noisy_stream_carry(config, base, carry, batch, train_state.step, branch_idx) + directional_active += int(stats["active"]) + directional_growth_sum += float(stats["growth_sum"]) + directional_growth_count += int(stats["growth_count"]) + carry, loss, metrics, _, _ = train_state.model(carry=carry, batch=batch, return_keys=[]) + train_state.carry[branch_idx] = carry + (loss / (global_batch_size * max(config.trajectory_n, 1))).backward() + if branch_idx == 0: + clean_metrics = metrics + clean_loss_value = float(loss.detach().cpu()) + else: + noisy_loss_value += float(loss.detach().cpu()) + + lr_this_step = None + for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): + lr_this_step = compute_lr(base_lr, config, train_state) + for param_group in optim.param_groups: + param_group["lr"] = lr_this_step + optim.step() + optim.zero_grad() + + if rank == 0 and clean_metrics is not None: + reduced_metrics = {f"train/{k}": float(v.detach().cpu()) for k, v in clean_metrics.items()} + count = max(reduced_metrics.get("train/count", 1.0), 1.0) + reduced_metrics = { + k: v / (global_batch_size if k.endswith("loss") else count) + for k, v in reduced_metrics.items() + } + reduced_metrics["train/lr"] = lr_this_step + reduced_metrics["train/trajectory_clean_loss"] = clean_loss_value / global_batch_size + reduced_metrics["train/trajectory_noisy_loss"] = noisy_loss_value / max(config.trajectory_n - 1, 1) / global_batch_size + reduced_metrics["train/trajectory_sigma"] = _trajectory_noise_target(config, train_state.step) + reduced_metrics["train/trajectory_directional"] = 1.0 if config.trajectory_directional else 0.0 + reduced_metrics["train/trajectory_directional_active"] = float(directional_active) + if directional_growth_count > 0: + reduced_metrics["train/trajectory_directional_growth"] = directional_growth_sum / directional_growth_count + return reduced_metrics + + +def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]: + data_paths =config.data_paths_test if len(config.data_paths_test)>0 else config.data_paths + # Initialize evaluators + evaluators = [] + for cfg in config.evaluators: + for data_path in data_paths: + cls = load_model_class(cfg.name, "evaluators.")( + data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__ + ) # type: ignore + evaluators.append(cls) + + return evaluators + +def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): + if config.trajectory_augment: + return train_batch_trajectory_aug(config, train_state, batch, global_batch_size, rank, world_size) + + train_state.step += 1 + if train_state.step > train_state.total_steps: # At most train_total_steps + return + + # To device + batch = {k: v.cuda() for k, v in batch.items()} + + # Init carry if it is None + if train_state.carry is None: + with torch.device("cuda"): + train_state.carry = train_state.model.initial_carry(batch) # type: ignore + + # Forward + train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[]) + + ((1 / global_batch_size) * loss).backward() + + # Allreduce + if world_size > 1: + for param in train_state.model.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad) + + # Apply optimizer + lr_this_step = None + for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): + lr_this_step = compute_lr(base_lr, config, train_state) + + for param_group in optim.param_groups: + param_group['lr'] = lr_this_step + + optim.step() + optim.zero_grad() + + # Reduce metrics + if len(metrics): + assert not any(v.requires_grad for v in metrics.values()) + + metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. + # Reduce and reconstruct + metric_values = torch.stack([metrics[k] for k in metric_keys]) + if world_size > 1: + dist.reduce(metric_values, dst=0) + + if rank == 0: + metric_values = metric_values.cpu().numpy() + reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)} + + # Postprocess + count = max(reduced_metrics["count"], 1) # Avoid NaNs + reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()} + + reduced_metrics["train/lr"] = lr_this_step + return reduced_metrics + +def evaluate( + config: PretrainConfig, + train_state: TrainState, + eval_loader: torch.utils.data.DataLoader, + eval_metadata: PuzzleDatasetMetadata, + evaluators: List[Any], + rank: int, + world_size: int, + cpu_group: Optional[dist.ProcessGroup], +): + reduced_metrics = None + + with torch.inference_mode(): + return_keys = set(config.eval_save_outputs) + for evaluator in evaluators: + evaluator.begin_eval() + return_keys.update(evaluator.required_outputs) + + # Run evaluation + set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} + + save_preds = {} + + metric_keys = [] + metric_values = None + + carry = None + processed_batches = 0 + + for set_name, batch, global_batch_size in eval_loader: + processed_batches += 1 + if rank == 0: + print(f"Processing batch {processed_batches}: {set_name}") + + # To device + batch = {k: v.cuda() for k, v in batch.items()} + with torch.device("cuda"): + carry = train_state.model.initial_carry(batch) # type: ignore + + # Forward + inference_steps = 0 + while True: + carry, loss, metrics, preds, all_finish = train_state.model( + carry=carry, batch=batch, return_keys=return_keys + ) + inference_steps += 1 + + if all_finish: + break + + if rank == 0: + print(f" Completed inference in {inference_steps} steps") + + for collection in (batch, preds): + for k, v in collection.items(): + if k in config.eval_save_outputs: + save_preds.setdefault(k, []) + save_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory + + for evaluator in evaluators: + evaluator.update_batch(batch, preds) + + del carry, loss, preds, batch, all_finish + + # Aggregate metrics + set_id = set_ids[set_name] + + if metric_values is None: + metric_keys = list( + sorted(metrics.keys()) + ) # Sort keys to guarantee all processes use the same order. + metric_values = torch.zeros( + (len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda" + ) + + metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) + + del metrics + + # concatenate save preds + save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()} + + # Save preds + if config.checkpoint_path is not None and len(save_preds): + # Each rank save predictions independently + os.makedirs(os.path.dirname(config.checkpoint_path), exist_ok=True) + torch.save( + save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}") + ) + + del save_preds + + # Reduce to rank 0 + if metric_values is not None: + if world_size > 1: + dist.reduce(metric_values, dst=0) + + if rank == 0: + reduced_metrics = metric_values.cpu().numpy() + reduced_metrics = { + set_name: { + metric_name: reduced_metrics[set_id, metric_id] + for metric_id, metric_name in enumerate(metric_keys) + } + for set_id, set_name in enumerate(set_ids) + } + + # Postprocess + for set_name, m in reduced_metrics.items(): + count = m.pop("count") + reduced_metrics[set_name] = {k: v / count for k, v in m.items()} + + # Run evaluators + if rank == 0: + print(f"\nRunning {len(evaluators)} evaluator(s)...") + + for i, evaluator in enumerate(evaluators): + if rank == 0: + print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}") + + # Path for saving + evaluator_save_path = None + if config.checkpoint_path is not None: + evaluator_save_path = os.path.join( + config.checkpoint_path, + f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}", + ) + os.makedirs(evaluator_save_path, exist_ok=True) + + # Run and log + metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group) + if rank == 0 and metrics is not None: + if reduced_metrics is None: + reduced_metrics = {} + + reduced_metrics.update(metrics) + print(f" Completed {evaluator.__class__.__name__}") + + if rank == 0: + print("All evaluators completed!") + + return reduced_metrics + +def save_code_and_config(config: PretrainConfig): + if config.checkpoint_path is None or wandb.run is None: + return + + os.makedirs(config.checkpoint_path, exist_ok=True) + + # Copy code + code_list = [ + get_model_source_path(config.arch.name), + get_model_source_path(config.arch.loss.name) + ] + for code_file in code_list: + if code_file is not None: + code_name = os.path.basename(code_file) + + shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name)) + + # Dump config as yaml + config_file = os.path.join(config.checkpoint_path, "all_config.yaml") + with open(config_file, "wt") as f: + yaml.dump(config.model_dump(), f) + + # Log code + wandb.run.log_code(config.checkpoint_path) + + +def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig: + objects = [None] + if rank == 0: + config = PretrainConfig(**hydra_config) # type: ignore + + # Naming + if config.project_name is None: + config.project_name = f"{os.path.basename(config.data_paths[0]).capitalize()}-ACT-torch" + if config.run_name is None: + config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}" + if config.checkpoint_path is None: + config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name) + + objects = [config] + + if world_size > 1: + dist.broadcast_object_list(objects, src=0) + + return objects[0] # type: ignore + + +@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None) +def launch(hydra_config: DictConfig): + RANK = 0 + WORLD_SIZE = 1 + CPU_PROCESS_GROUP = None + + # Initialize distributed training if in distributed environment (e.g. torchrun) + if "LOCAL_RANK" in os.environ: + # Initialize distributed, default device and dtype + dist.init_process_group(backend="nccl") + + RANK = dist.get_rank() + WORLD_SIZE = dist.get_world_size() + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + # CPU GLOO process group + CPU_PROCESS_GROUP = dist.new_group(backend="gloo") + assert ( + dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE + ) + + # Load sync'ed config + config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE) + + # Seed RNGs to ensure consistency + torch.random.manual_seed(config.seed + RANK) + + # Dataset + train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs + total_iters = config.epochs // train_epochs_per_iter + + assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs." + + train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + try: + eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + except: + print("NO EVAL DATA FOUND") + eval_loader = eval_metadata = None + + try: + evaluators = create_evaluators(config, eval_metadata) + except: + print("No evaluator found") + evaluators = [] + + # Train state + train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE) + + # Progress bar and logger + progress_bar = None + ema_helper = None + if RANK == 0: + progress_bar = tqdm.tqdm(total=train_state.total_steps) + wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore + wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0) + save_code_and_config(config) + if config.ema: + print('Setup EMA') + ema_helper = EMAHelper(mu=config.ema_rate) + ema_helper.register(train_state.model) + + # Training Loop + for _iter_id in range(total_iters): + print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}") + + ############ Train Iter + if RANK == 0: + print("TRAIN") + train_state.model.train() + for set_name, batch, global_batch_size in train_loader: + metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE) + + if RANK == 0 and metrics is not None: + wandb.log(metrics, step=train_state.step) + progress_bar.update(train_state.step - progress_bar.n) # type: ignore + if config.ema: + ema_helper.update(train_state.model) + + if _iter_id >= config.min_eval_interval: + ############ Evaluation + if RANK == 0: + print("EVALUATE") + if config.ema: + print("SWITCH TO EMA") + train_state_eval = copy.deepcopy(train_state) + train_state_eval.model = ema_helper.ema_copy(train_state_eval.model) + else: + train_state_eval = train_state + train_state_eval.model.eval() + metrics = evaluate(config, + train_state_eval, + eval_loader, + eval_metadata, + evaluators, + rank=RANK, + world_size=WORLD_SIZE, + cpu_group=CPU_PROCESS_GROUP) + + if RANK == 0 and metrics is not None: + wandb.log(metrics, step=train_state.step) + + ############ Checkpointing + if RANK == 0: + print("SAVE CHECKPOINT") + if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)): + save_train_state(config, train_state_eval) + + if config.ema: + del train_state_eval + + # finalize + if dist.is_initialized(): + dist.destroy_process_group() + wandb.finish() + + +if __name__ == "__main__": + launch() |
