from typing import Optional, Any, Sequence, List from dataclasses import dataclass, replace import os import math import yaml import shutil import copy import torch import torch.distributed as dist from torch import nn from torch.utils.data import DataLoader import tqdm import wandb import coolname import hydra import pydantic from omegaconf import DictConfig from adam_atan2 import AdamATan2 from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata from utils.functions import load_model_class, get_model_source_path from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed from models.ema import EMAHelper class LossConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra='allow') name: str class ArchConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra='allow') name: str loss: LossConfig class EvaluatorConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra="allow") name: str class PretrainConfig(pydantic.BaseModel): # Config arch: ArchConfig # Data data_paths: List[str] data_paths_test: List[str] = [] # Evaluators evaluators: List[EvaluatorConfig] = [] # Hyperparams global_batch_size: int epochs: int lr: float lr_min_ratio: float lr_warmup_steps: int weight_decay: float beta1: float beta2: float # Puzzle embedding puzzle_emb_lr: float puzzle_emb_weight_decay: float # Names project_name: Optional[str] = None run_name: Optional[str] = None load_checkpoint: Optional[str] = None checkpoint_path: Optional[str] = None # Extras seed: int = 0 checkpoint_every_eval: bool = False eval_interval: Optional[int] = None min_eval_interval: Optional[int] = 0 # when to start eval eval_save_outputs: List[str] = [] ema: bool = False # use Exponential-Moving-Average ema_rate: float = 0.999 # EMA-rate freeze_weights: bool = False # If True, freeze weights and only learn the embeddings trajectory_augment: bool = False trajectory_n: int = 4 trajectory_noise_std: float = 1e-3 trajectory_noise_min: Optional[float] = None trajectory_noise_max: Optional[float] = None trajectory_noise_sampling: str = "loguniform" trajectory_sigma_start: Optional[float] = 0.0 trajectory_sigma_ramp_steps: int = 5000 trajectory_perturb: str = "both" trajectory_micro_batch: int = 0 trajectory_parallel: bool = False trajectory_directional: bool = False trajectory_directional_candidates: int = 4 trajectory_directional_fd_eps: float = 3e-2 trajectory_directional_horizon: int = 0 trajectory_directional_sign: str = "alternate" @dataclass class TrainState: model: nn.Module optimizers: Sequence[torch.optim.Optimizer] optimizer_lrs: Sequence[float] carry: Any step: int total_steps: int def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs): dataset = PuzzleDataset(PuzzleDatasetConfig( seed=config.seed, dataset_paths=config.data_paths_test if len(config.data_paths_test)>0 and split=="test" else config.data_paths, rank=rank, num_replicas=world_size, **kwargs ), split=split) dataloader = DataLoader( dataset, batch_size=None, num_workers=1, prefetch_factor=8, pin_memory=True, persistent_workers=True ) return dataloader, dataset.metadata def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): model_batch_size = config.global_batch_size // world_size if config.trajectory_augment and config.trajectory_parallel: model_batch_size *= config.trajectory_n model_cfg = dict( **config.arch.__pydantic_extra__, # type: ignore batch_size=model_batch_size, vocab_size=train_metadata.vocab_size, seq_len=train_metadata.seq_len, num_puzzle_identifiers=train_metadata.num_puzzle_identifiers, causal=False # Non-autoregressive ) # Instantiate model with loss head model_cls = load_model_class(config.arch.name) loss_head_cls = load_model_class(config.arch.loss.name) with torch.device("cuda"): model: nn.Module = model_cls(model_cfg) print(model) model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore if "DISABLE_COMPILE" not in os.environ: model = torch.compile(model) # type: ignore # Load checkpoint if rank == 0: load_checkpoint(model, config) # Broadcast parameters from rank 0 if world_size > 1: with torch.no_grad(): for param in list(model.parameters()) + list(model.buffers()): dist.broadcast(param, src=0) # Optimizers and lr if config.arch.puzzle_emb_ndim == 0: optimizers = [ AdamATan2( model.parameters(), lr=0, # Needs to be set by scheduler weight_decay=config.weight_decay, betas=(config.beta1, config.beta2) ) ] optimizer_lrs = [ config.lr ] elif config.freeze_weights: optimizers = [ CastedSparseEmbeddingSignSGD_Distributed( model.model.puzzle_emb.buffers(), # type: ignore lr=0, # Needs to be set by scheduler weight_decay=config.puzzle_emb_weight_decay, world_size=world_size ) ] optimizer_lrs = [ config.puzzle_emb_lr ] else: optimizers = [ CastedSparseEmbeddingSignSGD_Distributed( model.model.puzzle_emb.buffers(), # type: ignore lr=0, # Needs to be set by scheduler weight_decay=config.puzzle_emb_weight_decay, world_size=world_size ), AdamATan2( model.parameters(), lr=0, # Needs to be set by scheduler weight_decay=config.weight_decay, betas=(config.beta1, config.beta2) ) ] optimizer_lrs = [ config.puzzle_emb_lr, config.lr ] return model, optimizers, optimizer_lrs def mix_weights_direct(device, alpha, net, nets): sd = [] for i in range(len(nets)): sd += [nets[i].state_dict()] sd_alpha = {} for k in sd[0].keys(): comb_net = alpha[0]*sd[0][k].to(device) for i in range(1,len(nets)): comb_net += alpha[i]*sd[i][k].to(device) sd_alpha[k] = comb_net net.load_state_dict(sd_alpha) return net def cosine_schedule_with_warmup_lr_lambda( current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 ): if current_step < num_warmup_steps: return base_lr * float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))) def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): # Estimated total training steps total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size) # Model model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size) return TrainState( step=0, total_steps=total_steps, model=model, optimizers=optimizers, optimizer_lrs=optimizer_lrs, carry=None ) def save_train_state(config: PretrainConfig, train_state: TrainState): # FIXME: Only saved model. if config.checkpoint_path is None: return os.makedirs(config.checkpoint_path, exist_ok=True) torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) def load_checkpoint(model: nn.Module, config: PretrainConfig): if config.load_checkpoint is not None: print(f"Loading checkpoint {config.load_checkpoint}") # Load state dict state_dict = torch.load(config.load_checkpoint, map_location="cuda") # Resize and reset puzzle emb if needed puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights" expected_shape: torch.Size = model.model.puzzle_emb.weights.shape # type: ignore if puzzle_emb_name in state_dict: puzzle_emb = state_dict[puzzle_emb_name] if puzzle_emb.shape != expected_shape: print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, Expected {expected_shape}") # Re-initialize using mean state_dict[puzzle_emb_name] = ( torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous() ) model.load_state_dict(state_dict, assign=True) def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState): return cosine_schedule_with_warmup_lr_lambda( current_step=train_state.step, base_lr=base_lr, num_warmup_steps=round(config.lr_warmup_steps), num_training_steps=train_state.total_steps, min_ratio=config.lr_min_ratio ) def _unwrap_loss_head(model: nn.Module): return getattr(model, "_orig_mod", model) def _unit_noise_like(tensor: torch.Tensor, sampling: str): if sampling == "uniform": return ((2.0 * torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32) - 1.0) * math.sqrt(3.0)).to(tensor.dtype) return torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32).to(tensor.dtype) def _trajectory_noise_target(config: PretrainConfig, step: int) -> float: if config.trajectory_sigma_ramp_steps <= 0: return config.trajectory_noise_std start = config.trajectory_sigma_start if config.trajectory_sigma_start is not None else config.trajectory_noise_std frac = min(max(step / config.trajectory_sigma_ramp_steps, 0.0), 1.0) return float(start + frac * (config.trajectory_noise_std - start)) def _sample_noise_stds(config: PretrainConfig, batch_size: int, device: torch.device, step: int): target = _trajectory_noise_target(config, step) if target <= 0: return torch.zeros(batch_size, device=device, dtype=torch.float32) if config.trajectory_noise_sampling == "loguniform": scale = 1.0 if config.trajectory_noise_std <= 0 else target / config.trajectory_noise_std hi = float(config.trajectory_noise_max if config.trajectory_noise_max is not None else config.trajectory_noise_std) * scale lo = float(config.trajectory_noise_min if config.trajectory_noise_min is not None else max(hi / 10.0, 1e-8)) * scale hi = max(hi, 1e-12) lo = max(min(lo, hi), 1e-12) u = torch.rand(batch_size, device=device, dtype=torch.float32) return torch.exp(math.log(lo) + u * (math.log(hi) - math.log(lo))) return torch.full((batch_size,), target, device=device, dtype=torch.float32) def _add_initial_noise(config: PretrainConfig, inner_carry: Any, noise_stds: torch.Tensor): if noise_stds.numel() == 0 or float(noise_stds.max().item()) <= 0: return inner_carry view_shape = (noise_stds.shape[0],) + (1,) * (inner_carry.z_H.ndim - 1) scale = noise_stds.view(view_shape) z_h = inner_carry.z_H z_l = inner_carry.z_L if config.trajectory_perturb in ("h", "both"): z_h = z_h + scale.to(z_h.dtype) * _unit_noise_like(z_h, config.trajectory_noise_sampling) if config.trajectory_perturb in ("l", "both"): z_l = z_l + scale.to(z_l.dtype) * _unit_noise_like(z_l, config.trajectory_noise_sampling) return replace(inner_carry, z_H=z_h, z_L=z_l) def _directional_fields(config: PretrainConfig): if config.trajectory_perturb == "h": return ("z_H",) if config.trajectory_perturb == "l": return ("z_L",) return ("z_H", "z_L") def _index_inner_carry(inner_carry: Any, indices: torch.Tensor): inner_type = type(inner_carry) return inner_type(**{ name: getattr(inner_carry, name).index_select(0, indices).contiguous() for name in inner_carry.__dataclass_fields__.keys() }) def _cat_inner_carries(carries: Sequence[Any]): inner_type = type(carries[0]) return inner_type(**{ name: torch.cat([getattr(c, name) for c in carries], dim=0) for name in carries[0].__dataclass_fields__.keys() }) def _split_inner_carry(inner_carry: Any, first_size: int): inner_type = type(inner_carry) first = {} second = {} for name in inner_carry.__dataclass_fields__.keys(): value = getattr(inner_carry, name) first[name] = value[:first_size].contiguous() second[name] = value[first_size:].contiguous() return inner_type(**first), inner_type(**second) def _repeat_batch_interleave(batch: Any, repeats: int): return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} def _cat_batches(a: Any, b: Any): return {k: torch.cat([a[k], b[k]], dim=0) for k in a} def _rand_unit_direction_candidates(config: PretrainConfig, inner_carry: Any, candidates: int): fields = _directional_fields(config) dirs = {} norm_sq = None for name in inner_carry.__dataclass_fields__.keys(): value = getattr(inner_carry, name) if name in fields: direction = torch.randn( (value.shape[0], candidates) + tuple(value.shape[1:]), device=value.device, dtype=torch.float32, ) term = direction.flatten(2).square().sum(-1) norm_sq = term if norm_sq is None else norm_sq + term dirs[name] = direction else: dirs[name] = torch.zeros( (value.shape[0], candidates) + tuple(value.shape[1:]), device=value.device, dtype=torch.float32, ) if norm_sq is None: raise ValueError(f"unknown trajectory_perturb={config.trajectory_perturb!r}") norm = torch.sqrt(norm_sq).clamp_min(1e-30) out = {} for name, direction in dirs.items(): view_shape = direction.shape[:2] + (1,) * (direction.ndim - 2) out[name] = (direction / norm.view(view_shape)).to(getattr(inner_carry, name).dtype) return out def _make_shadow_inner(inner_carry: Any, dirs: dict[str, torch.Tensor], eps: float): inner_type = type(inner_carry) fields = {} for name in inner_carry.__dataclass_fields__.keys(): value = getattr(inner_carry, name) perturbed = value[:, None] + eps * dirs[name] fields[name] = perturbed.reshape((value.shape[0] * dirs[name].shape[1],) + tuple(value.shape[1:])).detach() return inner_type(**fields) def _separation(main: Any, shadow: Any, candidates: int): sep_sq = None batch_size = getattr(main, next(iter(main.__dataclass_fields__.keys()))).shape[0] for name in main.__dataclass_fields__.keys(): main_value = getattr(main, name) shadow_value = getattr(shadow, name).reshape((batch_size, candidates) + tuple(main_value.shape[1:])) diff_sq = (shadow_value.float() - main_value[:, None].float()).flatten(2).square().sum(-1) sep_sq = diff_sq if sep_sq is None else sep_sq + diff_sq return torch.sqrt(sep_sq).clamp_min(1e-30) def _gather_direction(dirs: dict[str, torch.Tensor], best_idx: torch.Tensor): batch_size = best_idx.shape[0] row_idx = torch.arange(batch_size, device=best_idx.device) return {name: value[row_idx, best_idx].contiguous() for name, value in dirs.items()} def _directional_signs(config: PretrainConfig, branch_idx: int, count: int, device: torch.device): mode = config.trajectory_directional_sign if mode == "random": return torch.where( torch.rand(count, device=device, dtype=torch.float32) < 0.5, torch.full((count,), -1.0, device=device, dtype=torch.float32), torch.full((count,), 1.0, device=device, dtype=torch.float32), ) if mode == "negative": return torch.full((count,), -1.0, device=device, dtype=torch.float32) if mode == "positive": return torch.full((count,), 1.0, device=device, dtype=torch.float32) if mode != "alternate": raise ValueError(f"unknown trajectory_directional_sign={mode!r}") sign = 1.0 if branch_idx % 2 == 1 else -1.0 return torch.full((count,), sign, device=device, dtype=torch.float32) @torch.no_grad() def _select_directional_perturbation( config: PretrainConfig, base: nn.Module, inner_carry: Any, batch: Any, active_mask: torch.Tensor, ): active_indices = torch.nonzero(active_mask, as_tuple=False).flatten() active_count = int(active_indices.numel()) if active_count == 0: return None, None, { "active": 0, "growth_sum": 0.0, "growth_count": 0, } candidates = max(int(config.trajectory_directional_candidates), 1) fd_eps = float(config.trajectory_directional_fd_eps) horizon = int(config.trajectory_directional_horizon) if horizon <= 0: horizon = int(base.config.halt_max_steps) # type: ignore[attr-defined] horizon = max(horizon, 1) active_inner = _index_inner_carry(inner_carry, active_indices) active_batch = {k: v.index_select(0, active_indices).contiguous() for k, v in batch.items()} dirs = _rand_unit_direction_candidates(config, active_inner, candidates) main = type(active_inner)(**{ name: getattr(active_inner, name).detach() for name in active_inner.__dataclass_fields__.keys() }) shadow = _make_shadow_inner(main, dirs, fd_eps) combined = _cat_inner_carries([main, shadow]) combined_batch = _cat_batches(active_batch, _repeat_batch_interleave(active_batch, candidates)) was_training = base.inner.training # type: ignore[attr-defined] base.inner.eval() # type: ignore[attr-defined] try: for _ in range(horizon): combined, _logits, _q = base.inner(combined, combined_batch) # type: ignore[attr-defined] finally: if was_training: base.inner.train() # type: ignore[attr-defined] main_final, shadow_final = _split_inner_carry(combined, active_count) sep = _separation(main_final, shadow_final, candidates) best_idx = sep.argmax(dim=1) best_sep = sep.gather(1, best_idx[:, None]).squeeze(1) best_dirs = _gather_direction(dirs, best_idx) growth = torch.log(best_sep / max(fd_eps, 1e-30)).float() / horizon return active_indices, best_dirs, { "active": active_count, "growth_sum": float(growth.sum().item()), "growth_count": active_count, } def _add_directional_initial_noise( config: PretrainConfig, base: nn.Module, inner_carry: Any, batch: Any, noise_stds: torch.Tensor, active_mask: torch.Tensor, branch_idx: int, ): selected_indices, selected_dirs, stats = _select_directional_perturbation( config, base, inner_carry, batch, active_mask, ) if selected_indices is None or selected_dirs is None: return inner_carry, stats signs = _directional_signs(config, branch_idx, selected_indices.shape[0], noise_stds.device) scales = noise_stds.index_select(0, selected_indices) * signs fields = {} for name in inner_carry.__dataclass_fields__.keys(): value = getattr(inner_carry, name) updated = value.clone() direction = selected_dirs[name].to(value.dtype) view_shape = (scales.shape[0],) + (1,) * (direction.ndim - 1) updated.index_copy_(0, selected_indices, value.index_select(0, selected_indices) + scales.view(view_shape).to(value.dtype) * direction) fields[name] = updated return type(inner_carry)(**fields), stats def _token_loss(loss_fn, logits, labels, mask): kwargs = {"ignore_index": -100} code = getattr(loss_fn, "__code__", None) arg_names = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount] if code is not None else () if "valid_mask" in arg_names: kwargs["valid_mask"] = mask return loss_fn(logits, labels, **kwargs) def _fixed_unroll_branch_loss(config: PretrainConfig, head: nn.Module, batch: Any, noise_stds: Optional[torch.Tensor]): base = head.model # type: ignore[attr-defined] with torch.device("cuda"): carry = base.initial_carry(batch) reset_flag = torch.ones_like(carry.halted) inner_carry = base.inner.reset_carry(reset_flag, carry.inner_carry) if noise_stds is not None: inner_carry = _add_initial_noise(config, inner_carry, noise_stds) batch_size = batch["inputs"].shape[0] loss_sum = torch.zeros((), device=batch["inputs"].device, dtype=torch.float32) last_exact = torch.zeros((), device=batch["inputs"].device, dtype=torch.float32) for act_step in range(base.config.halt_max_steps): inner_carry, logits, (q_halt_logits, q_continue_logits) = base.inner(inner_carry, batch) labels = batch["labels"] with torch.no_grad(): mask = labels != -100 loss_counts = mask.sum(-1) loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) is_correct = mask & (torch.argmax(logits, dim=-1) == labels) seq_is_correct = is_correct.sum(-1) == loss_counts last_exact = seq_is_correct.to(torch.float32).sum() lm_loss = (_token_loss(head.loss_fn, logits, labels, mask) / loss_divisor).sum() q_halt_loss = torch.nn.functional.binary_cross_entropy_with_logits( q_halt_logits, seq_is_correct.to(q_halt_logits.dtype), reduction="sum", ) q_continue_loss = torch.zeros_like(q_halt_loss) if not getattr(base.config, "no_ACT_continue", False): with torch.no_grad(): next_q_halt_logits, next_q_continue_logits = base.inner(inner_carry, batch)[-1] is_last = act_step + 1 >= base.config.halt_max_steps target_q_continue = torch.sigmoid( torch.where( torch.full_like(next_q_halt_logits, is_last, dtype=torch.bool), next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits), ) ) q_continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(q_continue_logits, target_q_continue, reduction="sum") loss_sum = loss_sum + lm_loss + 0.5 * (q_halt_loss + q_continue_loss) return loss_sum / max(base.config.halt_max_steps, 1), last_exact, batch_size def _prepare_noisy_stream_carry(config: PretrainConfig, base: nn.Module, carry: Any, batch: Any, step: int, branch_idx: int): reset_mask = carry.halted new_inner = base.inner.reset_carry(reset_mask, carry.inner_carry) noise_stds = _sample_noise_stds(config, batch["inputs"].shape[0], batch["inputs"].device, step) noise_stds = torch.where(reset_mask, noise_stds, torch.zeros_like(noise_stds)) directional_stats = { "active": int((noise_stds > 0).sum().detach().cpu().item()), "growth_sum": 0.0, "growth_count": 0, } if config.trajectory_directional and float(noise_stds.max().item()) > 0: new_inner, directional_stats = _add_directional_initial_noise( config, base, new_inner, batch, noise_stds, noise_stds > 0, branch_idx, ) else: new_inner = _add_initial_noise(config, new_inner, noise_stds) new_steps = torch.where(reset_mask, 0, carry.steps) new_current_data = { k: torch.where(reset_mask.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items() } return ( replace( carry, inner_carry=new_inner, steps=new_steps, halted=torch.zeros_like(reset_mask), current_data=new_current_data, ), directional_stats, ) def _merge_trajectory_carries(carries: Sequence[Any]): inner_type = type(carries[0].inner_carry) inner_fields = carries[0].inner_carry.__dataclass_fields__.keys() inner_carry = inner_type(**{ name: torch.cat([getattr(c.inner_carry, name) for c in carries], dim=0) for name in inner_fields }) current_data = { k: torch.cat([c.current_data[k] for c in carries], dim=0) for k in carries[0].current_data } return replace( carries[0], inner_carry=inner_carry, steps=torch.cat([c.steps for c in carries], dim=0), halted=torch.cat([c.halted for c in carries], dim=0), current_data=current_data, ) def _split_trajectory_carry(carry: Any, branch_size: int, num_branches: int): inner_type = type(carry.inner_carry) inner_fields = carry.inner_carry.__dataclass_fields__.keys() inner_chunks = { name: [chunk.contiguous() for chunk in getattr(carry.inner_carry, name).split(branch_size, dim=0)] for name in inner_fields } steps_chunks = [chunk.contiguous() for chunk in carry.steps.split(branch_size, dim=0)] halted_chunks = [chunk.contiguous() for chunk in carry.halted.split(branch_size, dim=0)] data_chunks = { k: [chunk.contiguous() for chunk in v.split(branch_size, dim=0)] for k, v in carry.current_data.items() } return [ replace( carry, inner_carry=inner_type(**{name: inner_chunks[name][idx] for name in inner_fields}), steps=steps_chunks[idx], halted=halted_chunks[idx], current_data={k: chunks[idx] for k, chunks in data_chunks.items()}, ) for idx in range(num_branches) ] def _repeat_batch_for_trajectories(batch: Any, num_branches: int): return { k: torch.cat([v for _ in range(num_branches)], dim=0) for k, v in batch.items() } def _split_outputs(outputs: Any, branch_size: int, num_branches: int): return [ { k: v.narrow(0, idx * branch_size, branch_size) for k, v in outputs.items() } for idx in range(num_branches) ] def _branch_metrics_and_loss(head: nn.Module, carry: Any, outputs: Any): labels = carry.current_data["labels"] with torch.no_grad(): mask = labels != -100 loss_counts = mask.sum(-1) loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels) seq_is_correct = is_correct.sum(-1) == loss_counts valid_metrics = carry.halted & (loss_counts > 0) metrics = { "count": valid_metrics.sum(), "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(), "exact_accuracy": (valid_metrics & seq_is_correct).sum(), "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(), "steps": torch.where(valid_metrics, carry.steps, 0).sum(), } lm_loss = (head.loss_fn(outputs["logits"], labels, ignore_index=-100) / loss_divisor).sum() q_halt_loss = torch.nn.functional.binary_cross_entropy_with_logits( outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum", ) q_continue_loss = torch.zeros_like(q_halt_loss) if "target_q_continue" in outputs: q_continue_loss = torch.nn.functional.binary_cross_entropy_with_logits( outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum", ) metrics.update({ "lm_loss": lm_loss, "q_halt_loss": q_halt_loss, }) if "target_q_continue" in outputs: metrics["q_continue_loss"] = q_continue_loss total_loss = lm_loss + 0.5 * (q_halt_loss + q_continue_loss) return metrics, total_loss def train_batch_trajectory_aug_parallel(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): if world_size != 1: raise NotImplementedError("trajectory_parallel currently expects single-GPU baseline configs") train_state.step += 1 if train_state.step > train_state.total_steps: return batch = {k: v.cuda() for k, v in batch.items()} head = _unwrap_loss_head(train_state.model) base = head.model # type: ignore[attr-defined] num_branches = max(config.trajectory_n, 1) branch_size = batch["inputs"].shape[0] if train_state.carry is None or not isinstance(train_state.carry, list): with torch.device("cuda"): train_state.carry = [train_state.model.initial_carry(batch) for _ in range(num_branches)] # type: ignore for optim in train_state.optimizers: optim.zero_grad() carries = [] directional_active = 0 directional_growth_sum = 0.0 directional_growth_count = 0 for branch_idx, carry in enumerate(train_state.carry): if branch_idx > 0: carry, stats = _prepare_noisy_stream_carry(config, base, carry, batch, train_state.step, branch_idx) directional_active += int(stats["active"]) directional_growth_sum += float(stats["growth_sum"]) directional_growth_count += int(stats["growth_count"]) carries.append(carry) parallel_carry = _merge_trajectory_carries(carries) parallel_batch = _repeat_batch_for_trajectories(batch, num_branches) return_keys = ["logits", "q_halt_logits", "q_continue_logits", "target_q_continue"] parallel_carry, loss, _metrics, outputs, _ = train_state.model( carry=parallel_carry, batch=parallel_batch, return_keys=return_keys, ) train_state.carry = _split_trajectory_carry(parallel_carry, branch_size, num_branches) (loss / (global_batch_size * num_branches)).backward() lr_this_step = None for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): lr_this_step = compute_lr(base_lr, config, train_state) for param_group in optim.param_groups: param_group["lr"] = lr_this_step optim.step() optim.zero_grad() if rank == 0 and outputs is not None: split_outputs = _split_outputs(outputs, branch_size, num_branches) clean_metrics, clean_loss = _branch_metrics_and_loss(head, train_state.carry[0], split_outputs[0]) noisy_loss = 0.0 for branch_idx in range(1, num_branches): _, branch_loss = _branch_metrics_and_loss(head, train_state.carry[branch_idx], split_outputs[branch_idx]) noisy_loss += float(branch_loss.detach().cpu()) reduced_metrics = {f"train/{k}": float(v.detach().cpu()) for k, v in clean_metrics.items()} count = max(reduced_metrics.get("train/count", 1.0), 1.0) reduced_metrics = { k: v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items() } reduced_metrics["train/lr"] = lr_this_step reduced_metrics["train/trajectory_clean_loss"] = float(clean_loss.detach().cpu()) / global_batch_size reduced_metrics["train/trajectory_noisy_loss"] = noisy_loss / max(num_branches - 1, 1) / global_batch_size reduced_metrics["train/trajectory_sigma"] = _trajectory_noise_target(config, train_state.step) reduced_metrics["train/trajectory_parallel"] = 1.0 reduced_metrics["train/trajectory_directional"] = 1.0 if config.trajectory_directional else 0.0 reduced_metrics["train/trajectory_directional_active"] = float(directional_active) if directional_growth_count > 0: reduced_metrics["train/trajectory_directional_growth"] = directional_growth_sum / directional_growth_count return reduced_metrics def train_batch_trajectory_aug(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): if config.trajectory_parallel: return train_batch_trajectory_aug_parallel(config, train_state, batch, global_batch_size, rank, world_size) if world_size != 1: raise NotImplementedError("trajectory_augment currently expects single-GPU baseline configs") train_state.step += 1 if train_state.step > train_state.total_steps: return batch = {k: v.cuda() for k, v in batch.items()} head = _unwrap_loss_head(train_state.model) base = head.model # type: ignore[attr-defined] if train_state.carry is None or not isinstance(train_state.carry, list): with torch.device("cuda"): train_state.carry = [train_state.model.initial_carry(batch) for _ in range(config.trajectory_n)] # type: ignore for optim in train_state.optimizers: optim.zero_grad() clean_metrics = None clean_loss_value = 0.0 noisy_loss_value = 0.0 directional_active = 0 directional_growth_sum = 0.0 directional_growth_count = 0 for branch_idx in range(config.trajectory_n): carry = train_state.carry[branch_idx] if branch_idx > 0: carry, stats = _prepare_noisy_stream_carry(config, base, carry, batch, train_state.step, branch_idx) directional_active += int(stats["active"]) directional_growth_sum += float(stats["growth_sum"]) directional_growth_count += int(stats["growth_count"]) carry, loss, metrics, _, _ = train_state.model(carry=carry, batch=batch, return_keys=[]) train_state.carry[branch_idx] = carry (loss / (global_batch_size * max(config.trajectory_n, 1))).backward() if branch_idx == 0: clean_metrics = metrics clean_loss_value = float(loss.detach().cpu()) else: noisy_loss_value += float(loss.detach().cpu()) lr_this_step = None for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): lr_this_step = compute_lr(base_lr, config, train_state) for param_group in optim.param_groups: param_group["lr"] = lr_this_step optim.step() optim.zero_grad() if rank == 0 and clean_metrics is not None: reduced_metrics = {f"train/{k}": float(v.detach().cpu()) for k, v in clean_metrics.items()} count = max(reduced_metrics.get("train/count", 1.0), 1.0) reduced_metrics = { k: v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items() } reduced_metrics["train/lr"] = lr_this_step reduced_metrics["train/trajectory_clean_loss"] = clean_loss_value / global_batch_size reduced_metrics["train/trajectory_noisy_loss"] = noisy_loss_value / max(config.trajectory_n - 1, 1) / global_batch_size reduced_metrics["train/trajectory_sigma"] = _trajectory_noise_target(config, train_state.step) reduced_metrics["train/trajectory_directional"] = 1.0 if config.trajectory_directional else 0.0 reduced_metrics["train/trajectory_directional_active"] = float(directional_active) if directional_growth_count > 0: reduced_metrics["train/trajectory_directional_growth"] = directional_growth_sum / directional_growth_count return reduced_metrics def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]: data_paths =config.data_paths_test if len(config.data_paths_test)>0 else config.data_paths # Initialize evaluators evaluators = [] for cfg in config.evaluators: for data_path in data_paths: cls = load_model_class(cfg.name, "evaluators.")( data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__ ) # type: ignore evaluators.append(cls) return evaluators def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): if config.trajectory_augment: return train_batch_trajectory_aug(config, train_state, batch, global_batch_size, rank, world_size) train_state.step += 1 if train_state.step > train_state.total_steps: # At most train_total_steps return # To device batch = {k: v.cuda() for k, v in batch.items()} # Init carry if it is None if train_state.carry is None: with torch.device("cuda"): train_state.carry = train_state.model.initial_carry(batch) # type: ignore # Forward train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[]) ((1 / global_batch_size) * loss).backward() # Allreduce if world_size > 1: for param in train_state.model.parameters(): if param.grad is not None: dist.all_reduce(param.grad) # Apply optimizer lr_this_step = None for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): lr_this_step = compute_lr(base_lr, config, train_state) for param_group in optim.param_groups: param_group['lr'] = lr_this_step optim.step() optim.zero_grad() # Reduce metrics if len(metrics): assert not any(v.requires_grad for v in metrics.values()) metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. # Reduce and reconstruct metric_values = torch.stack([metrics[k] for k in metric_keys]) if world_size > 1: dist.reduce(metric_values, dst=0) if rank == 0: metric_values = metric_values.cpu().numpy() reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)} # Postprocess count = max(reduced_metrics["count"], 1) # Avoid NaNs reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()} reduced_metrics["train/lr"] = lr_this_step return reduced_metrics def evaluate( config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, evaluators: List[Any], rank: int, world_size: int, cpu_group: Optional[dist.ProcessGroup], ): reduced_metrics = None with torch.inference_mode(): return_keys = set(config.eval_save_outputs) for evaluator in evaluators: evaluator.begin_eval() return_keys.update(evaluator.required_outputs) # Run evaluation set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} save_preds = {} metric_keys = [] metric_values = None carry = None processed_batches = 0 for set_name, batch, global_batch_size in eval_loader: processed_batches += 1 if rank == 0: print(f"Processing batch {processed_batches}: {set_name}") # To device batch = {k: v.cuda() for k, v in batch.items()} with torch.device("cuda"): carry = train_state.model.initial_carry(batch) # type: ignore # Forward inference_steps = 0 while True: carry, loss, metrics, preds, all_finish = train_state.model( carry=carry, batch=batch, return_keys=return_keys ) inference_steps += 1 if all_finish: break if rank == 0: print(f" Completed inference in {inference_steps} steps") for collection in (batch, preds): for k, v in collection.items(): if k in config.eval_save_outputs: save_preds.setdefault(k, []) save_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory for evaluator in evaluators: evaluator.update_batch(batch, preds) del carry, loss, preds, batch, all_finish # Aggregate metrics set_id = set_ids[set_name] if metric_values is None: metric_keys = list( sorted(metrics.keys()) ) # Sort keys to guarantee all processes use the same order. metric_values = torch.zeros( (len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda" ) metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) del metrics # concatenate save preds save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()} # Save preds if config.checkpoint_path is not None and len(save_preds): # Each rank save predictions independently os.makedirs(os.path.dirname(config.checkpoint_path), exist_ok=True) torch.save( save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}") ) del save_preds # Reduce to rank 0 if metric_values is not None: if world_size > 1: dist.reduce(metric_values, dst=0) if rank == 0: reduced_metrics = metric_values.cpu().numpy() reduced_metrics = { set_name: { metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys) } for set_id, set_name in enumerate(set_ids) } # Postprocess for set_name, m in reduced_metrics.items(): count = m.pop("count") reduced_metrics[set_name] = {k: v / count for k, v in m.items()} # Run evaluators if rank == 0: print(f"\nRunning {len(evaluators)} evaluator(s)...") for i, evaluator in enumerate(evaluators): if rank == 0: print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}") # Path for saving evaluator_save_path = None if config.checkpoint_path is not None: evaluator_save_path = os.path.join( config.checkpoint_path, f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}", ) os.makedirs(evaluator_save_path, exist_ok=True) # Run and log metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group) if rank == 0 and metrics is not None: if reduced_metrics is None: reduced_metrics = {} reduced_metrics.update(metrics) print(f" Completed {evaluator.__class__.__name__}") if rank == 0: print("All evaluators completed!") return reduced_metrics def save_code_and_config(config: PretrainConfig): if config.checkpoint_path is None or wandb.run is None: return os.makedirs(config.checkpoint_path, exist_ok=True) # Copy code code_list = [ get_model_source_path(config.arch.name), get_model_source_path(config.arch.loss.name) ] for code_file in code_list: if code_file is not None: code_name = os.path.basename(code_file) shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name)) # Dump config as yaml config_file = os.path.join(config.checkpoint_path, "all_config.yaml") with open(config_file, "wt") as f: yaml.dump(config.model_dump(), f) # Log code wandb.run.log_code(config.checkpoint_path) def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig: objects = [None] if rank == 0: config = PretrainConfig(**hydra_config) # type: ignore # Naming if config.project_name is None: config.project_name = f"{os.path.basename(config.data_paths[0]).capitalize()}-ACT-torch" if config.run_name is None: config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}" if config.checkpoint_path is None: config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name) objects = [config] if world_size > 1: dist.broadcast_object_list(objects, src=0) return objects[0] # type: ignore @hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None) def launch(hydra_config: DictConfig): RANK = 0 WORLD_SIZE = 1 CPU_PROCESS_GROUP = None # Initialize distributed training if in distributed environment (e.g. torchrun) if "LOCAL_RANK" in os.environ: # Initialize distributed, default device and dtype dist.init_process_group(backend="nccl") RANK = dist.get_rank() WORLD_SIZE = dist.get_world_size() torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) # CPU GLOO process group CPU_PROCESS_GROUP = dist.new_group(backend="gloo") assert ( dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE ) # Load sync'ed config config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE) # Seed RNGs to ensure consistency torch.random.manual_seed(config.seed + RANK) # Dataset train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs total_iters = config.epochs // train_epochs_per_iter assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs." train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) try: eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) except: print("NO EVAL DATA FOUND") eval_loader = eval_metadata = None try: evaluators = create_evaluators(config, eval_metadata) except: print("No evaluator found") evaluators = [] # Train state train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE) # Progress bar and logger progress_bar = None ema_helper = None if RANK == 0: progress_bar = tqdm.tqdm(total=train_state.total_steps) wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0) save_code_and_config(config) if config.ema: print('Setup EMA') ema_helper = EMAHelper(mu=config.ema_rate) ema_helper.register(train_state.model) # Training Loop for _iter_id in range(total_iters): print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}") ############ Train Iter if RANK == 0: print("TRAIN") train_state.model.train() for set_name, batch, global_batch_size in train_loader: metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE) if RANK == 0 and metrics is not None: wandb.log(metrics, step=train_state.step) progress_bar.update(train_state.step - progress_bar.n) # type: ignore if config.ema: ema_helper.update(train_state.model) if _iter_id >= config.min_eval_interval: ############ Evaluation if RANK == 0: print("EVALUATE") if config.ema: print("SWITCH TO EMA") train_state_eval = copy.deepcopy(train_state) train_state_eval.model = ema_helper.ema_copy(train_state_eval.model) else: train_state_eval = train_state train_state_eval.model.eval() metrics = evaluate(config, train_state_eval, eval_loader, eval_metadata, evaluators, rank=RANK, world_size=WORLD_SIZE, cpu_group=CPU_PROCESS_GROUP) if RANK == 0 and metrics is not None: wandb.log(metrics, step=train_state.step) ############ Checkpointing if RANK == 0: print("SAVE CHECKPOINT") if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)): save_train_state(config, train_state_eval) if config.ema: del train_state_eval # finalize if dist.is_initialized(): dist.destroy_process_group() wandb.finish() if __name__ == "__main__": launch()