""" The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. """ from __future__ import annotations import copy import glob import io import math import os import random import subprocess import sys import time import uuid import zlib from pathlib import Path import numpy as np import sentencepiece as spm import torch import torch.distributed as dist import torch.nn.functional as F from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP # ----------------------------- # HYPERPARAMETERS # ----------------------------- # Default Simple Baseline run: # - 9 transformer blocks at width 512 # - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion # - vocab size 1024, sequence length 1024, tied embeddings # - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap class Hyperparameters: # Data paths are shard globs produced by the existing preprocessing pipeline. data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") train_files = os.path.join(data_path, "fineweb_train_*.bin") val_files = os.path.join(data_path, "fineweb_val_*.bin") tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) seed = int(os.environ.get("SEED", 1337)) # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) prune_ratio = float(os.environ.get("PRUNE_RATIO", 0.0)) # fraction of int8 range to prune (e.g. 0.1 = zero out |val| <= 12) int4_layers = os.environ.get("INT4_LAYERS", "") # comma-separated layer indices for reduced precision (e.g. "3,4,5,6") int4_step = int(os.environ.get("INT4_STEP", 16)) # rounding step: 2=int7, 4=int6, 8=int5, 16=int4 # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) num_layers = int(os.environ.get("NUM_LAYERS", 9)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) # ----------------------------- # MUON OPTIMIZER # ----------------------------- # # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. # Muon uses this to normalize matrix-shaped gradients before applying them. a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= X.norm() + eps transposed = G.size(0) > G.size(1) if transposed: X = X.T for _ in range(steps): A = X @ X.T B = b * A + c * A @ A X = a * X + B @ X return X.T if transposed else X class Muon(torch.optim.Optimizer): def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): super().__init__( params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), ) @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() distributed = dist.is_available() and dist.is_initialized() world_size = dist.get_world_size() if distributed else 1 rank = dist.get_rank() if distributed else 0 for group in self.param_groups: params = group["params"] if not params: continue lr = group["lr"] momentum = group["momentum"] backend_steps = group["backend_steps"] nesterov = group["nesterov"] total_params = sum(int(p.numel()) for p in params) updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) curr = 0 for i, p in enumerate(params): if i % world_size == rank and p.grad is not None: g = p.grad state = self.state[p] if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) buf = state["momentum_buffer"] buf.mul_(momentum).add_(g) if nesterov: g = g.add(buf, alpha=momentum) g = zeropower_via_newtonschulz5(g, steps=backend_steps) # Scale correction from Muon reference implementations. g *= max(1, g.size(0) / g.size(1)) ** 0.5 updates_flat[curr : curr + p.numel()] = g.reshape(-1) curr += p.numel() if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) curr = 0 for p in params: g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) p.add_(g, alpha=-lr) curr += p.numel() return loss # ----------------------------- # TOKENIZER-AGNOSTIC EVALUATION SETUP # ----------------------------- # # It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. # Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. # We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. # Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. def build_sentencepiece_luts( sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device ) -> tuple[Tensor, Tensor, Tensor]: sp_vocab_size = int(sp.vocab_size()) table_size = max(sp_vocab_size, vocab_size) base_bytes_np = np.zeros((table_size,), dtype=np.int16) has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) for token_id in range(sp_vocab_size): if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): continue is_boundary_token_np[token_id] = False if sp.is_byte(token_id): base_bytes_np[token_id] = 1 continue piece = sp.id_to_piece(token_id) if piece.startswith("▁"): has_leading_space_np[token_id] = True piece = piece[1:] base_bytes_np[token_id] = len(piece.encode("utf-8")) return ( torch.tensor(base_bytes_np, dtype=torch.int16, device=device), torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), ) def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") return tokens[: usable + 1] def eval_val( args: Hyperparameters, model: nn.Module, rank: int, world_size: int, device: torch.device, grad_accum_steps: int, val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, ) -> tuple[float, float]: # Validation computes two metrics: # - val_loss: token cross-entropy (natural log) # - val_bpb: tokenizer-agnostic compression metric used by the challenge local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) if local_batch_tokens < args.train_seq_len: raise ValueError( "VAL_BATCH_SIZE must provide at least one sequence per rank; " f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" ) local_batch_seqs = local_batch_tokens // args.train_seq_len total_seqs = (val_tokens.numel() - 1) // args.train_seq_len seq_start = (total_seqs * rank) // world_size seq_end = (total_seqs * (rank + 1)) // world_size val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) val_byte_count = torch.zeros((), device=device, dtype=torch.float64) model.eval() with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) raw_start = batch_seq_start * args.train_seq_len raw_end = batch_seq_end * args.train_seq_len + 1 local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) x = local[:-1].reshape(-1, args.train_seq_len) y = local[1:].reshape(-1, args.train_seq_len) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): batch_loss = model(x, y).detach() batch_token_count = float(y.numel()) val_loss_sum += batch_loss.to(torch.float64) * batch_token_count val_token_count += batch_token_count prev_ids = x.reshape(-1) tgt_ids = y.reshape(-1) token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) val_loss = val_loss_sum / val_token_count bits_per_token = val_loss.item() / math.log(2.0) tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) # ----------------------------- # POST-TRAINING QUANTIZATION # ----------------------------- # # It's silly to export our model, which is trained in bf16 and fp32, at that same precision. # Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. # We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", ).split(",") if pattern ) INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", ",".join(CONTROL_TENSOR_NAME_PATTERNS), ).split(",") if pattern ) INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): return t.float().contiguous() if t.dtype in {torch.float32, torch.bfloat16}: passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: # Matrices get one scale per row, which usually tracks output-channel # ranges much better than a single tensor-wide scale. clip_abs = ( torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) ) clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() # Vectors / scalars use a simpler per-tensor scale. clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale def quantize_state_dict_int8(state_dict: dict[str, Tensor]): # Single supported clean-script export format: # - per-row int8 for 2D float tensors # - per-tensor int8 for other float tensors # - exact passthrough for non-floats # - passthrough for small float tensors, stored as fp16 to save bytes quantized: dict[str, Tensor] = {} scales: dict[str, Tensor] = {} dtypes: dict[str, str] = {} passthrough: dict[str, Tensor] = {} passthrough_orig_dtypes: dict[str, str] = {} qmeta: dict[str, dict[str, object]] = {} stats = dict.fromkeys( ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), 0, ) for name, tensor in state_dict.items(): t = tensor.detach().to("cpu").contiguous() stats["param_count"] += int(t.numel()) stats["num_tensors"] += 1 stats["baseline_tensor_bytes"] += tensor_nbytes(t) if not t.is_floating_point(): stats["num_nonfloat_tensors"] += 1 passthrough[name] = t stats["int8_payload_bytes"] += tensor_nbytes(t) continue # Small float tensors are cheap enough to keep directly. We still downcast # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: kept = keep_float_tensor(name, t, passthrough_orig_dtypes) passthrough[name] = kept stats["int8_payload_bytes"] += tensor_nbytes(kept) continue stats["num_float_tensors"] += 1 q, s = quantize_float_tensor(t) if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} quantized[name] = q scales[name] = s dtypes[name] = str(t.dtype).removeprefix("torch.") stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) obj: dict[str, object] = { "__quant_format__": "int8_clean_per_row_v1", "quantized": quantized, "scales": scales, "dtypes": dtypes, "passthrough": passthrough, } if qmeta: obj["qmeta"] = qmeta if passthrough_orig_dtypes: obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes return obj, stats def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: out: dict[str, Tensor] = {} qmeta = obj.get("qmeta", {}) passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) for name, q in obj["quantized"].items(): dtype = getattr(torch, obj["dtypes"][name]) s = obj["scales"][name] if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: s = s.to(dtype=torch.float32) # Broadcast the saved row scale back across trailing dimensions. out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() else: scale = float(s.item()) out[name] = (q.float() * scale).to(dtype=dtype).contiguous() for name, t in obj["passthrough"].items(): # Restore small tensors, undoing the temporary fp16 storage cast if needed. out_t = t.detach().to("cpu").contiguous() orig_dtype = passthrough_orig_dtypes.get(name) if isinstance(orig_dtype, str): out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() out[name] = out_t return out # ----------------------------- # DATA LOADING # ----------------------------- def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" None: self.file_idx = (self.file_idx + 1) % len(self.files) self.tokens = load_data_shard(self.files[self.file_idx]) self.pos = 0 def take(self, n: int) -> Tensor: chunks: list[Tensor] = [] remaining = n while remaining > 0: avail = self.tokens.numel() - self.pos if avail <= 0: self._advance_file() continue k = min(remaining, avail) chunks.append(self.tokens[self.pos : self.pos + k]) self.pos += k remaining -= k return chunks[0] if len(chunks) == 1 else torch.cat(chunks) class DistributedTokenLoader: # Each call consumes a contiguous chunk from the shared token stream, then slices out # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size self.device = device self.stream = TokenStream(pattern) def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: local_tokens = global_tokens // (self.world_size * grad_accum_steps) per_rank_span = local_tokens + 1 chunk = self.stream.take(per_rank_span * self.world_size) start = self.rank * per_rank_span local = chunk[start : start + per_rank_span].to(dtype=torch.int64) x = local[:-1].reshape(-1, seq_len) y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) # ----------------------------- # TRANSFORMER MODULES # ----------------------------- class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): super().__init__() self.eps = eps def forward(self, x: Tensor) -> Tensor: return F.rms_norm(x, (x.size(-1),), eps=self.eps) class CastedLinear(nn.Linear): # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. def forward(self, x: Tensor) -> Tensor: bias = self.bias.to(x.dtype) if self.bias is not None else None return F.linear(x, self.weight.to(x.dtype), bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: # Keep small/control parameters in fp32 even when the model body runs in bf16. with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: param.data = param.data.float() class Rotary(nn.Module): # Caches cos/sin tables per sequence length on the current device. def __init__(self, dim: int, base: float = 10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._cos_cached: Tensor | None = None self._sin_cached: Tensor | None = None def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: if ( self._cos_cached is None or self._sin_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device ): t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq.to(device)) self._cos_cached = freqs.cos()[None, None, :, :] self._sin_cached = freqs.sin()[None, None, :, :] self._seq_len_cached = seq_len return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) class CausalSelfAttention(nn.Module): def __init__( self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, ): super().__init__() if dim % num_heads != 0: raise ValueError("model_dim must be divisible by num_heads") if num_heads % num_kv_heads != 0: raise ValueError("num_heads must be divisible by num_kv_heads") self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads if self.head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE") kv_dim = self.num_kv_heads * self.head_dim self.c_q = CastedLinear(dim, dim, bias=False) self.c_k = CastedLinear(dim, kv_dim, bias=False) self.c_v = CastedLinear(dim, kv_dim, bias=False) self.proj = CastedLinear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, base=rope_base) def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention( q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), ) y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) class MLP(nn.Module): # relu^2 MLP from the original modded-nanogpt setup def __init__(self, dim: int, mlp_mult: int): super().__init__() hidden = mlp_mult * dim self.fc = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) self.proj._zero_init = True def forward(self, x: Tensor) -> Tensor: x = torch.relu(self.fc(x)) return self.proj(x.square()) class Block(nn.Module): def __init__( self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, rope_base: float, qk_gain_init: float, ): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) def forward(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 attn_out = self.attn(self.attn_norm(x)) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x class GPT(nn.Module): def __init__( self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, tie_embeddings: bool, tied_embed_init_std: float, logit_softcap: float, rope_base: float, qk_gain_init: float, ): super().__init__() if logit_softcap <= 0.0: raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) self.blocks = nn.ModuleList( [ Block( model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, ) for i in range(num_layers) ] ) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True self._init_weights() def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) for module in self.modules(): if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) x0 = x skips: list[Tensor] = [] # First half stores skips; second half reuses them in reverse order. for i in range(self.num_encoder_layers): x = self.blocks[i](x, x0) skips.append(x) for i in range(self.num_decoder_layers): if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() x = self.blocks[self.num_encoder_layers + i](x, x0) x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) if self.tie_embeddings: logits_proj = F.linear(x, self.tok_emb.weight) else: if self.lm_head is None: raise RuntimeError("lm_head is required when tie_embeddings=False") logits_proj = self.lm_head(x) logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) return F.cross_entropy(logits.float(), targets, reduction="mean") # ----------------------------- # TRAINING # ----------------------------- def main() -> None: global zeropower_via_newtonschulz5 code = Path(__file__).read_text(encoding="utf-8") args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) # ----------------------------- # DISTRIBUTED + CUDA SETUP # ----------------------------- distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) if world_size <= 0: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") if 8 % world_size != 0: raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") grad_accum_steps = 8 // world_size grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") device = torch.device("cuda", local_rank) torch.cuda.set_device(device) if distributed: dist.init_process_group(backend="nccl", device_id=device) dist.barrier() master_process = rank == 0 # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp enable_cudnn_sdp(False) enable_flash_sdp(True) enable_mem_efficient_sdp(False) enable_math_sdp(False) logfile = None if master_process: os.makedirs("logs", exist_ok=True) logfile = f"logs/{args.run_id}.txt" print(logfile) def log0(msg: str, console: bool = True) -> None: if not master_process: return if console: print(msg) if logfile is not None: with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) log0(code, console=False) log0("=" * 100, console=False) log0(f"Running Python {sys.version}", console=False) log0(f"Running PyTorch {torch.__version__}", console=False) log0( subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False, ) log0("=" * 100, console=False) # ----------------------------- # TOKENIZER + VALIDATION METRIC SETUP # ----------------------------- random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) if not args.tokenizer_path.endswith(".model"): raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) if int(sp.vocab_size()) != args.vocab_size: raise ValueError( f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" ) dataset_dir = Path(args.data_path).resolve() actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( sp, args.vocab_size, device ) log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") # ----------------------------- # MODEL + OPTIMIZER SETUP # ----------------------------- base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): module.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model # Optimizer split: # - token embedding (Adam) uses EMBED_LR # - untied lm_head (Adam) uses HEAD_LR # - matrix params in transformer blocks use MATRIX_LR via Muon # - vectors/scalars use SCALAR_LR via Adam block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ p for name, p in block_named_params if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] scalar_params = [ p for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr optimizer_tok = torch.optim.Adam( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, ) optimizer_muon = Muon( matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr optimizer_scalar = torch.optim.Adam( [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, ) optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] if base_model.lm_head is not None: optimizer_head = torch.optim.Adam( [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, ) optimizers.insert(1, optimizer_head) n_params = sum(p.numel() for p in base_model.parameters()) log0(f"model_params:{n_params}") log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") log0( f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" ) log0( f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" ) log0(f"seed:{args.seed}") # ----------------------------- # DATA LOADER & MODEL WARMUP # ----------------------------- train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) def zero_grad_all() -> None: for opt in optimizers: opt.zero_grad(set_to_none=True) max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None def lr_mul(step: int, elapsed_ms: float) -> float: if args.warmdown_iters <= 0: return 1.0 if max_wallclock_ms is None: warmdown_start = max(args.iterations - args.warmdown_iters, 0) return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 step_ms = elapsed_ms / max(step, 1) warmdown_ms = args.warmdown_iters * step_ms remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 # Warmup primes the compiled forward/backward/optimizer paths, then we restore the # initial weights/optimizer state so measured training starts from the true init. if args.warmup_steps > 0: initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] model.train() for warmup_step in range(args.warmup_steps): zero_grad_all() for micro_step in range(grad_accum_steps): if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): warmup_loss = model(x, y) (warmup_loss * grad_scale).backward() for opt in optimizers: opt.step() zero_grad_all() if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") base_model.load_state_dict(initial_model_state, strict=True) for opt, state in zip(optimizers, initial_optimizer_states, strict=True): opt.load_state_dict(state) zero_grad_all() if distributed: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) # ----------------------------- # MAIN TRAINING LOOP # ----------------------------- training_time_ms = 0.0 stop_after_step: int | None = None torch.cuda.synchronize() t0 = time.perf_counter() step = 0 while True: last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) if should_validate: torch.cuda.synchronize() training_time_ms += 1000.0 * (time.perf_counter() - t0) val_loss, val_bpb = eval_val( args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, ) log0( f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" ) torch.cuda.synchronize() t0 = time.perf_counter() if last_step: if stop_after_step is not None and step < args.iterations: log0( f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " f"step:{step}/{args.iterations}" ) break elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) scale = lr_mul(step, elapsed_ms) zero_grad_all() train_loss = torch.zeros((), device=device) for micro_step in range(grad_accum_steps): if distributed: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y) train_loss += loss.detach() (loss * grad_scale).backward() train_loss /= grad_accum_steps frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum for group in optimizer_muon.param_groups: group["momentum"] = muon_momentum for opt in optimizers: for group in opt.param_groups: group["lr"] = group["base_lr"] * scale if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) for opt in optimizers: opt.step() zero_grad_all() step += 1 approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) should_log_train = ( args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) ) if should_log_train: log0( f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) # Needed to sync whether we've reached the wallclock cap. reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms if distributed and max_wallclock_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) reached_cap = bool(reached_cap_tensor.item()) if stop_after_step is None and reached_cap: stop_after_step = step log0( f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) # ----------------------------- # SERIALIZATION + ROUNDTRIP VALIDATION # ----------------------------- # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce # the compressed int8+zlib artifact and validate the round-tripped weights. if master_process: torch.save(base_model.state_dict(), "final_model.pt") model_bytes = os.path.getsize("final_model.pt") code_bytes = len(code.encode("utf-8")) log0(f"Serialized model: {model_bytes} bytes") log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) # Optional post-quantization pruning: zero out small int8 values for better compression if args.prune_ratio > 0: threshold = int(127 * args.prune_ratio) for name in list(quant_obj.get("quantized", {}).keys()): t = quant_obj["quantized"][name] t[t.abs() <= threshold] = 0 # Optional mixed-precision: round middle layers to int4 (16 levels) for better compression if args.int4_layers: int4_set = set(int(x) for x in args.int4_layers.split(",") if x.strip()) for name in list(quant_obj.get("quantized", {}).keys()): layer_num = -1 if "blocks." in name: try: layer_num = int(name.split("blocks.")[1].split(".")[0]) except (ValueError, IndexError): pass if layer_num in int4_set: t = quant_obj["quantized"][name] step = args.int4_step quant_obj["quantized"][name] = ((t.float() / step).round() * step).clamp(-127, 127).to(torch.int8) quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() quant_blob = zlib.compress(quant_raw, level=9) quant_raw_bytes = len(quant_raw) if master_process: with open("final_model.int8.ptz", "wb") as f: f.write(quant_blob) quant_file_bytes = os.path.getsize("final_model.int8.ptz") code_bytes = len(code.encode("utf-8")) ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) log0( f"Serialized model int8+zlib: {quant_file_bytes} bytes " f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" ) log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") if distributed: dist.barrier() with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() q_val_loss, q_val_bpb = eval_val( args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, ) torch.cuda.synchronize() log0( f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") if distributed: dist.destroy_process_group() if __name__ == "__main__": main() ==================================================================================================== Running Python 3.11.9 (main, Nov 10 2025, 02:08:09) [GCC 11.4.0] Running PyTorch 2.8.0+cu128 Thu Mar 19 02:57:29 2026 +-----------------------------------------------------------------------------------------+ | NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 | +-----------------------------------------+------------------------+----------------------+ | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |=========================================+========================+======================| | 0 NVIDIA H200 Off | 00000002:00:01.0 Off | 0 | | N/A 32C P0 122W / 700W | 1516MiB / 143771MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 1 NVIDIA H200 Off | 00000002:00:02.0 Off | 0 | | N/A 34C P0 118W / 700W | 1516MiB / 143771MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 2 NVIDIA H200 Off | 00000002:00:03.0 Off | 0 | | N/A 35C P0 121W / 700W | 1516MiB / 143771MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 3 NVIDIA H200 Off | 00000002:00:04.0 Off | 0 | | N/A 32C P0 121W / 700W | 1516MiB / 143771MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 4 NVIDIA H200 Off | 00000003:00:01.0 Off | 0 | | N/A 32C P0 120W / 700W | 1516MiB / 143771MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 5 NVIDIA H200 Off | 00000003:00:02.0 Off | 0 | | N/A 37C P0 121W / 700W | 1516MiB / 143771MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 6 NVIDIA H200 Off | 00000003:00:03.0 Off | 0 | | N/A 35C P0 117W / 700W | 1516MiB / 143771MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 7 NVIDIA H200 Off | 00000003:00:04.0 Off | 0 | | N/A 31C P0 119W / 700W | 1516MiB / 143771MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ +-----------------------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=========================================================================================| | 0 N/A N/A 301282 C ...nv/versions/3.11.9/bin/python 1506MiB | | 1 N/A N/A 301283 C ...nv/versions/3.11.9/bin/python 1506MiB | | 2 N/A N/A 301284 C ...nv/versions/3.11.9/bin/python 1506MiB | | 3 N/A N/A 301285 C ...nv/versions/3.11.9/bin/python 1530MiB | | 4 N/A N/A 301286 C ...nv/versions/3.11.9/bin/python 1534MiB | | 5 N/A N/A 301287 C ...nv/versions/3.11.9/bin/python 1506MiB | | 6 N/A N/A 301288 C ...nv/versions/3.11.9/bin/python 1506MiB | | 7 N/A N/A 301289 C ...nv/versions/3.11.9/bin/python 1506MiB | +-----------------------------------------------------------------------------------------+ ==================================================================================================== val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_loader:dataset:fineweb10B_sp1024 train_shards:180 val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 model_params:18897488 world_size:8 grad_accum_steps:1 sdp_backends:cudnn=False flash=True mem_efficient=False math=False attention_mode:gqa num_heads:8 num_kv_heads:4 tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 seed:1337 warmup_step:1/20 warmup_step:2/20 warmup_step:3/20 warmup_step:4/20 warmup_step:5/20 warmup_step:6/20 warmup_step:7/20 warmup_step:8/20 warmup_step:9/20 warmup_step:10/20 warmup_step:11/20 warmup_step:12/20 warmup_step:13/20 warmup_step:14/20 warmup_step:15/20 warmup_step:16/20 warmup_step:17/20 warmup_step:18/20 warmup_step:19/20 warmup_step:20/20 step:0/20000 val_loss:6.9363 val_bpb:4.1080 train_time:0ms step_avg:0.02ms step:1/20000 train_loss:6.9355 train_time:31ms step_avg:31.28ms step:2/20000 train_loss:12.1231 train_time:73ms step_avg:36.68ms step:3/20000 train_loss:7.2165 train_time:125ms step_avg:41.81ms step:4/20000 train_loss:6.4176 train_time:165ms step_avg:41.19ms step:5/20000 train_loss:6.8785 train_time:213ms step_avg:42.51ms step:6/20000 train_loss:7.6248 train_time:256ms step_avg:42.66ms step:7/20000 train_loss:6.8402 train_time:303ms step_avg:43.24ms step:8/20000 train_loss:6.4526 train_time:348ms step_avg:43.54ms step:9/20000 train_loss:6.2574 train_time:394ms step_avg:43.79ms step:10/20000 train_loss:6.1774 train_time:441ms step_avg:44.06ms step:50/20000 train_loss:4.0508 train_time:2226ms step_avg:44.51ms step:100/20000 train_loss:3.2461 train_time:4465ms step_avg:44.65ms step:150/20000 train_loss:2.9349 train_time:6704ms step_avg:44.69ms step:200/20000 train_loss:2.7576 train_time:9077ms step_avg:45.38ms step:200/20000 val_loss:2.7303 val_bpb:1.6170 train_time:9103ms step_avg:45.52ms step:250/20000 train_loss:2.6695 train_time:11324ms step_avg:45.30ms step:300/20000 train_loss:2.4316 train_time:13571ms step_avg:45.24ms step:350/20000 train_loss:2.6215 train_time:15818ms step_avg:45.19ms step:400/20000 train_loss:2.3094 train_time:18198ms step_avg:45.49ms step:400/20000 val_loss:2.5186 val_bpb:1.4917 train_time:18221ms step_avg:45.55ms step:450/20000 train_loss:2.4622 train_time:20441ms step_avg:45.42ms step:500/20000 train_loss:2.4606 train_time:22696ms step_avg:45.39ms step:550/20000 train_loss:2.3666 train_time:24931ms step_avg:45.33ms step:600/20000 train_loss:2.5197 train_time:27345ms step_avg:45.57ms step:600/20000 val_loss:2.4169 val_bpb:1.4315 train_time:27361ms step_avg:45.60ms step:650/20000 train_loss:2.3557 train_time:29586ms step_avg:45.52ms step:700/20000 train_loss:2.4120 train_time:31830ms step_avg:45.47ms step:750/20000 train_loss:2.2413 train_time:34073ms step_avg:45.43ms step:800/20000 train_loss:2.2639 train_time:36456ms step_avg:45.57ms step:800/20000 val_loss:2.3534 val_bpb:1.3938 train_time:36482ms step_avg:45.60ms step:850/20000 train_loss:2.6951 train_time:38715ms step_avg:45.55ms step:900/20000 train_loss:2.3127 train_time:40946ms step_avg:45.50ms step:950/20000 train_loss:2.3720 train_time:43193ms step_avg:45.47ms step:1000/20000 train_loss:2.3507 train_time:45593ms step_avg:45.59ms step:1000/20000 val_loss:2.3092 val_bpb:1.3677 train_time:45617ms step_avg:45.62ms step:1050/20000 train_loss:2.4618 train_time:47845ms step_avg:45.57ms step:1100/20000 train_loss:2.2360 train_time:50109ms step_avg:45.55ms step:1150/20000 train_loss:2.2286 train_time:52507ms step_avg:45.66ms step:1200/20000 train_loss:2.3629 train_time:54759ms step_avg:45.63ms step:1200/20000 val_loss:2.2795 val_bpb:1.3501 train_time:54781ms step_avg:45.65ms step:1250/20000 train_loss:2.1861 train_time:57017ms step_avg:45.61ms step:1300/20000 train_loss:2.3350 train_time:59267ms step_avg:45.59ms step:1350/20000 train_loss:2.2513 train_time:61658ms step_avg:45.67ms step:1400/20000 train_loss:2.4046 train_time:63916ms step_avg:45.65ms step:1400/20000 val_loss:2.2574 val_bpb:1.3370 train_time:63932ms step_avg:45.67ms step:1450/20000 train_loss:2.2156 train_time:66161ms step_avg:45.63ms step:1500/20000 train_loss:2.1999 train_time:68424ms step_avg:45.62ms step:1550/20000 train_loss:2.1331 train_time:70797ms step_avg:45.68ms step:1600/20000 train_loss:2.0744 train_time:73040ms step_avg:45.65ms step:1600/20000 val_loss:2.2430 val_bpb:1.3284 train_time:73065ms step_avg:45.67ms step:1650/20000 train_loss:2.2082 train_time:75294ms step_avg:45.63ms step:1700/20000 train_loss:2.1512 train_time:77538ms step_avg:45.61ms step:1750/20000 train_loss:2.2288 train_time:79928ms step_avg:45.67ms step:1800/20000 train_loss:2.1758 train_time:82184ms step_avg:45.66ms step:1800/20000 val_loss:2.2278 val_bpb:1.3194 train_time:82201ms step_avg:45.67ms step:1850/20000 train_loss:2.2862 train_time:84429ms step_avg:45.64ms step:1900/20000 train_loss:2.1669 train_time:86684ms step_avg:45.62ms step:1950/20000 train_loss:2.1909 train_time:89069ms step_avg:45.68ms step:2000/20000 train_loss:2.2319 train_time:91324ms step_avg:45.66ms step:2000/20000 val_loss:2.2129 val_bpb:1.3106 train_time:91351ms step_avg:45.68ms step:2050/20000 train_loss:2.2279 train_time:93580ms step_avg:45.65ms step:2100/20000 train_loss:2.2434 train_time:95987ms step_avg:45.71ms step:2150/20000 train_loss:2.1649 train_time:98238ms step_avg:45.69ms step:2200/20000 train_loss:2.0522 train_time:100494ms step_avg:45.68ms step:2200/20000 val_loss:2.2044 val_bpb:1.3055 train_time:100513ms step_avg:45.69ms step:2250/20000 train_loss:2.1400 train_time:102739ms step_avg:45.66ms step:2300/20000 train_loss:2.3533 train_time:105126ms step_avg:45.71ms step:2350/20000 train_loss:2.1788 train_time:107379ms step_avg:45.69ms step:2400/20000 train_loss:2.1812 train_time:109631ms step_avg:45.68ms step:2400/20000 val_loss:2.1944 val_bpb:1.2997 train_time:109653ms step_avg:45.69ms step:2450/20000 train_loss:2.1842 train_time:111889ms step_avg:45.67ms step:2500/20000 train_loss:2.0987 train_time:114285ms step_avg:45.71ms step:2550/20000 train_loss:2.1142 train_time:116542ms step_avg:45.70ms step:2600/20000 train_loss:2.3890 train_time:118794ms step_avg:45.69ms step:2600/20000 val_loss:2.1926 val_bpb:1.2986 train_time:118817ms step_avg:45.70ms step:2650/20000 train_loss:2.2227 train_time:121046ms step_avg:45.68ms step:2700/20000 train_loss:2.1354 train_time:123439ms step_avg:45.72ms step:2750/20000 train_loss:2.3391 train_time:125689ms step_avg:45.71ms step:2800/20000 train_loss:2.2132 train_time:127942ms step_avg:45.69ms step:2800/20000 val_loss:2.1801 val_bpb:1.2912 train_time:127965ms step_avg:45.70ms step:2850/20000 train_loss:2.1652 train_time:130191ms step_avg:45.68ms step:2900/20000 train_loss:2.1553 train_time:132603ms step_avg:45.73ms step:2950/20000 train_loss:2.2221 train_time:134847ms step_avg:45.71ms step:3000/20000 train_loss:2.2030 train_time:137103ms step_avg:45.70ms step:3000/20000 val_loss:2.1722 val_bpb:1.2865 train_time:137127ms step_avg:45.71ms step:3050/20000 train_loss:2.1472 train_time:139356ms step_avg:45.69ms step:3100/20000 train_loss:2.1888 train_time:141742ms step_avg:45.72ms step:3150/20000 train_loss:2.1400 train_time:143993ms step_avg:45.71ms step:3200/20000 train_loss:2.1645 train_time:146242ms step_avg:45.70ms step:3200/20000 val_loss:2.1671 val_bpb:1.2835 train_time:146268ms step_avg:45.71ms step:3250/20000 train_loss:2.0671 train_time:148637ms step_avg:45.73ms step:3300/20000 train_loss:2.2137 train_time:150885ms step_avg:45.72ms step:3350/20000 train_loss:2.0731 train_time:153128ms step_avg:45.71ms step:3400/20000 train_loss:2.1360 train_time:155380ms step_avg:45.70ms step:3400/20000 val_loss:2.1642 val_bpb:1.2818 train_time:155406ms step_avg:45.71ms step:3450/20000 train_loss:2.0890 train_time:157762ms step_avg:45.73ms step:3500/20000 train_loss:2.2285 train_time:160022ms step_avg:45.72ms step:3550/20000 train_loss:2.3672 train_time:162275ms step_avg:45.71ms step:3600/20000 train_loss:2.0920 train_time:164524ms step_avg:45.70ms step:3600/20000 val_loss:2.1558 val_bpb:1.2768 train_time:164550ms step_avg:45.71ms step:3650/20000 train_loss:2.2009 train_time:166952ms step_avg:45.74ms step:3700/20000 train_loss:2.1262 train_time:169204ms step_avg:45.73ms step:3750/20000 train_loss:2.1242 train_time:171454ms step_avg:45.72ms step:3800/20000 train_loss:2.1979 train_time:173690ms step_avg:45.71ms step:3800/20000 val_loss:2.1522 val_bpb:1.2747 train_time:173716ms step_avg:45.71ms step:3850/20000 train_loss:2.1522 train_time:176098ms step_avg:45.74ms step:3900/20000 train_loss:1.9648 train_time:178363ms step_avg:45.73ms step:3950/20000 train_loss:2.1063 train_time:180605ms step_avg:45.72ms step:4000/20000 train_loss:2.1404 train_time:182855ms step_avg:45.71ms step:4000/20000 val_loss:2.1483 val_bpb:1.2723 train_time:182880ms step_avg:45.72ms step:4050/20000 train_loss:2.0791 train_time:185248ms step_avg:45.74ms step:4100/20000 train_loss:2.1671 train_time:187504ms step_avg:45.73ms step:4150/20000 train_loss:2.3022 train_time:189754ms step_avg:45.72ms step:4200/20000 train_loss:2.1544 train_time:192160ms step_avg:45.75ms step:4200/20000 val_loss:2.1438 val_bpb:1.2697 train_time:192179ms step_avg:45.76ms step:4250/20000 train_loss:2.1045 train_time:194411ms step_avg:45.74ms step:4300/20000 train_loss:2.0035 train_time:196665ms step_avg:45.74ms step:4350/20000 train_loss:2.1856 train_time:198915ms step_avg:45.73ms step:4400/20000 train_loss:2.0900 train_time:201320ms step_avg:45.75ms step:4400/20000 val_loss:2.1441 val_bpb:1.2698 train_time:201344ms step_avg:45.76ms step:4450/20000 train_loss:2.0453 train_time:203578ms step_avg:45.75ms step:4500/20000 train_loss:2.2388 train_time:205817ms step_avg:45.74ms step:4550/20000 train_loss:2.0374 train_time:208069ms step_avg:45.73ms step:4600/20000 train_loss:1.9508 train_time:210467ms step_avg:45.75ms step:4600/20000 val_loss:2.1397 val_bpb:1.2672 train_time:210490ms step_avg:45.76ms step:4650/20000 train_loss:2.0559 train_time:212730ms step_avg:45.75ms step:4700/20000 train_loss:2.2390 train_time:214973ms step_avg:45.74ms step:4750/20000 train_loss:1.9568 train_time:217228ms step_avg:45.73ms step:4800/20000 train_loss:2.2412 train_time:219613ms step_avg:45.75ms step:4800/20000 val_loss:2.1358 val_bpb:1.2649 train_time:219633ms step_avg:45.76ms step:4850/20000 train_loss:2.1316 train_time:221864ms step_avg:45.75ms step:4900/20000 train_loss:2.1431 train_time:224117ms step_avg:45.74ms step:4950/20000 train_loss:2.3209 train_time:226367ms step_avg:45.73ms step:5000/20000 train_loss:2.0024 train_time:228775ms step_avg:45.75ms step:5000/20000 val_loss:2.1310 val_bpb:1.2621 train_time:228799ms step_avg:45.76ms step:5050/20000 train_loss:2.1805 train_time:231025ms step_avg:45.75ms step:5100/20000 train_loss:2.0028 train_time:233279ms step_avg:45.74ms step:5150/20000 train_loss:2.2555 train_time:235682ms step_avg:45.76ms step:5200/20000 train_loss:2.1495 train_time:237949ms step_avg:45.76ms step:5200/20000 val_loss:2.1311 val_bpb:1.2621 train_time:237972ms step_avg:45.76ms step:5250/20000 train_loss:2.1005 train_time:240194ms step_avg:45.75ms step:5300/20000 train_loss:2.1919 train_time:242450ms step_avg:45.75ms step:5350/20000 train_loss:2.1187 train_time:244835ms step_avg:45.76ms step:5400/20000 train_loss:2.1622 train_time:247089ms step_avg:45.76ms step:5400/20000 val_loss:2.1266 val_bpb:1.2595 train_time:247105ms step_avg:45.76ms step:5450/20000 train_loss:2.1771 train_time:249340ms step_avg:45.75ms step:5500/20000 train_loss:2.1191 train_time:251580ms step_avg:45.74ms step:5550/20000 train_loss:2.0819 train_time:253964ms step_avg:45.76ms step:5600/20000 train_loss:2.1589 train_time:256223ms step_avg:45.75ms step:5600/20000 val_loss:2.1262 val_bpb:1.2593 train_time:256248ms step_avg:45.76ms step:5650/20000 train_loss:2.0342 train_time:258475ms step_avg:45.75ms step:5700/20000 train_loss:2.1558 train_time:260721ms step_avg:45.74ms step:5750/20000 train_loss:2.1967 train_time:263132ms step_avg:45.76ms step:5800/20000 train_loss:2.1182 train_time:265385ms step_avg:45.76ms step:5800/20000 val_loss:2.1233 val_bpb:1.2575 train_time:265407ms step_avg:45.76ms step:5850/20000 train_loss:2.1585 train_time:267633ms step_avg:45.75ms step:5900/20000 train_loss:2.0729 train_time:269888ms step_avg:45.74ms step:5950/20000 train_loss:2.1111 train_time:272265ms step_avg:45.76ms step:6000/20000 train_loss:2.2017 train_time:274533ms step_avg:45.76ms step:6000/20000 val_loss:2.1214 val_bpb:1.2564 train_time:274549ms step_avg:45.76ms step:6050/20000 train_loss:2.1040 train_time:276778ms step_avg:45.75ms step:6100/20000 train_loss:2.0972 train_time:279041ms step_avg:45.74ms step:6150/20000 train_loss:2.0780 train_time:281419ms step_avg:45.76ms step:6200/20000 train_loss:2.0625 train_time:283675ms step_avg:45.75ms step:6200/20000 val_loss:2.1192 val_bpb:1.2551 train_time:283696ms step_avg:45.76ms step:6250/20000 train_loss:2.1312 train_time:285929ms step_avg:45.75ms step:6300/20000 train_loss:2.0109 train_time:288345ms step_avg:45.77ms step:6350/20000 train_loss:1.9996 train_time:290597ms step_avg:45.76ms step:6400/20000 train_loss:2.1393 train_time:292855ms step_avg:45.76ms step:6400/20000 val_loss:2.1165 val_bpb:1.2535 train_time:292870ms step_avg:45.76ms step:6450/20000 train_loss:2.0561 train_time:295104ms step_avg:45.75ms step:6500/20000 train_loss:2.0563 train_time:297478ms step_avg:45.77ms step:6550/20000 train_loss:2.1868 train_time:299733ms step_avg:45.76ms step:6600/20000 train_loss:2.1016 train_time:301990ms step_avg:45.76ms step:6600/20000 val_loss:2.1126 val_bpb:1.2512 train_time:302015ms step_avg:45.76ms step:6650/20000 train_loss:2.2676 train_time:304249ms step_avg:45.75ms step:6700/20000 train_loss:2.1344 train_time:306638ms step_avg:45.77ms step:6750/20000 train_loss:2.3069 train_time:308884ms step_avg:45.76ms step:6800/20000 train_loss:2.1721 train_time:311133ms step_avg:45.75ms step:6800/20000 val_loss:2.1120 val_bpb:1.2508 train_time:311158ms step_avg:45.76ms step:6850/20000 train_loss:2.0027 train_time:313387ms step_avg:45.75ms step:6900/20000 train_loss:2.0708 train_time:315782ms step_avg:45.77ms step:6950/20000 train_loss:2.1526 train_time:318017ms step_avg:45.76ms step:7000/20000 train_loss:2.2026 train_time:320274ms step_avg:45.75ms step:7000/20000 val_loss:2.1100 val_bpb:1.2497 train_time:320294ms step_avg:45.76ms step:7050/20000 train_loss:2.2315 train_time:322525ms step_avg:45.75ms step:7100/20000 train_loss:2.0491 train_time:325043ms step_avg:45.78ms step:7150/20000 train_loss:2.1278 train_time:327297ms step_avg:45.78ms step:7200/20000 train_loss:2.1779 train_time:329554ms step_avg:45.77ms step:7200/20000 val_loss:2.1091 val_bpb:1.2492 train_time:329578ms step_avg:45.77ms step:7250/20000 train_loss:2.0836 train_time:331941ms step_avg:45.78ms step:7300/20000 train_loss:2.0657 train_time:334194ms step_avg:45.78ms step:7350/20000 train_loss:2.1627 train_time:336441ms step_avg:45.77ms step:7400/20000 train_loss:2.0957 train_time:338691ms step_avg:45.77ms step:7400/20000 val_loss:2.1063 val_bpb:1.2475 train_time:338718ms step_avg:45.77ms step:7450/20000 train_loss:2.0930 train_time:341085ms step_avg:45.78ms step:7500/20000 train_loss:2.0893 train_time:343339ms step_avg:45.78ms step:7550/20000 train_loss:2.1512 train_time:345594ms step_avg:45.77ms step:7600/20000 train_loss:1.9759 train_time:347849ms step_avg:45.77ms step:7600/20000 val_loss:2.1051 val_bpb:1.2467 train_time:347873ms step_avg:45.77ms step:7650/20000 train_loss:2.2603 train_time:350250ms step_avg:45.78ms step:7700/20000 train_loss:2.0707 train_time:352506ms step_avg:45.78ms step:7750/20000 train_loss:2.0895 train_time:354756ms step_avg:45.78ms step:7800/20000 train_loss:2.1256 train_time:357009ms step_avg:45.77ms step:7800/20000 val_loss:2.1026 val_bpb:1.2453 train_time:357034ms step_avg:45.77ms step:7850/20000 train_loss:1.9783 train_time:359400ms step_avg:45.78ms step:7900/20000 train_loss:2.1132 train_time:361649ms step_avg:45.78ms step:7950/20000 train_loss:2.0734 train_time:363898ms step_avg:45.77ms step:8000/20000 train_loss:2.0955 train_time:366149ms step_avg:45.77ms step:8000/20000 val_loss:2.1000 val_bpb:1.2437 train_time:366174ms step_avg:45.77ms step:8050/20000 train_loss:2.0608 train_time:368554ms step_avg:45.78ms step:8100/20000 train_loss:2.1254 train_time:370808ms step_avg:45.78ms step:8150/20000 train_loss:2.2333 train_time:373056ms step_avg:45.77ms step:8200/20000 train_loss:2.1677 train_time:375301ms step_avg:45.77ms step:8200/20000 val_loss:2.0991 val_bpb:1.2432 train_time:375326ms step_avg:45.77ms step:8250/20000 train_loss:2.1263 train_time:377716ms step_avg:45.78ms step:8300/20000 train_loss:2.0939 train_time:379966ms step_avg:45.78ms step:8350/20000 train_loss:2.2056 train_time:382213ms step_avg:45.77ms step:8400/20000 train_loss:2.1145 train_time:384602ms step_avg:45.79ms step:8400/20000 val_loss:2.0986 val_bpb:1.2429 train_time:384627ms step_avg:45.79ms step:8450/20000 train_loss:2.2060 train_time:386859ms step_avg:45.78ms step:8500/20000 train_loss:2.1064 train_time:389125ms step_avg:45.78ms step:8550/20000 train_loss:2.1724 train_time:391359ms step_avg:45.77ms step:8600/20000 train_loss:2.1130 train_time:393767ms step_avg:45.79ms step:8600/20000 val_loss:2.0957 val_bpb:1.2412 train_time:393779ms step_avg:45.79ms step:8650/20000 train_loss:2.0786 train_time:396002ms step_avg:45.78ms step:8700/20000 train_loss:2.0084 train_time:398259ms step_avg:45.78ms step:8750/20000 train_loss:2.1722 train_time:400508ms step_avg:45.77ms step:8800/20000 train_loss:2.0755 train_time:402921ms step_avg:45.79ms step:8800/20000 val_loss:2.0949 val_bpb:1.2407 train_time:402947ms step_avg:45.79ms step:8850/20000 train_loss:2.2856 train_time:405174ms step_avg:45.78ms step:8900/20000 train_loss:2.1794 train_time:407424ms step_avg:45.78ms step:8950/20000 train_loss:2.1362 train_time:409675ms step_avg:45.77ms step:9000/20000 train_loss:2.0033 train_time:412068ms step_avg:45.79ms step:9000/20000 val_loss:2.0952 val_bpb:1.2409 train_time:412090ms step_avg:45.79ms step:9050/20000 train_loss:2.0378 train_time:414315ms step_avg:45.78ms step:9100/20000 train_loss:2.2843 train_time:416568ms step_avg:45.78ms step:9150/20000 train_loss:1.9771 train_time:418810ms step_avg:45.77ms step:9200/20000 train_loss:2.0639 train_time:421205ms step_avg:45.78ms step:9200/20000 val_loss:2.0933 val_bpb:1.2398 train_time:421229ms step_avg:45.79ms step:9250/20000 train_loss:2.1723 train_time:423454ms step_avg:45.78ms step:9300/20000 train_loss:2.1050 train_time:425706ms step_avg:45.77ms step:9350/20000 train_loss:2.2069 train_time:428101ms step_avg:45.79ms step:9400/20000 train_loss:2.1070 train_time:430349ms step_avg:45.78ms step:9400/20000 val_loss:2.0910 val_bpb:1.2384 train_time:430374ms step_avg:45.78ms step:9450/20000 train_loss:2.1418 train_time:432606ms step_avg:45.78ms step:9500/20000 train_loss:2.2405 train_time:434854ms step_avg:45.77ms step:9550/20000 train_loss:2.1759 train_time:437256ms step_avg:45.79ms step:9600/20000 train_loss:2.1233 train_time:439511ms step_avg:45.78ms step:9600/20000 val_loss:2.0904 val_bpb:1.2381 train_time:439538ms step_avg:45.79ms step:9650/20000 train_loss:2.0687 train_time:441764ms step_avg:45.78ms step:9700/20000 train_loss:2.0842 train_time:444016ms step_avg:45.77ms step:9750/20000 train_loss:2.0422 train_time:446408ms step_avg:45.79ms step:9800/20000 train_loss:2.0480 train_time:448658ms step_avg:45.78ms step:9800/20000 val_loss:2.0920 val_bpb:1.2390 train_time:448685ms step_avg:45.78ms step:9850/20000 train_loss:2.0115 train_time:450915ms step_avg:45.78ms step:9900/20000 train_loss:2.1271 train_time:453166ms step_avg:45.77ms step:9950/20000 train_loss:2.0026 train_time:455566ms step_avg:45.79ms step:10000/20000 train_loss:2.0928 train_time:457811ms step_avg:45.78ms step:10000/20000 val_loss:2.0901 val_bpb:1.2379 train_time:457836ms step_avg:45.78ms step:10050/20000 train_loss:2.0821 train_time:460064ms step_avg:45.78ms step:10100/20000 train_loss:2.0749 train_time:462314ms step_avg:45.77ms step:10150/20000 train_loss:2.0414 train_time:464692ms step_avg:45.78ms step:10200/20000 train_loss:2.0427 train_time:466956ms step_avg:45.78ms step:10200/20000 val_loss:2.0864 val_bpb:1.2357 train_time:466975ms step_avg:45.78ms step:10250/20000 train_loss:2.0468 train_time:469202ms step_avg:45.78ms step:10300/20000 train_loss:2.1732 train_time:471628ms step_avg:45.79ms step:10350/20000 train_loss:2.1053 train_time:473869ms step_avg:45.78ms step:10400/20000 train_loss:2.0771 train_time:476113ms step_avg:45.78ms step:10400/20000 val_loss:2.0866 val_bpb:1.2358 train_time:476129ms step_avg:45.78ms step:10450/20000 train_loss:2.0673 train_time:478356ms step_avg:45.78ms step:10500/20000 train_loss:1.9587 train_time:480740ms step_avg:45.78ms step:10550/20000 train_loss:1.9904 train_time:482992ms step_avg:45.78ms step:10600/20000 train_loss:1.9568 train_time:485256ms step_avg:45.78ms step:10600/20000 val_loss:2.0864 val_bpb:1.2357 train_time:485269ms step_avg:45.78ms step:10650/20000 train_loss:2.1711 train_time:487496ms step_avg:45.77ms step:10700/20000 train_loss:2.0553 train_time:489875ms step_avg:45.78ms step:10750/20000 train_loss:2.1130 train_time:492156ms step_avg:45.78ms step:10800/20000 train_loss:2.1674 train_time:494388ms step_avg:45.78ms step:10800/20000 val_loss:2.0843 val_bpb:1.2345 train_time:494408ms step_avg:45.78ms step:10850/20000 train_loss:2.1129 train_time:496643ms step_avg:45.77ms step:10900/20000 train_loss:2.1275 train_time:499051ms step_avg:45.78ms step:10950/20000 train_loss:2.0914 train_time:501302ms step_avg:45.78ms step:11000/20000 train_loss:2.0934 train_time:503554ms step_avg:45.78ms step:11000/20000 val_loss:2.0832 val_bpb:1.2338 train_time:503578ms step_avg:45.78ms step:11050/20000 train_loss:2.0554 train_time:505806ms step_avg:45.77ms step:11100/20000 train_loss:2.0394 train_time:508195ms step_avg:45.78ms step:11150/20000 train_loss:2.1367 train_time:510449ms step_avg:45.78ms step:11200/20000 train_loss:2.0480 train_time:512705ms step_avg:45.78ms step:11200/20000 val_loss:2.0825 val_bpb:1.2333 train_time:512727ms step_avg:45.78ms step:11250/20000 train_loss:1.9284 train_time:514958ms step_avg:45.77ms step:11300/20000 train_loss:1.9753 train_time:517364ms step_avg:45.78ms step:11350/20000 train_loss:1.9595 train_time:519613ms step_avg:45.78ms step:11400/20000 train_loss:2.0328 train_time:521860ms step_avg:45.78ms step:11400/20000 val_loss:2.0824 val_bpb:1.2333 train_time:521885ms step_avg:45.78ms step:11450/20000 train_loss:2.0225 train_time:524252ms step_avg:45.79ms step:11500/20000 train_loss:2.0859 train_time:526498ms step_avg:45.78ms step:11550/20000 train_loss:2.0876 train_time:528748ms step_avg:45.78ms step:11600/20000 train_loss:2.0380 train_time:531003ms step_avg:45.78ms step:11600/20000 val_loss:2.0805 val_bpb:1.2322 train_time:531027ms step_avg:45.78ms step:11650/20000 train_loss:2.1590 train_time:533404ms step_avg:45.79ms step:11700/20000 train_loss:2.1872 train_time:535658ms step_avg:45.78ms step:11750/20000 train_loss:2.0971 train_time:537908ms step_avg:45.78ms step:11800/20000 train_loss:2.0735 train_time:540163ms step_avg:45.78ms step:11800/20000 val_loss:2.0792 val_bpb:1.2314 train_time:540177ms step_avg:45.78ms step:11850/20000 train_loss:2.1114 train_time:542551ms step_avg:45.78ms step:11900/20000 train_loss:2.0381 train_time:544803ms step_avg:45.78ms step:11950/20000 train_loss:2.0598 train_time:547043ms step_avg:45.78ms step:12000/20000 train_loss:2.0470 train_time:549297ms step_avg:45.77ms step:12000/20000 val_loss:2.0757 val_bpb:1.2293 train_time:549321ms step_avg:45.78ms step:12050/20000 train_loss:2.0646 train_time:551692ms step_avg:45.78ms step:12100/20000 train_loss:2.0815 train_time:553944ms step_avg:45.78ms step:12150/20000 train_loss:2.2381 train_time:556195ms step_avg:45.78ms step:12200/20000 train_loss:2.1889 train_time:558450ms step_avg:45.77ms step:12200/20000 val_loss:2.0693 val_bpb:1.2256 train_time:558473ms step_avg:45.78ms step:12250/20000 train_loss:1.8816 train_time:560835ms step_avg:45.78ms step:12300/20000 train_loss:2.0743 train_time:563087ms step_avg:45.78ms step:12350/20000 train_loss:2.1383 train_time:565338ms step_avg:45.78ms step:12400/20000 train_loss:1.8282 train_time:567731ms step_avg:45.78ms step:12400/20000 val_loss:2.0625 val_bpb:1.2215 train_time:567756ms step_avg:45.79ms step:12450/20000 train_loss:1.9936 train_time:569994ms step_avg:45.78ms step:12500/20000 train_loss:2.3264 train_time:572244ms step_avg:45.78ms step:12550/20000 train_loss:2.1008 train_time:574498ms step_avg:45.78ms step:12600/20000 train_loss:2.0522 train_time:576894ms step_avg:45.79ms step:12600/20000 val_loss:2.0554 val_bpb:1.2174 train_time:576920ms step_avg:45.79ms step:12650/20000 train_loss:2.0168 train_time:579149ms step_avg:45.78ms step:12700/20000 train_loss:2.0485 train_time:581399ms step_avg:45.78ms step:12750/20000 train_loss:2.0600 train_time:583649ms step_avg:45.78ms step:12800/20000 train_loss:2.0687 train_time:586059ms step_avg:45.79ms step:12800/20000 val_loss:2.0480 val_bpb:1.2129 train_time:586089ms step_avg:45.79ms step:12850/20000 train_loss:1.9659 train_time:588321ms step_avg:45.78ms step:12900/20000 train_loss:2.0918 train_time:590577ms step_avg:45.78ms step:12950/20000 train_loss:1.9700 train_time:592831ms step_avg:45.78ms step:13000/20000 train_loss:2.1380 train_time:595226ms step_avg:45.79ms step:13000/20000 val_loss:2.0416 val_bpb:1.2092 train_time:595253ms step_avg:45.79ms step:13050/20000 train_loss:2.0707 train_time:597483ms step_avg:45.78ms step:13100/20000 train_loss:1.9696 train_time:599748ms step_avg:45.78ms step:13101/20000 val_loss:2.0396 val_bpb:1.2080 train_time:599812ms step_avg:45.78ms stopping_early: wallclock_cap train_time:599812ms step:13101/20000 peak memory allocated: 11389 MiB reserved: 11704 MiB Serialized model: 74578915 bytes Code size: 49058 bytes Total submission size: 74627973 bytes Serialized model int8+zlib: 15879916 bytes (payload:19030336 raw_torch:19080377 payload_ratio:3.92x) Total submission size int8+zlib: 15928974 bytes final_int8_zlib_roundtrip val_loss:2.0510 val_bpb:1.2147 eval_time:1432ms final_int8_zlib_roundtrip_exact val_loss:2.05104604 val_bpb:1.21474500