From 6a08c9dcda3edd33202c48d5df104918318ab004 Mon Sep 17 00:00:00 2001
From: sandrone
Date: Fri, 20 Mar 2026 00:35:33 +0800
Subject: Add MLX_EAGER_EVAL flag to further reduce memory pressure by force-evaluating the graph after each sub-batch step

---
 train_gpt_mlx.py | 8 ++++++++
 1 file changed, 8 insertions(+)

(limited to 'train_gpt_mlx.py')

diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index 7ab9bb6..7b9e935 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -59,6 +59,10 @@ class Hyperparameters:
     # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak
     # memory pressure without changing the effective optimizer batch.
     mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192))
+    # Force MLX to materialize the graph after every sub-batch, preventing lazy
+    # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines.
+    # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0).
+    mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1")))
     warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20))
     warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200))
     max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
@@ -749,6 +753,8 @@ def loss_and_grad_chunked(
         scale = float(y.size) / total_tokens
         loss_value = loss_value + loss.astype(mx.float32) * scale
         grad_accum = accumulate_flat_grads(grad_accum, grads, scale)
+        if args.mlx_eager_eval:
+            mx.eval(loss_value, grad_accum)  # materialize each chunk to cap peak memory
     return loss_value, tree_unflatten(list(grad_accum.items()))
 
 
@@ -1029,6 +1035,8 @@ def main() -> None:
             loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
             accum = accumulate_flat_grads(accum, grads, grad_scale)
             train_loss = train_loss + loss.astype(mx.float32) * grad_scale
+            if args.mlx_eager_eval:
+                mx.eval(train_loss, accum)  # materialize each microbatch to cap peak memory
 
         grads = tree_unflatten(list(accum.items()))
         grads = clip_grad_tree(grads, args.grad_clip_norm)
-- 
cgit v1.2.3