From a15093adad328a650d421e53c078cbd2c45beb0e Mon Sep 17 00:00:00 2001 From: Will DePue Date: Wed, 18 Mar 2026 09:32:01 -0700 Subject: Launch snapshot --- train_gpt_mlx.py | 1088 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1088 insertions(+) create mode 100644 train_gpt_mlx.py (limited to 'train_gpt_mlx.py') diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py new file mode 100644 index 0000000..bf7c7d1 --- /dev/null +++ b/train_gpt_mlx.py @@ -0,0 +1,1088 @@ +#!/usr/bin/env python3 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" +from __future__ import annotations + +import glob +import json +import math +import os +import pickle +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +# ============================================================================== +# SHARD FORMAT + COMPUTE DTYPE +# ============================================================================== + +COMPUTE_DTYPE = mx.bfloat16 + +# ============================================================================== +# HYPERPARAMETERS +# ============================================================================== +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 
minute cap +class Hyperparameters: + # Data / tokenizer. + data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed: int = int(os.environ.get("SEED", 1337)) + + # Training loop. These defaults now mirror train_gpt.py on a single process. + iterations: int = int(os.environ.get("ITERATIONS", 20_000)) + val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) + # Validation always uses the full fineweb_val split. + val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) + # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak + # memory pressure without changing the effective optimizer batch. + mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) + warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) + warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) + max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # Model (defaults match the current baseline setup). 
+ vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) + model_dim: int = int(os.environ.get("MODEL_DIM", 512)) + num_heads: int = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) + logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) + qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Optimizer. We keep the same per-group defaults as train_gpt.py. + beta1: float = float(os.environ.get("BETA1", 0.9)) + beta2: float = float(os.environ.get("BETA2", 0.95)) + adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + out_dir: str = os.environ.get("OUT_DIR", "logs") + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def microbatch_tokens(self) -> int: + return self.train_batch_tokens // self.grad_accum_steps + + def lr_mul(self, step: int, elapsed_ms: 
float) -> float: + if self.warmdown_iters <= 0: + return 1.0 + if self.max_wallclock_seconds <= 0: + warmdown_start = max(self.iterations - self.warmdown_iters, 0) + return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = self.warmdown_iters * step_ms + remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) + + +def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: + usable_total = (total_tokens // seq_len) * seq_len + if usable_total <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) + chunks: list[int] = [] + remaining = usable_total + while remaining > 0: + chunk = min(remaining, usable_chunk) + chunks.append(chunk) + remaining -= chunk + return chunks + + +def accumulate_flat_grads( + accum: dict[str, mx.array] | None, + grads_tree: dict, + scale: float, +) -> dict[str, mx.array]: + flat = dict(tree_flatten(grads_tree)) + if accum is None: + return {k: g * scale for k, g in flat.items()} + for k, g in flat.items(): + accum[k] = accum[k] + g * scale + return accum + + +# ============================================================================== +# MATH HELPERS +# ============================================================================== + +def rms_norm(x: 
def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    # Background on Muon: https://kellerjordan.github.io/posts/muon/
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g.astype(mx.float32)
    x = x / (mx.sqrt(mx.sum(x * x)) + eps)
    # Iterate on the wide orientation; transpose back at the end if needed.
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        a_mat = x @ x.T
        b_mat = b * a_mat + c * (a_mat @ a_mat)
        x = a * x + b_mat @ x
    if transposed:
        x = x.T
    return x.astype(g.dtype)


def load_data_shard(path: Path) -> np.ndarray:
    """Load one token shard from disk and return it as a 1-D numpy array.

    NOTE(review): the source dump lost everything between `np.dtype("<u4` and
    `def next_file(self) ->` (stripped like an HTML tag), so this body and the
    `TokenStream` constructor below are reconstructions. The layout assumed
    here is the standard modded-nanogpt shard format: a header of 256
    little-endian uint32 words (header[2] presumably the token count) followed
    by uint16 token ids — confirm against the data-export pipeline.
    """
    header_bytes = 256 * np.dtype("<u4").itemsize
    with path.open("rb") as f:
        header = np.frombuffer(f.read(header_bytes), dtype="<u4")
        # header[0]/header[1] look like magic/version words; not validated here.
        num_tokens = int(header[2])
        tokens = np.frombuffer(f.read(), dtype="<u2")
    if tokens.size < num_tokens:
        raise ValueError(
            f"shard {path} is truncated: header says {num_tokens}, found {tokens.size}"
        )
    return tokens[:num_tokens]


class TokenStream:
    # Endless token stream over sorted shard files, wrapping around per epoch.
    def __init__(
        self,
        pattern: str,
        log_fn: Callable[[str], None] | None = None,
        dataset_name: str = "",
    ):
        # NOTE(review): reconstructed constructor (see load_data_shard). It
        # initializes exactly the attributes next_file()/take() read.
        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
        if not self.files:
            raise FileNotFoundError(f"No files found for pattern: {pattern}")
        self.log_fn = log_fn
        self.dataset_name = dataset_name
        self.epoch = 0
        self.file_idx = 0
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def next_file(self) -> None:
        # Advance to the next shard; wrapping to shard 0 starts a new epoch.
        self.file_idx = (self.file_idx + 1) % len(self.files)
        if self.file_idx == 0:
            self.epoch += 1
            if self.log_fn is not None:
                self.log_fn(
                    f"WARNING: starting epoch:{self.epoch} "
                    f"dataset:{self.dataset_name} train_shards:{len(self.files)}"
                )
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> np.ndarray:
        # Return the next n tokens, spilling across shard boundaries as needed.
        chunks: list[np.ndarray] = []
        left = n
        while left > 0:
            if self.pos >= self.tokens.size:
                self.next_file()
            k = min(left, int(self.tokens.size - self.pos))
            chunks.append(self.tokens[self.pos : self.pos + k])
            self.pos += k
            left -= k
        return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0)


class TokenLoader:
    # Thin adapter turning a TokenStream into shifted (input, target) MLX batches.
    def __init__(
        self,
        pattern: str,
        log_fn: Callable[[str], None] | None = None,
        dataset_name: str = "",
    ):
        self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name)

    def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]:
        # Take usable+1 tokens so x and y are the same window shifted by one.
        usable = (batch_tokens // seq_len) * seq_len
        if usable <= 0:
            raise ValueError(f"token budget too small for seq_len={seq_len}")
        chunk = self.stream.take(usable + 1)
        x = chunk[:-1].reshape(-1, seq_len)
        y = chunk[1:].reshape(-1, seq_len)
        return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32)
# ==============================================================================
# MODEL BLOCKS
# ==============================================================================

class CastedLinear(nn.Module):
    # Bias-free linear layer whose master weight is kept in fp32 and cast to
    # the activation dtype at matmul time.
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        return x @ self.weight.astype(x.dtype).T


class RMSNormNoWeight(nn.Module):
    # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks.
    def __call__(self, x: mx.array) -> mx.array:
        return rms_norm(x)


class CausalSelfAttention(nn.Module):
    # - separate q/k/v projections
    # - RMSNorm on q and k before attention
    # - RoPE on q and k
    # - causal masked SDPA
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, kv_dim)
        self.c_v = CastedLinear(dim, kv_dim)
        self.proj = CastedLinear(dim, dim)
        # Learned per-head gain on q (a "control" tensor; see
        # CONTROL_TENSOR_NAME_PATTERNS), initialized above 1.
        self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init
        self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base)
        self.scale = self.head_dim ** -0.5

    def __call__(self, x: mx.array) -> mx.array:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

        q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE))
        k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE))
        q = q * self.q_gain.astype(q.dtype)[None, :, None, None]
        # GQA: fewer k/v heads than q heads; mx.fast SDPA broadcasts them.
        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
        y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup.
    def __init__(self, dim: int, mlp_mult: int):
        super().__init__()
        hidden = dim * mlp_mult
        self.fc = CastedLinear(dim, hidden)
        self.proj = CastedLinear(hidden, dim)

    def __call__(self, x: mx.array) -> mx.array:
        x = nn.relu(self.fc(x))
        return self.proj(x * x)


class Block(nn.Module):
    # Pre-norm transformer block with learned per-channel sublayer scales and a
    # learned mix of the running residual with the embedding stream x0.
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        self.attn_norm = RMSNormNoWeight()
        self.mlp_norm = RMSNormNoWeight()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = mx.ones((dim,), dtype=mx.float32)
        self.mlp_scale = mx.ones((dim,), dtype=mx.float32)
        # resid_mix[0] weights the residual, resid_mix[1] weights x0; starts as identity mix.
        self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32))))

    def __call__(self, x: mx.array, x0: mx.array) -> mx.array:
        mix = self.resid_mix.astype(x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    # - token embedding + RMSNorm
    # - encoder half accumulates skip tensors
    # - decoder half consumes reversed skips with learned skip_weights
    # - tied embeddings for the LM head (the baseline default setup)
    def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
                 logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float,
                 qk_gain_init: float):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.logit_chunk_tokens = logit_chunk_tokens
        self.logit_softcap = logit_softcap

        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32)
        self.blocks = [
            Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
            for i in range(num_layers)
        ]
        self.final_norm = RMSNormNoWeight()

        # Zero-init both residual projections so every block starts as identity,
        # and give the tied embedding a small-std normal init in fp32 before
        # casting to the compute dtype.
        for b in self.blocks:
            b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight)
            b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight)
        self.tok_emb.weight = (
            mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std
        ).astype(COMPUTE_DTYPE)

    def softcap(self, logits: mx.array) -> mx.array:
        # Smoothly bound logits to (-c, c) via tanh.
        c = self.logit_softcap
        return c * mx.tanh(logits / c)

    def __call__(self, input_ids: mx.array) -> mx.array:
        x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
        x0 = x
        skips: list[mx.array] = []

        for i in range(self.num_encoder_layers):
            x = self.blocks[i](x, x0)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            # Odd layer counts have one more decoder block than encoder block. The baseline only
            # applies a skip connection when one exists, then runs the remaining decoder block(s)
            # without an added skip.
            if skips:
                x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self.num_encoder_layers + i](x, x0)
        return self.final_norm(x)

    def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
        # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful
        # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE).
        x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
        y = target_ids.reshape(-1)
        if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens:
            logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T
            logits = self.softcap(logits_proj)
            return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean")

        loss_sum = mx.array(0.0, dtype=mx.float32)
        n = int(x.shape[0])
        for s in range(0, n, self.logit_chunk_tokens):
            e = min(s + self.logit_chunk_tokens, n)
            logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T
            logits = self.softcap(logits_proj)
            loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum")
        return loss_sum / float(n)
# ==============================================================================
# OPTIMIZERS (MUON + ADAM SPLIT)
# ==============================================================================
class Muon:
    # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the
    # parameter update.
    def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters):
        self.keys = keys
        self.args = args
        # One momentum buffer per matrix parameter, zero-initialized.
        self.buffers = {k: mx.zeros_like(params[k]) for k in keys}

    def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]:
        """Return updated values for all Muon-managed parameters.

        Momentum is warmed up linearly from muon_momentum_warmup_start to
        muon_momentum over muon_momentum_warmup_steps steps.
        """
        if self.args.muon_momentum_warmup_steps:
            t = min(step / self.args.muon_momentum_warmup_steps, 1.0)
            momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum
        else:
            momentum = self.args.muon_momentum
        lr = self.args.matrix_lr * lr_mul
        out: dict[str, mx.array] = {}
        for k in self.keys:
            p = params[k]
            g = grads[k]
            buf = momentum * self.buffers[k] + g
            self.buffers[k] = buf
            # Nesterov-style lookahead before orthogonalization.
            g_eff = g + momentum * buf
            g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps)
            # Shape-dependent scale so tall matrices get proportionally larger updates.
            scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1])))
            out[k] = p - lr * (g_ortho * scale).astype(p.dtype)
        return out


class SplitOptimizers:
    # - embeddings: Adam with the tied-embedding LR
    # - block matrices (2D): Muon
    # - block scalars + skip weights: Adam
    # This preserves the high-level optimization behavior even though MLX internals differ.
    def __init__(self, model: GPT, args: Hyperparameters):
        self.args = args
        params = dict(tree_flatten(model.parameters()))
        self.embed_key = "tok_emb.weight"
        # 2D block tensors that are not control tensors go to Muon.
        self.matrix_keys = [
            k
            for k, p in params.items()
            if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS)
        ]
        # Everything small (gains, scales, mixes, skip weights) goes to Adam.
        self.scalar_keys = [
            k
            for k, p in params.items()
            if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS)))
        ]

        self.muon = Muon(self.matrix_keys, params, args)
        self.adam_embed = optim.Adam(
            learning_rate=args.tied_embed_lr,
            betas=[args.beta1, args.beta2],
            eps=args.adam_eps,
            bias_correction=True,
        )
        self.adam_scalar = optim.Adam(
            learning_rate=args.scalar_lr,
            betas=[args.beta1, args.beta2],
            eps=args.adam_eps,
            bias_correction=True,
        )

    def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None:
        """Apply one optimizer step in-place on `model`."""
        params = dict(tree_flatten(model.parameters()))
        grads = dict(tree_flatten(grads_tree))
        updated = dict(params)

        updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul))

        self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul
        updated.update(
            self.adam_embed.apply_gradients(
                {self.embed_key: grads[self.embed_key]},
                {self.embed_key: params[self.embed_key]},
            )
        )

        self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul
        scalar_grads = {k: grads[k] for k in self.scalar_keys}
        scalar_params = {k: params[k] for k in self.scalar_keys}
        updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params))

        model.update(tree_unflatten(list(updated.items())))
# ==============================================================================
# QUANTIZATION (INT8 + ZLIB)
# ==============================================================================
# - per-row int8 for 2D float tensors
# - per-tensor int8 for other float tensors
# - fp16 passthrough for small float tensors
# - exact passthrough for non-floats

MX_DTYPE_FROM_NAME = {
    "float32": mx.float32,
    "float16": mx.float16,
    "bfloat16": mx.bfloat16,
}

INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = np.float16
INT8_PER_ROW_SCALE_DTYPE = np.float16
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0


def _np_float32(arr: mx.array) -> np.ndarray:
    # Materialize an MLX array as a float32 numpy array.
    return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False)


def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray:
    """Store a small float tensor without int8 quantization.

    Control tensors stay exact fp32; other fp32/bf16 tensors are downcast to
    fp16 with their original dtype recorded in `passthrough_orig_dtypes`
    (which this function mutates) so dequantization can restore it.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return np.ascontiguousarray(_np_float32(arr))
    if arr.dtype in {mx.float32, mx.bfloat16}:
        passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1]
        return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False))
    return np.ascontiguousarray(np.array(arr, copy=True))


def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]:
    """Quantize a float tensor to int8; returns (q, scale).

    2D tensors get a per-row scale (scale.ndim == 1); everything else a single
    per-tensor scale (0-d). Values beyond the INT8_CLIP_Q quantile are clipped
    before scaling.
    """
    f32 = _np_float32(arr)
    if f32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32)
        clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None])
        scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False)
        q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False)
        return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False))

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0
    scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32)
    q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False)
    return np.ascontiguousarray(q), scale


def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]:
    """Pack a flattened state dict into an int8 checkpoint object plus size stats."""
    quantized: dict[str, np.ndarray] = {}
    scales: dict[str, np.ndarray] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, np.ndarray] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )
    for name, arr in flat_state.items():
        stats["param_count"] += int(arr.size)
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += int(arr.nbytes)
        if not mx.issubdtype(arr.dtype, mx.floating):
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = np.ascontiguousarray(np.array(arr))
            stats["int8_payload_bytes"] += int(passthrough[name].nbytes)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_array(name, arr, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += int(kept.nbytes)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_array(arr)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(arr.dtype).split(".")[-1]
        stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes)
    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats


def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]:
    """Inverse of quantize_state_dict_int8: rebuild MLX arrays in their stored dtypes."""
    out: dict[str, mx.array] = {}
    qmeta = quant_obj.get("qmeta", {})
    passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {})
    for name, q in quant_obj["quantized"].items():
        q_np = np.asarray(q, dtype=np.int8)
        dtype_name = quant_obj["dtypes"][name]
        scale = np.asarray(quant_obj["scales"][name], dtype=np.float32)
        if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0:
            # Broadcast the saved row scale back across trailing dimensions.
            out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1))
        else:
            out_arr = q_np.astype(np.float32) * float(scale)
        out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name])
    for name, arr in quant_obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        out_arr = np.array(arr, copy=True)
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype])
        else:
            out[name] = mx.array(out_arr)
    return out


def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build per-token lookup tables used by the bytes-per-byte validation metric.

    Returns (base UTF-8 byte count per token, token-starts-with-space flag,
    token-is-boundary flag for control/unknown/unused tokens).
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_lut = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_lut[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_lut[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_lut[token_id] = True
            piece = piece[1:]
        base_bytes_lut[token_id] = len(piece.encode("utf-8"))
    return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut


def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]:
    # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we
    # decode bytes with the exact tokenizer that produced the shards. The manifest
    # lets the training script fail fast on accidental dataset/tokenizer mismatches.
    dataset_dir = Path(data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    if len(dataset_dir.parents) < 2:
        return dataset_dir.name, actual_train_files, None
    manifest_path = dataset_dir.parents[1] / "manifest.json"
    if not manifest_path.is_file():
        return dataset_dir.name, actual_train_files, None

    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None)
    if dataset_entry is None:
        return dataset_dir.name, actual_train_files, None

    tokenizer_name = dataset_entry.get("tokenizer_name")
    tokenizer_entry = (
        next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None)
        if tokenizer_name
        else None
    )
    expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name
    if expected_name and Path(tokenizer_path).name != expected_name:
        raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}")
    expected_train_files = (dataset_entry.get("stats") or {}).get("files_train")
    if expected_train_files is not None:
        expected_train_files = int(expected_train_files)
        if actual_train_files > expected_train_files:
            raise ValueError(
                f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, "
                f"manifest says {expected_train_files}"
            )
    return dataset_dir.name, actual_train_files, expected_train_files


def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray:
    """Concatenate all validation shards, trimmed to a whole number of sequences (+1 shift token)."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0))
    usable = ((tokens.size - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]
def loss_and_grad_chunked(
    args: Hyperparameters,
    train_loader: TokenLoader,
    compiled_loss_and_grad,
) -> tuple[mx.array, dict]:
    """Run one logical microbatch as several smaller chunks and average loss/grads.

    Each chunk's contribution is weighted by its token count so the result is
    identical (up to accumulation order) to a single big microbatch.
    """
    chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens)
    total_tokens = float(sum(chunk_sizes))
    loss_value = mx.array(0.0, dtype=mx.float32)
    grad_accum: dict[str, mx.array] | None = None
    for chunk_tokens in chunk_sizes:
        x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len)
        loss, grads = compiled_loss_and_grad(x, y)
        scale = float(y.size) / total_tokens
        loss_value = loss_value + loss.astype(mx.float32) * scale
        grad_accum = accumulate_flat_grads(grad_accum, grads, scale)
    return loss_value, tree_unflatten(list(grad_accum.items()))


def eval_val(
    args: Hyperparameters,
    compiled_loss,
    val_tokens: np.ndarray,
    base_bytes_lut: np.ndarray,
    has_leading_space_lut: np.ndarray,
    is_boundary_token_lut: np.ndarray,
) -> tuple[float, float]:
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    val_batch_tokens = args.val_batch_size // args.grad_accum_steps
    if val_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, "
            f"TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    val_batch_seqs = val_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.size - 1) // args.train_seq_len
    total_loss = mx.array(0.0, dtype=mx.float32)
    total_tokens = 0.0
    total_bytes = 0.0
    for batch_seq_start in range(0, total_seqs, val_batch_seqs):
        batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs)
        raw_start = batch_seq_start * args.train_seq_len
        raw_end = batch_seq_end * args.train_seq_len + 1
        chunk = val_tokens[raw_start:raw_end]
        x_np = chunk[:-1].reshape(-1, args.train_seq_len)
        y_np = chunk[1:].reshape(-1, args.train_seq_len)
        x = mx.array(x_np, dtype=mx.int32)
        y = mx.array(y_np, dtype=mx.int32)
        chunk_token_count = float(y.size)
        # Token-weighted loss so partial final batches do not skew the mean.
        total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count
        prev_ids = x_np.reshape(-1)
        tgt_ids = y_np.reshape(-1)
        # Byte count per target token; a leading-space byte only counts when the
        # previous token is not a boundary (control/unknown/unused) token.
        bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True)
        bytes_np += (
            has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]
        ).astype(np.int16, copy=False)
        total_tokens += chunk_token_count
        total_bytes += float(bytes_np.astype(np.float64).sum())
    total_loss = total_loss / total_tokens
    mx.eval(total_loss)
    val_loss = float(total_loss.item())
    bits_per_token = val_loss / math.log(2.0)
    val_bpb = bits_per_token * (total_tokens / total_bytes)
    return val_loss, val_bpb

# -----------------------------
# TRAINING
# -----------------------------

def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict:
    """Globally clip a gradient tree to `max_norm` (no-op when max_norm <= 0)."""
    if max_norm <= 0:
        return grads_tree
    flat = dict(tree_flatten(grads_tree))
    total_sq = 0.0
    for grad in flat.values():
        total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64))
    if total_sq <= 0.0:
        return grads_tree
    total_norm = math.sqrt(total_sq)
    if total_norm <= max_norm:
        return grads_tree
    scale = max_norm / (total_norm + 1e-12)
    return tree_unflatten([(k, g * scale) for k, g in flat.items()])
exist_ok=True) + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log(msg: str, console: bool = True) -> None: + if console: + print(msg) + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + code = Path(__file__).read_text(encoding="utf-8") + log(code, console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running MLX {mx.__version__}", console=False) + log("=" * 100, console=False) + + if not args.tie_embeddings: + raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( + args.data_path, + args.tokenizer_path, + ) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size + ) + + # ============================================================================== + # TRAINING SETUP + # ============================================================================== + mx.random.seed(args.seed) + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # ============================================================================== + # MODEL + OPTIMIZER SETUP + # ============================================================================== + model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
logit_chunk_tokens=args.logit_chunk_tokens, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + tied_embed_init_std=args.tied_embed_init_std, + qk_gain_init=args.qk_gain_init, + ) + opt = SplitOptimizers(model, args) + + # ============================================================================== + # COMPILED TRAIN / EVAL FUNCTIONS (MLX) + # ============================================================================== + # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example + # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". + # Compiling the model-bound functions and capturing the full model state fixes that while still + # returning gradients only for trainable parameters via nn.value_and_grad(...). + compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) + compiled_loss_and_grad = mx.compile( + nn.value_and_grad(model, lambda x, y: model.loss(x, y)), + inputs=model.state, + outputs=model.state, + ) + + # Print config once so logs are self-describing. 
    # Parameter count over the trainable parameter tree only.
    n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters()))
    log(f"run_id:{args.run_id}")
    log(f"mlx_version:{mx.__version__}")
    log(f"train_loader:shards pattern={args.train_files}")
    log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}")
    if expected_train_files is None:
        log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}")
    elif actual_train_files < expected_train_files:
        # Training on a subset of shards: epochs wrap sooner than on the full set.
        log(
            f"WARNING: train_loader:subset dataset:{dataset_name} "
            f"train_shards:{actual_train_files}/{expected_train_files} "
            f"new epochs will arrive sooner than the full dataset"
        )
    else:
        log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}")
    log(f"tokenizer_path:{args.tokenizer_path}")
    log(
        f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} "
        f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} "
        f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}"
    )
    log(
        f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} "
        f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} "
        f"val_batch_size:{args.val_batch_size} "
        f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}")
    log(
        f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} "
        f"embed_lr:{args.tied_embed_lr} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} "
        f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}"
    )
    log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log(f"compute_dtype:{COMPUTE_DTYPE} compile:True")
    log(
        f"dtypes tok_emb:{model.tok_emb.weight.dtype} "
        f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} "
        f"skip_weights:{model.skip_weights.dtype}"
    )

    # ==============================================================================
    # TRAINING LOOP
    # ==============================================================================
    if args.warmup_steps > 0:
        # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us
        # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs.
        # Instead we run the real train shapes, force the loss/grads to materialize, and then reset
        # the loader so measured training still starts from the true init and token window.
        for warmup_step in range(args.warmup_steps):
            accum: dict[str, mx.array] | None = None
            warmup_loss = mx.array(0.0, dtype=mx.float32)
            grad_scale = 1.0 / args.grad_accum_steps
            # NOTE(review): warmup_loss is reassigned (not accumulated) each inner
            # iteration; only the last microbatch's loss is forced below, which is
            # fine since warmup results are discarded anyway.
            for _ in range(args.grad_accum_steps):
                warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
                accum = accumulate_flat_grads(accum, grads, grad_scale)
            # Force the lazy graphs to actually execute/allocate.
            mx.eval(warmup_loss, accum)
            mx.synchronize()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")

        # Prime the standalone eval graph once too. It is compiled separately from value_and_grad.
        # Same sizing guard as eval_val, repeated here so a bad VAL_BATCH_SIZE
        # fails during warmup rather than at the first validation pass.
        val_batch_tokens = args.val_batch_size // args.grad_accum_steps
        if val_batch_tokens < args.train_seq_len:
            raise ValueError(
                "VAL_BATCH_SIZE must provide at least one sequence; "
                f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, "
                f"TRAIN_SEQ_LEN={args.train_seq_len}"
            )
        # Run one eval-shaped forward pass so compiled_loss is traced/allocated.
        warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len)
        warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1]
        x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32)
        y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32)
        warm_val_loss = compiled_loss(x_val, y_val)
        mx.eval(warm_val_loss)
        mx.synchronize()

        # Warmup consumed tokens; rebuild the loader so measured training sees
        # the same token window as a run without warmup.
        train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name)

    train_time_ms = 0.0
    # Optional wallclock budget: when set, training stops shortly after the step
    # that crosses the cap (see stop_after_step handling in the loop).
    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
    stop_after_step: int | None = None
    t0 = time.perf_counter()
    step = 0
    while True:
        # Natural end (step == iterations) or armed wallclock stop.
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
        if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
            # Validation always scans the same fixed full validation split.
            val_loss, val_bpb = eval_val(
                args,
                compiled_loss,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            # Fold time spent training into train_time_ms, then restart the
            # timer after logging so evaluation time is excluded from it.
            train_time_ms += 1000.0 * (time.perf_counter() - t0)
            if step % 25 == 0 or last_step:
                log(
                    f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                    f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms"
                )
            t0 = time.perf_counter()
            if last_step:
                if stop_after_step is not None and step < args.iterations:
                    log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}")
                break

        # LR multiplier may depend on both the step count and elapsed time.
        lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0))
        step_t0 = time.perf_counter()

        # Gradient accumulation: average grad_accum_steps chunked microbatches.
        accum: dict[str, mx.array] | None = None
        train_loss = mx.array(0.0, dtype=mx.float32)
        grad_scale = 1.0 / args.grad_accum_steps
        for _ in range(args.grad_accum_steps):
            loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
            accum = accumulate_flat_grads(accum, grads, grad_scale)
            train_loss = train_loss + loss.astype(mx.float32) * grad_scale

        grads = tree_unflatten(list(accum.items()))
        grads = clip_grad_tree(grads, args.grad_clip_norm)
        # .item() forces the lazy loss graph to evaluate before the update.
        train_loss_value = float(train_loss.item())
        opt.step(model, grads, step=step, lr_mul=lr_mul)
        mx.synchronize()

        step_ms = 1000.0 * (time.perf_counter() - step_t0)
        approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0)
        tok_s = args.train_batch_tokens / (step_ms / 1000.0)
        step += 1
        if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None):
            log(
                f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} "
                f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}"
            )
        # Arm the wallclock stop once the cap is crossed; the loop then runs a
        # final validation pass and exits at the top of the next iteration.
        if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms:
            stop_after_step = step

    # ==============================================================================
    # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL
    # ==============================================================================
    # We always write a raw artifact and a quantized artifact, then validate the
    # quantized roundtrip directly by loading the dequantized tensors back into the
    # model and running one final validation pass.
    out_path = out_dir / f"{args.run_id}_mlx_model.npz"
    flat_state = {k: v for k, v in tree_flatten(model.state)}
    mx.savez(str(out_path), **flat_state)
    log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}")

    # int8-quantize the flat state, pickle it, then zlib-compress the pickle.
    quant_obj, quant_stats = quantize_state_dict_int8(flat_state)
    quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL)
    quant_blob = zlib.compress(quant_raw, level=9)
    quant_serialized_bytes = len(quant_raw)
    quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz"
    with quant_path.open("wb") as f:
        f.write(quant_blob)
    quant_file_bytes = quant_path.stat().st_size
    ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
    log(
        f"serialized_model_int8_zlib:{quant_file_bytes} bytes "
        f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)"
    )

    # Re-read the artifact from disk (not the in-memory blob) so the roundtrip
    # eval verifies the exact bytes a consumer would load.
    with quant_path.open("rb") as f:
        quant_blob_disk = f.read()
    quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk)))
    model.update(tree_unflatten(list(quant_flat.items())))
    q_t0 = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args,
        compiled_loss,
        val_tokens,
        base_bytes_lut,
        has_leading_space_lut,
        is_boundary_token_lut,
    )
    q_eval_ms = 1000.0 * (time.perf_counter() - q_t0)
    log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms")
    log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")


if __name__ == "__main__":
    main()
-- cgit v1.2.3