summaryrefslogtreecommitdiff
path: root/trm/pretrain.py
diff options
context:
space:
mode:
author: YurenHao0426 <blackhao0426@gmail.com>, 2026-06-13 12:35:36 -0500
committer: YurenHao0426 <blackhao0426@gmail.com>, 2026-06-13 12:35:36 -0500
commit: 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
tree: c29cba61124018755a19b02c9d33e3ad5f2e05cc /trm/pretrain.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline [HEAD, main]
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'trm/pretrain.py')
-rw-r--r--trm/pretrain.py1277
1 files changed, 1277 insertions, 0 deletions
diff --git a/trm/pretrain.py b/trm/pretrain.py
new file mode 100644
index 0000000..b8f9ef0
--- /dev/null
+++ b/trm/pretrain.py
@@ -0,0 +1,1277 @@
+from typing import Optional, Any, Sequence, List
+from dataclasses import dataclass, replace
+import os
+import math
+import yaml
+import shutil
+import copy
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.utils.data import DataLoader
+
+import tqdm
+import wandb
+import coolname
+import hydra
+import pydantic
+from omegaconf import DictConfig
+from adam_atan2 import AdamATan2
+
+from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
+from utils.functions import load_model_class, get_model_source_path
+from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
+from models.ema import EMAHelper
+
+
class LossConfig(pydantic.BaseModel):
    """Loss-head selection: ``name`` identifies the loss class to load."""

    # Unknown keys are kept (extra="allow"); create_model forwards them to the
    # loss head constructor via __pydantic_extra__.
    model_config = pydantic.ConfigDict(extra="allow")

    name: str
+
+
class ArchConfig(pydantic.BaseModel):
    """Architecture selection: model class ``name`` plus its loss-head config."""

    # Unknown keys are kept (extra="allow"); create_model passes them into the
    # model config dict via __pydantic_extra__.
    model_config = pydantic.ConfigDict(extra="allow")

    name: str
    loss: LossConfig
+
+
class EvaluatorConfig(pydantic.BaseModel):
    """Evaluator selection: ``name`` identifies the evaluator class to load."""

    # Unknown keys are kept (extra="allow"); create_evaluators forwards them
    # as constructor keyword arguments via __pydantic_extra__.
    model_config = pydantic.ConfigDict(extra="allow")

    name: str
+
+
class PretrainConfig(pydantic.BaseModel):
    """Top-level pretraining configuration (architecture, data, optimisation, trajectory augmentation)."""

    # --- Architecture ---
    arch: ArchConfig

    # --- Data ---
    data_paths: List[str]
    # When non-empty, used instead of data_paths for the "test" split and the evaluators.
    data_paths_test: List[str] = []

    # --- Evaluators ---
    evaluators: List[EvaluatorConfig] = []

    # --- Optimisation hyperparameters ---
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float  # floor of the cosine schedule, as a fraction of the base lr
    lr_warmup_steps: int  # number of linear-warmup steps

    weight_decay: float
    beta1: float
    beta2: float

    # --- Puzzle-embedding optimiser ---
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # --- Run naming and checkpoints ---
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    load_checkpoint: Optional[str] = None  # file to initialise weights from
    checkpoint_path: Optional[str] = None  # directory checkpoints are written to

    # --- Extras ---
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    min_eval_interval: Optional[int] = 0  # when to start eval
    eval_save_outputs: List[str] = []

    ema: bool = False  # use Exponential-Moving-Average
    ema_rate: float = 0.999  # EMA-rate
    freeze_weights: bool = False  # If True, freeze weights and only learn the embeddings

    # --- Trajectory augmentation ---
    trajectory_augment: bool = False
    trajectory_n: int = 4  # branches per batch; branch 0 is clean, the others get noise
    trajectory_noise_std: float = 1e-3  # noise std reached at the end of the ramp
    trajectory_noise_min: Optional[float] = None  # lower bound for "loguniform" sampling
    trajectory_noise_max: Optional[float] = None  # upper bound for "loguniform" sampling
    # "loguniform" draws a per-sample std; "uniform" selects uniform unit noise.
    trajectory_noise_sampling: str = "loguniform"
    trajectory_sigma_start: Optional[float] = 0.0  # std at step 0 (None: no ramp)
    trajectory_sigma_ramp_steps: int = 5000  # length of the linear ramp; <= 0 disables it
    trajectory_perturb: str = "both"  # "h", "l", or anything else for both states
    trajectory_micro_batch: int = 0  # NOTE(review): not referenced in the visible code
    trajectory_parallel: bool = False  # run all branches in a single forward pass
    trajectory_directional: bool = False  # pick the noise direction by a finite-difference probe
    trajectory_directional_candidates: int = 4  # random directions probed per sample
    trajectory_directional_fd_eps: float = 3e-2  # probe displacement
    trajectory_directional_horizon: int = 0  # probe rollout steps; <= 0 uses halt_max_steps
    trajectory_directional_sign: str = "alternate"  # "random", "negative", "positive" or "alternate"
+
@dataclass
class TrainState:
    """Mutable training state threaded through the training loop."""

    # Model wrapped with its loss head, one optimizer per parameter group family,
    # and the base learning rate for each optimizer (same order as `optimizers`).
    model: nn.Module
    optimizers: Sequence[torch.optim.Optimizer]
    optimizer_lrs: Sequence[float]
    # Recurrent carry kept across steps; None until the first batch is seen.
    carry: Any

    # Current step counter and the estimated total number of steps.
    step: int
    total_steps: int
+
+
def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
    """Build the dataloader for ``split`` and return it with the dataset metadata.

    The "test" split reads ``config.data_paths_test`` when that list is
    non-empty; every other case reads ``config.data_paths``. Extra keyword
    arguments are forwarded to ``PuzzleDatasetConfig``.
    """
    use_test_paths = split == "test" and len(config.data_paths_test) > 0
    dataset_config = PuzzleDatasetConfig(
        seed=config.seed,
        dataset_paths=config.data_paths_test if use_test_paths else config.data_paths,
        rank=rank,
        num_replicas=world_size,
        **kwargs
    )
    dataset = PuzzleDataset(dataset_config, split=split)

    # batch_size=None: the dataset is handed to the loader without automatic batching.
    loader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=1,
        prefetch_factor=8,
        pin_memory=True,
        persistent_workers=True,
    )
    return loader, dataset.metadata
+
+
def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
    """Instantiate the model with its loss head and build the optimizers.

    Returns ``(model, optimizers, optimizer_lrs)`` where ``optimizer_lrs`` holds
    the base learning rate of each optimizer, in the same order.
    """
    per_rank_batch = config.global_batch_size // world_size
    if config.trajectory_augment and config.trajectory_parallel:
        # All trajectory branches are stacked into one forward pass.
        per_rank_batch *= config.trajectory_n

    model_cfg = dict(
        **config.arch.__pydantic_extra__,  # type: ignore
        batch_size=per_rank_batch,
        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False  # Non-autoregressive
    )

    # Resolve the model and loss-head classes by name.
    model_cls = load_model_class(config.arch.name)
    loss_head_cls = load_model_class(config.arch.loss.name)

    with torch.device("cuda"):
        model: nn.Module = model_cls(model_cfg)
        print(model)
        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model)  # type: ignore

        # Only rank 0 reads the checkpoint; other ranks receive it by broadcast.
        if rank == 0:
            load_checkpoint(model, config)

        if world_size > 1:
            with torch.no_grad():
                for tensor in [*model.parameters(), *model.buffers()]:
                    dist.broadcast(tensor, src=0)

    def _puzzle_emb_optimizer():
        # The puzzle embedding is optimised through its buffers.
        return CastedSparseEmbeddingSignSGD_Distributed(
            model.model.puzzle_emb.buffers(),  # type: ignore
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.puzzle_emb_weight_decay,
            world_size=world_size
        )

    def _weights_optimizer():
        return AdamATan2(
            model.parameters(),
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.weight_decay,
            betas=(config.beta1, config.beta2)
        )

    if config.arch.puzzle_emb_ndim == 0:
        # No puzzle embedding: a single optimizer over the model parameters.
        optimizers = [_weights_optimizer()]
        optimizer_lrs = [config.lr]
    elif config.freeze_weights:
        # Frozen weights: only the puzzle embedding is trained.
        optimizers = [_puzzle_emb_optimizer()]
        optimizer_lrs = [config.puzzle_emb_lr]
    else:
        optimizers = [_puzzle_emb_optimizer(), _weights_optimizer()]
        optimizer_lrs = [config.puzzle_emb_lr, config.lr]

    return model, optimizers, optimizer_lrs
+
def mix_weights_direct(device, alpha, net, nets):
    """Load into ``net`` the ``alpha``-weighted sum of the state dicts of ``nets``.

    Every entry of the first network's state dict is combined as
    ``sum(alpha[i] * nets[i].state_dict()[key])`` on ``device``. Returns ``net``.
    """
    state_dicts = [member.state_dict() for member in nets]
    blended = {}
    for key in state_dicts[0]:
        mixed = alpha[0] * state_dicts[0][key].to(device)
        for idx, other in enumerate(state_dicts[1:], start=1):
            mixed += alpha[idx] * other[key].to(device)
        blended[key] = mixed
    net.load_state_dict(blended)
    return net
+
def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
    """Return the learning rate at ``current_step``: linear warmup, then cosine decay.

    During warmup the rate grows linearly from 0 to ``base_lr``. Afterwards it
    follows a cosine curve scaled so it never drops below ``base_lr * min_ratio``.
    """
    if current_step < num_warmup_steps:
        warmup_span = float(max(1, num_warmup_steps))
        return base_lr * float(current_step) / warmup_span

    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / decay_span
    angle = math.pi * float(num_cycles) * 2.0 * progress
    decay = (1 - min_ratio) * 0.5 * (1.0 + math.cos(angle))
    return base_lr * (min_ratio + max(0.0, decay))
+
+
def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
    """Create the model and optimizers and wrap them in a fresh ``TrainState``."""
    # Estimated number of optimisation steps for the whole run.
    estimated_steps = int(
        config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size
    )

    model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size)

    # The carry is created lazily on the first training batch.
    return TrainState(
        model=model,
        optimizers=optimizers,
        optimizer_lrs=optimizer_lrs,
        carry=None,
        step=0,
        total_steps=estimated_steps,
    )
+
+
def save_train_state(config: PretrainConfig, train_state: TrainState):
    """Write the model state dict to ``<checkpoint_path>/step_<step>``.

    Does nothing when ``config.checkpoint_path`` is None.
    """
    # FIXME: Only saved model.
    checkpoint_dir = config.checkpoint_path
    if checkpoint_dir is None:
        return

    os.makedirs(checkpoint_dir, exist_ok=True)
    target_file = os.path.join(checkpoint_dir, f"step_{train_state.step}")
    torch.save(train_state.model.state_dict(), target_file)
+
+
def load_checkpoint(model: nn.Module, config: PretrainConfig):
    """Load ``config.load_checkpoint`` into ``model`` (no-op when it is None).

    If the checkpoint contains a puzzle-embedding table whose shape differs
    from the model's, the table is replaced by its row-wise mean broadcast to
    the expected shape before loading.
    """
    if config.load_checkpoint is None:
        return

    print(f"Loading checkpoint {config.load_checkpoint}")

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(config.load_checkpoint, map_location="cuda")

    # Resize and reset puzzle emb if needed. Checkpoints written from a
    # torch.compile-wrapped model carry the "_orig_mod." prefix; the bare key
    # covers checkpoints written with compilation disabled.
    puzzle_emb_names = (
        "_orig_mod.model.inner.puzzle_emb.weights",
        "model.inner.puzzle_emb.weights",
    )
    for puzzle_emb_name in puzzle_emb_names:
        if puzzle_emb_name not in state_dict:
            continue

        # Looked up only once the checkpoint is known to hold a puzzle
        # embedding, so models without one do not fail on the attribute access.
        expected_shape: torch.Size = model.model.puzzle_emb.weights.shape  # type: ignore
        puzzle_emb = state_dict[puzzle_emb_name]
        if puzzle_emb.shape != expected_shape:
            print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, Expected {expected_shape}")
            # Re-initialize using mean
            state_dict[puzzle_emb_name] = (
                torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous()
            )

    model.load_state_dict(state_dict, assign=True)
+
+
def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
    """Return the scheduled learning rate for ``base_lr`` at the current training step."""
    schedule_args = dict(
        base_lr=base_lr,
        num_warmup_steps=round(config.lr_warmup_steps),
        num_training_steps=train_state.total_steps,
        min_ratio=config.lr_min_ratio,
    )
    return cosine_schedule_with_warmup_lr_lambda(train_state.step, **schedule_args)
+
+
+
def _unwrap_loss_head(model: nn.Module):
    """Return ``model._orig_mod`` when that attribute exists, otherwise ``model`` itself."""
    try:
        return model._orig_mod
    except AttributeError:
        return model
+
+
def _unit_noise_like(tensor: torch.Tensor, sampling: str):
    """Sample noise with the shape, device and dtype of ``tensor``.

    ``sampling == "uniform"`` draws from a uniform distribution on
    ``[-sqrt(3), sqrt(3))``; any other value draws standard normal noise.
    Sampling is done in float32 and then cast to ``tensor.dtype``.
    """
    if sampling != "uniform":
        gaussian = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32)
        return gaussian.to(tensor.dtype)

    unit_interval = torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32)
    centred = 2.0 * unit_interval - 1.0
    return (centred * math.sqrt(3.0)).to(tensor.dtype)
+
+
def _trajectory_noise_target(config: PretrainConfig, step: int) -> float:
    """Return the target noise std at ``step`` under the linear sigma ramp.

    With no ramp (``trajectory_sigma_ramp_steps <= 0``) the configured
    ``trajectory_noise_std`` is returned as-is. Otherwise the value moves
    linearly from ``trajectory_sigma_start`` (or the final std when that is
    None) to ``trajectory_noise_std`` over the ramp, clamped at both ends.
    """
    ramp_steps = config.trajectory_sigma_ramp_steps
    final_sigma = config.trajectory_noise_std
    if ramp_steps <= 0:
        return final_sigma

    if config.trajectory_sigma_start is None:
        initial_sigma = final_sigma
    else:
        initial_sigma = config.trajectory_sigma_start
    progress = min(max(step / ramp_steps, 0.0), 1.0)
    return float(initial_sigma + progress * (final_sigma - initial_sigma))
+
+
def _sample_noise_stds(config: PretrainConfig, batch_size: int, device: torch.device, step: int):
    """Return a float32 vector of ``batch_size`` per-sample noise stds for ``step``.

    All zeros when the ramped target is not positive. With
    ``trajectory_noise_sampling == "loguniform"`` each entry is drawn
    log-uniformly between a lower and an upper bound; otherwise every entry
    equals the ramped target.
    """
    sigma = _trajectory_noise_target(config, step)
    if sigma <= 0:
        return torch.zeros(batch_size, device=device, dtype=torch.float32)
    if config.trajectory_noise_sampling != "loguniform":
        return torch.full((batch_size,), sigma, device=device, dtype=torch.float32)

    base_std = config.trajectory_noise_std
    # Fraction of the ramp completed, used to shrink the sampling bounds.
    ramp_scale = 1.0 if base_std <= 0 else sigma / base_std

    upper_cfg = config.trajectory_noise_max if config.trajectory_noise_max is not None else base_std
    upper = float(upper_cfg) * ramp_scale
    # NOTE(review): the default lower bound derives from the already-scaled
    # upper bound and is then scaled again — confirm this is intended.
    if config.trajectory_noise_min is not None:
        lower_cfg = config.trajectory_noise_min
    else:
        lower_cfg = max(upper / 10.0, 1e-8)
    lower = float(lower_cfg) * ramp_scale

    # Keep both bounds strictly positive and ordered before taking logs.
    upper = max(upper, 1e-12)
    lower = max(min(lower, upper), 1e-12)

    uniform = torch.rand(batch_size, device=device, dtype=torch.float32)
    return torch.exp(math.log(lower) + uniform * (math.log(upper) - math.log(lower)))
+
+
def _add_initial_noise(config: PretrainConfig, inner_carry: Any, noise_stds: torch.Tensor):
    """Add per-sample noise to ``z_H`` and/or ``z_L`` of ``inner_carry``.

    ``noise_stds`` holds one std per batch row. The carry is returned
    unchanged when it is empty or has no positive entry. Which states are
    perturbed follows ``config.trajectory_perturb`` ("h", "l" or "both").
    """
    if noise_stds.numel() == 0 or float(noise_stds.max().item()) <= 0:
        return inner_carry

    # Reshape to (batch, 1, ..., 1) so the std broadcasts over each state.
    broadcast_shape = (noise_stds.shape[0],) + (1,) * (inner_carry.z_H.ndim - 1)
    per_sample_std = noise_stds.view(broadcast_shape)

    target = config.trajectory_perturb
    states = {"z_H": inner_carry.z_H, "z_L": inner_carry.z_L}
    for mode, field_name in (("h", "z_H"), ("l", "z_L")):
        if target in (mode, "both"):
            state = states[field_name]
            noise = _unit_noise_like(state, config.trajectory_noise_sampling)
            states[field_name] = state + per_sample_std.to(state.dtype) * noise
    return replace(inner_carry, z_H=states["z_H"], z_L=states["z_L"])
+
+
def _directional_fields(config: PretrainConfig):
    """Return the carry field names selected by ``config.trajectory_perturb``.

    "h" selects ``z_H``, "l" selects ``z_L``; any other value selects both.
    """
    single_field = {"h": ("z_H",), "l": ("z_L",)}
    return single_field.get(config.trajectory_perturb, ("z_H", "z_L"))
+
+
def _index_inner_carry(inner_carry: Any, indices: torch.Tensor):
    """Return a new carry of the same type holding rows ``indices`` of every field."""
    selected = {}
    for field_name in inner_carry.__dataclass_fields__:
        field_value = getattr(inner_carry, field_name)
        selected[field_name] = field_value.index_select(0, indices).contiguous()
    return type(inner_carry)(**selected)
+
+
def _cat_inner_carries(carries: Sequence[Any]):
    """Concatenate the carries field by field along dim 0 (type taken from the first)."""
    template = carries[0]
    joined = {}
    for field_name in template.__dataclass_fields__:
        pieces = [getattr(carry, field_name) for carry in carries]
        joined[field_name] = torch.cat(pieces, dim=0)
    return type(template)(**joined)
+
+
def _split_inner_carry(inner_carry: Any, first_size: int):
    """Split every field at row ``first_size``; return ``(first_rows, remaining_rows)`` carries."""
    carry_type = type(inner_carry)
    field_names = list(inner_carry.__dataclass_fields__)
    leading = {name: getattr(inner_carry, name)[:first_size].contiguous() for name in field_names}
    trailing = {name: getattr(inner_carry, name)[first_size:].contiguous() for name in field_names}
    return carry_type(**leading), carry_type(**trailing)
+
+
def _repeat_batch_interleave(batch: Any, repeats: int):
    """Repeat each row of every batch tensor ``repeats`` times in place (a, a, b, b, ...)."""
    repeated = {}
    for key, tensor in batch.items():
        repeated[key] = tensor.repeat_interleave(repeats, dim=0)
    return repeated
+
+
def _cat_batches(a: Any, b: Any):
    """Concatenate two batches along dim 0, key by key (keys are taken from ``a``)."""
    merged = {}
    for key in a:
        merged[key] = torch.cat((a[key], b[key]), dim=0)
    return merged
+
+
def _rand_unit_direction_candidates(config: PretrainConfig, inner_carry: Any, candidates: int):
    """Draw ``candidates`` random directions per sample with unit joint norm.

    Returns a dict mapping each carry field to a tensor of shape
    ``(batch, candidates, *field_shape[1:])``. Fields that are not perturbed
    (per ``config.trajectory_perturb``) get all-zero directions. The norm is
    taken jointly over all perturbed fields of one candidate.

    Raises:
        ValueError: if no carry field is selected for perturbation.
    """
    perturbed_fields = _directional_fields(config)
    raw_dirs = {}
    joint_norm_sq = None
    for field_name in inner_carry.__dataclass_fields__:
        state = getattr(inner_carry, field_name)
        candidate_shape = (state.shape[0], candidates) + tuple(state.shape[1:])
        if field_name not in perturbed_fields:
            raw_dirs[field_name] = torch.zeros(candidate_shape, device=state.device, dtype=torch.float32)
            continue

        sample = torch.randn(candidate_shape, device=state.device, dtype=torch.float32)
        # Squared norm per (sample, candidate), accumulated across fields.
        sample_sq = sample.flatten(2).square().sum(-1)
        joint_norm_sq = sample_sq if joint_norm_sq is None else joint_norm_sq + sample_sq
        raw_dirs[field_name] = sample

    if joint_norm_sq is None:
        raise ValueError(f"unknown trajectory_perturb={config.trajectory_perturb!r}")

    joint_norm = torch.sqrt(joint_norm_sq).clamp_min(1e-30)
    unit_dirs = {}
    for field_name, sample in raw_dirs.items():
        broadcast_shape = sample.shape[:2] + (1,) * (sample.ndim - 2)
        normalized = sample / joint_norm.view(broadcast_shape)
        unit_dirs[field_name] = normalized.to(getattr(inner_carry, field_name).dtype)
    return unit_dirs
+
+
def _make_shadow_inner(inner_carry: Any, dirs: dict[str, torch.Tensor], eps: float):
    """Build a detached carry of states displaced by ``eps`` along each candidate direction.

    For a field of shape ``(batch, ...)`` and directions of shape
    ``(batch, candidates, ...)`` the result has ``batch * candidates`` rows,
    with the candidates of one sample stored consecutively.
    """
    shadow_fields = {}
    for field_name in inner_carry.__dataclass_fields__:
        state = getattr(inner_carry, field_name)
        field_dirs = dirs[field_name]
        displaced = state.unsqueeze(1) + eps * field_dirs
        flat_shape = (state.shape[0] * field_dirs.shape[1],) + tuple(state.shape[1:])
        shadow_fields[field_name] = displaced.reshape(flat_shape).detach()
    return type(inner_carry)(**shadow_fields)
+
+
def _separation(main: Any, shadow: Any, candidates: int):
    """Return the ``(batch, candidates)`` Euclidean distance between shadow and main states.

    The distance is computed in float32 jointly over all carry fields and is
    clamped from below at 1e-30.
    """
    first_field = next(iter(main.__dataclass_fields__.keys()))
    batch_size = getattr(main, first_field).shape[0]

    total_sq = None
    for field_name in main.__dataclass_fields__:
        reference = getattr(main, field_name)
        per_candidate_shape = (batch_size, candidates) + tuple(reference.shape[1:])
        candidate_states = getattr(shadow, field_name).reshape(per_candidate_shape)
        delta = candidate_states.float() - reference.unsqueeze(1).float()
        field_sq = delta.flatten(2).square().sum(-1)
        total_sq = field_sq if total_sq is None else total_sq + field_sq
    return torch.sqrt(total_sq).clamp_min(1e-30)
+
+
def _gather_direction(dirs: dict[str, torch.Tensor], best_idx: torch.Tensor):
    """Pick, for each sample ``i``, candidate ``best_idx[i]`` from every direction tensor."""
    rows = torch.arange(best_idx.shape[0], device=best_idx.device)
    chosen = {}
    for field_name, field_candidates in dirs.items():
        chosen[field_name] = field_candidates[rows, best_idx].contiguous()
    return chosen
+
+
def _directional_signs(config: PretrainConfig, branch_idx: int, count: int, device: torch.device):
    """Return ``count`` float32 signs (+1/-1) per ``config.trajectory_directional_sign``.

    "random" draws each sign independently, "negative"/"positive" are
    constant, and "alternate" gives +1 on odd branches and -1 on even ones.

    Raises:
        ValueError: for any other mode.
    """
    mode = config.trajectory_directional_sign

    def _constant(value: float):
        return torch.full((count,), value, device=device, dtype=torch.float32)

    if mode == "random":
        pick_negative = torch.rand(count, device=device, dtype=torch.float32) < 0.5
        return torch.where(pick_negative, _constant(-1.0), _constant(1.0))
    if mode == "negative":
        return _constant(-1.0)
    if mode == "positive":
        return _constant(1.0)
    if mode == "alternate":
        return _constant(1.0 if branch_idx % 2 == 1 else -1.0)
    raise ValueError(f"unknown trajectory_directional_sign={mode!r}")
+
+
@torch.no_grad()
def _select_directional_perturbation(
    config: PretrainConfig,
    base: nn.Module,
    inner_carry: Any,
    batch: Any,
    active_mask: torch.Tensor,
):
    """Choose, per active sample, the random direction whose probe separates most.

    For the rows selected by ``active_mask`` several random unit directions
    are drawn, the state is displaced by ``trajectory_directional_fd_eps``
    along each, and the main and displaced states are rolled through
    ``base.inner`` for a fixed number of steps. The direction with the largest
    final separation is kept.

    Returns:
        ``(active_indices, best_dirs, stats)``; the first two are None when no
        row is active. ``stats`` holds "active", "growth_sum" and "growth_count".
    """
    active_indices = torch.nonzero(active_mask, as_tuple=False).flatten()
    active_count = int(active_indices.numel())
    if active_count == 0:
        return None, None, {"active": 0, "growth_sum": 0.0, "growth_count": 0}

    num_candidates = max(int(config.trajectory_directional_candidates), 1)
    probe_eps = float(config.trajectory_directional_fd_eps)
    rollout_steps = int(config.trajectory_directional_horizon)
    if rollout_steps <= 0:
        # No explicit horizon configured: fall back to the model's step budget.
        rollout_steps = int(base.config.halt_max_steps)  # type: ignore[attr-defined]
    rollout_steps = max(rollout_steps, 1)

    active_inner = _index_inner_carry(inner_carry, active_indices)
    active_batch = {}
    for key, tensor in batch.items():
        active_batch[key] = tensor.index_select(0, active_indices).contiguous()

    candidate_dirs = _rand_unit_direction_candidates(config, active_inner, num_candidates)
    reference = type(active_inner)(**{
        field_name: getattr(active_inner, field_name).detach()
        for field_name in active_inner.__dataclass_fields__
    })
    probes = _make_shadow_inner(reference, candidate_dirs, probe_eps)

    # Roll the reference rows and all probe rows through the model together.
    stacked_state = _cat_inner_carries([reference, probes])
    stacked_batch = _cat_batches(active_batch, _repeat_batch_interleave(active_batch, num_candidates))

    inner_module = base.inner  # type: ignore[attr-defined]
    restore_train_mode = inner_module.training
    inner_module.eval()
    try:
        for _ in range(rollout_steps):
            stacked_state, _logits, _q = inner_module(stacked_state, stacked_batch)
    finally:
        if restore_train_mode:
            inner_module.train()

    reference_final, probes_final = _split_inner_carry(stacked_state, active_count)
    separation = _separation(reference_final, probes_final, num_candidates)
    best_idx = separation.argmax(dim=1)
    best_separation = separation.gather(1, best_idx[:, None]).squeeze(1)
    best_dirs = _gather_direction(candidate_dirs, best_idx)

    # Per-step log growth of the separation relative to the initial displacement.
    growth = torch.log(best_separation / max(probe_eps, 1e-30)).float() / rollout_steps
    stats = {
        "active": active_count,
        "growth_sum": float(growth.sum().item()),
        "growth_count": active_count,
    }
    return active_indices, best_dirs, stats
+
+
def _add_directional_initial_noise(
    config: PretrainConfig,
    base: nn.Module,
    inner_carry: Any,
    batch: Any,
    noise_stds: torch.Tensor,
    active_mask: torch.Tensor,
    branch_idx: int,
):
    """Displace the active rows of ``inner_carry`` along their selected directions.

    Each active row moves by ``noise_stds[row] * sign`` along the direction
    chosen by ``_select_directional_perturbation``. Returns
    ``(new_carry, stats)``; the carry is returned untouched when no row is active.
    """
    picked_rows, picked_dirs, stats = _select_directional_perturbation(
        config, base, inner_carry, batch, active_mask
    )
    if picked_rows is None or picked_dirs is None:
        return inner_carry, stats

    signs = _directional_signs(config, branch_idx, picked_rows.shape[0], noise_stds.device)
    signed_scales = noise_stds.index_select(0, picked_rows) * signs

    new_fields = {}
    for field_name in inner_carry.__dataclass_fields__:
        state = getattr(inner_carry, field_name)
        direction = picked_dirs[field_name].to(state.dtype)
        broadcast_shape = (signed_scales.shape[0],) + (1,) * (direction.ndim - 1)
        step = signed_scales.view(broadcast_shape).to(state.dtype) * direction
        displaced_rows = state.index_select(0, picked_rows) + step
        # Write the displaced rows into a copy so the input carry is not modified.
        patched = state.clone()
        patched.index_copy_(0, picked_rows, displaced_rows)
        new_fields[field_name] = patched
    return type(inner_carry)(**new_fields), stats
+
+
def _token_loss(loss_fn, logits, labels, mask):
    """Call ``loss_fn(logits, labels, ignore_index=-100)``, adding ``valid_mask`` when supported.

    Support is detected from the named parameters in ``loss_fn.__code__``;
    callables without a ``__code__`` attribute never receive ``valid_mask``.
    """
    code = getattr(loss_fn, "__code__", None)
    if code is None:
        accepted_names = ()
    else:
        named_arg_count = code.co_argcount + code.co_kwonlyargcount
        accepted_names = code.co_varnames[:named_arg_count]

    if "valid_mask" in accepted_names:
        return loss_fn(logits, labels, ignore_index=-100, valid_mask=mask)
    return loss_fn(logits, labels, ignore_index=-100)
+
+
def _fixed_unroll_branch_loss(config: PretrainConfig, head: nn.Module, batch: Any, noise_stds: Optional[torch.Tensor]):
    """Unroll the model for ``halt_max_steps`` steps from a fresh carry and average the loss.

    The carry is fully reset, optionally perturbed with ``noise_stds``, and
    every step contributes an LM loss plus half of the halting losses.

    Returns:
        ``(mean_loss_over_steps, exact_match_count_at_last_step, batch_size)``.
    """
    base = head.model  # type: ignore[attr-defined]
    with torch.device("cuda"):
        carry = base.initial_carry(batch)
    reset_all = torch.ones_like(carry.halted)
    state = base.inner.reset_carry(reset_all, carry.inner_carry)
    if noise_stds is not None:
        state = _add_initial_noise(config, state, noise_stds)

    inputs = batch["inputs"]
    batch_size = inputs.shape[0]
    max_steps = base.config.halt_max_steps
    use_continue_loss = not getattr(base.config, "no_ACT_continue", False)

    loss_sum = torch.zeros((), device=inputs.device, dtype=torch.float32)
    last_exact = torch.zeros((), device=inputs.device, dtype=torch.float32)
    for act_step in range(max_steps):
        state, logits, (q_halt_logits, q_continue_logits) = base.inner(state, batch)
        labels = batch["labels"]
        with torch.no_grad():
            mask = labels != -100
            loss_counts = mask.sum(-1)
            # Avoid dividing by zero for sequences without any labelled token.
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)
            is_correct = mask & (torch.argmax(logits, dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts
            last_exact = seq_is_correct.to(torch.float32).sum()

        lm_loss = (_token_loss(head.loss_fn, logits, labels, mask) / loss_divisor).sum()
        q_halt_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            q_halt_logits,
            seq_is_correct.to(q_halt_logits.dtype),
            reduction="sum",
        )
        q_continue_loss = torch.zeros_like(q_halt_loss)
        if use_continue_loss:
            with torch.no_grad():
                next_q_halt_logits, next_q_continue_logits = base.inner(state, batch)[-1]
                if act_step + 1 >= max_steps:
                    # Final step: the target can only come from halting.
                    bootstrap_logits = next_q_halt_logits
                else:
                    bootstrap_logits = torch.maximum(next_q_halt_logits, next_q_continue_logits)
                target_q_continue = torch.sigmoid(bootstrap_logits)
            q_continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                q_continue_logits, target_q_continue, reduction="sum"
            )
        loss_sum = loss_sum + lm_loss + 0.5 * (q_halt_loss + q_continue_loss)

    return loss_sum / max(max_steps, 1), last_exact, batch_size
+
+
def _prepare_noisy_stream_carry(config: PretrainConfig, base: nn.Module, carry: Any, batch: Any, step: int, branch_idx: int):
    """Restart the halted rows of ``carry`` on the new batch with a noisy initial state.

    Rows flagged as halted get a reset inner state, a sampled noise std, a
    zeroed step counter and the new batch data; other rows are kept as they
    are. Noise is directional when ``config.trajectory_directional`` is set
    and some std is positive, isotropic otherwise.

    Returns:
        ``(updated_carry, directional_stats)``.
    """
    restarting = carry.halted
    fresh_inner = base.inner.reset_carry(restarting, carry.inner_carry)

    inputs = batch["inputs"]
    sampled_stds = _sample_noise_stds(config, inputs.shape[0], inputs.device, step)
    # Only rows that restart receive noise.
    noise_stds = torch.where(restarting, sampled_stds, torch.zeros_like(sampled_stds))
    noisy_rows = noise_stds > 0

    directional_stats = {
        "active": int(noisy_rows.sum().detach().cpu().item()),
        "growth_sum": 0.0,
        "growth_count": 0,
    }
    if config.trajectory_directional and float(noise_stds.max().item()) > 0:
        fresh_inner, directional_stats = _add_directional_initial_noise(
            config, base, fresh_inner, batch, noise_stds, noisy_rows, branch_idx
        )
    else:
        fresh_inner = _add_initial_noise(config, fresh_inner, noise_stds)

    fresh_steps = torch.where(restarting, 0, carry.steps)
    fresh_data = {}
    for key, previous in carry.current_data.items():
        incoming = batch[key]
        row_mask = restarting.view((-1,) + (1,) * (incoming.ndim - 1))
        fresh_data[key] = torch.where(row_mask, incoming, previous)

    updated_carry = replace(
        carry,
        inner_carry=fresh_inner,
        steps=fresh_steps,
        halted=torch.zeros_like(restarting),
        current_data=fresh_data,
    )
    return updated_carry, directional_stats
+
+
def _merge_trajectory_carries(carries: Sequence[Any]):
    """Stack per-branch carries into one carry by concatenating along dim 0.

    Inner-carry fields, ``steps``, ``halted`` and every ``current_data`` tensor
    are concatenated in branch order; remaining fields come from the first carry.
    """
    template = carries[0]

    def _stack(tensors):
        return torch.cat(list(tensors), dim=0)

    merged_inner = type(template.inner_carry)(**{
        field_name: _stack(getattr(c.inner_carry, field_name) for c in carries)
        for field_name in template.inner_carry.__dataclass_fields__
    })
    merged_data = {
        key: _stack(c.current_data[key] for c in carries)
        for key in template.current_data
    }
    return replace(
        template,
        inner_carry=merged_inner,
        steps=_stack(c.steps for c in carries),
        halted=_stack(c.halted for c in carries),
        current_data=merged_data,
    )
+
+
def _split_trajectory_carry(carry: Any, branch_size: int, num_branches: int):
    """Split a stacked carry into ``num_branches`` carries of ``branch_size`` rows each.

    Inner-carry fields, ``steps``, ``halted`` and every ``current_data`` tensor
    are cut into contiguous chunks along dim 0; remaining fields are shared.
    """
    def _chunks(tensor):
        return [piece.contiguous() for piece in tensor.split(branch_size, dim=0)]

    inner_type = type(carry.inner_carry)
    inner_names = list(carry.inner_carry.__dataclass_fields__)
    inner_parts = {name: _chunks(getattr(carry.inner_carry, name)) for name in inner_names}
    step_parts = _chunks(carry.steps)
    halted_parts = _chunks(carry.halted)
    data_parts = {key: _chunks(tensor) for key, tensor in carry.current_data.items()}

    branches = []
    for idx in range(num_branches):
        branch_inner = inner_type(**{name: inner_parts[name][idx] for name in inner_names})
        branch_data = {key: parts[idx] for key, parts in data_parts.items()}
        branches.append(
            replace(
                carry,
                inner_carry=branch_inner,
                steps=step_parts[idx],
                halted=halted_parts[idx],
                current_data=branch_data,
            )
        )
    return branches
+
+
def _repeat_batch_for_trajectories(batch: Any, num_branches: int):
    """Tile every batch tensor ``num_branches`` times along dim 0 (whole copies back to back)."""
    tiled = {}
    for key, tensor in batch.items():
        tiled[key] = torch.cat([tensor] * num_branches, dim=0)
    return tiled
+
+
def _split_outputs(outputs: Any, branch_size: int, num_branches: int):
    """Return one dict per branch, each holding that branch's ``branch_size`` rows of every output."""
    per_branch = []
    for idx in range(num_branches):
        start = idx * branch_size
        per_branch.append({key: tensor.narrow(0, start, branch_size) for key, tensor in outputs.items()})
    return per_branch
+
+
def _branch_metrics_and_loss(head: nn.Module, carry: Any, outputs: Any):
    """Compute the summed metrics and the total loss for one trajectory branch.

    Counting metrics only include rows that are halted and have at least one
    labelled token. The total loss is the LM loss plus half of the halting
    losses; the continue loss is included only when ``outputs`` provides
    ``target_q_continue``.

    Returns:
        ``(metrics, total_loss)`` where ``metrics`` maps names to tensors.
    """
    labels = carry.current_data["labels"]
    logits = outputs["logits"]
    q_halt_logits = outputs["q_halt_logits"]
    has_continue_target = "target_q_continue" in outputs

    with torch.no_grad():
        mask = labels != -100
        loss_counts = mask.sum(-1)
        # Avoid dividing by zero for sequences without any labelled token.
        loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)
        is_correct = mask & (torch.argmax(logits, dim=-1) == labels)
        seq_is_correct = is_correct.sum(-1) == loss_counts
        valid_metrics = carry.halted & (loss_counts > 0)
        token_accuracy = (is_correct.to(torch.float32) / loss_divisor).sum(-1)
        metrics = {
            "count": valid_metrics.sum(),
            "accuracy": torch.where(valid_metrics, token_accuracy, 0).sum(),
            "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
            "q_halt_accuracy": (valid_metrics & ((q_halt_logits >= 0) == seq_is_correct)).sum(),
            "steps": torch.where(valid_metrics, carry.steps, 0).sum(),
        }

    lm_loss = (head.loss_fn(logits, labels, ignore_index=-100) / loss_divisor).sum()
    q_halt_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        q_halt_logits,
        seq_is_correct.to(q_halt_logits.dtype),
        reduction="sum",
    )
    if has_continue_target:
        q_continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            outputs["q_continue_logits"],
            outputs["target_q_continue"],
            reduction="sum",
        )
    else:
        q_continue_loss = torch.zeros_like(q_halt_loss)

    metrics["lm_loss"] = lm_loss
    metrics["q_halt_loss"] = q_halt_loss
    if has_continue_target:
        metrics["q_continue_loss"] = q_continue_loss

    total_loss = lm_loss + 0.5 * (q_halt_loss + q_continue_loss)
    return metrics, total_loss
+
+
def train_batch_trajectory_aug_parallel(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    """Run one trajectory-augmented training step with all branches in a single forward pass.

    Branch 0 sees the clean carry; every other branch restarts its halted rows
    from a noisy state. The branch carries are stacked, the batch is tiled,
    and the summed loss is normalised by ``global_batch_size * num_branches``.

    Returns:
        A dict of ``train/...`` metrics on rank 0, or None (also when the step
        budget is exhausted).

    Raises:
        NotImplementedError: when ``world_size != 1``.
    """
    if world_size != 1:
        raise NotImplementedError("trajectory_parallel currently expects single-GPU baseline configs")
    train_state.step += 1
    if train_state.step > train_state.total_steps:
        return

    batch = {name: tensor.cuda() for name, tensor in batch.items()}
    head = _unwrap_loss_head(train_state.model)
    base = head.model  # type: ignore[attr-defined]
    num_branches = max(config.trajectory_n, 1)
    branch_size = batch["inputs"].shape[0]

    # One carry per branch, created lazily on the first batch.
    if not isinstance(train_state.carry, list):
        with torch.device("cuda"):
            train_state.carry = [train_state.model.initial_carry(batch) for _ in range(num_branches)]  # type: ignore

    for optim in train_state.optimizers:
        optim.zero_grad()

    branch_carries = []
    directional_active = 0
    directional_growth_sum = 0.0
    directional_growth_count = 0
    for branch_idx, branch_carry in enumerate(train_state.carry):
        if branch_idx > 0:
            branch_carry, stats = _prepare_noisy_stream_carry(
                config, base, branch_carry, batch, train_state.step, branch_idx
            )
            directional_active += int(stats["active"])
            directional_growth_sum += float(stats["growth_sum"])
            directional_growth_count += int(stats["growth_count"])
        branch_carries.append(branch_carry)

    stacked_carry = _merge_trajectory_carries(branch_carries)
    stacked_batch = _repeat_batch_for_trajectories(batch, num_branches)
    stacked_carry, loss, _metrics, outputs, _ = train_state.model(
        carry=stacked_carry,
        batch=stacked_batch,
        return_keys=["logits", "q_halt_logits", "q_continue_logits", "target_q_continue"],
    )
    train_state.carry = _split_trajectory_carry(stacked_carry, branch_size, num_branches)
    (loss / (global_batch_size * num_branches)).backward()

    # Apply the scheduled learning rate, then step every optimizer.
    lr_this_step = None
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)
        for param_group in optim.param_groups:
            param_group["lr"] = lr_this_step
        optim.step()
        optim.zero_grad()

    if rank != 0 or outputs is None:
        return None

    branch_outputs = _split_outputs(outputs, branch_size, num_branches)
    clean_metrics, clean_loss = _branch_metrics_and_loss(head, train_state.carry[0], branch_outputs[0])
    noisy_loss = 0.0
    for branch_idx in range(1, num_branches):
        _, branch_loss = _branch_metrics_and_loss(head, train_state.carry[branch_idx], branch_outputs[branch_idx])
        noisy_loss += float(branch_loss.detach().cpu())

    raw_metrics = {f"train/{name}": float(value.detach().cpu()) for name, value in clean_metrics.items()}
    count = max(raw_metrics.get("train/count", 1.0), 1.0)
    # Losses are normalised by the batch size, every other metric by the valid count.
    reduced_metrics = {
        name: value / (global_batch_size if name.endswith("loss") else count)
        for name, value in raw_metrics.items()
    }
    reduced_metrics["train/lr"] = lr_this_step
    reduced_metrics["train/trajectory_clean_loss"] = float(clean_loss.detach().cpu()) / global_batch_size
    reduced_metrics["train/trajectory_noisy_loss"] = noisy_loss / max(num_branches - 1, 1) / global_batch_size
    reduced_metrics["train/trajectory_sigma"] = _trajectory_noise_target(config, train_state.step)
    reduced_metrics["train/trajectory_parallel"] = 1.0
    reduced_metrics["train/trajectory_directional"] = 1.0 if config.trajectory_directional else 0.0
    reduced_metrics["train/trajectory_directional_active"] = float(directional_active)
    if directional_growth_count > 0:
        reduced_metrics["train/trajectory_directional_growth"] = directional_growth_sum / directional_growth_count
    return reduced_metrics
+
+
def train_batch_trajectory_aug(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    """Run one trajectory-augmented training step, one forward/backward pass per branch.

    Delegates to ``train_batch_trajectory_aug_parallel`` when
    ``config.trajectory_parallel`` is set. Otherwise branch 0 is trained on the
    clean carry and every other branch restarts its halted rows from a noisy
    state; gradients accumulate across branches before a single optimizer step.

    The branch count is ``max(config.trajectory_n, 1)``, matching the parallel
    path and the loss normalisation, so the clean branch always runs.

    Returns:
        A dict of ``train/...`` metrics on rank 0, or None (also when the step
        budget is exhausted).

    Raises:
        NotImplementedError: when ``world_size != 1``.
    """
    if config.trajectory_parallel:
        return train_batch_trajectory_aug_parallel(config, train_state, batch, global_batch_size, rank, world_size)
    if world_size != 1:
        raise NotImplementedError("trajectory_augment currently expects single-GPU baseline configs")
    train_state.step += 1
    if train_state.step > train_state.total_steps:
        return

    batch = {k: v.cuda() for k, v in batch.items()}
    head = _unwrap_loss_head(train_state.model)
    base = head.model  # type: ignore[attr-defined]
    # Always keep at least the clean branch, even for trajectory_n <= 0.
    num_branches = max(config.trajectory_n, 1)

    # One carry per branch, created lazily on the first batch.
    if train_state.carry is None or not isinstance(train_state.carry, list):
        with torch.device("cuda"):
            train_state.carry = [train_state.model.initial_carry(batch) for _ in range(num_branches)]  # type: ignore

    for optim in train_state.optimizers:
        optim.zero_grad()

    clean_metrics = None
    clean_loss_value = 0.0
    noisy_loss_value = 0.0
    directional_active = 0
    directional_growth_sum = 0.0
    directional_growth_count = 0
    for branch_idx in range(num_branches):
        carry = train_state.carry[branch_idx]
        if branch_idx > 0:
            carry, stats = _prepare_noisy_stream_carry(config, base, carry, batch, train_state.step, branch_idx)
            directional_active += int(stats["active"])
            directional_growth_sum += float(stats["growth_sum"])
            directional_growth_count += int(stats["growth_count"])
        carry, loss, metrics, _, _ = train_state.model(carry=carry, batch=batch, return_keys=[])
        train_state.carry[branch_idx] = carry
        # Gradients accumulate over branches; normalise by batch size and branch count.
        (loss / (global_batch_size * num_branches)).backward()
        if branch_idx == 0:
            clean_metrics = metrics
            clean_loss_value = float(loss.detach().cpu())
        else:
            noisy_loss_value += float(loss.detach().cpu())

    # Apply the scheduled learning rate, then step every optimizer.
    lr_this_step = None
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)
        for param_group in optim.param_groups:
            param_group["lr"] = lr_this_step
        optim.step()
        optim.zero_grad()

    if rank == 0 and clean_metrics is not None:
        reduced_metrics = {f"train/{k}": float(v.detach().cpu()) for k, v in clean_metrics.items()}
        count = max(reduced_metrics.get("train/count", 1.0), 1.0)
        # Losses are normalised by the batch size, every other metric by the valid count.
        reduced_metrics = {
            k: v / (global_batch_size if k.endswith("loss") else count)
            for k, v in reduced_metrics.items()
        }
        reduced_metrics["train/lr"] = lr_this_step
        reduced_metrics["train/trajectory_clean_loss"] = clean_loss_value / global_batch_size
        reduced_metrics["train/trajectory_noisy_loss"] = noisy_loss_value / max(num_branches - 1, 1) / global_batch_size
        reduced_metrics["train/trajectory_sigma"] = _trajectory_noise_target(config, train_state.step)
        reduced_metrics["train/trajectory_directional"] = 1.0 if config.trajectory_directional else 0.0
        reduced_metrics["train/trajectory_directional_active"] = float(directional_active)
        if directional_growth_count > 0:
            reduced_metrics["train/trajectory_directional_growth"] = directional_growth_sum / directional_growth_count
        return reduced_metrics
+
+
def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]:
    """Instantiate one evaluator per (evaluator config, data path) pair.

    ``config.data_paths_test`` is used when it is non-empty, otherwise
    ``config.data_paths``.  Each evaluator class is looked up by name under the
    ``"evaluators."`` prefix and constructed with the data path, the evaluation
    metadata and any extra fields carried by its config entry.
    """
    data_paths = config.data_paths_test if len(config.data_paths_test) > 0 else config.data_paths

    return [
        load_model_class(evaluator_cfg.name, "evaluators.")(
            data_path=data_path, eval_metadata=eval_metadata, **evaluator_cfg.__pydantic_extra__
        )  # type: ignore
        for evaluator_cfg in config.evaluators
        for data_path in data_paths
    ]
+
def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    """Run a single training step and return the reduced metrics on rank 0.

    Delegates to ``train_batch_trajectory_aug`` when ``config.trajectory_augment``
    is set.  Otherwise: advances the step counter, runs forward/backward with the
    loss scaled by ``1 / global_batch_size``, sums gradients across ranks,
    applies the scheduled learning rate and steps every optimizer.

    Returns:
        On rank 0 (and only when the model produced metrics) a dict of
        ``train/...`` values; ``None`` in every other case.
    """
    if config.trajectory_augment:
        return train_batch_trajectory_aug(config, train_state, batch, global_batch_size, rank, world_size)

    train_state.step += 1
    if train_state.step > train_state.total_steps:  # Never exceed the planned number of steps
        return

    # Move the batch to the GPU
    batch = {name: tensor.cuda() for name, tensor in batch.items()}

    # Lazily create the recurrent carry on first use
    if train_state.carry is None:
        with torch.device("cuda"):
            train_state.carry = train_state.model.initial_carry(batch)  # type: ignore

    # Forward / backward
    train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
    ((1 / global_batch_size) * loss).backward()

    # Sum gradients across ranks
    if world_size > 1:
        grads = (p.grad for p in train_state.model.parameters() if p.grad is not None)
        for grad in grads:
            dist.all_reduce(grad)

    # Optimizer update with the scheduled learning rate
    lr_this_step = None
    for optimizer, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)
        for group in optimizer.param_groups:
            group['lr'] = lr_this_step
        optimizer.step()
        optimizer.zero_grad()

    # Nothing to report if the model produced no metrics
    if not len(metrics):
        return None

    assert not any(v.requires_grad for v in metrics.values())

    # Sorted keys guarantee that every process stacks the values in the same order.
    metric_keys = sorted(metrics.keys())
    metric_values = torch.stack([metrics[name] for name in metric_keys])
    if world_size > 1:
        dist.reduce(metric_values, dst=0)

    if rank != 0:
        return None

    host_values = metric_values.cpu().numpy()
    raw_metrics = dict(zip(metric_keys, host_values))

    # Normalise: losses by the global batch size, everything else by the count.
    count = max(raw_metrics["count"], 1)  # Avoid NaNs
    reduced_metrics = {}
    for name, value in raw_metrics.items():
        denominator = global_batch_size if name.endswith("loss") else count
        reduced_metrics[f"train/{name}"] = value / denominator

    reduced_metrics["train/lr"] = lr_this_step
    return reduced_metrics
+
def evaluate(
    config: PretrainConfig,
    train_state: TrainState,
    eval_loader: torch.utils.data.DataLoader,
    eval_metadata: PuzzleDatasetMetadata,
    evaluators: List[Any],
    rank: int,
    world_size: int,
    cpu_group: Optional[dist.ProcessGroup],
):
    """Evaluate ``train_state.model`` over ``eval_loader`` and collect metrics.

    For every batch the model is stepped from a fresh carry until it reports
    ``all_finish``.  Model metrics are summed per evaluation set, reduced to
    rank 0 and divided by each set's ``count`` entry.  Outputs named in
    ``config.eval_save_outputs`` are gathered and, when a checkpoint path is
    configured, saved per rank.  Finally each evaluator's ``result`` is called
    and its metrics are merged into the returned dict on rank 0.

    Args:
        config: Run configuration (``eval_save_outputs``, ``checkpoint_path``).
        train_state: Holds the model to evaluate and the current step.
        eval_loader: Yields ``(set_name, batch, global_batch_size)`` tuples.
        eval_metadata: Provides ``sets``, the names of the evaluation sets.
        evaluators: Objects exposing ``begin_eval``, ``required_outputs``,
            ``update_batch`` and ``result``.
        rank: Rank of this process.
        world_size: Number of processes.
        cpu_group: Process group handed through to ``evaluator.result``.

    Returns:
        On rank 0, a dict of metrics (``None`` if nothing was produced);
        ``None`` on all other ranks.
    """
    reduced_metrics = None

    with torch.inference_mode():
        return_keys = set(config.eval_save_outputs)
        for evaluator in evaluators:
            evaluator.begin_eval()
            return_keys.update(evaluator.required_outputs)

        # Run evaluation
        set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}

        save_preds = {}

        metric_keys = []
        metric_values = None

        carry = None
        processed_batches = 0

        for set_name, batch, global_batch_size in eval_loader:
            processed_batches += 1
            if rank == 0:
                print(f"Processing batch {processed_batches}: {set_name}")

            # To device
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.device("cuda"):
                carry = train_state.model.initial_carry(batch)  # type: ignore

            # Forward: keep stepping until the model signals that every sequence finished
            inference_steps = 0
            while True:
                carry, loss, metrics, preds, all_finish = train_state.model(
                    carry=carry, batch=batch, return_keys=return_keys
                )
                inference_steps += 1

                if all_finish:
                    break

            if rank == 0:
                print(f" Completed inference in {inference_steps} steps")

            for collection in (batch, preds):
                for k, v in collection.items():
                    if k in config.eval_save_outputs:
                        save_preds.setdefault(k, [])
                        save_preds[k].append(v.cpu())  # Move to CPU for saving GPU memory

            for evaluator in evaluators:
                evaluator.update_batch(batch, preds)

            del carry, loss, preds, batch, all_finish

            # Aggregate metrics
            set_id = set_ids[set_name]

            if metric_values is None:
                # Sort keys to guarantee all processes use the same order.
                metric_keys = sorted(metrics.keys())
                metric_values = torch.zeros(
                    (len(set_ids), len(metric_keys)), dtype=torch.float32, device="cuda"
                )

            metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])

            del metrics

        # Concatenate saved predictions across batches
        save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()}

        # Save preds
        if config.checkpoint_path is not None and len(save_preds):
            # Each rank saves its predictions independently.  The file is written
            # *inside* checkpoint_path, so that directory itself (not merely its
            # parent) has to exist before torch.save is called.
            os.makedirs(config.checkpoint_path, exist_ok=True)
            torch.save(
                save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")
            )

        del save_preds

        # Reduce to rank 0
        if metric_values is not None:
            if world_size > 1:
                dist.reduce(metric_values, dst=0)

            if rank == 0:
                reduced_metrics = metric_values.cpu().numpy()
                reduced_metrics = {
                    set_name: {
                        metric_name: reduced_metrics[set_id, metric_id]
                        for metric_id, metric_name in enumerate(metric_keys)
                    }
                    for set_id, set_name in enumerate(set_ids)
                }

                # Postprocess: turn per-set sums into per-set averages.
                # NOTE(review): a set that contributed no batches keeps an all-zero
                # row, so this division yields NaN for it -- left unchanged here.
                for set_name, m in reduced_metrics.items():
                    count = m.pop("count")
                    reduced_metrics[set_name] = {k: v / count for k, v in m.items()}

        # Run evaluators
        if rank == 0:
            print(f"\nRunning {len(evaluators)} evaluator(s)...")

        for i, evaluator in enumerate(evaluators):
            if rank == 0:
                print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}")

            # Path for saving
            evaluator_save_path = None
            if config.checkpoint_path is not None:
                evaluator_save_path = os.path.join(
                    config.checkpoint_path,
                    f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}",
                )
                os.makedirs(evaluator_save_path, exist_ok=True)

            # Run and log
            metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group)
            if rank == 0 and metrics is not None:
                if reduced_metrics is None:
                    reduced_metrics = {}

                reduced_metrics.update(metrics)
                print(f" Completed {evaluator.__class__.__name__}")

        if rank == 0:
            print("All evaluators completed!")

    return reduced_metrics
+
def save_code_and_config(config: PretrainConfig):
    """Snapshot model source files and the full config into the checkpoint directory.

    Does nothing when no checkpoint path is configured or when there is no
    active wandb run.  Otherwise the architecture and loss source files (when
    resolvable) are copied next to an ``all_config.yaml`` dump of the config,
    and the directory is registered with wandb via ``log_code``.
    """
    if config.checkpoint_path is None or wandb.run is None:
        return

    checkpoint_dir = config.checkpoint_path
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Source files for the architecture and its loss head; either may be unresolved (None).
    source_files = (
        get_model_source_path(config.arch.name),
        get_model_source_path(config.arch.loss.name),
    )
    for source_file in source_files:
        if source_file is None:
            continue
        destination = os.path.join(checkpoint_dir, os.path.basename(source_file))
        shutil.copy(source_file, destination)

    # Dump the resolved config as YAML
    with open(os.path.join(checkpoint_dir, "all_config.yaml"), "wt") as f:
        yaml.dump(config.model_dump(), f)

    # Attach the snapshot to the wandb run
    wandb.run.log_code(checkpoint_dir)
+
+
def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
    """Build the run config on rank 0 and share it with every other rank.

    Rank 0 validates ``hydra_config`` into a ``PretrainConfig`` and fills in
    default project name, run name and checkpoint path when they are unset.
    With more than one process, the resulting object is broadcast from rank 0
    so that all ranks use the same (partly randomly generated) names.
    """
    payload = [None]

    if rank == 0:
        cfg = PretrainConfig(**hydra_config)  # type: ignore

        # Fill in naming defaults
        if cfg.project_name is None:
            dataset_name = os.path.basename(cfg.data_paths[0]).capitalize()
            cfg.project_name = f"{dataset_name}-ACT-torch"
        if cfg.run_name is None:
            arch_name = cfg.arch.name.split('@')[-1]
            cfg.run_name = f"{arch_name} {coolname.generate_slug(2)}"
        if cfg.checkpoint_path is None:
            cfg.checkpoint_path = os.path.join("checkpoints", cfg.project_name, cfg.run_name)

        payload[0] = cfg

    if world_size > 1:
        dist.broadcast_object_list(payload, src=0)

    return payload[0]  # type: ignore
+
+
@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
    """Hydra entry point: set up (optionally distributed) training and run it.

    Steps: initialise process groups when ``LOCAL_RANK`` is present in the
    environment, load the rank-synchronised config, build train/eval data
    loaders and evaluators, create the train state, then alternate training
    iterations with evaluation and checkpointing.

    Missing evaluation data or evaluators are tolerated: the failure is
    reported and the run continues.  When no evaluation data is available the
    evaluation call is skipped, but checkpointing still happens.
    """
    RANK = 0
    WORLD_SIZE = 1
    CPU_PROCESS_GROUP = None

    # Initialize distributed training if in distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Initialize distributed, default device and dtype
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

        # CPU GLOO process group
        CPU_PROCESS_GROUP = dist.new_group(backend="gloo")
        assert (
            dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE
        )

    # Load sync'ed config
    config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)

    # Seed RNGs to ensure consistency (offset by rank so each process differs)
    torch.random.manual_seed(config.seed + RANK)

    # Dataset
    train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
    total_iters = config.epochs // train_epochs_per_iter

    assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."

    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)

    # Evaluation data is optional.  Catch Exception (not a bare except) so that
    # Ctrl-C / SystemExit still propagate, and report the cause instead of hiding it.
    try:
        eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    except Exception as err:
        print("NO EVAL DATA FOUND")
        print(f"  reason: {type(err).__name__}: {err}")
        eval_loader = eval_metadata = None

    try:
        evaluators = create_evaluators(config, eval_metadata)
    except Exception as err:
        print("No evaluator found")
        print(f"  reason: {type(err).__name__}: {err}")
        evaluators = []

    # Train state
    train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)

    # Progress bar and logger (rank 0 only)
    progress_bar = None
    ema_helper = None
    if RANK == 0:
        progress_bar = tqdm.tqdm(total=train_state.total_steps)
        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore
        wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
        save_code_and_config(config)

    # EMA is set up on every rank: the training loop below updates and reads it
    # whenever config.ema is set, regardless of rank.
    if config.ema:
        print('Setup EMA')
        ema_helper = EMAHelper(mu=config.ema_rate)
        ema_helper.register(train_state.model)

    # Evaluation is only possible when the eval loader could be created.
    has_eval_data = eval_loader is not None and eval_metadata is not None

    # Training Loop
    for _iter_id in range(total_iters):
        print(f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")

        ############ Train Iter
        if RANK == 0:
            print("TRAIN")
        train_state.model.train()
        for set_name, batch, global_batch_size in train_loader:
            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)

            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore
            if config.ema:
                ema_helper.update(train_state.model)

        if _iter_id >= config.min_eval_interval:
            ############ Evaluation
            if RANK == 0:
                print("EVALUATE")
            if config.ema:
                # Evaluate (and checkpoint) a copy carrying the EMA weights.
                print("SWITCH TO EMA")
                train_state_eval = copy.deepcopy(train_state)
                train_state_eval.model = ema_helper.ema_copy(train_state_eval.model)
            else:
                train_state_eval = train_state
            train_state_eval.model.eval()

            if has_eval_data:
                metrics = evaluate(config,
                                   train_state_eval,
                                   eval_loader,
                                   eval_metadata,
                                   evaluators,
                                   rank=RANK,
                                   world_size=WORLD_SIZE,
                                   cpu_group=CPU_PROCESS_GROUP)

                if RANK == 0 and metrics is not None:
                    wandb.log(metrics, step=train_state.step)
            elif RANK == 0:
                # evaluate() dereferences eval_metadata / iterates eval_loader and
                # would crash on None, so skip it rather than abort the run.
                print("SKIP EVALUATION (no eval data)")

            ############ Checkpointing
            if RANK == 0:
                print("SAVE CHECKPOINT")
            if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
                save_train_state(config, train_state_eval)

            if config.ema:
                del train_state_eval

    # finalize
    if dist.is_initialized():
        dist.destroy_process_group()
    wandb.finish()


# Script entry point: hydra parses the command line and calls launch().
if __name__ == "__main__":
    launch()