author     Will DePue <williamd@openai.com>  2026-03-19 10:07:55 -0700
committer  GitHub <noreply@github.com>       2026-03-19 10:07:55 -0700
commit     2081ba1cb7c779b3aedbf728bf4448f772083ce2 (patch)
tree       6c131613093eb9b0733d117fc5400924cfad2ca9
parent     954a158102ec64c292ad82b2442e387e505a9388 (diff)
parent     6a08c9dcda3edd33202c48d5df104918318ab004 (diff)
Merge pull request #100 from sandsevenone/mlx_eager_eval
Use eager mx.eval() to fix running the train script on 16GB Mac devices
-rw-r--r--  train_gpt_mlx.py  8
1 file changed, 8 insertions(+), 0 deletions(-)
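
Context for the change: MLX evaluates lazily, so without an explicit mx.eval() each accumulation step only extends the computation graph, and the buffers behind every sub-batch stay live until the whole graph is finally forced. Below is a minimal standalone sketch of that pattern (toy shapes and step count, not the repository's code) showing where the per-chunk mx.eval() lands:

    import mlx.core as mx

    # Toy gradient-accumulation loop. Left lazy, the 32 additions build one
    # large graph and MLX keeps each chunk's buffers alive until it is forced.
    accum = mx.zeros((4096, 4096))
    for step in range(32):
        chunk = mx.random.normal((4096, 4096))  # stand-in for a sub-batch gradient
        accum = accum + chunk                   # lazy add: extends the graph
        mx.eval(accum)                          # eager: materialize now, free the chunk

The trade-off is a little extra kernel-launch overhead per chunk in exchange for a bounded memory peak, which is why the flag defaults to on and the comment below suggests disabling it only on 32GB+ unified memory.
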
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index 7ab9bb6..7b9e935 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -59,6 +59,10 @@ class Hyperparameters:
     # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak
     # memory pressure without changing the effective optimizer batch.
     mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192))
+    # Force MLX to materialize the graph after every sub-batch, preventing lazy
+    # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines.
+    # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0).
+    mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1")))
     warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20))
     warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200))
     max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
@@ -749,6 +753,8 @@ def loss_and_grad_chunked(
         scale = float(y.size) / total_tokens
         loss_value = loss_value + loss.astype(mx.float32) * scale
         grad_accum = accumulate_flat_grads(grad_accum, grads, scale)
+        if args.mlx_eager_eval:
+            mx.eval(loss_value, grad_accum)  # materialize each chunk to cap peak memory
     return loss_value, tree_unflatten(list(grad_accum.items()))
@@ -1029,6 +1035,8 @@ def main() -> None:
         loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
         accum = accumulate_flat_grads(accum, grads, grad_scale)
         train_loss = train_loss + loss.astype(mx.float32) * grad_scale
+        if args.mlx_eager_eval:
+            mx.eval(train_loss, accum)  # materialize each microbatch to cap peak memory
     grads = tree_unflatten(list(accum.items()))
     grads = clip_grad_tree(grads, args.grad_clip_norm)
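
Given the Hyperparameters wiring above, the behavior is controlled at launch time: MLX_EAGER_EVAL=0 python train_gpt_mlx.py restores fully lazy accumulation on machines with enough unified memory, while the default of "1" keeps the per-chunk mx.eval() calls active.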