diff options
| author | Will DePue <williamd@openai.com> | 2026-03-19 10:07:55 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-03-19 10:07:55 -0700 |
| commit | 2081ba1cb7c779b3aedbf728bf4448f772083ce2 (patch) | |
| tree | 6c131613093eb9b0733d117fc5400924cfad2ca9 | |
| parent | 954a158102ec64c292ad82b2442e387e505a9388 (diff) | |
| parent | 6a08c9dcda3edd33202c48d5df104918318ab004 (diff) | |
Merge pull request #100 from sandsevenone/mlx_eager_eval
Use eager mx.eval() to fix running train script on 16GB Mac devices
| -rw-r--r-- | train_gpt_mlx.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py index 7ab9bb6..7b9e935 100644 --- a/train_gpt_mlx.py +++ b/train_gpt_mlx.py @@ -59,6 +59,10 @@ class Hyperparameters: # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak # memory pressure without changing the effective optimizer batch. mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) + # Force MLX to materialize the graph after every sub-batch, preventing lazy + # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. + # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). + mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) @@ -749,6 +753,8 @@ def loss_and_grad_chunked( scale = float(y.size) / total_tokens loss_value = loss_value + loss.astype(mx.float32) * scale grad_accum = accumulate_flat_grads(grad_accum, grads, scale) + if args.mlx_eager_eval: + mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory return loss_value, tree_unflatten(list(grad_accum.items())) @@ -1029,6 +1035,8 @@ def main() -> None: loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) accum = accumulate_flat_grads(accum, grads, grad_scale) train_loss = train_loss + loss.astype(mx.float32) * grad_scale + if args.mlx_eager_eval: + mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory grads = tree_unflatten(list(accum.items())) grads = clip_grad_tree(grads, args.grad_clip_norm) |
