summaryrefslogtreecommitdiff
path: root/train_gpt_mlx.py
diff options
context:
space:
mode:
authorWill DePue <williamd@openai.com>2026-03-18 09:32:01 -0700
committerWill DePue <williamd@openai.com>2026-03-18 09:32:01 -0700
commita15093adad328a650d421e53c078cbd2c45beb0e (patch)
treee054c4bde12b89e6d3b39d611d9caadabc7f7234 /train_gpt_mlx.py
Launch snapshot
Diffstat (limited to 'train_gpt_mlx.py')
-rw-r--r--train_gpt_mlx.py1088
1 files changed, 1088 insertions, 0 deletions
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
new file mode 100644
index 0000000..bf7c7d1
--- /dev/null
+++ b/train_gpt_mlx.py
@@ -0,0 +1,1088 @@
+#!/usr/bin/env python3
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines.
+"""
+from __future__ import annotations
+
+import glob
+import json
+import math
+import os
+import pickle
+import sys
+import time
+import uuid
+import zlib
+from collections.abc import Callable
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+from mlx.utils import tree_flatten, tree_unflatten
+
# ==============================================================================
# SHARD FORMAT + COMPUTE DTYPE
# ==============================================================================

# Activations run in bfloat16 throughout the model; fp32 master copies exist
# where needed (CastedLinear weights and the per-channel control tensors).
COMPUTE_DTYPE = mx.bfloat16
+
+# ==============================================================================
+# HYPERPARAMETERS
+# ==============================================================================
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
class Hyperparameters:
    """Run configuration; every field reads its default from an environment
    variable at import time, so runs can be reconfigured without editing the
    script."""

    # Data / tokenizer.
    data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed: int = int(os.environ.get("SEED", 1337))

    # Training loop. These defaults now mirror train_gpt.py on a single process.
    iterations: int = int(os.environ.get("ITERATIONS", 20_000))
    val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0))
    # Validation always uses the full fineweb_val split.
    val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200))
    train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8))
    train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024)))
    # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak
    # memory pressure without changing the effective optimizer batch.
    mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192))
    warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20))
    warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200))
    max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))

    # Model (defaults match the current baseline setup).
    vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers: int = int(os.environ.get("NUM_LAYERS", 9))
    model_dim: int = int(os.environ.get("MODEL_DIM", 512))
    num_heads: int = int(os.environ.get("NUM_HEADS", 8))
    num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4))
    mlp_mult: int = int(os.environ.get("MLP_MULT", 2))
    tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0))
    logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
    rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0))
    qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5))

    # Optimizer. We keep the same per-group defaults as train_gpt.py.
    beta1: float = float(os.environ.get("BETA1", 0.9))
    beta2: float = float(os.environ.get("BETA2", 0.95))
    adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8))
    tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05))
    matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0))

    out_dir: str = os.environ.get("OUT_DIR", "logs")

    @property
    def train_files(self) -> str:
        """Glob pattern for the training shards under data_path."""
        return f"{self.data_path}/fineweb_train_*.bin"

    @property
    def val_files(self) -> str:
        """Glob pattern for the validation shards under data_path."""
        return f"{self.data_path}/fineweb_val_*.bin"

    @property
    def microbatch_tokens(self) -> int:
        """Tokens per gradient-accumulation microbatch."""
        return self.train_batch_tokens // self.grad_accum_steps

    def lr_mul(self, step: int, elapsed_ms: float) -> float:
        """Return the LR multiplier for this step (1.0 until warmdown starts).

        With a wall-clock budget (max_wallclock_seconds > 0) the warmdown is
        triggered by remaining time, estimated from the average ms/step so
        far; otherwise it is a linear decay over the final warmdown_iters
        optimizer steps.
        """
        if self.warmdown_iters <= 0:
            return 1.0
        if self.max_wallclock_seconds <= 0:
            # Iteration-based schedule: linear decay over the last warmdown_iters steps.
            warmdown_start = max(self.iterations - self.warmdown_iters, 0)
            return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0
        # Time-based schedule: decay linearly once the projected warmdown window
        # no longer fits inside the remaining wall-clock budget.
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = self.warmdown_iters * step_ms
        remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+
+
# Name substrings marking "control" tensors (per-channel gains / mixing
# weights). Tensors matching these patterns are optimized by Adam rather than
# Muon (see SplitOptimizers) and kept in fp32 when checkpoints are
# int8-quantized. Env-overridable.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights",
    ).split(",")
    if pattern
)
# Defaults to the control-tensor patterns above; matching tensors are stored
# as full fp32 passthrough in the quantized checkpoint (see keep_float_array).
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
+
+
def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]:
    """Split a token budget into sequence-aligned chunk sizes.

    Rounds ``total_tokens`` down to a multiple of ``seq_len``, then emits
    chunks of at most ``max_chunk_tokens`` (itself rounded down to a multiple
    of ``seq_len``, but never below one full sequence).

    Raises ValueError when the budget does not cover a single sequence.
    """
    budget = (total_tokens // seq_len) * seq_len
    if budget <= 0:
        raise ValueError(f"token budget too small for seq_len={seq_len}")
    step = max((max_chunk_tokens // seq_len) * seq_len, seq_len)
    full, tail = divmod(budget, step)
    sizes = [step] * full
    if tail:
        sizes.append(tail)
    return sizes
+
+
def accumulate_flat_grads(
    accum: dict[str, mx.array] | None,
    grads_tree: dict,
    scale: float,
) -> dict[str, mx.array]:
    """Flatten ``grads_tree`` and fold ``scale``-weighted grads into ``accum``.

    When ``accum`` is None a fresh accumulator dict is returned; otherwise the
    existing dict is updated key-by-key and returned.
    """
    scaled = {name: grad * scale for name, grad in tree_flatten(grads_tree)}
    if accum is None:
        return scaled
    for name, grad in scaled.items():
        accum[name] = accum[name] + grad
    return accum
+
+
+# ==============================================================================
+# MATH HELPERS
+# ==============================================================================
+
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
    """Weightless RMSNorm over the last axis, returned in x's dtype."""
    return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype)
+
+
def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    # Background on Muon: https://kellerjordan.github.io/posts/muon/
    # The quintic coefficients trade exact convergence for speed; the result
    # only needs to be approximately orthogonal.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g.astype(mx.float32)
    # Pre-scale to unit Frobenius norm so the fixed coefficients are valid.
    x = x / (mx.sqrt(mx.sum(x * x)) + eps)
    # Work in the wide orientation (rows <= cols) so x @ x.T is the smaller
    # Gram matrix; transpose back at the end.
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        a_mat = x @ x.T
        b_mat = b * a_mat + c * (a_mat @ a_mat)
        x = a * x + b_mat @ x
    if transposed:
        x = x.T
    return x.astype(g.dtype)
+
+
def load_data_shard(path: Path) -> np.ndarray:
    """Load one token shard and return its tokens as an int32 numpy array.

    Shard layout: a 256-entry little-endian int32 header (magic 20240520,
    version 1, token count at index 2) followed by little-endian uint16
    tokens. Raises ValueError on a bad header, a file-size mismatch, or a
    short read.
    """
    header = np.fromfile(path, dtype="<i4", count=256)
    if header.size != 256 or int(header[0]) != 20240520 or int(header[1]) != 1:
        raise ValueError(f"Unexpected shard header for {path}")
    num_tokens = int(header[2])
    header_nbytes = 256 * np.dtype("<i4").itemsize
    expected_nbytes = header_nbytes + num_tokens * np.dtype("<u2").itemsize
    if path.stat().st_size != expected_nbytes:
        raise ValueError(f"Shard size mismatch for {path}")
    tokens = np.fromfile(path, dtype="<u2", count=num_tokens, offset=header_nbytes)
    if tokens.size != num_tokens:
        raise ValueError(f"Short read for {path}")
    return tokens.astype(np.int32, copy=False)
+
+
+# ==============================================================================
+# TOKEN STREAMING / BATCHING
+# ==============================================================================
+
+
class TokenStream:
    """Sequential token reader over a sorted set of shard files.

    Cycles back to the first shard when all shards are exhausted, bumping
    `epoch` and emitting a warning through `log_fn` when it wraps.
    """

    def __init__(
        self,
        pattern: str,
        log_fn: Callable[[str], None] | None = None,
        dataset_name: str = "",
    ):
        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
        if not self.files:
            raise FileNotFoundError(f"No files found for pattern: {pattern}")
        self.epoch = 1
        self.file_idx = 0
        self.log_fn = log_fn
        self.dataset_name = dataset_name
        # Eagerly load the first shard so `take` can read immediately.
        self.tokens = load_data_shard(self.files[0])
        self.pos = 0

    def next_file(self) -> None:
        """Advance to the next shard (cycling) and reload tokens from disk."""
        self.file_idx = (self.file_idx + 1) % len(self.files)
        if self.file_idx == 0:
            self.epoch += 1
            if self.log_fn is not None:
                self.log_fn(
                    f"WARNING: starting epoch:{self.epoch} "
                    f"dataset:{self.dataset_name} train_shards:{len(self.files)}"
                )
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> np.ndarray:
        """Return the next `n` tokens, spanning shard boundaries as needed."""
        chunks: list[np.ndarray] = []
        left = n
        while left > 0:
            if self.pos >= self.tokens.size:
                self.next_file()
            k = min(left, int(self.tokens.size - self.pos))
            chunks.append(self.tokens[self.pos : self.pos + k])
            self.pos += k
            left -= k
        # Fast path: a single in-shard slice avoids the concatenate copy.
        return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0)
+
+
class TokenLoader:
    """Turns a TokenStream into (input, target) next-token training batches."""

    def __init__(
        self,
        pattern: str,
        log_fn: Callable[[str], None] | None = None,
        dataset_name: str = "",
    ):
        self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name)

    def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]:
        """Return (x, y) int32 arrays of shape (batch_tokens // seq_len, seq_len);
        y is x shifted left by one token (next-token targets)."""
        usable = (batch_tokens // seq_len) * seq_len
        if usable <= 0:
            raise ValueError(f"token budget too small for seq_len={seq_len}")
        # Take one extra token so targets can be the one-step-shifted inputs.
        chunk = self.stream.take(usable + 1)
        x = chunk[:-1].reshape(-1, seq_len)
        y = chunk[1:].reshape(-1, seq_len)
        return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32)
+
+
+# ==============================================================================
+# MODEL BLOCKS
+# ==============================================================================
+
class CastedLinear(nn.Module):
    """Bias-free linear layer with an fp32 master weight, cast to the
    activation dtype on every forward pass."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Borrow nn.Linear's initialization scheme, then keep the weight in fp32.
        self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        return x @ self.weight.astype(x.dtype).T
+
+
class RMSNormNoWeight(nn.Module):
    # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks.
    def __call__(self, x: mx.array) -> mx.array:
        """Apply weightless RMSNorm (no learned scale) over the last axis."""
        return rms_norm(x)
+
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with grouped-query KV heads.

    - separate q/k/v projections
    - RMSNorm on q and k before attention
    - RoPE on q and k
    - causal masked SDPA
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, kv_dim)
        self.c_v = CastedLinear(dim, kv_dim)
        self.proj = CastedLinear(dim, dim)
        # Learned per-head gain applied to q after QK-norm, initialized at qk_gain_init.
        self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init
        self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base)
        self.scale = self.head_dim ** -0.5

    def __call__(self, x: mx.array) -> mx.array:
        bsz, seqlen, dim = x.shape
        # (B, T, dim) -> (B, heads, T, head_dim)
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

        # QK-norm then rotary embedding; compute in COMPUTE_DTYPE (bf16).
        q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE))
        k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE))
        q = q * self.q_gain.astype(q.dtype)[None, :, None, None]
        # NOTE(review): relies on mx.fast SDPA handling fewer KV heads than
        # query heads (GQA) — confirm against the installed MLX version.
        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
        y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim)
        return self.proj(y)
+
+
class MLP(nn.Module):
    # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup.
    def __init__(self, dim: int, mlp_mult: int):
        """Two-layer MLP with a `mlp_mult`x hidden expansion."""
        super().__init__()
        hidden = dim * mlp_mult
        self.fc = CastedLinear(dim, hidden)
        self.proj = CastedLinear(hidden, dim)

    def __call__(self, x: mx.array) -> mx.array:
        # Squared ReLU: relu(fc(x)) ** 2, then project back down.
        x = nn.relu(self.fc(x))
        return self.proj(x * x)
+
+
class Block(nn.Module):
    """Pre-norm transformer block with learned per-channel residual controls.

    `resid_mix` blends the running stream x with the embedding stream x0 per
    channel before the block; `attn_scale` / `mlp_scale` gate the two
    residual branches per channel.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        self.attn_norm = RMSNormNoWeight()
        self.mlp_norm = RMSNormNoWeight()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        # Residual-branch gains, initialized to 1 (plain residual).
        self.attn_scale = mx.ones((dim,), dtype=mx.float32)
        self.mlp_scale = mx.ones((dim,), dtype=mx.float32)
        # Row 0 weights x, row 1 weights x0; init keeps x and ignores x0.
        self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32))))

    def __call__(self, x: mx.array, x0: mx.array) -> mx.array:
        # Blend the embedding stream back in, then pre-norm attention and MLP.
        mix = self.resid_mix.astype(x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x
+
+
class GPT(nn.Module):
    """U-net-style GPT with tied embeddings.

    - token embedding + RMSNorm
    - encoder half accumulates skip tensors
    - decoder half consumes reversed skips with learned skip_weights
    - tied embeddings for the LM head (the baseline default setup)
    """

    def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
                 logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float,
                 qk_gain_init: float):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.logit_chunk_tokens = logit_chunk_tokens
        self.logit_softcap = logit_softcap

        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        # One per-channel skip gain per encoder/decoder pair.
        self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32)
        self.blocks = [
            Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
            for i in range(num_layers)
        ]
        self.final_norm = RMSNormNoWeight()

        # Zero-init both residual output projections so every block starts as
        # (near) identity; give the tied embedding a small normal init.
        for b in self.blocks:
            b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight)
            b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight)
        self.tok_emb.weight = (
            mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std
        ).astype(COMPUTE_DTYPE)

    def softcap(self, logits: mx.array) -> mx.array:
        """Smoothly bound logits to (-c, c) via c * tanh(logits / c)."""
        c = self.logit_softcap
        return c * mx.tanh(logits / c)

    def __call__(self, input_ids: mx.array) -> mx.array:
        """Return final-norm hidden states (no LM head applied here)."""
        x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
        x0 = x
        skips: list[mx.array] = []

        for i in range(self.num_encoder_layers):
            x = self.blocks[i](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            # Odd layer counts have one more decoder block than encoder block. The baseline only
            # applies a skip connection when one exists, then runs the remaining decoder block(s)
            # without an added skip.
            if skips:
                x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self.num_encoder_layers + i](x, x0)
        return self.final_norm(x)

    def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
        """Mean token cross-entropy (nats) using softcapped tied-embedding logits."""
        # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful
        # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE).
        x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
        y = target_ids.reshape(-1)
        if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens:
            logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T
            logits = self.softcap(logits_proj)
            return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean")

        # Chunked path: sum per-chunk CE and normalize by the total token count.
        loss_sum = mx.array(0.0, dtype=mx.float32)
        n = int(x.shape[0])
        for s in range(0, n, self.logit_chunk_tokens):
            e = min(s + self.logit_chunk_tokens, n)
            logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T
            logits = self.softcap(logits_proj)
            loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum")
        return loss_sum / float(n)
+
+# ==============================================================================
+# OPTIMIZERS (MUON + ADAM SPLIT)
+# ==============================================================================
class Muon:
    # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the
    # parameter update.
    def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters):
        """Track one momentum buffer per matrix parameter listed in `keys`."""
        self.keys = keys
        self.args = args
        self.buffers = {k: mx.zeros_like(params[k]) for k in keys}

    def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]:
        """Return updated copies of the Muon-managed params (inputs unmodified)."""
        # Linearly warm momentum from warmup_start up to the final value.
        if self.args.muon_momentum_warmup_steps:
            t = min(step / self.args.muon_momentum_warmup_steps, 1.0)
            momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum
        else:
            momentum = self.args.muon_momentum
        lr = self.args.matrix_lr * lr_mul
        out: dict[str, mx.array] = {}
        for k in self.keys:
            p = params[k]
            g = grads[k]
            buf = momentum * self.buffers[k] + g
            self.buffers[k] = buf
            # Nesterov-style lookahead against the freshly updated buffer.
            g_eff = g + momentum * buf
            g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps)
            # Aspect-ratio scale keeps update magnitudes comparable for tall matrices.
            scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1])))
            out[k] = p - lr * (g_ortho * scale).astype(p.dtype)
        return out
+
+
class SplitOptimizers:
    # - embeddings: Adam with the tied-embedding LR
    # - block matrices (2D): Muon
    # - block scalars + skip weights: Adam
    # This preserves the high-level optimization behavior even though MLX internals differ.
    def __init__(self, model: GPT, args: Hyperparameters):
        """Partition the model's flat parameter names into the three groups
        above and build one optimizer per group."""
        self.args = args
        params = dict(tree_flatten(model.parameters()))
        self.embed_key = "tok_emb.weight"
        # 2D weights inside blocks go to Muon — unless they match a
        # control-tensor pattern, which routes them to the scalar Adam below.
        self.matrix_keys = [
            k
            for k, p in params.items()
            if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS)
        ]
        # Skip weights plus all sub-2D / control tensors inside blocks.
        self.scalar_keys = [
            k
            for k, p in params.items()
            if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS)))
        ]

        self.muon = Muon(self.matrix_keys, params, args)
        self.adam_embed = optim.Adam(
            learning_rate=args.tied_embed_lr,
            betas=[args.beta1, args.beta2],
            eps=args.adam_eps,
            bias_correction=True,
        )
        self.adam_scalar = optim.Adam(
            learning_rate=args.scalar_lr,
            betas=[args.beta1, args.beta2],
            eps=args.adam_eps,
            bias_correction=True,
        )

    def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None:
        """Apply one optimizer step in place on `model`, scaling every group's
        LR by `lr_mul` (the warmdown multiplier)."""
        params = dict(tree_flatten(model.parameters()))
        grads = dict(tree_flatten(grads_tree))
        updated = dict(params)

        updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul))

        self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul
        updated.update(
            self.adam_embed.apply_gradients(
                {self.embed_key: grads[self.embed_key]},
                {self.embed_key: params[self.embed_key]},
            )
        )

        self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul
        scalar_grads = {k: grads[k] for k in self.scalar_keys}
        scalar_params = {k: params[k] for k in self.scalar_keys}
        updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params))

        model.update(tree_unflatten(list(updated.items())))
+
+# ==============================================================================
+# QUANTIZATION (INT8 + ZLIB)
+# ==============================================================================
+# - per-row int8 for 2D float tensors
+# - per-tensor int8 for other float tensors
+# - fp16 passthrough for small float tensors
+# - exact passthrough for non-floats
+
# Maps the dtype names recorded in checkpoints back to MLX dtypes.
MX_DTYPE_FROM_NAME = {
    "float32": mx.float32,
    "float16": mx.float16,
    "bfloat16": mx.bfloat16,
}

# Float tensors at or below this element count skip int8 and are stored as floats.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
# Storage dtype for kept-float tensors (fp32/bf16 are downcast to this).
INT8_KEEP_FLOAT_STORE_DTYPE = np.float16
# Storage dtype for the per-row scales of quantized matrices.
INT8_PER_ROW_SCALE_DTYPE = np.float16
# Symmetric clipping quantile applied to |x| before rounding to int8.
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+
def _np_float32(arr: mx.array) -> np.ndarray:
    """Materialize an MLX array as a float32 numpy array (avoiding a copy when possible)."""
    return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False)
+
+
def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray:
    """Return a small float tensor for direct (non-int8) storage.

    Control tensors are kept in full fp32. Other fp32/bf16 tensors are
    downcast to fp16, recording the original dtype in
    `passthrough_orig_dtypes` so loading can restore it. Anything else is
    copied unchanged.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return np.ascontiguousarray(_np_float32(arr))
    if arr.dtype in {mx.float32, mx.bfloat16}:
        passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1]
        return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False))
    return np.ascontiguousarray(np.array(arr, copy=True))
+
+
def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]:
    """Symmetric int8 quantization; returns (int8 values, scale(s)).

    2D tensors get one fp16 scale per row; other shapes get a single scalar
    fp32 scale. Values are clipped at the INT8_CLIP_Q quantile of |x| before
    rounding.
    """
    f32 = _np_float32(arr)
    if f32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32)
        clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None])
        # Floor the scale at 1/127 so all-zero rows stay well-defined.
        scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False)
        q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False)
        return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False))

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0
    scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32)
    q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False)
    return np.ascontiguousarray(q), scale
+
+
def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]:
    """Quantize a flat name->array state dict to int8; returns (payload, stats).

    Non-float tensors and small float tensors pass through unquantized
    (see keep_float_array); larger float tensors get per-row (2D) or
    per-tensor int8 scales. `stats` counts tensors/params and compares the
    payload bytes against the unquantized baseline.
    """
    quantized: dict[str, np.ndarray] = {}
    scales: dict[str, np.ndarray] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, np.ndarray] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )
    for name, arr in flat_state.items():
        stats["param_count"] += int(arr.size)
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += int(arr.nbytes)
        if not mx.issubdtype(arr.dtype, mx.floating):
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = np.ascontiguousarray(np.array(arr))
            stats["int8_payload_bytes"] += int(passthrough[name].nbytes)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_array(name, arr, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += int(kept.nbytes)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_array(arr)
        # A non-scalar scale means per-row quantization; record it for loading.
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(arr.dtype).split(".")[-1]
        stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes)
    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats
+
+
def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]:
    """Invert quantize_state_dict_int8, rebuilding a flat name->mx.array dict
    in each tensor's recorded original dtype."""
    out: dict[str, mx.array] = {}
    qmeta = quant_obj.get("qmeta", {})
    passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {})
    for name, q in quant_obj["quantized"].items():
        q_np = np.asarray(q, dtype=np.int8)
        dtype_name = quant_obj["dtypes"][name]
        scale = np.asarray(quant_obj["scales"][name], dtype=np.float32)
        if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0:
            # Broadcast the saved row scale back across trailing dimensions.
            out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1))
        else:
            out_arr = q_np.astype(np.float32) * float(scale)
        out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name])
    for name, arr in quant_obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        out_arr = np.array(arr, copy=True)
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype])
        else:
            out[name] = mx.array(out_arr)
    return out
+
+
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build per-token-id lookup tables for the val_bpb byte accounting.

    Returns (base_bytes, has_leading_space, is_boundary), each indexed by
    token id: the piece's UTF-8 byte length with any leading "▁" marker
    stripped (byte tokens count as 1), whether the piece starts with "▁",
    and whether the id is a control/unknown/unused token.
    """
    sp_vocab_size = int(sp.vocab_size())
    # Tables cover the larger of the two vocab sizes so ids never go out of range.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_lut = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_lut[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_lut[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_lut[token_id] = True
            piece = piece[1:]
        base_bytes_lut[token_id] = len(piece.encode("utf-8"))
    return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut
+
+
+def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]:
+ # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we
+ # decode bytes with the exact tokenizer that produced the shards. The manifest
+ # lets the training script fail fast on accidental dataset/tokenizer mismatches.
+ dataset_dir = Path(data_path).resolve()
+ actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+ if len(dataset_dir.parents) < 2:
+ return dataset_dir.name, actual_train_files, None
+ manifest_path = dataset_dir.parents[1] / "manifest.json"
+ if not manifest_path.is_file():
+ return dataset_dir.name, actual_train_files, None
+
+ manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
+ dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None)
+ if dataset_entry is None:
+ return dataset_dir.name, actual_train_files, None
+
+ tokenizer_name = dataset_entry.get("tokenizer_name")
+ tokenizer_entry = (
+ next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None)
+ if tokenizer_name
+ else None
+ )
+ expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name
+ if expected_name and Path(tokenizer_path).name != expected_name:
+ raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}")
+ expected_train_files = (dataset_entry.get("stats") or {}).get("files_train")
+ if expected_train_files is not None:
+ expected_train_files = int(expected_train_files)
+ if actual_train_files > expected_train_files:
+ raise ValueError(
+ f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, "
+ f"manifest says {expected_train_files}"
+ )
+ return dataset_dir.name, actual_train_files, expected_train_files
+
+
def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray:
    """Concatenate every validation shard matching `pattern` into one array.

    The export pipeline writes the fixed first-50k-doc validation set to
    fineweb_val_*. The result is trimmed to the largest multiple of
    `seq_len` plus one extra token (for next-token targets).
    """
    shard_paths = sorted(glob.glob(pattern))
    if not shard_paths:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    tokens = np.ascontiguousarray(
        np.concatenate([load_data_shard(Path(p)) for p in shard_paths], axis=0)
    )
    usable = ((tokens.size - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]
+
+
def loss_and_grad_chunked(
    args: Hyperparameters,
    train_loader: TokenLoader,
    compiled_loss_and_grad,
) -> tuple[mx.array, dict]:
    """Run one optimizer microbatch as several smaller forward/backward chunks.

    Each chunk's loss and gradients are weighted by its share of the
    microbatch's tokens, so the result matches a single full-microbatch pass
    while peak memory stays bounded by MLX_MAX_MICROBATCH_TOKENS.
    """
    chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens)
    total_tokens = float(sum(chunk_sizes))
    loss_value = mx.array(0.0, dtype=mx.float32)
    grad_accum: dict[str, mx.array] | None = None
    for chunk_tokens in chunk_sizes:
        x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len)
        loss, grads = compiled_loss_and_grad(x, y)
        # Weight by this chunk's fraction of the microbatch's tokens.
        scale = float(y.size) / total_tokens
        loss_value = loss_value + loss.astype(mx.float32) * scale
        grad_accum = accumulate_flat_grads(grad_accum, grads, scale)
    return loss_value, tree_unflatten(list(grad_accum.items()))
+
+
def eval_val(
    args: Hyperparameters,
    compiled_loss,
    val_tokens: np.ndarray,
    base_bytes_lut: np.ndarray,
    has_leading_space_lut: np.ndarray,
    is_boundary_token_lut: np.ndarray,
) -> tuple[float, float]:
    """Evaluate the full validation split; returns (val_loss, val_bpb)."""
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    val_batch_tokens = args.val_batch_size // args.grad_accum_steps
    if val_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, "
            f"TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    val_batch_seqs = val_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.size - 1) // args.train_seq_len
    total_loss = mx.array(0.0, dtype=mx.float32)
    total_tokens = 0.0
    total_bytes = 0.0
    for batch_seq_start in range(0, total_seqs, val_batch_seqs):
        batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs)
        raw_start = batch_seq_start * args.train_seq_len
        raw_end = batch_seq_end * args.train_seq_len + 1
        chunk = val_tokens[raw_start:raw_end]
        x_np = chunk[:-1].reshape(-1, args.train_seq_len)
        y_np = chunk[1:].reshape(-1, args.train_seq_len)
        x = mx.array(x_np, dtype=mx.int32)
        y = mx.array(y_np, dtype=mx.int32)
        chunk_token_count = float(y.size)
        # Accumulate a token-weighted loss sum; normalized after the loop.
        total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count
        prev_ids = x_np.reshape(-1)
        tgt_ids = y_np.reshape(-1)
        # Bytes per target token: the base UTF-8 length, plus one byte for the
        # token's leading space unless the previous token is a boundary
        # (control/unknown/unused) token — see build_sentencepiece_luts.
        bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True)
        bytes_np += (
            has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]
        ).astype(np.int16, copy=False)
        total_tokens += chunk_token_count
        total_bytes += float(bytes_np.astype(np.float64).sum())
    total_loss = total_loss / total_tokens
    mx.eval(total_loss)
    val_loss = float(total_loss.item())
    # bpb = total bits / total bytes, expressed via the per-token mean loss.
    bits_per_token = val_loss / math.log(2.0)
    val_bpb = bits_per_token * (total_tokens / total_bytes)
    return val_loss, val_bpb
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict:
    """Scale the whole gradient tree so its global L2 norm is <= max_norm.

    A non-positive ``max_norm`` disables clipping. The tree is returned
    unchanged when the norm is zero or already within the bound.
    """
    if max_norm <= 0:
        return grads_tree
    flat = dict(tree_flatten(grads_tree))
    # Accumulate the squared norm in float64 on the host for stability.
    sq_norm = 0.0
    for grad in flat.values():
        sq_norm += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64))
    if sq_norm <= 0.0:
        return grads_tree
    norm = math.sqrt(sq_norm)
    if norm <= max_norm:
        return grads_tree
    factor = max_norm / (norm + 1e-12)
    return tree_unflatten([(name, grad * factor) for name, grad in flat.items()])
+
+
def main() -> None:
    """Train the MLX GPT baseline end to end.

    Pipeline: validate the tokenizer/dataset pairing, build the model and the
    split Muon/Adam optimizers, compile the train/eval graphs, optionally run
    an untimed warmup to prime MLX compile caches, run the timed training loop
    with periodic validation and an optional wallclock cap, then serialize a
    raw checkpoint plus an int8+zlib checkpoint and re-validate the quantized
    roundtrip.
    """
    # ==============================================================================
    # TOKENIZER + VALIDATION METRIC SETUP
    # ==============================================================================
    args = Hyperparameters()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    logfile = out_dir / f"{args.run_id}.txt"
    print(logfile)

    def log(msg: str, console: bool = True) -> None:
        # Append-only run log; console echo is optional so large dumps (like
        # the full source snapshot below) do not spam stdout.
        if console:
            print(msg)
        with logfile.open("a", encoding="utf-8") as f:
            print(msg, file=f)

    # Snapshot this script's own source into the log so each run is self-describing.
    code = Path(__file__).read_text(encoding="utf-8")
    log(code, console=False)
    log("=" * 100, console=False)
    log(f"Running Python {sys.version}", console=False)
    log(f"Running MLX {mx.__version__}", console=False)
    log("=" * 100, console=False)

    if not args.tie_embeddings:
        raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings")
    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair(
        args.data_path,
        args.tokenizer_path,
    )
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)

    # Per-token lookup tables used by the tokenizer-agnostic bpb metric in eval_val.
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size
    )

    # ==============================================================================
    # TRAINING SETUP
    # ==============================================================================
    mx.random.seed(args.seed)

    train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name)

    # ==============================================================================
    # MODEL + OPTIMIZER SETUP
    # ==============================================================================
    model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        logit_chunk_tokens=args.logit_chunk_tokens,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        tied_embed_init_std=args.tied_embed_init_std,
        qk_gain_init=args.qk_gain_init,
    )
    opt = SplitOptimizers(model, args)

    # ==============================================================================
    # COMPILED TRAIN / EVAL FUNCTIONS (MLX)
    # ==============================================================================
    # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example
    # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs".
    # Compiling the model-bound functions and capturing the full model state fixes that while still
    # returning gradients only for trainable parameters via nn.value_and_grad(...).
    compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state)
    compiled_loss_and_grad = mx.compile(
        nn.value_and_grad(model, lambda x, y: model.loss(x, y)),
        inputs=model.state,
        outputs=model.state,
    )

    # Print config once so logs are self-describing.
    n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters()))
    log(f"run_id:{args.run_id}")
    log(f"mlx_version:{mx.__version__}")
    log(f"train_loader:shards pattern={args.train_files}")
    log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}")
    if expected_train_files is None:
        log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}")
    elif actual_train_files < expected_train_files:
        # Training on a subset of shards means the loader wraps around earlier.
        log(
            f"WARNING: train_loader:subset dataset:{dataset_name} "
            f"train_shards:{actual_train_files}/{expected_train_files} "
            f"new epochs will arrive sooner than the full dataset"
        )
    else:
        log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}")
    log(f"tokenizer_path:{args.tokenizer_path}")
    log(
        f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} "
        f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} "
        f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}"
    )
    log(
        f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} "
        f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} "
        f"val_batch_size:{args.val_batch_size} "
        f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}")
    log(
        f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} "
        f"embed_lr:{args.tied_embed_lr} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} "
        f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}"
    )
    log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log(f"compute_dtype:{COMPUTE_DTYPE} compile:True")
    log(
        f"dtypes tok_emb:{model.tok_emb.weight.dtype} "
        f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} "
        f"skip_weights:{model.skip_weights.dtype}"
    )

    # ==============================================================================
    # TRAINING LOOP
    # ==============================================================================
    if args.warmup_steps > 0:
        # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us
        # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs.
        # Instead we run the real train shapes, force the loss/grads to materialize, and then reset
        # the loader so measured training still starts from the true init and token window.
        for warmup_step in range(args.warmup_steps):
            accum: dict[str, mx.array] | None = None
            warmup_loss = mx.array(0.0, dtype=mx.float32)
            grad_scale = 1.0 / args.grad_accum_steps
            for _ in range(args.grad_accum_steps):
                warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
                accum = accumulate_flat_grads(accum, grads, grad_scale)
            # Force materialization so compile/allocation work actually happens now.
            mx.eval(warmup_loss, accum)
            mx.synchronize()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")

        # Prime the standalone eval graph once too. It is compiled separately from value_and_grad.
        val_batch_tokens = args.val_batch_size // args.grad_accum_steps
        if val_batch_tokens < args.train_seq_len:
            raise ValueError(
                "VAL_BATCH_SIZE must provide at least one sequence; "
                f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, "
                f"TRAIN_SEQ_LEN={args.train_seq_len}"
            )
        warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len)
        warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1]
        x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32)
        y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32)
        warm_val_loss = compiled_loss(x_val, y_val)
        mx.eval(warm_val_loss)
        mx.synchronize()

        # Reset the loader so measured training consumes the dataset from the start.
        train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name)

    train_time_ms = 0.0
    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
    # When the wallclock cap is hit mid-run, we finish the current step and
    # stop at this step index instead of args.iterations.
    stop_after_step: int | None = None
    t0 = time.perf_counter()
    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
        if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
            # Validation always scans the same fixed full validation split.
            val_loss, val_bpb = eval_val(
                args,
                compiled_loss,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            # Bank elapsed train time before validation so eval time is excluded.
            train_time_ms += 1000.0 * (time.perf_counter() - t0)
            if step % 25 == 0 or last_step:
                log(
                    f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                    f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms"
                )
            t0 = time.perf_counter()
            if last_step:
                if stop_after_step is not None and step < args.iterations:
                    log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}")
                break

        # LR schedule may depend on both step index and elapsed wallclock.
        lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0))
        step_t0 = time.perf_counter()

        accum: dict[str, mx.array] | None = None
        train_loss = mx.array(0.0, dtype=mx.float32)
        grad_scale = 1.0 / args.grad_accum_steps
        for _ in range(args.grad_accum_steps):
            loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
            accum = accumulate_flat_grads(accum, grads, grad_scale)
            train_loss = train_loss + loss.astype(mx.float32) * grad_scale

        grads = tree_unflatten(list(accum.items()))
        grads = clip_grad_tree(grads, args.grad_clip_norm)
        # .item() forces evaluation of the accumulated graph before the update.
        train_loss_value = float(train_loss.item())
        opt.step(model, grads, step=step, lr_mul=lr_mul)
        mx.synchronize()

        step_ms = 1000.0 * (time.perf_counter() - step_t0)
        approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0)
        tok_s = args.train_batch_tokens / (step_ms / 1000.0)
        step += 1
        if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None):
            log(
                f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} "
                f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}"
            )
        if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms:
            stop_after_step = step

    # ==============================================================================
    # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL
    # ==============================================================================
    # We always write a raw artifact and a quantized artifact, then validate the
    # quantized roundtrip directly by loading the dequantized tensors back into the
    # model and running one final validation pass.
    out_path = out_dir / f"{args.run_id}_mlx_model.npz"
    flat_state = {k: v for k, v in tree_flatten(model.state)}
    mx.savez(str(out_path), **flat_state)
    log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}")

    quant_obj, quant_stats = quantize_state_dict_int8(flat_state)
    quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL)
    quant_blob = zlib.compress(quant_raw, level=9)
    quant_serialized_bytes = len(quant_raw)
    quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz"
    with quant_path.open("wb") as f:
        f.write(quant_blob)
    quant_file_bytes = quant_path.stat().st_size
    ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
    log(
        f"serialized_model_int8_zlib:{quant_file_bytes} bytes "
        f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)"
    )

    # Re-read the artifact from disk so the roundtrip covers serialization too.
    with quant_path.open("rb") as f:
        quant_blob_disk = f.read()
    quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk)))
    model.update(tree_unflatten(list(quant_flat.items())))
    q_t0 = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args,
        compiled_loss,
        val_tokens,
        base_bytes_lut,
        has_leading_space_lut,
        is_boundary_token_lut,
    )
    q_eval_ms = 1000.0 * (time.perf_counter() - q_t0)
    log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms")
    log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+
# Script entry point.
if __name__ == "__main__":
    main()