summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWill DePue <williamd@openai.com>2026-03-18 16:30:22 -0700
committerGitHub <noreply@github.com>2026-03-18 16:30:22 -0700
commit825357724a36e54bb61dca99700b21b07aaa8c47 (patch)
treebddbdbf375bf1fd7438d936a43fd094e7645b803
parent09c3e8edaa478068bcae05982b426026f1d3a023 (diff)
parente17ed019f49ab5bc45b1f2e8488809e9f4f5e314 (diff)
Merge pull request #32 from yhn112/fix-mlx-eval-memory-growth
Fix MLX multi-batch validation memory growth
-rw-r--r--train_gpt_mlx.py20
1 file changed, 14 insertions, 6 deletions
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index ea7fdca..e20794d 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -759,6 +759,7 @@ def eval_val(
base_bytes_lut: np.ndarray,
has_leading_space_lut: np.ndarray,
is_boundary_token_lut: np.ndarray,
+ log_fn: Callable[[str], None] | None = None,
) -> tuple[float, float]:
# Validation computes two metrics:
# - val_loss: token cross-entropy (natural log)
@@ -772,10 +773,11 @@ def eval_val(
)
val_batch_seqs = val_batch_tokens // args.train_seq_len
total_seqs = (val_tokens.size - 1) // args.train_seq_len
- total_loss = mx.array(0.0, dtype=mx.float32)
+ total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1)
+ total_loss_sum = 0.0
total_tokens = 0.0
total_bytes = 0.0
- for batch_seq_start in range(0, total_seqs, val_batch_seqs):
+ for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1):
batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs)
raw_start = batch_seq_start * args.train_seq_len
raw_end = batch_seq_end * args.train_seq_len + 1
@@ -785,7 +787,9 @@ def eval_val(
x = mx.array(x_np, dtype=mx.int32)
y = mx.array(y_np, dtype=mx.int32)
chunk_token_count = float(y.size)
- total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count
+ batch_loss = compiled_loss(x, y).astype(mx.float32)
+ mx.eval(batch_loss)
+ total_loss_sum += float(batch_loss.item()) * chunk_token_count
prev_ids = x_np.reshape(-1)
tgt_ids = y_np.reshape(-1)
bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True)
@@ -794,9 +798,11 @@ def eval_val(
).astype(np.int16, copy=False)
total_tokens += chunk_token_count
total_bytes += float(bytes_np.astype(np.float64).sum())
- total_loss = total_loss / total_tokens
- mx.eval(total_loss)
- val_loss = float(total_loss.item())
+ if log_fn is not None and total_batches > 1 and (
+ batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0
+ ):
+ log_fn(f"val_progress:{batch_idx}/{total_batches}")
+ val_loss = total_loss_sum / total_tokens
bits_per_token = val_loss / math.log(2.0)
val_bpb = bits_per_token * (total_tokens / total_bytes)
return val_loss, val_bpb
@@ -1000,6 +1006,7 @@ def main() -> None:
base_bytes_lut,
has_leading_space_lut,
is_boundary_token_lut,
+ log_fn=log,
)
if step % 25 == 0 or last_step:
log(
@@ -1078,6 +1085,7 @@ def main() -> None:
base_bytes_lut,
has_leading_space_lut,
is_boundary_token_lut,
+ log_fn=log,
)
q_eval_ms = 1000.0 * (time.perf_counter() - q_t0)
log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms")