"""
The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.

Hard stop: To keep them readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
"""

from __future__ import annotations

import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path

import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP

# -----------------------------
# HYPERPARAMETERS
# -----------------------------
# Default Simple Baseline run:
# - 9 transformer blocks at width 512
# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
# - vocab size 1024, sequence length 1024, tied embeddings
# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap

class Hyperparameters:
    """Run configuration, read from environment variables.

    All values are class attributes evaluated once, when the class body runs,
    so environment changes after import are not picked up.
    """

    # Data paths are shard globs produced by the existing preprocessing pipeline.
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    # Default run id is a fresh uuid4 generated at import time.
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 1337))
    resume_from = os.environ.get("RESUME_FROM", "")

    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))

    # Training length.
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))

    # Model shape.
    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers = int(os.environ.get("NUM_LAYERS", 9))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
    model_dim = int(os.environ.get("MODEL_DIM", 512))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
    mlp_mult = int(os.environ.get("MLP_MULT", 2))
    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

    # Optimizer hyperparameters.
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))

    # Test-time training (LoRA) hyperparameters.
    ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8))
    ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01))
    ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256))
    ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024))
    ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize a 2D matrix with a quintic Newton-Schulz iteration.

    Muon uses this to normalize matrix-shaped gradients before applying them.
    The iteration runs in bfloat16 and returns a bfloat16 tensor with the same
    shape as ``G``. ``G`` itself is never modified.

    Args:
        G: 2D tensor to orthogonalize.
        steps: number of Newton-Schulz iterations.
        eps: added to the Frobenius norm to avoid division by zero.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Out-of-place on purpose: if G is already bfloat16, `.bfloat16()` returns G
    # itself, and an in-place divide would overwrite the caller's tensor
    # (e.g. p.grad or Muon's momentum buffer).
    X = X / (X.norm() + eps)
    # Iterate on the wide orientation so X @ X.T is the smaller Gram matrix.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum optimizer whose per-matrix update is orthogonalized by Newton-Schulz.

    Every parameter is expected to be a 2D matrix. When torch.distributed is
    initialized, parameter ``i`` of a group is processed only on rank
    ``i % world_size``. Each rank writes its results into a zero-initialized
    flat bfloat16 buffer, and the buffer is SUM-all-reduced, so all ranks apply
    the same update. A parameter whose owning rank has ``grad is None`` gets a
    zero update.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one optimizer step; returns the closure's loss, if a closure is given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    # Nesterov looks ahead along the momentum direction.
                    # Plain (heavy-ball) momentum applies the buffer itself.
                    g = g.add(buf, alpha=momentum) if nesterov else buf
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale correction from Muon reference implementations.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Advance for every param (owned or not) so offsets match on all ranks.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding weights can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count how many UTF-8 bytes each token decodes to.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token-id lookup tables used to count UTF-8 bytes for BPB.

    All three tables have length ``max(sp.vocab_size(), vocab_size)``.

    Returns:
        base_bytes (int16): UTF-8 byte length of each piece with any leading
            "▁" removed. Byte-fallback tokens count as 1. Control, unknown, and
            unused ids, and ids beyond the SentencePiece vocab, stay 0.
        has_leading_space (bool): True when the piece starts with "▁".
        is_boundary_token (bool): True for control, unknown, and unused ids,
            and for ids beyond the SentencePiece vocab.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all shards matching ``pattern``, trimmed to whole sequences plus one token.

    The returned length is ``k * seq_len + 1``, so every sequence has a shifted
    target.

    Raises:
        FileNotFoundError: if no file matches ``pattern``.
        ValueError: if the tokens do not fill at least one full sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Compute validation loss and bits-per-byte over ``val_tokens``.

    Validation computes two metrics:
      - val_loss: token cross-entropy (natural log)
      - val_bpb: tokenizer-agnostic compression metric used by the challenge

    Each rank scores a contiguous, disjoint range of ``args.train_seq_len``-long
    sequences. The loss, token, and byte sums are SUM-all-reduced when
    torch.distributed is initialized. The model runs in eval mode under bf16
    autocast on ``device.type``. Its previous train/eval mode is restored
    afterwards, even if an exception is raised.

    Returns:
        (val_loss, val_bpb) as Python floats.

    Raises:
        ValueError: if VAL_BATCH_SIZE leaves less than one sequence per rank
            per accumulation step.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size

    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
                batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
                raw_start = batch_seq_start * args.train_seq_len
                raw_end = batch_seq_end * args.train_seq_len + 1
                local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
                x = local[:-1].reshape(-1, args.train_seq_len)
                y = local[1:].reshape(-1, args.train_seq_len)
                with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
                    batch_loss = model(x, y).detach()
                # model(x, y) returns a mean loss; weight it by token count to get a sum.
                batch_token_count = float(y.numel())
                val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
                val_token_count += batch_token_count
                # Bytes per target token: its base bytes, plus one for a leading
                # space unless the preceding token is a boundary token.
                prev_ids = x.reshape(-1)
                tgt_ids = y.reshape(-1)
                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
                val_byte_count += token_bytes.to(torch.float64).sum()
    finally:
        model.train(was_training)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run it in higher precision for evaluation, while the stored artifact stays under the size limit.
CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", ).split(",") if pattern ) INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", ",".join(CONTROL_TENSOR_NAME_PATTERNS), ).split(",") if pattern ) INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): return t.float().contiguous() if t.dtype in {torch.float32, torch.bfloat16}: passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: # Matrices get one scale per row, which usually tracks output-channel # ranges much better than a single tensor-wide scale. clip_abs = ( torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) ) clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() # Vectors / scalars use a simpler per-tensor scale. 
clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale def quantize_state_dict_int8(state_dict: dict[str, Tensor]): # Single supported clean-script export format: # - per-row int8 for 2D float tensors # - per-tensor int8 for other float tensors # - exact passthrough for non-floats # - passthrough for small float tensors, stored as fp16 to save bytes quantized: dict[str, Tensor] = {} scales: dict[str, Tensor] = {} dtypes: dict[str, str] = {} passthrough: dict[str, Tensor] = {} passthrough_orig_dtypes: dict[str, str] = {} qmeta: dict[str, dict[str, object]] = {} stats = dict.fromkeys( ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), 0, ) for name, tensor in state_dict.items(): t = tensor.detach().to("cpu").contiguous() stats["param_count"] += int(t.numel()) stats["num_tensors"] += 1 stats["baseline_tensor_bytes"] += tensor_nbytes(t) if not t.is_floating_point(): stats["num_nonfloat_tensors"] += 1 passthrough[name] = t stats["int8_payload_bytes"] += tensor_nbytes(t) continue # Small float tensors are cheap enough to keep directly. We still downcast # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: kept = keep_float_tensor(name, t, passthrough_orig_dtypes) passthrough[name] = kept stats["int8_payload_bytes"] += tensor_nbytes(kept) continue stats["num_float_tensors"] += 1 q, s = quantize_float_tensor(t) if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} quantized[name] = q scales[name] = s dtypes[name] = str(t.dtype).removeprefix("torch.") stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) obj: dict[str, object] = { "__quant_format__": "int8_clean_per_row_v1", "quantized": quantized, "scales": scales, "dtypes": dtypes, "passthrough": passthrough, } if qmeta: obj["qmeta"] = qmeta if passthrough_orig_dtypes: obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes return obj, stats def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out: dict[str, Tensor] = {} qmeta = obj.get("qmeta", {}) passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) for name, q in obj["quantized"].items(): dtype = getattr(torch, obj["dtypes"][name]) s = obj["scales"][name] if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: s = s.to(dtype=torch.float32) # Broadcast the saved row scale back across trailing dimensions. out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() else: scale = float(s.item()) out[name] = (q.float() * scale).to(dtype=dtype).contiguous() for name, t in obj["passthrough"].items(): # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
out_t = t.detach().to("cpu").contiguous() orig_dtype = passthrough_orig_dtypes.get(name) if isinstance(orig_dtype, str): out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() out[name] = out_t return out # ----------------------------- # DATA LOADING # ----------------------------- def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" None: self.file_idx = (self.file_idx + 1) % len(self.files) self.tokens = load_data_shard(self.files[self.file_idx]) self.pos = 0 def take(self, n: int) -> Tensor: chunks: list[Tensor] = [] remaining = n while remaining > 0: avail = self.tokens.numel() - self.pos if avail <= 0: self._advance_file() continue k = min(remaining, avail) chunks.append(self.tokens[self.pos : self.pos + k]) self.pos += k remaining -= k return chunks[0] if len(chunks) == 1 else torch.cat(chunks) class DistributedTokenLoader: # Each call consumes a contiguous chunk from the shared token stream, then slices out # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size self.device = device self.stream = TokenStream(pattern) def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: local_tokens = global_tokens // (self.world_size * grad_accum_steps) per_rank_span = local_tokens + 1 chunk = self.stream.take(per_rank_span * self.world_size) start = self.rank * per_rank_span local = chunk[start : start + per_rank_span].to(dtype=torch.int64) x = local[:-1].reshape(-1, seq_len) y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) # ----------------------------- # TRANSFORMER MODULES # ----------------------------- class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() self.eps = eps def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) class CastedLinear(nn.Linear): # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. def forward(self, x: Tensor) -> Tensor: bias = self.bias.to(x.dtype) if self.bias is not None else None return F.linear(x, self.weight.to(x.dtype), bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: # Keep small/control parameters in fp32 even when the model body runs in bf16. with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: param.data = param.data.float() class Rotary(nn.Module): # Caches cos/sin tables per sequence length on the current device. 
def __init__(self, dim: int, base: float = 10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._cos_cached: Tensor | None = None self._sin_cached: Tensor | None = None def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: if ( self._cos_cached is None or self._sin_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device ): t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq.to(device)) self._cos_cached = freqs.cos()[None, None, :, :] self._sin_cached = freqs.sin()[None, None, :, :] self._seq_len_cached = seq_len return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) class CausalSelfAttention(nn.Module): def __init__( self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, ): super().__init__() if dim % num_heads != 0: raise ValueError("model_dim must be divisible by num_heads") if num_heads % num_kv_heads != 0: raise ValueError("num_heads must be divisible by num_kv_heads") self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads if self.head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE") kv_dim = self.num_kv_heads * self.head_dim self.c_q = CastedLinear(dim, dim, bias=False) self.c_k = CastedLinear(dim, kv_dim, bias=False) self.c_v = CastedLinear(dim, kv_dim, bias=False) self.proj = CastedLinear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, base=rope_base) def 
forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: bsz, seqlen, dim = x.shape q = self.c_q(x) + (q_delta if q_delta is not None else 0) k = self.c_k(x) v = self.c_v(x) + (v_delta if v_delta is not None else 0) q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention( q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), ) y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) class MLP(nn.Module): # relu^2 MLP from the original modded-nanogpt setup def __init__(self, dim: int, mlp_mult: int): super().__init__() hidden = mlp_mult * dim self.fc = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) self.proj._zero_init = True def forward(self, x: Tensor) -> Tensor: x = torch.relu(self.fc(x)) return self.proj(x.square()) class Block(nn.Module): def __init__( self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, rope_base: float, qk_gain_init: float, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] 
* x + mix[1][None, None, :] * x0 n = self.attn_norm(x) qd = q_delta_fn(n) if q_delta_fn is not None else None vd = v_delta_fn(n) if v_delta_fn is not None else None attn_out = self.attn(n, qd, vd) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x class GPT(nn.Module): def __init__( self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, tie_embeddings: bool, tied_embed_init_std: float, logit_softcap: float, rope_base: float, qk_gain_init: float, ): super().__init__() if logit_softcap <= 0.0: raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) self.blocks = nn.ModuleList( [ Block( model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, ) for i in range(num_layers) ] ) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True self._init_weights() def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) for module in self.modules(): if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) x0 = x skips: list[Tensor] = [] # First half stores skips; 
second half reuses them in reverse order. for i in range(self.num_encoder_layers): qd = lora.q_loras[i] if lora else None vd = lora.v_loras[i] if lora else None x = self.blocks[i](x, x0, qd, vd) skips.append(x) for i in range(self.num_decoder_layers): bi = self.num_encoder_layers + i if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() qd = lora.q_loras[bi] if lora else None vd = lora.v_loras[bi] if lora else None x = self.blocks[bi](x, x0, qd, vd) x = self.final_norm(x) if self.tie_embeddings: logits = F.linear(x, self.tok_emb.weight) else: logits = self.lm_head(x) logits = logits + (lora.lm_head_lora(x) if lora else 0) logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) if lora: bsz, sl, V = logits.shape return F.cross_entropy( logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") # ----------------------------- # TEST-TIME TRAINING (LoRA) # ----------------------------- # # At evaluation time, we adapt per-document low-rank adapters on the validation data. # Each document gets its own adapter, so there is no inter-document dependency. BOS_ID = 1 class BatchedLinearLoRA(nn.Module): """LoRA for a linear layer, with independent weights per batch element. Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. 
the LoRA delta is ΔW = BA.""" def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): super().__init__() self.in_features = in_features self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection self.reset() def forward(self, x: Tensor) -> Tensor: return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) def reset(self) -> None: bound = 1.0 / math.sqrt(self.in_features) with torch.no_grad(): self.A.uniform_(-bound, bound) # kaiming-uniform self.B.zero_() class BatchedTTTLoRA(nn.Module): """All LoRA adapters for one batch: LM head and Q/V per block.""" def __init__(self, bsz: int, model: GPT, rank: int): super().__init__() dim = model.tok_emb.embedding_dim vocab = model.tok_emb.num_embeddings self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) self.q_loras = nn.ModuleList() self.v_loras = nn.ModuleList() for block in model.blocks: self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) def reset(self) -> None: for m in self.modules(): if isinstance(m, BatchedLinearLoRA): m.reset() def _reset_ttt_optimizer(opt): for group in opt.param_groups: for p in group['params']: s = opt.state.get(p) if not s: # Fresh state. continue s['exp_avg'].zero_() s['exp_avg_sq'].zero_() s['step'].fill_(0) def _build_ttt_optimizer(lora, args: Hyperparameters): return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: """Return (start_offset, length) for each document, identified by BOS boundaries. If include_next_bos is True, include next document's BOS (to match continuous-stream eval token count exactly). 
""" bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() docs = [] for i in range(len(bos_positions)): start = int(bos_positions[i]) end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() if include_next_bos and i + 1 < len(bos_positions): end += 1 assert end - start >= 2 docs.append((start, end - start)) return docs def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" chunk_start = ci * chunk_size chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size win_start = max(0, chunk_end - eval_seq_len) win_len = chunk_end - win_start chunk_offset = chunk_start - win_start chunk_len = chunk_end - chunk_start return win_start, win_len, chunk_offset, chunk_len def _accumulate_bpb( ptl: Tensor, x: Tensor, y: Tensor, batch_i: int, chunk_offset: int, chunk_len: int, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, ): """Add one doc-chunk's contribution to the running BPB accumulators.""" lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] tok_bytes = base_bytes_lut[tgt].to(torch.float64) tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] loss_sum += lbl.sum() byte_sum += tok_bytes.sum() token_count += chunk_len def eval_val_ttt_lora( args: Hyperparameters, base_model: GPT, rank: int, world_size: int, device: torch.device, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, ) -> tuple[float, float]: """Evaluate with batched LoRA test-time training. 
Returns (val_loss, val_bpb).""" # Load validation tokens and find document boundaries files = sorted(glob.glob(args.val_files)) all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) docs = _find_docs(all_tokens) # Each rank takes a contiguous slice of documents rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] chunk_size = args.ttt_chunk_size eval_seq_len = args.ttt_eval_seq_len batch_size = args.ttt_batch_size lora_rank = args.ttt_lora_rank rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) base_model.eval() for p in base_model.parameters(): p.requires_grad_(False) lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) opt = _build_ttt_optimizer(lora, args) loss_sum = torch.zeros((), device=device, dtype=torch.float64) byte_sum = torch.zeros((), device=device, dtype=torch.float64) token_count = torch.zeros((), device=device, dtype=torch.float64) for bi in range(0, len(rank_docs), batch_size): batch = rank_docs[bi:bi + batch_size] bsz = len(batch) if bsz == batch_size: cur_lora, cur_opt = lora, opt cur_lora.reset() _reset_ttt_optimizer(cur_opt) else: cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) cur_opt = _build_ttt_optimizer(cur_lora, args) pred_lens = [doc_len - 1 for _, doc_len in batch] num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] max_nc = max(num_chunks) for ci in range(max_nc): chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) context_size, chunk_offset = chunk_stats[1], chunk_stats[2] active = [ci < nc for nc in num_chunks] needs_train = any(ci < nc - 1 for nc in num_chunks) x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) doc_info = [] # (chunk_offset, chunk_len) per doc for b in range(bsz): if not active[b]: doc_info.append((0, 0)) continue ds, dl = batch[b] ws, wl, co, cl = _compute_chunk_window(ci, 
pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) chunk = all_tokens[ds + ws: ds + ws + wl + 1] toks = chunk.to(dtype=torch.int64, device=device) x[b, :wl] = toks[:-1] y[b, :wl] = toks[1:] doc_info.append((co, cl)) # Forward pass (keep grad graph alive only when we need to train) if needs_train: with torch.autocast(device_type="cuda", dtype=torch.bfloat16): ptl = base_model(x, y, lora=cur_lora) else: with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): ptl = base_model(x, y, lora=cur_lora) # Score: accumulate loss and byte counts for BPB (before training on chunk) with torch.no_grad(): for b in range(bsz): if not active[b]: continue co, cl = doc_info[b] _accumulate_bpb( ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, loss_sum, byte_sum, token_count) # Train: one Adam step on the LoRA params using this chunk's loss if needs_train: mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) cur_opt.zero_grad() (per_doc * mask).sum().backward() cur_opt.step() if dist.is_available() and dist.is_initialized(): dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) dist.all_reduce(token_count, op=dist.ReduceOp.SUM) val_loss = float(loss_sum.item() / token_count.item()) val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) return val_loss, val_bpb # ----------------------------- # TRAINING # ----------------------------- def main() -> None: global zeropower_via_newtonschulz5 code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) # ----------------------------- # DISTRIBUTED + CUDA SETUP # ----------------------------- distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = 
int(os.environ.get("WORLD_SIZE", "1")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) if world_size <= 0: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") if 8 % world_size != 0: raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") grad_accum_steps = 8 // world_size grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") device = torch.device("cuda", local_rank) torch.cuda.set_device(device) if distributed: dist.init_process_group(backend="nccl", device_id=device) dist.barrier() master_process = rank == 0 # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp enable_cudnn_sdp(False) enable_flash_sdp(True) enable_mem_efficient_sdp(False) enable_math_sdp(False) logfile = None if master_process: os.makedirs("logs", exist_ok=True) logfile = f"logs/{args.run_id}.txt" print(logfile) def log0(msg: str, console: bool = True) -> None: if not master_process: return if console: print(msg) if logfile is not None: with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) log0(code, console=False) log0("=" * 100, console=False) log0(f"Running Python {sys.version}", console=False) log0(f"Running PyTorch {torch.__version__}", console=False) log0( subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False, ) log0("=" * 100, console=False) # ----------------------------- # TOKENIZER + VALIDATION METRIC SETUP # ----------------------------- random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) if not args.tokenizer_path.endswith(".model"): raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") sp = 
spm.SentencePieceProcessor(model_file=args.tokenizer_path) if int(sp.vocab_size()) != args.vocab_size: raise ValueError( f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" ) dataset_dir = Path(args.data_path).resolve() actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( sp, args.vocab_size, device ) log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") # ----------------------------- # MODEL + OPTIMIZER SETUP # ----------------------------- base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() if isinstance(module, Rotary): module.inv_freq.data = module.inv_freq.data.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model # Optimizer split: # - token embedding (Adam) uses EMBED_LR # - untied lm_head (Adam) uses HEAD_LR # - matrix params in transformer blocks use MATRIX_LR via Muon # - vectors/scalars use SCALAR_LR via Adam block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ p for name, p in block_named_params if p.ndim == 2 
and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] scalar_params = [ p for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr optimizer_tok = torch.optim.Adam( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, ) optimizer_muon = Muon( matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr optimizer_scalar = torch.optim.Adam( [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, ) optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] if base_model.lm_head is not None: optimizer_head = torch.optim.Adam( [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, ) optimizers.insert(1, optimizer_head) n_params = sum(p.numel() for p in base_model.parameters()) log0(f"model_params:{n_params}") log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") log0( f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" ) log0( f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" ) log0(f"seed:{args.seed}") # ----------------------------- # DATA LOADER & MODEL WARMUP # ----------------------------- train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) def zero_grad_all() -> None: for opt in optimizers: opt.zero_grad(set_to_none=True) max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None def lr_mul(step: int, elapsed_ms: float) -> float: if args.warmdown_iters <= 0: return 1.0 if max_wallclock_ms is None: warmdown_start = max(args.iterations - args.warmdown_iters, 0) return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 step_ms = elapsed_ms / max(step, 1) warmdown_ms = args.warmdown_iters * step_ms remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 # Warmup primes the compiled forward/backward/optimizer paths, then we restore the # initial weights/optimizer state so measured training starts from the true init. 
if args.warmup_steps > 0: initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] model.train() for warmup_step in range(args.warmup_steps): zero_grad_all() for micro_step in range(grad_accum_steps): if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): warmup_loss = model(x, y) (warmup_loss * grad_scale).backward() for opt in optimizers: opt.step() zero_grad_all() if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") base_model.load_state_dict(initial_model_state, strict=True) for opt, state in zip(optimizers, initial_optimizer_states, strict=True): opt.load_state_dict(state) zero_grad_all() if distributed: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) # ----------------------------- # MAIN TRAINING LOOP # ----------------------------- training_time_ms = 0.0 stop_after_step: int | None = None torch.cuda.synchronize() t0 = time.perf_counter() step = 0 while True: last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) if should_validate: torch.cuda.synchronize() training_time_ms += 1000.0 * (time.perf_counter() - t0) val_loss, val_bpb = eval_val( args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, ) log0( f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " f"train_time:{training_time_ms:.0f}ms 
step_avg:{training_time_ms / max(step, 1):.2f}ms" ) torch.cuda.synchronize() t0 = time.perf_counter() if last_step: if stop_after_step is not None and step < args.iterations: log0( f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " f"step:{step}/{args.iterations}" ) break elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) scale = lr_mul(step, elapsed_ms) zero_grad_all() train_loss = torch.zeros((), device=device) for micro_step in range(grad_accum_steps): if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y) train_loss += loss.detach() (loss * grad_scale).backward() train_loss /= grad_accum_steps frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum for group in optimizer_muon.param_groups: group["momentum"] = muon_momentum for opt in optimizers: for group in opt.param_groups: group["lr"] = group["base_lr"] * scale if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) for opt in optimizers: opt.step() zero_grad_all() step += 1 approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) should_log_train = ( args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) ) if should_log_train: log0( f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) # Needed to sync whether we've reached the wallclock cap. 
reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms if distributed and max_wallclock_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) reached_cap = bool(reached_cap_tensor.item()) if stop_after_step is None and reached_cap: stop_after_step = step log0( f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) # ----------------------------- # SERIALIZATION + ROUNDTRIP VALIDATION # ----------------------------- # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce # the compressed int8+zlib artifact and validate the round-tripped weights. if master_process: torch.save(base_model.state_dict(), "final_model.pt") model_bytes = os.path.getsize("final_model.pt") code_bytes = len(code.encode("utf-8")) log0(f"Serialized model: {model_bytes} bytes") log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() quant_blob = zlib.compress(quant_raw, level=9) quant_raw_bytes = len(quant_raw) if master_process: with open("final_model.int8.ptz", "wb") as f: f.write(quant_blob) quant_file_bytes = os.path.getsize("final_model.int8.ptz") code_bytes = len(code.encode("utf-8")) ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) log0( f"Serialized model int8+zlib: {quant_file_bytes} bytes " f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" ) log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") if distributed: dist.barrier() with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() 
quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() q_val_loss, q_val_bpb = eval_val( args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, ) torch.cuda.synchronize() log0( f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") # LoRA test-time training evaluation (the competition score) torch._dynamo.reset() torch.cuda.synchronize() t_ttt = time.perf_counter() ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( args, base_model, rank, world_size, device, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, ) torch.cuda.synchronize() log0( f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" ) if distributed: dist.destroy_process_group() if __name__ == "__main__": main() ==================================================================================================== Running Python 3.12.13 (main, Mar 10 2026, 18:17:25) [Clang 21.1.4 ] Running PyTorch 2.10.0+cu128 Thu Mar 19 11:15:42 2026 +-----------------------------------------------------------------------------------------+ | NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | |-----------------------------------------+------------------------+----------------------+ | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | | | MIG M. 
| |=========================================+========================+======================| | 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | | N/A 41C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | | N/A 35C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | | N/A 40C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | | N/A 35C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | | N/A 41C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | | N/A 36C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | | N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | | N/A 34C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ 
+-----------------------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=========================================================================================| | 0 N/A N/A 55839 C ...ai-codegolf/.venv/bin/python3 1510MiB | | 1 N/A N/A 55840 C ...ai-codegolf/.venv/bin/python3 1510MiB | | 2 N/A N/A 55841 C ...ai-codegolf/.venv/bin/python3 1510MiB | | 3 N/A N/A 55842 C ...ai-codegolf/.venv/bin/python3 1510MiB | | 4 N/A N/A 55843 C ...ai-codegolf/.venv/bin/python3 1510MiB | | 5 N/A N/A 55844 C ...ai-codegolf/.venv/bin/python3 1510MiB | | 6 N/A N/A 55845 C ...ai-codegolf/.venv/bin/python3 1510MiB | | 7 N/A N/A 55846 C ...ai-codegolf/.venv/bin/python3 1510MiB | +-----------------------------------------------------------------------------------------+ ==================================================================================================== val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_loader:dataset:fineweb10B_sp1024 train_shards:25 val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 model_params:17059912 world_size:8 grad_accum_steps:1 sdp_backends:cudnn=False flash=True mem_efficient=False math=False attention_mode:gqa num_heads:8 num_kv_heads:4 tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 seed:1337 warmup_step:1/20 warmup_step:2/20 warmup_step:3/20 warmup_step:4/20 warmup_step:5/20 warmup_step:6/20 warmup_step:7/20 warmup_step:8/20 warmup_step:9/20 warmup_step:10/20 warmup_step:11/20 warmup_step:12/20 warmup_step:13/20 warmup_step:14/20 warmup_step:15/20 warmup_step:16/20 warmup_step:17/20 warmup_step:18/20 warmup_step:19/20 warmup_step:20/20 step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms 
step:1/20000 train_loss:6.9370 train_time:24ms step_avg:24.04ms step:2/20000 train_loss:16.8367 train_time:66ms step_avg:32.99ms step:3/20000 train_loss:8.7608 train_time:110ms step_avg:36.59ms step:4/20000 train_loss:6.6385 train_time:153ms step_avg:38.28ms step:5/20000 train_loss:6.6121 train_time:197ms step_avg:39.34ms step:6/20000 train_loss:7.4220 train_time:241ms step_avg:40.08ms step:7/20000 train_loss:6.3502 train_time:284ms step_avg:40.58ms step:8/20000 train_loss:6.1582 train_time:328ms step_avg:40.98ms step:9/20000 train_loss:6.0679 train_time:371ms step_avg:41.26ms step:10/20000 train_loss:5.9745 train_time:415ms step_avg:41.52ms step:50/20000 train_loss:4.0980 train_time:2156ms step_avg:43.12ms step:100/20000 train_loss:3.4059 train_time:4333ms step_avg:43.33ms step:150/20000 train_loss:3.0556 train_time:6509ms step_avg:43.39ms step:200/20000 train_loss:2.8570 train_time:8746ms step_avg:43.73ms step:200/20000 val_loss:2.8396 val_bpb:1.6817 train_time:8772ms step_avg:43.86ms step:250/20000 train_loss:2.7420 train_time:10923ms step_avg:43.69ms step:300/20000 train_loss:2.4765 train_time:13097ms step_avg:43.66ms step:350/20000 train_loss:2.6724 train_time:15280ms step_avg:43.66ms step:400/20000 train_loss:2.3554 train_time:17522ms step_avg:43.80ms step:400/20000 val_loss:2.5660 val_bpb:1.5197 train_time:17548ms step_avg:43.87ms step:450/20000 train_loss:2.5112 train_time:19694ms step_avg:43.76ms step:500/20000 train_loss:2.4970 train_time:21865ms step_avg:43.73ms step:550/20000 train_loss:2.4029 train_time:24036ms step_avg:43.70ms step:600/20000 train_loss:2.5472 train_time:26269ms step_avg:43.78ms step:600/20000 val_loss:2.4526 val_bpb:1.4526 train_time:26295ms step_avg:43.83ms step:650/20000 train_loss:2.3897 train_time:28441ms step_avg:43.76ms step:700/20000 train_loss:2.4431 train_time:30612ms step_avg:43.73ms step:750/20000 train_loss:2.2780 train_time:32782ms step_avg:43.71ms step:800/20000 train_loss:2.2972 train_time:35024ms step_avg:43.78ms 
step:800/20000 val_loss:2.3835 val_bpb:1.4116 train_time:35051ms step_avg:43.81ms step:850/20000 train_loss:2.7247 train_time:37196ms step_avg:43.76ms step:900/20000 train_loss:2.3416 train_time:39367ms step_avg:43.74ms step:950/20000 train_loss:2.4075 train_time:41538ms step_avg:43.72ms step:1000/20000 train_loss:2.3741 train_time:43774ms step_avg:43.77ms step:1000/20000 val_loss:2.3353 val_bpb:1.3831 train_time:43800ms step_avg:43.80ms step:1050/20000 train_loss:2.4838 train_time:45945ms step_avg:43.76ms step:1100/20000 train_loss:2.2593 train_time:48115ms step_avg:43.74ms step:1150/20000 train_loss:2.2540 train_time:50353ms step_avg:43.79ms step:1200/20000 train_loss:2.3866 train_time:52523ms step_avg:43.77ms step:1200/20000 val_loss:2.3042 val_bpb:1.3647 train_time:52549ms step_avg:43.79ms step:1250/20000 train_loss:2.2102 train_time:54693ms step_avg:43.75ms step:1300/20000 train_loss:2.3626 train_time:56863ms step_avg:43.74ms step:1350/20000 train_loss:2.2732 train_time:59102ms step_avg:43.78ms step:1400/20000 train_loss:2.4311 train_time:61273ms step_avg:43.77ms step:1400/20000 val_loss:2.2830 val_bpb:1.3521 train_time:61299ms step_avg:43.79ms step:1450/20000 train_loss:2.2370 train_time:63445ms step_avg:43.76ms step:1500/20000 train_loss:2.2247 train_time:65619ms step_avg:43.75ms step:1550/20000 train_loss:2.1566 train_time:67854ms step_avg:43.78ms step:1600/20000 train_loss:2.0967 train_time:70025ms step_avg:43.77ms step:1600/20000 val_loss:2.2676 val_bpb:1.3430 train_time:70051ms step_avg:43.78ms step:1650/20000 train_loss:2.2340 train_time:72198ms step_avg:43.76ms step:1700/20000 train_loss:2.1750 train_time:74368ms step_avg:43.75ms step:1750/20000 train_loss:2.2492 train_time:76604ms step_avg:43.77ms step:1800/20000 train_loss:2.1997 train_time:78778ms step_avg:43.77ms step:1800/20000 val_loss:2.2522 val_bpb:1.3339 train_time:78804ms step_avg:43.78ms step:1850/20000 train_loss:2.3072 train_time:80950ms step_avg:43.76ms step:1900/20000 train_loss:2.1950 
train_time:83120ms step_avg:43.75ms step:1950/20000 train_loss:2.2103 train_time:85359ms step_avg:43.77ms step:2000/20000 train_loss:2.2531 train_time:87532ms step_avg:43.77ms step:2000/20000 val_loss:2.2376 val_bpb:1.3252 train_time:87558ms step_avg:43.78ms step:2050/20000 train_loss:2.2545 train_time:89706ms step_avg:43.76ms step:2100/20000 train_loss:2.2682 train_time:91951ms step_avg:43.79ms step:2150/20000 train_loss:2.1910 train_time:94122ms step_avg:43.78ms step:2200/20000 train_loss:2.0775 train_time:96292ms step_avg:43.77ms step:2200/20000 val_loss:2.2287 val_bpb:1.3200 train_time:96318ms step_avg:43.78ms step:2250/20000 train_loss:2.1628 train_time:98463ms step_avg:43.76ms step:2300/20000 train_loss:2.3838 train_time:100706ms step_avg:43.79ms step:2350/20000 train_loss:2.1988 train_time:102878ms step_avg:43.78ms step:2400/20000 train_loss:2.2034 train_time:105048ms step_avg:43.77ms step:2400/20000 val_loss:2.2178 val_bpb:1.3135 train_time:105074ms step_avg:43.78ms step:2450/20000 train_loss:2.2071 train_time:107220ms step_avg:43.76ms step:2500/20000 train_loss:2.1254 train_time:109456ms step_avg:43.78ms step:2550/20000 train_loss:2.1383 train_time:111628ms step_avg:43.78ms step:2600/20000 train_loss:2.4129 train_time:113800ms step_avg:43.77ms step:2600/20000 val_loss:2.2204 val_bpb:1.3150 train_time:113826ms step_avg:43.78ms step:2650/20000 train_loss:2.2454 train_time:115974ms step_avg:43.76ms step:2700/20000 train_loss:2.1558 train_time:118222ms step_avg:43.79ms step:2750/20000 train_loss:2.3613 train_time:120392ms step_avg:43.78ms step:2800/20000 train_loss:2.2343 train_time:122566ms step_avg:43.77ms step:2800/20000 val_loss:2.2031 val_bpb:1.3048 train_time:122592ms step_avg:43.78ms step:2850/20000 train_loss:2.1913 train_time:124738ms step_avg:43.77ms step:2900/20000 train_loss:2.1802 train_time:127001ms step_avg:43.79ms step:2950/20000 train_loss:2.2385 train_time:129172ms step_avg:43.79ms step:3000/20000 train_loss:2.2329 train_time:131345ms 
step_avg:43.78ms step:3000/20000 val_loss:2.1966 val_bpb:1.3009 train_time:131372ms step_avg:43.79ms step:3050/20000 train_loss:2.1716 train_time:133517ms step_avg:43.78ms step:3100/20000 train_loss:2.2063 train_time:135763ms step_avg:43.79ms step:3150/20000 train_loss:2.1626 train_time:137938ms step_avg:43.79ms step:3200/20000 train_loss:2.1921 train_time:140108ms step_avg:43.78ms step:3200/20000 val_loss:2.1905 val_bpb:1.2973 train_time:140134ms step_avg:43.79ms step:3250/20000 train_loss:2.0906 train_time:142349ms step_avg:43.80ms step:3300/20000 train_loss:2.2413 train_time:144521ms step_avg:43.79ms step:3350/20000 train_loss:2.0963 train_time:146692ms step_avg:43.79ms step:3400/20000 train_loss:2.1598 train_time:148864ms step_avg:43.78ms step:3400/20000 val_loss:2.1874 val_bpb:1.2955 train_time:148890ms step_avg:43.79ms step:3450/20000 train_loss:2.1116 train_time:151109ms step_avg:43.80ms step:3500/20000 train_loss:2.2545 train_time:153278ms step_avg:43.79ms step:3550/20000 train_loss:2.3901 train_time:155449ms step_avg:43.79ms step:3600/20000 train_loss:2.1151 train_time:157620ms step_avg:43.78ms step:3600/20000 val_loss:2.1798 val_bpb:1.2910 train_time:157646ms step_avg:43.79ms step:3650/20000 train_loss:2.2218 train_time:159858ms step_avg:43.80ms step:3700/20000 train_loss:2.1565 train_time:162029ms step_avg:43.79ms step:3750/20000 train_loss:2.1475 train_time:164200ms step_avg:43.79ms step:3800/20000 train_loss:2.2253 train_time:166371ms step_avg:43.78ms step:3800/20000 val_loss:2.1760 val_bpb:1.2888 train_time:166398ms step_avg:43.79ms step:3850/20000 train_loss:2.1791 train_time:168620ms step_avg:43.80ms step:3900/20000 train_loss:1.9959 train_time:170789ms step_avg:43.79ms step:3950/20000 train_loss:2.1253 train_time:172959ms step_avg:43.79ms step:4000/20000 train_loss:2.1645 train_time:175133ms step_avg:43.78ms step:4000/20000 val_loss:2.1710 val_bpb:1.2858 train_time:175159ms step_avg:43.79ms step:4050/20000 train_loss:2.1017 train_time:177373ms 
step_avg:43.80ms step:4100/20000 train_loss:2.1904 train_time:179543ms step_avg:43.79ms step:4150/20000 train_loss:2.3270 train_time:181715ms step_avg:43.79ms step:4200/20000 train_loss:2.1751 train_time:183961ms step_avg:43.80ms step:4200/20000 val_loss:2.1674 val_bpb:1.2836 train_time:183987ms step_avg:43.81ms step:4250/20000 train_loss:2.1313 train_time:186133ms step_avg:43.80ms step:4300/20000 train_loss:2.0313 train_time:188304ms step_avg:43.79ms step:4350/20000 train_loss:2.2108 train_time:190473ms step_avg:43.79ms step:4400/20000 train_loss:2.1132 train_time:192710ms step_avg:43.80ms step:4400/20000 val_loss:2.1677 val_bpb:1.2839 train_time:192737ms step_avg:43.80ms step:4450/20000 train_loss:2.0662 train_time:194882ms step_avg:43.79ms step:4500/20000 train_loss:2.2548 train_time:197054ms step_avg:43.79ms step:4550/20000 train_loss:2.0614 train_time:199226ms step_avg:43.79ms step:4600/20000 train_loss:1.9724 train_time:201465ms step_avg:43.80ms step:4600/20000 val_loss:2.1629 val_bpb:1.2810 train_time:201492ms step_avg:43.80ms step:4650/20000 train_loss:2.0723 train_time:203638ms step_avg:43.79ms step:4700/20000 train_loss:2.2682 train_time:205809ms step_avg:43.79ms step:4750/20000 train_loss:1.9806 train_time:207980ms step_avg:43.79ms step:4800/20000 train_loss:2.1260 train_time:210219ms step_avg:43.80ms step:4800/20000 val_loss:2.1585 val_bpb:1.2784 train_time:210245ms step_avg:43.80ms step:4850/20000 train_loss:2.2142 train_time:212393ms step_avg:43.79ms step:4900/20000 train_loss:2.4104 train_time:214564ms step_avg:43.79ms step:4950/20000 train_loss:2.1696 train_time:216737ms step_avg:43.79ms step:5000/20000 train_loss:2.1399 train_time:218974ms step_avg:43.79ms step:5000/20000 val_loss:2.1552 val_bpb:1.2765 train_time:219000ms step_avg:43.80ms step:5050/20000 train_loss:2.0838 train_time:221146ms step_avg:43.79ms step:5100/20000 train_loss:2.0876 train_time:223315ms step_avg:43.79ms step:5150/20000 train_loss:2.1514 train_time:225550ms step_avg:43.80ms 
step:5200/20000 train_loss:2.2388 train_time:227721ms step_avg:43.79ms step:5200/20000 val_loss:2.1522 val_bpb:1.2746 train_time:227747ms step_avg:43.80ms step:5250/20000 train_loss:2.0890 train_time:229893ms step_avg:43.79ms step:5300/20000 train_loss:2.2180 train_time:232064ms step_avg:43.79ms step:5350/20000 train_loss:2.5612 train_time:234307ms step_avg:43.80ms step:5400/20000 train_loss:2.2834 train_time:236479ms step_avg:43.79ms step:5400/20000 val_loss:2.1502 val_bpb:1.2735 train_time:236505ms step_avg:43.80ms step:5450/20000 train_loss:2.1625 train_time:238651ms step_avg:43.79ms step:5500/20000 train_loss:2.1710 train_time:240821ms step_avg:43.79ms step:5550/20000 train_loss:2.1861 train_time:243062ms step_avg:43.79ms step:5600/20000 train_loss:2.1630 train_time:245232ms step_avg:43.79ms step:5600/20000 val_loss:2.1468 val_bpb:1.2715 train_time:245258ms step_avg:43.80ms step:5650/20000 train_loss:2.1339 train_time:247403ms step_avg:43.79ms step:5700/20000 train_loss:2.2647 train_time:249575ms step_avg:43.79ms step:5750/20000 train_loss:2.0960 train_time:251821ms step_avg:43.80ms step:5800/20000 train_loss:2.2417 train_time:253991ms step_avg:43.79ms step:5800/20000 val_loss:2.1449 val_bpb:1.2703 train_time:254018ms step_avg:43.80ms step:5850/20000 train_loss:2.3108 train_time:256163ms step_avg:43.79ms step:5900/20000 train_loss:2.1466 train_time:258333ms step_avg:43.79ms step:5950/20000 train_loss:2.0357 train_time:260573ms step_avg:43.79ms step:6000/20000 train_loss:2.2133 train_time:262743ms step_avg:43.79ms step:6000/20000 val_loss:2.1415 val_bpb:1.2683 train_time:262769ms step_avg:43.79ms step:6050/20000 train_loss:2.0175 train_time:264915ms step_avg:43.79ms step:6100/20000 train_loss:2.2965 train_time:267085ms step_avg:43.78ms step:6150/20000 train_loss:1.9643 train_time:269337ms step_avg:43.79ms step:6200/20000 train_loss:2.1052 train_time:271507ms step_avg:43.79ms step:6200/20000 val_loss:2.1403 val_bpb:1.2676 train_time:271533ms step_avg:43.80ms 
step:6250/20000 train_loss:2.1371 train_time:273679ms step_avg:43.79ms step:6300/20000 train_loss:1.9363 train_time:275911ms step_avg:43.80ms step:6350/20000 train_loss:2.1653 train_time:278081ms step_avg:43.79ms step:6400/20000 train_loss:2.1143 train_time:280252ms step_avg:43.79ms step:6400/20000 val_loss:2.1402 val_bpb:1.2675 train_time:280279ms step_avg:43.79ms step:6450/20000 train_loss:2.1194 train_time:282423ms step_avg:43.79ms step:6500/20000 train_loss:2.1142 train_time:284668ms step_avg:43.80ms step:6550/20000 train_loss:2.0981 train_time:286840ms step_avg:43.79ms step:6600/20000 train_loss:2.0068 train_time:289010ms step_avg:43.79ms step:6600/20000 val_loss:2.1372 val_bpb:1.2657 train_time:289036ms step_avg:43.79ms step:6650/20000 train_loss:2.2201 train_time:291182ms step_avg:43.79ms step:6700/20000 train_loss:2.1445 train_time:293420ms step_avg:43.79ms step:6750/20000 train_loss:2.1571 train_time:295591ms step_avg:43.79ms step:6800/20000 train_loss:1.9528 train_time:297763ms step_avg:43.79ms step:6800/20000 val_loss:2.1372 val_bpb:1.2658 train_time:297789ms step_avg:43.79ms step:6850/20000 train_loss:2.0690 train_time:299936ms step_avg:43.79ms step:6900/20000 train_loss:2.1382 train_time:302195ms step_avg:43.80ms step:6950/20000 train_loss:2.0363 train_time:304368ms step_avg:43.79ms step:7000/20000 train_loss:2.1949 train_time:306538ms step_avg:43.79ms step:7000/20000 val_loss:2.1330 val_bpb:1.2633 train_time:306564ms step_avg:43.79ms step:7050/20000 train_loss:2.0663 train_time:308710ms step_avg:43.79ms step:7100/20000 train_loss:2.2323 train_time:310948ms step_avg:43.80ms step:7150/20000 train_loss:2.1158 train_time:313119ms step_avg:43.79ms step:7200/20000 train_loss:2.0351 train_time:315289ms step_avg:43.79ms step:7200/20000 val_loss:2.1327 val_bpb:1.2631 train_time:315315ms step_avg:43.79ms step:7250/20000 train_loss:2.0840 train_time:317544ms step_avg:43.80ms step:7300/20000 train_loss:2.1822 train_time:319715ms step_avg:43.80ms step:7350/20000 
train_loss:2.2138 train_time:321888ms step_avg:43.79ms step:7400/20000 train_loss:2.1412 train_time:324057ms step_avg:43.79ms step:7400/20000 val_loss:2.1300 val_bpb:1.2615 train_time:324083ms step_avg:43.80ms step:7450/20000 train_loss:2.1683 train_time:326297ms step_avg:43.80ms step:7500/20000 train_loss:2.1322 train_time:328467ms step_avg:43.80ms step:7550/20000 train_loss:2.1391 train_time:330637ms step_avg:43.79ms step:7600/20000 train_loss:2.1593 train_time:332809ms step_avg:43.79ms step:7600/20000 val_loss:2.1284 val_bpb:1.2606 train_time:332835ms step_avg:43.79ms step:7650/20000 train_loss:2.1257 train_time:335067ms step_avg:43.80ms step:7700/20000 train_loss:2.1509 train_time:337238ms step_avg:43.80ms step:7750/20000 train_loss:2.2348 train_time:339409ms step_avg:43.79ms step:7800/20000 train_loss:2.0846 train_time:341580ms step_avg:43.79ms step:7800/20000 val_loss:2.1283 val_bpb:1.2605 train_time:341606ms step_avg:43.80ms step:7850/20000 train_loss:2.1334 train_time:343814ms step_avg:43.80ms step:7900/20000 train_loss:2.0901 train_time:345984ms step_avg:43.80ms step:7950/20000 train_loss:2.1335 train_time:348154ms step_avg:43.79ms step:8000/20000 train_loss:2.1574 train_time:350323ms step_avg:43.79ms step:8000/20000 val_loss:2.1254 val_bpb:1.2588 train_time:350349ms step_avg:43.79ms step:8050/20000 train_loss:2.1555 train_time:352566ms step_avg:43.80ms step:8100/20000 train_loss:2.1871 train_time:354738ms step_avg:43.79ms step:8150/20000 train_loss:2.0576 train_time:356908ms step_avg:43.79ms step:8200/20000 train_loss:2.0271 train_time:359078ms step_avg:43.79ms step:8200/20000 val_loss:2.1271 val_bpb:1.2598 train_time:359105ms step_avg:43.79ms step:8250/20000 train_loss:2.1031 train_time:361316ms step_avg:43.80ms step:8300/20000 train_loss:2.0508 train_time:363485ms step_avg:43.79ms step:8350/20000 train_loss:2.1812 train_time:365655ms step_avg:43.79ms step:8400/20000 train_loss:2.2092 train_time:367890ms step_avg:43.80ms step:8400/20000 val_loss:2.1231 
val_bpb:1.2574 train_time:367916ms step_avg:43.80ms step:8450/20000 train_loss:2.1858 train_time:370061ms step_avg:43.79ms step:8500/20000 train_loss:2.1313 train_time:372231ms step_avg:43.79ms step:8550/20000 train_loss:2.1930 train_time:374401ms step_avg:43.79ms step:8600/20000 train_loss:2.1264 train_time:376636ms step_avg:43.79ms step:8600/20000 val_loss:2.1198 val_bpb:1.2555 train_time:376663ms step_avg:43.80ms step:8650/20000 train_loss:2.0138 train_time:378807ms step_avg:43.79ms step:8700/20000 train_loss:2.0725 train_time:380977ms step_avg:43.79ms step:8750/20000 train_loss:2.1241 train_time:383147ms step_avg:43.79ms step:8800/20000 train_loss:2.0577 train_time:385386ms step_avg:43.79ms step:8800/20000 val_loss:2.1179 val_bpb:1.2543 train_time:385413ms step_avg:43.80ms step:8850/20000 train_loss:2.0572 train_time:387558ms step_avg:43.79ms step:8900/20000 train_loss:2.1148 train_time:389729ms step_avg:43.79ms step:8950/20000 train_loss:2.1656 train_time:391900ms step_avg:43.79ms step:9000/20000 train_loss:2.3235 train_time:394135ms step_avg:43.79ms step:9000/20000 val_loss:2.1174 val_bpb:1.2541 train_time:394161ms step_avg:43.80ms step:9050/20000 train_loss:2.1874 train_time:396307ms step_avg:43.79ms step:9100/20000 train_loss:2.0096 train_time:398478ms step_avg:43.79ms step:9150/20000 train_loss:2.2272 train_time:400649ms step_avg:43.79ms step:9200/20000 train_loss:2.2863 train_time:402906ms step_avg:43.79ms step:9200/20000 val_loss:2.1162 val_bpb:1.2533 train_time:402932ms step_avg:43.80ms step:9250/20000 train_loss:2.1488 train_time:405078ms step_avg:43.79ms step:9300/20000 train_loss:2.3716 train_time:407251ms step_avg:43.79ms step:9350/20000 train_loss:2.1750 train_time:409481ms step_avg:43.79ms step:9400/20000 train_loss:1.9241 train_time:411651ms step_avg:43.79ms step:9400/20000 val_loss:2.1168 val_bpb:1.2537 train_time:411677ms step_avg:43.80ms step:9450/20000 train_loss:2.0555 train_time:413821ms step_avg:43.79ms step:9500/20000 train_loss:2.1634 
train_time:415993ms step_avg:43.79ms step:9550/20000 train_loss:2.2153 train_time:418244ms step_avg:43.80ms step:9600/20000 train_loss:2.0163 train_time:420415ms step_avg:43.79ms step:9600/20000 val_loss:2.1162 val_bpb:1.2533 train_time:420441ms step_avg:43.80ms step:9650/20000 train_loss:2.0629 train_time:422589ms step_avg:43.79ms step:9700/20000 train_loss:2.1591 train_time:424759ms step_avg:43.79ms step:9750/20000 train_loss:2.1551 train_time:426994ms step_avg:43.79ms step:9800/20000 train_loss:2.0674 train_time:429164ms step_avg:43.79ms step:9800/20000 val_loss:2.1121 val_bpb:1.2509 train_time:429190ms step_avg:43.79ms step:9850/20000 train_loss:2.1399 train_time:431336ms step_avg:43.79ms step:9900/20000 train_loss:2.0134 train_time:433508ms step_avg:43.79ms step:9950/20000 train_loss:2.1507 train_time:435750ms step_avg:43.79ms step:10000/20000 train_loss:2.0270 train_time:437919ms step_avg:43.79ms step:10000/20000 val_loss:2.1148 val_bpb:1.2525 train_time:437945ms step_avg:43.79ms step:10050/20000 train_loss:2.0438 train_time:440091ms step_avg:43.79ms step:10100/20000 train_loss:2.1068 train_time:442261ms step_avg:43.79ms step:10150/20000 train_loss:2.1505 train_time:444498ms step_avg:43.79ms step:10200/20000 train_loss:2.1344 train_time:446670ms step_avg:43.79ms step:10200/20000 val_loss:2.1120 val_bpb:1.2509 train_time:446696ms step_avg:43.79ms step:10250/20000 train_loss:2.0986 train_time:448842ms step_avg:43.79ms step:10300/20000 train_loss:2.0060 train_time:451093ms step_avg:43.80ms step:10350/20000 train_loss:1.9689 train_time:453261ms step_avg:43.79ms step:10400/20000 train_loss:2.1015 train_time:455431ms step_avg:43.79ms step:10400/20000 val_loss:2.1101 val_bpb:1.2497 train_time:455458ms step_avg:43.79ms step:10450/20000 train_loss:1.9867 train_time:457604ms step_avg:43.79ms step:10500/20000 train_loss:2.0564 train_time:459840ms step_avg:43.79ms step:10550/20000 train_loss:2.1288 train_time:462011ms step_avg:43.79ms step:10600/20000 train_loss:2.0751 
train_time:464184ms step_avg:43.79ms step:10600/20000 val_loss:2.1095 val_bpb:1.2494 train_time:464210ms step_avg:43.79ms step:10650/20000 train_loss:2.1335 train_time:466355ms step_avg:43.79ms step:10700/20000 train_loss:2.2599 train_time:468597ms step_avg:43.79ms step:10750/20000 train_loss:2.0068 train_time:470768ms step_avg:43.79ms step:10800/20000 train_loss:2.1260 train_time:472938ms step_avg:43.79ms step:10800/20000 val_loss:2.1087 val_bpb:1.2489 train_time:472964ms step_avg:43.79ms step:10850/20000 train_loss:2.2389 train_time:475109ms step_avg:43.79ms step:10900/20000 train_loss:2.1941 train_time:477345ms step_avg:43.79ms step:10950/20000 train_loss:2.0311 train_time:479516ms step_avg:43.79ms step:11000/20000 train_loss:2.1000 train_time:481687ms step_avg:43.79ms step:11000/20000 val_loss:2.1090 val_bpb:1.2491 train_time:481713ms step_avg:43.79ms step:11050/20000 train_loss:2.1310 train_time:483857ms step_avg:43.79ms step:11100/20000 train_loss:2.1062 train_time:486104ms step_avg:43.79ms step:11150/20000 train_loss:2.0740 train_time:488274ms step_avg:43.79ms step:11200/20000 train_loss:2.1338 train_time:490445ms step_avg:43.79ms step:11200/20000 val_loss:2.1066 val_bpb:1.2476 train_time:490471ms step_avg:43.79ms step:11250/20000 train_loss:2.1177 train_time:492616ms step_avg:43.79ms step:11300/20000 train_loss:2.0860 train_time:494849ms step_avg:43.79ms step:11350/20000 train_loss:2.0760 train_time:497021ms step_avg:43.79ms step:11400/20000 train_loss:2.2242 train_time:499192ms step_avg:43.79ms step:11400/20000 val_loss:2.1057 val_bpb:1.2471 train_time:499218ms step_avg:43.79ms step:11450/20000 train_loss:2.1219 train_time:501431ms step_avg:43.79ms step:11500/20000 train_loss:2.1016 train_time:503600ms step_avg:43.79ms step:11550/20000 train_loss:2.1222 train_time:505770ms step_avg:43.79ms step:11600/20000 train_loss:2.1468 train_time:507939ms step_avg:43.79ms step:11600/20000 val_loss:2.1056 val_bpb:1.2471 train_time:507966ms step_avg:43.79ms 
step:11650/20000 train_loss:2.1549 train_time:510177ms step_avg:43.79ms step:11700/20000 train_loss:2.0195 train_time:512348ms step_avg:43.79ms step:11750/20000 train_loss:2.0654 train_time:514518ms step_avg:43.79ms step:11800/20000 train_loss:2.0257 train_time:516690ms step_avg:43.79ms step:11800/20000 val_loss:2.1059 val_bpb:1.2472 train_time:516716ms step_avg:43.79ms step:11850/20000 train_loss:2.1168 train_time:518940ms step_avg:43.79ms step:11900/20000 train_loss:2.2975 train_time:521110ms step_avg:43.79ms step:11950/20000 train_loss:2.1150 train_time:523282ms step_avg:43.79ms step:12000/20000 train_loss:2.0739 train_time:525452ms step_avg:43.79ms step:12000/20000 val_loss:2.1045 val_bpb:1.2464 train_time:525478ms step_avg:43.79ms step:12050/20000 train_loss:2.1966 train_time:527686ms step_avg:43.79ms step:12100/20000 train_loss:2.0933 train_time:529857ms step_avg:43.79ms step:12150/20000 train_loss:2.1247 train_time:532028ms step_avg:43.79ms step:12200/20000 train_loss:2.1644 train_time:534200ms step_avg:43.79ms step:12200/20000 val_loss:2.1027 val_bpb:1.2453 train_time:534226ms step_avg:43.79ms step:12250/20000 train_loss:2.0929 train_time:536449ms step_avg:43.79ms step:12300/20000 train_loss:1.9745 train_time:538621ms step_avg:43.79ms step:12350/20000 train_loss:2.0896 train_time:540792ms step_avg:43.79ms step:12400/20000 train_loss:2.1576 train_time:543026ms step_avg:43.79ms step:12400/20000 val_loss:2.1025 val_bpb:1.2452 train_time:543052ms step_avg:43.79ms step:12450/20000 train_loss:2.1487 train_time:545198ms step_avg:43.79ms step:12500/20000 train_loss:2.1525 train_time:547368ms step_avg:43.79ms step:12550/20000 train_loss:2.1040 train_time:549538ms step_avg:43.79ms step:12600/20000 train_loss:2.3639 train_time:551777ms step_avg:43.79ms step:12600/20000 val_loss:2.0988 val_bpb:1.2430 train_time:551803ms step_avg:43.79ms step:12650/20000 train_loss:1.9861 train_time:553950ms step_avg:43.79ms step:12700/20000 train_loss:2.0478 train_time:556120ms 
step_avg:43.79ms step:12750/20000 train_loss:2.0901 train_time:558294ms step_avg:43.79ms step:12800/20000 train_loss:2.1683 train_time:560532ms step_avg:43.79ms step:12800/20000 val_loss:2.0935 val_bpb:1.2399 train_time:560558ms step_avg:43.79ms step:12850/20000 train_loss:1.9931 train_time:562704ms step_avg:43.79ms step:12900/20000 train_loss:2.4707 train_time:564875ms step_avg:43.79ms step:12950/20000 train_loss:2.0368 train_time:567047ms step_avg:43.79ms step:13000/20000 train_loss:2.1259 train_time:569296ms step_avg:43.79ms step:13000/20000 val_loss:2.0847 val_bpb:1.2347 train_time:569323ms step_avg:43.79ms step:13050/20000 train_loss:2.0280 train_time:571468ms step_avg:43.79ms step:13100/20000 train_loss:1.9683 train_time:573638ms step_avg:43.79ms step:13150/20000 train_loss:2.0988 train_time:575809ms step_avg:43.79ms step:13200/20000 train_loss:2.1256 train_time:578042ms step_avg:43.79ms step:13200/20000 val_loss:2.0771 val_bpb:1.2302 train_time:578068ms step_avg:43.79ms step:13250/20000 train_loss:2.1152 train_time:580213ms step_avg:43.79ms step:13300/20000 train_loss:2.1388 train_time:582383ms step_avg:43.79ms step:13350/20000 train_loss:2.1556 train_time:584554ms step_avg:43.79ms step:13400/20000 train_loss:2.1470 train_time:586802ms step_avg:43.79ms step:13400/20000 val_loss:2.0698 val_bpb:1.2259 train_time:586828ms step_avg:43.79ms step:13450/20000 train_loss:2.1367 train_time:588974ms step_avg:43.79ms step:13500/20000 train_loss:1.9808 train_time:591144ms step_avg:43.79ms step:13550/20000 train_loss:2.0810 train_time:593382ms step_avg:43.79ms step:13600/20000 train_loss:2.0428 train_time:595553ms step_avg:43.79ms step:13600/20000 val_loss:2.0624 val_bpb:1.2215 train_time:595579ms step_avg:43.79ms step:13650/20000 train_loss:2.0705 train_time:597724ms step_avg:43.79ms step:13700/20000 train_loss:2.0979 train_time:599895ms step_avg:43.79ms step:13703/20000 val_loss:2.0601 val_bpb:1.2201 train_time:600051ms step_avg:43.79ms stopping_early: wallclock_cap 
train_time:600051ms step:13703/20000 peak memory allocated: 10185 MiB reserved: 10572 MiB Serialized model: 67224983 bytes Code size: 60906 bytes Total submission size: 67285889 bytes Serialized model int8+zlib: 15823937 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) Total submission size int8+zlib: 15884843 bytes final_int8_zlib_roundtrip val_loss:2.0724 val_bpb:1.2274 eval_time:1374ms final_int8_zlib_roundtrip_exact val_loss:2.07240908 val_bpb:1.22739739 final_int8_ttt_lora val_loss:2.0142 val_bpb:1.1929 eval_time:60184ms