1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
|
#!/usr/bin/env python3
"""
The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
Hard stop: to keep them readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
"""
from __future__ import annotations
import glob
import json
import math
import os
import pickle
import sys
import time
import uuid
import zlib
from collections.abc import Callable
from pathlib import Path
import numpy as np
import sentencepiece as spm
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten
# ==============================================================================
# SHARD FORMAT + COMPUTE DTYPE
# ==============================================================================
# Dtype the model computes in: GPT casts the token embeddings to it and
# CausalSelfAttention casts q/k to it before applying RoPE.
COMPUTE_DTYPE = mx.bfloat16
# ==============================================================================
# HYPERPARAMETERS
# ==============================================================================
# Default Simple Baseline run:
# - 9 transformer blocks at width 512
# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
# - vocab size 1024, sequence length 1024, tied embeddings
# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
class Hyperparameters:
    """Run configuration, resolved from environment variables when the class body runs.

    Every attribute reads one environment variable and falls back to the
    default written next to it. The defaults describe the baseline run:
    9 blocks at width 512, 8 heads with 4 KV heads, 2x MLP expansion,
    vocab size 1024, sequence length 1024 and tied embeddings.
    """

    # ---- Data / tokenizer ----
    data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed: int = int(os.environ.get("SEED", 1337))

    # ---- Training loop (defaults mirror train_gpt.py on a single process) ----
    iterations: int = int(os.environ.get("ITERATIONS", 20_000))
    val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0))
    # Validation always uses the full fineweb_val split.
    val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200))
    train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8))
    # TRAIN_MAX_SEQ_LEN is consulted only when TRAIN_SEQ_LEN is unset.
    train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024)))
    # Each logical MLX microbatch is split into sub-batches of at most this many
    # tokens to lower peak memory without changing the effective optimizer batch.
    mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192))
    # When enabled, MLX materializes the graph after every sub-batch so the lazy
    # graph does not grow across accumulation steps (keeps peak memory low on
    # 16GB machines). Set MLX_EAGER_EVAL=0 on 32GB+ unified memory for throughput.
    mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1")))
    warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20))
    warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200))
    max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))

    # ---- Model (defaults match the current baseline setup) ----
    vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers: int = int(os.environ.get("NUM_LAYERS", 9))
    model_dim: int = int(os.environ.get("MODEL_DIM", 512))
    num_heads: int = int(os.environ.get("NUM_HEADS", 8))
    num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4))
    mlp_mult: int = int(os.environ.get("MLP_MULT", 2))
    tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0))
    logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
    rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0))
    qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5))

    # ---- Optimizer (same per-group defaults as train_gpt.py) ----
    beta1: float = float(os.environ.get("BETA1", 0.9))
    beta2: float = float(os.environ.get("BETA2", 0.95))
    adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8))
    tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05))
    matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0))

    out_dir: str = os.environ.get("OUT_DIR", "logs")

    @property
    def train_files(self) -> str:
        """Glob pattern matching the training shards under ``data_path``."""
        return f"{self.data_path}/fineweb_train_*.bin"

    @property
    def val_files(self) -> str:
        """Glob pattern matching the validation shards under ``data_path``."""
        return f"{self.data_path}/fineweb_val_*.bin"

    @property
    def microbatch_tokens(self) -> int:
        """Tokens handled per gradient-accumulation step (floor division)."""
        return self.train_batch_tokens // self.grad_accum_steps

    def lr_mul(self, step: int, elapsed_ms: float) -> float:
        """Return the learning-rate multiplier for ``step``.

        * Warmdown disabled (``warmdown_iters <= 0``): always 1.0.
        * No wallclock cap (``max_wallclock_seconds <= 0``): 1.0 until the last
          ``warmdown_iters`` steps, then a linear ramp toward 0 at ``iterations``.
        * Wallclock cap active: the duration of ``warmdown_iters`` steps is
          estimated from the average step time so far; once the remaining
          wallclock budget is no larger than that estimate, the multiplier is
          the fraction of the estimate still remaining.
        """
        warmdown = self.warmdown_iters
        if warmdown <= 0:
            return 1.0

        if self.max_wallclock_seconds <= 0:
            # Step-based schedule.
            total_steps = self.iterations
            decay_begins = max(total_steps - warmdown, 0)
            if decay_begins <= step < total_steps:
                return max((total_steps - step) / max(warmdown, 1), 0.0)
            return 1.0

        # Time-based schedule.
        avg_step_ms = elapsed_ms / max(step, 1)
        window_ms = warmdown * avg_step_ms
        left_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0)
        if left_ms <= window_ms:
            return left_ms / max(window_ms, 1e-9)
        return 1.0
# Name fragments identifying the small "control" tensors. SplitOptimizers
# routes block parameters whose name contains one of these to the Adam scalar
# group instead of Muon. Override with a comma-separated environment variable;
# empty entries are dropped.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    filter(
        None,
        os.environ.get(
            "CONTROL_TENSOR_NAME_PATTERNS",
            "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights",
        ).split(","),
    )
)
# Name fragments of small float tensors that keep_float_array stores as float32
# rather than downcasting to float16. Defaults to the control-tensor patterns.
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    filter(
        None,
        os.environ.get(
            "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
            ",".join(CONTROL_TENSOR_NAME_PATTERNS),
        ).split(","),
    )
)
def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]:
    """Split a token budget into chunk sizes that are whole multiples of ``seq_len``.

    Args:
        total_tokens: Token budget to split; rounded down to a multiple of
            ``seq_len``.
        seq_len: Sequence length; must be positive.
        max_chunk_tokens: Upper bound on a chunk's size. It is rounded down to a
            multiple of ``seq_len`` but never below one full sequence.

    Returns:
        Chunk sizes in order. All chunks have the rounded maximum size except
        possibly the last, which holds the remainder. Their sum equals
        ``total_tokens`` rounded down to a multiple of ``seq_len``.

    Raises:
        ValueError: If ``seq_len`` is not positive, or the budget does not cover
            one full sequence.
    """
    # Reject bad sequence lengths up front: zero would surface as a bare
    # ZeroDivisionError and a negative value can make the chunk size
    # non-positive, which previously meant looping forever.
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    usable_total = (total_tokens // seq_len) * seq_len
    if usable_total <= 0:
        raise ValueError(f"token budget too small for seq_len={seq_len}")
    usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len)
    full_chunks, tail = divmod(usable_total, usable_chunk)
    chunks = [usable_chunk] * full_chunks
    if tail:
        chunks.append(tail)
    return chunks
def accumulate_flat_grads(
    accum: dict[str, mx.array] | None,
    grads_tree: dict,
    scale: float,
) -> dict[str, mx.array]:
    """Add ``scale * grads`` into a flat name -> gradient accumulator.

    The gradient tree is flattened with ``tree_flatten``. When ``accum`` is
    ``None`` a fresh dict of scaled gradients is returned; otherwise ``accum``
    is updated in place and returned.
    """
    scaled = {name: grad * scale for name, grad in tree_flatten(grads_tree)}
    if accum is None:
        return scaled
    for name, contribution in scaled.items():
        accum[name] = accum[name] + contribution
    return accum
# ==============================================================================
# MATH HELPERS
# ==============================================================================
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
    """Normalize ``x`` by the root-mean-square of its last axis (no learned weight).

    The result is cast back to ``x``'s dtype.
    """
    mean_square = mx.mean(x * x, axis=-1, keepdims=True)
    inv_rms = mx.rsqrt(mean_square + eps)
    return (x * inv_rms).astype(x.dtype)
def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array:
    """Approximately orthogonalize a 2D update matrix with Newton-Schulz iterations.

    Muon uses this to normalize matrix-shaped gradients before applying them.
    Background on Muon: https://kellerjordan.github.io/posts/muon/

    The work is done in float32 and the result is cast back to ``g``'s dtype.
    """
    coef_a, coef_b, coef_c = 3.4445, -4.7750, 2.0315
    work = g.astype(mx.float32)
    # Scale by the square root of the sum of squares before iterating.
    work = work / (mx.sqrt(mx.sum(work * work)) + eps)
    # Iterate on the wide orientation so the Gram matrix is the smaller one.
    is_tall = work.shape[0] > work.shape[1]
    if is_tall:
        work = work.T
    for _ in range(steps):
        gram = work @ work.T
        poly = coef_b * gram + coef_c * (gram @ gram)
        work = coef_a * work + poly @ work
    if is_tall:
        work = work.T
    return work.astype(g.dtype)
def load_data_shard(path: Path) -> np.ndarray:
    """Read one token shard and return its tokens as an int32 array.

    Layout checked here: 256 little-endian int32 header words (word 0 must be
    20240520, word 1 must be 1, word 2 is the token count) followed by that
    many little-endian uint16 tokens.

    Raises:
        ValueError: On a bad header, a file size that does not match the
            header's token count, or a short read.
    """
    header_dtype = np.dtype("<i4")
    token_dtype = np.dtype("<u2")
    header_nbytes = 256 * header_dtype.itemsize

    header = np.fromfile(path, dtype=header_dtype, count=256)
    header_ok = header.size == 256 and int(header[0]) == 20240520 and int(header[1]) == 1
    if not header_ok:
        raise ValueError(f"Unexpected shard header for {path}")

    num_tokens = int(header[2])
    expected_size = header_nbytes + num_tokens * token_dtype.itemsize
    if path.stat().st_size != expected_size:
        raise ValueError(f"Shard size mismatch for {path}")

    tokens = np.fromfile(path, dtype=token_dtype, count=num_tokens, offset=header_nbytes)
    if tokens.size != num_tokens:
        raise ValueError(f"Short read for {path}")
    return tokens.astype(np.int32, copy=False)
# ==============================================================================
# TOKEN STREAMING / BATCHING
# ==============================================================================
class TokenStream:
    """Sequential reader over the token shards matched by a glob pattern.

    Shards are visited in sorted path order. After the last shard the stream
    wraps back to the first and increments ``epoch``.
    """

    def __init__(
        self,
        pattern: str,
        log_fn: Callable[[str], None] | None = None,
        dataset_name: str = "",
    ):
        """Resolve ``pattern`` and load the first shard.

        Raises:
            FileNotFoundError: If ``pattern`` matches no files.
        """
        shard_paths = sorted(glob.glob(pattern))
        self.files = [Path(p) for p in shard_paths]
        if not self.files:
            raise FileNotFoundError(f"No files found for pattern: {pattern}")
        self.epoch = 1
        self.file_idx = 0
        self.log_fn = log_fn
        self.dataset_name = dataset_name
        self.tokens = load_data_shard(self.files[0])
        self.pos = 0

    def next_file(self) -> None:
        """Advance to the next shard, wrapping around and logging on a new epoch."""
        self.file_idx = (self.file_idx + 1) % len(self.files)
        wrapped = self.file_idx == 0
        if wrapped:
            self.epoch += 1
            if self.log_fn is not None:
                self.log_fn(
                    f"WARNING: starting epoch:{self.epoch} "
                    f"dataset:{self.dataset_name} train_shards:{len(self.files)}"
                )
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> np.ndarray:
        """Return the next ``n`` tokens, crossing shard boundaries as needed.

        When the request is served from a single shard the result is a slice of
        that shard's array; otherwise the pieces are concatenated.
        """
        pieces: list[np.ndarray] = []
        outstanding = n
        while outstanding > 0:
            if self.pos >= self.tokens.size:
                self.next_file()
            start = self.pos
            count = min(outstanding, int(self.tokens.size - start))
            stop = start + count
            pieces.append(self.tokens[start:stop])
            self.pos = stop
            outstanding -= count
        if len(pieces) == 1:
            return pieces[0]
        return np.concatenate(pieces, axis=0)
class TokenLoader:
    """Builds (input, target) batches of fixed-length sequences from a TokenStream."""

    def __init__(
        self,
        pattern: str,
        log_fn: Callable[[str], None] | None = None,
        dataset_name: str = "",
    ):
        """Open a ``TokenStream`` over the shards matched by ``pattern``."""
        self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name)

    def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]:
        """Return the next ``(x, y)`` int32 batch with rows of ``seq_len`` tokens.

        ``batch_tokens`` is rounded down to a multiple of ``seq_len``. One extra
        token is taken from the stream so that ``y`` is ``x`` shifted by one
        position.

        Raises:
            ValueError: If ``batch_tokens`` does not cover one full sequence.
        """
        usable = (batch_tokens // seq_len) * seq_len
        if usable <= 0:
            raise ValueError(f"token budget too small for seq_len={seq_len}")
        window = self.stream.take(usable + 1)
        inputs = window[:-1].reshape(-1, seq_len)
        targets = window[1:].reshape(-1, seq_len)
        return mx.array(inputs, dtype=mx.int32), mx.array(targets, dtype=mx.int32)
# ==============================================================================
# MODEL BLOCKS
# ==============================================================================
class CastedLinear(nn.Module):
    """Bias-free linear layer whose weight is stored in float32 and cast to the input dtype per call."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Borrow nn.Linear's weight initialization, keeping only the weight.
        template = nn.Linear(in_dim, out_dim, bias=False)
        self.weight = template.weight.astype(mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        weight = self.weight.astype(x.dtype)
        return x @ weight.T
class RMSNormNoWeight(nn.Module):
    """Module wrapper around the functional ``rms_norm`` helper so it composes in blocks."""

    def __call__(self, x: mx.array) -> mx.array:
        normalized = rms_norm(x)
        return normalized
class CausalSelfAttention(nn.Module):
    """Causal self-attention with separate q/k/v projections and fewer KV heads than query heads.

    - separate q/k/v projections
    - RMSNorm on q and k before attention
    - RoPE on q and k
    - per-head learned gain on q
    - causal masked SDPA
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        """Build the projections.

        Raises:
            ValueError: If ``dim`` is not divisible by ``num_heads``,
                ``num_heads`` is not divisible by ``num_kv_heads``, or the
                resulting head dimension is odd.
        """
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_width = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, kv_width)
        self.c_v = CastedLinear(dim, kv_width)
        self.proj = CastedLinear(dim, dim)
        self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init
        self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base)
        self.scale = self.head_dim ** -0.5

    def __call__(self, x: mx.array) -> mx.array:
        bsz, seqlen, dim = x.shape

        def to_heads(projected: mx.array, heads: int) -> mx.array:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return projected.reshape(bsz, seqlen, heads, self.head_dim).transpose(0, 2, 1, 3)

        q = to_heads(self.c_q(x), self.num_heads)
        k = to_heads(self.c_k(x), self.num_kv_heads)
        v = to_heads(self.c_v(x), self.num_kv_heads)

        q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE))
        k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE))
        gain = self.q_gain.astype(q.dtype)[None, :, None, None]
        q = q * gain

        attended = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
        merged = attended.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim)
        return self.proj(merged)
class MLP(nn.Module):
    """Two-layer feed-forward block with a relu^2 activation.

    Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well
    in this setup.
    """

    def __init__(self, dim: int, mlp_mult: int):
        super().__init__()
        hidden = dim * mlp_mult
        self.fc = CastedLinear(dim, hidden)
        self.proj = CastedLinear(hidden, dim)

    def __call__(self, x: mx.array) -> mx.array:
        activated = nn.relu(self.fc(x))
        squared = activated * activated
        return self.proj(squared)
class Block(nn.Module):
    """Transformer block: residual mix with the initial stream, then attention and MLP.

    ``resid_mix`` holds two per-channel vectors, initialized to ones and zeros,
    that weight the running stream ``x`` and the initial stream ``x0``.
    ``attn_scale`` and ``mlp_scale`` are per-channel gains on the two residual
    branches.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        self.attn_norm = RMSNormNoWeight()
        self.mlp_norm = RMSNormNoWeight()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = mx.ones((dim,), dtype=mx.float32)
        self.mlp_scale = mx.ones((dim,), dtype=mx.float32)
        keep_x = np.ones((dim,), dtype=np.float32)
        keep_x0 = np.zeros((dim,), dtype=np.float32)
        self.resid_mix = mx.array(np.stack((keep_x, keep_x0)))

    def __call__(self, x: mx.array, x0: mx.array) -> mx.array:
        mix = self.resid_mix.astype(x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0

        attn_out = self.attn(self.attn_norm(x))
        attn_gain = self.attn_scale.astype(x.dtype)[None, None, :]
        x = x + attn_gain * attn_out

        mlp_gain = self.mlp_scale.astype(x.dtype)[None, None, :]
        x = x + mlp_gain * self.mlp(self.mlp_norm(x))
        return x
class GPT(nn.Module):
    """Baseline language model with an encoder/decoder-style skip structure.

    - token embedding + RMSNorm
    - encoder half accumulates skip tensors
    - decoder half consumes reversed skips with learned skip_weights
    - tied embeddings for the LM head (the baseline default setup)
    """

    def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
                 logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float,
                 qk_gain_init: float):
        """Build the model.

        Raises:
            ValueError: If ``logit_softcap`` is not positive.
        """
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.logit_chunk_tokens = logit_chunk_tokens
        self.logit_softcap = logit_softcap

        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32)
        self.blocks = [
            Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
            for _ in range(num_layers)
        ]
        self.final_norm = RMSNormNoWeight()

        # Output projections of both branches start at zero.
        for block in self.blocks:
            block.attn.proj.weight = mx.zeros_like(block.attn.proj.weight)
            block.mlp.proj.weight = mx.zeros_like(block.mlp.proj.weight)

        # Re-initialize the (tied) embedding with a small normal, stored in the compute dtype.
        emb_init = mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std
        self.tok_emb.weight = emb_init.astype(COMPUTE_DTYPE)

    def softcap(self, logits: mx.array) -> mx.array:
        """Squash logits into ``(-logit_softcap, logit_softcap)`` with a scaled tanh."""
        cap = self.logit_softcap
        return cap * mx.tanh(logits / cap)

    def __call__(self, input_ids: mx.array) -> mx.array:
        """Return the final normalized hidden states for ``input_ids``."""
        x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
        x0 = x
        skips: list[mx.array] = []

        split = self.num_encoder_layers
        for block in self.blocks[:split]:
            x = block(x, x0)
            skips.append(x)

        for i, block in enumerate(self.blocks[split : split + self.num_decoder_layers]):
            # Odd layer counts have one more decoder block than encoder block. The baseline only
            # applies a skip connection when one exists, then runs the remaining decoder block(s)
            # without an added skip.
            if skips:
                x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop()
            x = block(x, x0)
        return self.final_norm(x)

    def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
        """Mean cross-entropy over all tokens, using the embedding matrix as the LM head.

        Optional logit chunking is kept because it is a useful memory knob on
        Macs, but the common path is ``logit_chunk_tokens=0`` (single matmul + CE).
        """
        hidden = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
        targets = target_ids.reshape(-1)

        def capped_logits(rows: mx.array) -> mx.array:
            projected = rows @ self.tok_emb.weight.astype(rows.dtype).T
            return self.softcap(projected).astype(mx.float32)

        chunk = self.logit_chunk_tokens
        if chunk <= 0 or hidden.shape[0] <= chunk:
            return nn.losses.cross_entropy(capped_logits(hidden), targets, reduction="mean")

        total = int(hidden.shape[0])
        loss_sum = mx.array(0.0, dtype=mx.float32)
        for start in range(0, total, chunk):
            stop = min(start + chunk, total)
            loss_sum = loss_sum + nn.losses.cross_entropy(
                capped_logits(hidden[start:stop]), targets[start:stop], reduction="sum"
            )
        return loss_sum / float(total)
# ==============================================================================
# OPTIMIZERS (MUON + ADAM SPLIT)
# ==============================================================================
class Muon:
    """Muon update rule for matrix parameters.

    Muon applies SGD-momentum to matrix gradients, then orthogonalizes the
    result before the parameter update.
    """

    def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters):
        self.keys = keys
        self.args = args
        # One zero-initialized momentum buffer per managed parameter.
        self.buffers = {name: mx.zeros_like(params[name]) for name in keys}

    def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]:
        """Return updated values for the managed parameters.

        Momentum is linearly interpolated from ``muon_momentum_warmup_start``
        to ``muon_momentum`` over ``muon_momentum_warmup_steps`` steps. The
        momentum buffers are updated as a side effect.
        """
        cfg = self.args
        warmup_steps = cfg.muon_momentum_warmup_steps
        if warmup_steps:
            frac = min(step / warmup_steps, 1.0)
            momentum = (1.0 - frac) * cfg.muon_momentum_warmup_start + frac * cfg.muon_momentum
        else:
            momentum = cfg.muon_momentum
        lr = cfg.matrix_lr * lr_mul

        updated: dict[str, mx.array] = {}
        for name in self.keys:
            weight = params[name]
            grad = grads[name]
            velocity = momentum * self.buffers[name] + grad
            self.buffers[name] = velocity
            # The gradient is combined with the freshly updated buffer.
            lookahead = grad + momentum * velocity
            direction = zeropower_newtonschulz5(lookahead, cfg.muon_backend_steps)
            # Matrices with more rows than columns get a larger step.
            aspect = math.sqrt(max(1.0, float(weight.shape[0]) / float(weight.shape[1])))
            updated[name] = weight - lr * (direction * aspect).astype(weight.dtype)
        return updated
class SplitOptimizers:
    """Routes each parameter group to its own optimizer.

    - embeddings: Adam with the tied-embedding LR
    - block matrices (2D): Muon
    - block scalars + skip weights: Adam

    This preserves the high-level optimization behavior even though MLX
    internals differ.
    """

    def __init__(self, model: GPT, args: Hyperparameters):
        self.args = args
        params = dict(tree_flatten(model.parameters()))
        self.embed_key = "tok_emb.weight"

        matrix_keys: list[str] = []
        scalar_keys: list[str] = []
        for name, p in params.items():
            in_block = name.startswith("blocks.")
            is_control = any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
            if in_block and p.ndim == 2 and not is_control:
                matrix_keys.append(name)
            if name == "skip_weights" or (in_block and (p.ndim < 2 or is_control)):
                scalar_keys.append(name)
        self.matrix_keys = matrix_keys
        self.scalar_keys = scalar_keys

        self.muon = Muon(self.matrix_keys, params, args)

        def make_adam(lr: float) -> optim.Adam:
            return optim.Adam(
                learning_rate=lr,
                betas=[args.beta1, args.beta2],
                eps=args.adam_eps,
                bias_correction=True,
            )

        self.adam_embed = make_adam(args.tied_embed_lr)
        self.adam_scalar = make_adam(args.scalar_lr)

    def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None:
        """Apply one optimizer step to ``model`` in place.

        ``lr_mul`` scales every group's base learning rate for this step.
        """
        params = dict(tree_flatten(model.parameters()))
        grads = dict(tree_flatten(grads_tree))
        # Start from the current values so unmanaged parameters pass through.
        updated = dict(params)

        updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul))

        embed = self.embed_key
        self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul
        updated.update(self.adam_embed.apply_gradients({embed: grads[embed]}, {embed: params[embed]}))

        self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul
        scalar_grads = {name: grads[name] for name in self.scalar_keys}
        scalar_params = {name: params[name] for name in self.scalar_keys}
        updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params))

        model.update(tree_unflatten(list(updated.items())))
# ==============================================================================
# QUANTIZATION (INT8 + ZLIB)
# ==============================================================================
# - per-row int8 for 2D float tensors
# - per-tensor int8 for other float tensors
# - fp16 passthrough for small float tensors
# - exact passthrough for non-floats
# Maps the dtype names saved in a quantized checkpoint back to MLX dtypes;
# dequantize_state_dict_int8 uses it to restore each tensor's original dtype.
MX_DTYPE_FROM_NAME = {
    "float32": mx.float32,
    "float16": mx.float16,
    "bfloat16": mx.bfloat16,
}
# Float tensors with at most this many elements are stored directly instead of
# being quantized to int8.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
# Storage dtype for fp32/bf16 passthrough tensors (see keep_float_array).
INT8_KEEP_FLOAT_STORE_DTYPE = np.float16
# Storage dtype for the per-row scales of quantized matrices.
INT8_PER_ROW_SCALE_DTYPE = np.float16
# Absolute values are clipped at this percentile before int8 rounding.
INT8_CLIP_PERCENTILE = 99.99984
# The same threshold as a fraction in [0, 1], the form np.quantile expects.
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
def _np_float32(arr: mx.array) -> np.ndarray:
    """Cast an MLX array to float32 and expose it as a float32 NumPy array."""
    as_f32 = arr.astype(mx.float32)
    return np.array(as_f32, dtype=np.float32, copy=False)
def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray:
    """Convert a small float tensor to the NumPy array stored without quantization.

    - Names matching ``INT8_KEEP_FLOAT_FP32_NAME_PATTERNS`` are stored as float32.
    - Other float32/bfloat16 tensors are stored as float16, and their original
      dtype name is recorded in ``passthrough_orig_dtypes`` (mutated in place).
    - Anything else is copied as-is.

    The returned array is always C-contiguous.
    """
    keep_fp32 = any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS)
    if keep_fp32:
        kept = _np_float32(arr)
    elif arr.dtype in {mx.float32, mx.bfloat16}:
        passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1]
        kept = np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)
    else:
        kept = np.array(arr, copy=True)
    return np.ascontiguousarray(kept)
def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]:
    """Quantize a float tensor to int8 and return ``(q, scale)``.

    2D tensors get one scale per row (returned as a 1D array in
    ``INT8_PER_ROW_SCALE_DTYPE``); all other ranks share a single 0-d float32
    scale. Values are clipped at the ``INT8_CLIP_Q`` quantile of their absolute
    value before rounding into ``[-127, 127]``.
    """
    values = _np_float32(arr)

    if values.ndim != 2:
        # Vectors / scalars use a simpler per-tensor scale.
        clip_abs = float(np.quantile(np.abs(values).reshape(-1), INT8_CLIP_Q)) if values.size else 0.0
        scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32)
        bounded = np.clip(values, -clip_abs, clip_abs)
        q = np.clip(np.round(bounded / scale), -127, 127).astype(np.int8, copy=False)
        return np.ascontiguousarray(q), scale

    # Matrices get one scale per row, which usually tracks output-channel
    # ranges much better than a single tensor-wide scale.
    if values.size:
        row_clip = np.quantile(np.abs(values), INT8_CLIP_Q, axis=1)
    else:
        row_clip = np.empty((values.shape[0],), dtype=np.float32)
    bounded = np.clip(values, -row_clip[:, None], row_clip[:, None])
    row_scale = np.maximum(row_clip / 127.0, 1.0 / 127.0).astype(np.float32, copy=False)
    q = np.clip(np.round(bounded / row_scale[:, None]), -127, 127).astype(np.int8, copy=False)
    stored_scale = row_scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)
    return np.ascontiguousarray(q), np.ascontiguousarray(stored_scale)
def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]:
    """Convert a flat state dict into the int8 export object plus size statistics.

    Per tensor:
    - non-float tensors are stored unchanged,
    - float tensors with at most ``INT8_KEEP_FLOAT_MAX_NUMEL`` elements are
      stored via ``keep_float_array``,
    - all remaining float tensors are quantized with ``quantize_float_array``.

    Returns:
        ``(obj, stats)``. ``obj`` holds the ``quantized``, ``scales``,
        ``dtypes`` and ``passthrough`` maps, plus ``qmeta`` and
        ``passthrough_orig_dtypes`` when they are non-empty. ``stats`` holds
        the counters; ``num_float_tensors`` counts only the quantized tensors.
    """
    quantized: dict[str, np.ndarray] = {}
    scales: dict[str, np.ndarray] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, np.ndarray] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = {
        "param_count": 0,
        "num_tensors": 0,
        "num_float_tensors": 0,
        "num_nonfloat_tensors": 0,
        "baseline_tensor_bytes": 0,
        "int8_payload_bytes": 0,
    }

    for name, arr in flat_state.items():
        stats["param_count"] += int(arr.size)
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += int(arr.nbytes)

        if not mx.issubdtype(arr.dtype, mx.floating):
            stats["num_nonfloat_tensors"] += 1
            stored = np.ascontiguousarray(np.array(arr))
            passthrough[name] = stored
            payload = int(stored.nbytes)
        elif int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL:
            # Small float tensors are cheap enough to keep directly. We still downcast
            # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
            stored = keep_float_array(name, arr, passthrough_orig_dtypes)
            passthrough[name] = stored
            payload = int(stored.nbytes)
        else:
            stats["num_float_tensors"] += 1
            q, s = quantize_float_array(arr)
            if s.ndim > 0:
                qmeta[name] = {"scheme": "per_row", "axis": 0}
            quantized[name] = q
            scales[name] = s
            dtypes[name] = str(arr.dtype).split(".")[-1]
            payload = int(q.nbytes + s.nbytes)
        stats["int8_payload_bytes"] += payload

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    # Optional sections are written only when they carry information.
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats
def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]:
    """Rebuild a flat name -> ``mx.array`` state dict from an int8 export object.

    Quantized tensors are multiplied by their saved scale and cast to their
    recorded dtype. Passthrough tensors are copied and cast back to their
    original dtype when one was recorded.
    """
    restored: dict[str, mx.array] = {}
    qmeta = quant_obj.get("qmeta", {})
    passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {})

    for name, q in quant_obj["quantized"].items():
        q_np = np.asarray(q, dtype=np.int8)
        dtype_name = quant_obj["dtypes"][name]
        scale = np.asarray(quant_obj["scales"][name], dtype=np.float32)
        is_per_row = qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0
        if is_per_row:
            # Broadcast the saved row scale back across trailing dimensions.
            broadcast_shape = (q_np.shape[0],) + (1,) * (q_np.ndim - 1)
            values = q_np.astype(np.float32) * scale.reshape(broadcast_shape)
        else:
            values = q_np.astype(np.float32) * float(scale)
        restored[name] = mx.array(values, dtype=MX_DTYPE_FROM_NAME[dtype_name])

    for name, arr in quant_obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        values = np.array(arr, copy=True)
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            restored[name] = mx.array(values, dtype=MX_DTYPE_FROM_NAME[orig_dtype])
        else:
            restored[name] = mx.array(values)
    return restored
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build per-token lookup tables from a SentencePiece model.

    The tables have ``max(sp.vocab_size(), vocab_size)`` entries.

    Returns:
        ``(base_bytes, has_leading_space, is_boundary_token)``:
        - ``base_bytes`` (int16): UTF-8 byte length of the piece without its
          leading "▁" marker; 1 for byte tokens; 0 for skipped ids.
        - ``has_leading_space`` (bool): the piece starts with "▁".
        - ``is_boundary_token`` (bool): True for control/unknown/unused ids and
          for ids beyond the tokenizer's vocabulary.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes = np.zeros((table_size,), dtype=np.int16)
    leading_space = np.zeros((table_size,), dtype=np.bool_)
    boundary = np.ones((table_size,), dtype=np.bool_)

    for tid in range(sp_vocab_size):
        is_special = sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid)
        if is_special:
            continue
        boundary[tid] = False
        if sp.is_byte(tid):
            base_bytes[tid] = 1
        else:
            piece = sp.id_to_piece(tid)
            if piece.startswith("▁"):
                leading_space[tid] = True
                piece = piece[1:]
            base_bytes[tid] = len(piece.encode("utf-8"))
    return base_bytes, leading_space, boundary
def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]:
    """Cross-check a dataset directory against ``manifest.json`` two levels above it.

    The shard directory and tokenizer are coupled: val_bpb is only meaningful
    if we decode bytes with the exact tokenizer that produced the shards. The
    manifest lets the training script fail fast on accidental
    dataset/tokenizer mismatches.

    Returns:
        ``(dataset_name, actual_train_files, expected_train_files)``. The last
        item is ``None`` when there is no manifest, the manifest has no entry
        for this dataset, or the entry has no ``stats.files_train`` value.

    Raises:
        ValueError: If the manifest names a different tokenizer file, or the
            directory holds more train shards than the manifest lists.
    """
    dataset_dir = Path(data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    unchecked = (dataset_dir.name, actual_train_files, None)

    if len(dataset_dir.parents) < 2:
        return unchecked
    manifest_path = dataset_dir.parents[1] / "manifest.json"
    if not manifest_path.is_file():
        return unchecked
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

    dataset_entry = None
    for entry in manifest.get("datasets", []):
        if entry.get("name") == dataset_dir.name:
            dataset_entry = entry
            break
    if dataset_entry is None:
        return unchecked

    tokenizer_name = dataset_entry.get("tokenizer_name")
    tokenizer_entry = None
    if tokenizer_name:
        for entry in manifest.get("tokenizers", []):
            if entry.get("name") == tokenizer_name:
                tokenizer_entry = entry
                break
    tokenizer_info = tokenizer_entry or {}
    expected_name = Path(tokenizer_info.get("model_path") or tokenizer_info.get("path") or "").name
    if expected_name and Path(tokenizer_path).name != expected_name:
        raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}")

    expected_train_files = (dataset_entry.get("stats") or {}).get("files_train")
    if expected_train_files is not None:
        expected_train_files = int(expected_train_files)
        # Fewer shards than listed is tolerated; more than listed is an error.
        if actual_train_files > expected_train_files:
            raise ValueError(
                f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, "
                f"manifest says {expected_train_files}"
            )
    return dataset_dir.name, actual_train_files, expected_train_files
def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray:
    """Load every validation shard matching *pattern* into one token array.

    Shards are read in sorted path order and concatenated. The result is
    trimmed to a whole number of ``seq_len`` windows plus one trailing token,
    so inputs and shifted-by-one targets can both be reshaped to
    ``(-1, seq_len)``.

    Raises:
        FileNotFoundError: If the glob pattern matches no files.
        ValueError: If the shards hold too few tokens for a single window.
    """
    shard_paths = sorted(glob.glob(pattern))
    if not shard_paths:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    shards = [load_data_shard(Path(shard_path)) for shard_path in shard_paths]
    tokens = np.ascontiguousarray(np.concatenate(shards, axis=0))
    # One token is reserved as the final target, hence the "- 1".
    full_windows = (tokens.size - 1) // seq_len
    usable = full_windows * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]
def loss_and_grad_chunked(
    args: Hyperparameters,
    train_loader: TokenLoader,
    compiled_loss_and_grad,
) -> tuple[mx.array, dict]:
    """Compute loss and gradients for one microbatch in smaller chunks.

    The microbatch is split by ``token_chunks`` and each chunk is run through
    ``compiled_loss_and_grad``. Every chunk's loss and gradients are weighted
    by its share of the microbatch's tokens, so the accumulated result is a
    token-weighted average over the chunks.

    Returns:
        ``(loss, grads)`` where ``loss`` is a float32 scalar ``mx.array`` and
        ``grads`` is the accumulated gradients rebuilt into a tree.
    """
    seq_len = args.train_seq_len
    chunk_plan = token_chunks(args.microbatch_tokens, seq_len, args.mlx_max_microbatch_tokens)
    planned_tokens = float(sum(chunk_plan))

    running_loss = mx.array(0.0, dtype=mx.float32)
    running_grads: dict[str, mx.array] | None = None
    for n_tokens in chunk_plan:
        inputs, targets = train_loader.next_batch(n_tokens, seq_len)
        chunk_loss, chunk_grads = compiled_loss_and_grad(inputs, targets)
        # Weight by the chunk's fraction of the microbatch's tokens.
        weight = float(targets.size) / planned_tokens
        running_loss = running_loss + chunk_loss.astype(mx.float32) * weight
        running_grads = accumulate_flat_grads(running_grads, chunk_grads, weight)
        if args.mlx_eager_eval:
            # materialize each chunk to cap peak memory
            mx.eval(running_loss, running_grads)
    return running_loss, tree_unflatten(list(running_grads.items()))
def eval_val(
    args: Hyperparameters,
    compiled_loss,
    val_tokens: np.ndarray,
    base_bytes_lut: np.ndarray,
    has_leading_space_lut: np.ndarray,
    is_boundary_token_lut: np.ndarray,
    log_fn: Callable[[str], None] | None = None,
) -> tuple[float, float]:
    """Run one full pass over the validation tokens.

    Two metrics are produced:
      - val_loss: token cross-entropy (natural log), averaged over all target
        tokens. Each batch loss is treated as a per-token mean and re-weighted
        by that batch's token count.
      - val_bpb: tokenizer-agnostic compression metric used by the challenge,
        i.e. bits per token scaled by tokens per byte.

    Byte counts come from the lookup tables: each target token contributes
    ``base_bytes_lut[target]`` plus one extra byte when the target has a
    leading space and the previous token is not a boundary token.

    Args:
        args: Supplies ``val_batch_size``, ``grad_accum_steps`` and
            ``train_seq_len``.
        compiled_loss: Callable ``(x, y) -> loss`` evaluated per batch.
        val_tokens: 1-D token array; consecutive windows of ``train_seq_len``
            (+1 token for the shifted targets) are scored.
        base_bytes_lut, has_leading_space_lut, is_boundary_token_lut:
            Per-token-id lookup tables indexed with the token ids.
        log_fn: Optional sink for ``val_progress`` messages.

    Raises:
        ValueError: If the per-step validation batch cannot hold one sequence.
    """
    seq_len = args.train_seq_len
    tokens_per_batch = args.val_batch_size // args.grad_accum_steps
    if tokens_per_batch < seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, "
            f"TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    seqs_per_batch = tokens_per_batch // seq_len
    num_seqs = (val_tokens.size - 1) // seq_len
    # Ceiling division, but never report fewer than one batch.
    num_batches = max(-(-num_seqs // seqs_per_batch), 1)

    loss_sum = 0.0
    token_count = 0.0
    byte_count = 0.0
    batch_starts = range(0, num_seqs, seqs_per_batch)
    for batch_no, first_seq in enumerate(batch_starts, start=1):
        last_seq = min(first_seq + seqs_per_batch, num_seqs)
        # One extra trailing token so the targets are the inputs shifted by one.
        window = val_tokens[first_seq * seq_len : last_seq * seq_len + 1]
        inputs_np = window[:-1].reshape(-1, seq_len)
        targets_np = window[1:].reshape(-1, seq_len)
        inputs = mx.array(inputs_np, dtype=mx.int32)
        targets = mx.array(targets_np, dtype=mx.int32)
        n_targets = float(targets.size)

        mean_loss = compiled_loss(inputs, targets).astype(mx.float32)
        mx.eval(mean_loss)
        loss_sum += float(mean_loss.item()) * n_targets

        prev_flat = inputs_np.reshape(-1)
        tgt_flat = targets_np.reshape(-1)
        piece_bytes = base_bytes_lut[tgt_flat].astype(np.int16, copy=True)
        leading_space_byte = has_leading_space_lut[tgt_flat] & ~is_boundary_token_lut[prev_flat]
        piece_bytes += leading_space_byte.astype(np.int16, copy=False)
        token_count += n_targets
        byte_count += float(piece_bytes.astype(np.float64).sum())

        if log_fn is None or num_batches <= 1:
            continue
        # Report the first batch, the last batch, and every 25th in between.
        if batch_no == 1 or batch_no == num_batches or batch_no % 25 == 0:
            log_fn(f"val_progress:{batch_no}/{num_batches}")

    val_loss = loss_sum / token_count
    bits_per_token = val_loss / math.log(2.0)
    val_bpb = bits_per_token * (token_count / byte_count)
    return val_loss, val_bpb
# -----------------------------
# TRAINING
# -----------------------------
def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict:
    """Rescale a gradient tree so its global L2 norm does not exceed *max_norm*.

    The norm is taken over all leaves of the flattened tree, accumulating the
    squared sums in float64 on the host.

    Returns:
        The input tree itself (same object) when clipping is disabled
        (``max_norm <= 0``), when the squared norm is not positive, or when
        the norm is already within bounds. Otherwise a new tree with every
        leaf multiplied by ``max_norm / (norm + 1e-12)``.
    """
    if max_norm <= 0:
        return grads_tree

    named_grads = dict(tree_flatten(grads_tree))
    # Plain sequential accumulation keeps the floating-point result stable.
    sq_sum = 0.0
    for leaf in named_grads.values():
        sq_sum += float(np.sum(np.square(_np_float32(leaf)), dtype=np.float64))
    if sq_sum <= 0.0:
        return grads_tree

    norm = math.sqrt(sq_sum)
    if norm <= max_norm:
        return grads_tree

    # The epsilon guards the division; norm is already > max_norm > 0 here.
    factor = max_norm / (norm + 1e-12)
    clipped = [(name, leaf * factor) for name, leaf in named_grads.items()]
    return tree_unflatten(clipped)
def main() -> None:
    """Run the whole training job end to end.

    Stages, in order:
      1. Create the output directory and a per-run log file; copy this
         script's source and the Python/MLX versions into the log.
      2. Load the SentencePiece tokenizer, check it against VOCAB_SIZE and the
         dataset manifest, load the validation tokens and build the byte LUTs.
      3. Build the model and optimizers and compile the loss / loss+grad
         functions against the full model state.
      4. Optionally run warmup steps that compute (but never apply) gradients,
         then recreate the train loader.
      5. Train until ``args.iterations`` or the wallclock cap, validating
         periodically and always on the last step.
      6. Save the raw weights, save an int8 + zlib quantized artifact, load the
         dequantized weights back into the model and run a final validation.

    Raises:
        NotImplementedError: If ``args.tie_embeddings`` is false.
        ValueError: On a tokenizer path/vocab mismatch, or a validation batch
            size too small for one sequence.
    """
    # ==============================================================================
    # TOKENIZER + VALIDATION METRIC SETUP
    # ==============================================================================
    args = Hyperparameters()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    logfile = out_dir / f"{args.run_id}.txt"
    print(logfile)

    def log(msg: str, console: bool = True) -> None:
        """Append *msg* to the run log file; also print it when *console* is true."""
        if console:
            print(msg)
        with logfile.open("a", encoding="utf-8") as f:
            print(msg, file=f)

    # Record this script's own source plus version info in the log file only.
    code = Path(__file__).read_text(encoding="utf-8")
    log(code, console=False)
    log("=" * 100, console=False)
    log(f"Running Python {sys.version}", console=False)
    log(f"Running MLX {mx.__version__}", console=False)
    log("=" * 100, console=False)
    if not args.tie_embeddings:
        raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings")
    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair(
        args.data_path,
        args.tokenizer_path,
    )
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size
    )
    # ==============================================================================
    # TRAINING SETUP
    # ==============================================================================
    mx.random.seed(args.seed)
    train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name)
    # ==============================================================================
    # MODEL + OPTIMIZER SETUP
    # ==============================================================================
    model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        logit_chunk_tokens=args.logit_chunk_tokens,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        tied_embed_init_std=args.tied_embed_init_std,
        qk_gain_init=args.qk_gain_init,
    )
    opt = SplitOptimizers(model, args)
    # ==============================================================================
    # COMPILED TRAIN / EVAL FUNCTIONS (MLX)
    # ==============================================================================
    # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example
    # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs".
    # Compiling the model-bound functions and capturing the full model state fixes that while still
    # returning gradients only for trainable parameters via nn.value_and_grad(...).
    compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state)
    compiled_loss_and_grad = mx.compile(
        nn.value_and_grad(model, lambda x, y: model.loss(x, y)),
        inputs=model.state,
        outputs=model.state,
    )
    # Print config once so logs are self-describing.
    n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters()))
    log(f"run_id:{args.run_id}")
    log(f"mlx_version:{mx.__version__}")
    log(f"train_loader:shards pattern={args.train_files}")
    log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}")
    # expected_train_files is None when no manifest entry was available to compare against.
    if expected_train_files is None:
        log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}")
    elif actual_train_files < expected_train_files:
        log(
            f"WARNING: train_loader:subset dataset:{dataset_name} "
            f"train_shards:{actual_train_files}/{expected_train_files} "
            f"new epochs will arrive sooner than the full dataset"
        )
    else:
        log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}")
    log(f"tokenizer_path:{args.tokenizer_path}")
    log(
        f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} "
        f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} "
        f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}"
    )
    log(
        f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} "
        f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} "
        f"val_batch_size:{args.val_batch_size} "
        f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}")
    log(
        f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} "
        f"embed_lr:{args.tied_embed_lr} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} "
        f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}"
    )
    log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log(f"compute_dtype:{COMPUTE_DTYPE} compile:True")
    log(
        f"dtypes tok_emb:{model.tok_emb.weight.dtype} "
        f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} "
        f"skip_weights:{model.skip_weights.dtype}"
    )
    # ==============================================================================
    # TRAINING LOOP
    # ==============================================================================
    if args.warmup_steps > 0:
        # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us
        # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs.
        # Instead we run the real train shapes, force the loss/grads to materialize, and then reset
        # the loader so measured training still starts from the true init and token window.
        for warmup_step in range(args.warmup_steps):
            accum: dict[str, mx.array] | None = None
            warmup_loss = mx.array(0.0, dtype=mx.float32)
            grad_scale = 1.0 / args.grad_accum_steps
            for _ in range(args.grad_accum_steps):
                # warmup_loss is simply overwritten each microbatch; its value is never reported.
                warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
                accum = accumulate_flat_grads(accum, grads, grad_scale)
            # Force evaluation; the gradients are discarded (opt.step is never called here).
            mx.eval(warmup_loss, accum)
            mx.synchronize()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        # Prime the standalone eval graph once too. It is compiled separately from value_and_grad.
        val_batch_tokens = args.val_batch_size // args.grad_accum_steps
        if val_batch_tokens < args.train_seq_len:
            raise ValueError(
                "VAL_BATCH_SIZE must provide at least one sequence; "
                f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, "
                f"TRAIN_SEQ_LEN={args.train_seq_len}"
            )
        # Same batch sizing as eval_val, capped by the number of sequences available.
        warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len)
        warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1]
        x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32)
        y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32)
        warm_val_loss = compiled_loss(x_val, y_val)
        mx.eval(warm_val_loss)
        mx.synchronize()
        # Recreate the loader so the batches consumed by warmup are not skipped in the measured run.
        train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name)
    train_time_ms = 0.0
    # A non-positive MAX_WALLCLOCK_SECONDS disables the wallclock cap.
    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
    stop_after_step: int | None = None
    t0 = time.perf_counter()
    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
        if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
            # Bank the elapsed training time before validating; t0 is restarted afterwards so
            # validation time is excluded from train_time_ms.
            train_time_ms += 1000.0 * (time.perf_counter() - t0)
            # Validation always scans the same fixed full validation split.
            val_loss, val_bpb = eval_val(
                args,
                compiled_loss,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
                log_fn=log,
            )
            if step % 25 == 0 or last_step:
                log(
                    f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                    f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms"
                )
            t0 = time.perf_counter()
        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}")
            break
        # The LR schedule receives both the step index and the elapsed training time in ms.
        lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0))
        step_t0 = time.perf_counter()
        accum: dict[str, mx.array] | None = None
        train_loss = mx.array(0.0, dtype=mx.float32)
        grad_scale = 1.0 / args.grad_accum_steps
        for _ in range(args.grad_accum_steps):
            loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
            accum = accumulate_flat_grads(accum, grads, grad_scale)
            train_loss = train_loss + loss.astype(mx.float32) * grad_scale
            if args.mlx_eager_eval:
                mx.eval(train_loss, accum)  # materialize each microbatch to cap peak memory
        grads = tree_unflatten(list(accum.items()))
        grads = clip_grad_tree(grads, args.grad_clip_norm)
        train_loss_value = float(train_loss.item())
        opt.step(model, grads, step=step, lr_mul=lr_mul)
        mx.synchronize()
        step_ms = 1000.0 * (time.perf_counter() - step_t0)
        approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0)
        tok_s = args.train_batch_tokens / (step_ms / 1000.0)
        step += 1
        # Log the first 10 steps, every train_log_every-th step, and any step once the cap has tripped.
        if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None):
            log(
                f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} "
                f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}"
            )
        # Wallclock cap reached: the next loop iteration sees last_step=True, validates, and breaks.
        if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms:
            stop_after_step = step
    # ==============================================================================
    # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL
    # ==============================================================================
    # We always write a raw artifact and a quantized artifact, then validate the
    # quantized roundtrip directly by loading the dequantized tensors back into the
    # model and running one final validation pass.
    out_path = out_dir / f"{args.run_id}_mlx_model.npz"
    flat_state = {k: v for k, v in tree_flatten(model.state)}
    mx.savez(str(out_path), **flat_state)
    log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}")
    quant_obj, quant_stats = quantize_state_dict_int8(flat_state)
    quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL)
    quant_blob = zlib.compress(quant_raw, level=9)
    quant_serialized_bytes = len(quant_raw)
    quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz"
    with quant_path.open("wb") as f:
        f.write(quant_blob)
    quant_file_bytes = quant_path.stat().st_size
    # max(..., 1) avoids a division by zero if the payload is reported as empty.
    ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
    log(
        f"serialized_model_int8_zlib:{quant_file_bytes} bytes "
        f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)"
    )
    # Read the artifact back from disk so the roundtrip covers the bytes actually written.
    with quant_path.open("rb") as f:
        quant_blob_disk = f.read()
    # NOTE(review): pickle.loads is acceptable here only because the blob was written by this
    # same process just above; never reuse this path to load an untrusted file.
    quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk)))
    # Replace the live weights with the dequantized tensors so the eval below scores the quantized model.
    model.update(tree_unflatten(list(quant_flat.items())))
    q_t0 = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args,
        compiled_loss,
        val_tokens,
        base_bytes_lut,
        has_leading_space_lut,
        is_boundary_token_lut,
        log_fn=log,
    )
    q_eval_ms = 1000.0 * (time.perf_counter() - q_t0)
    log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms")
    log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
# Script entry point: start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
|