summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Diskin <yhn1124@gmail.com>2026-03-19 01:27:44 +0300
committerMichael Diskin <yhn1124@gmail.com>2026-03-19 01:27:44 +0300
commit321e82cf0bd5277d0becedd76a3b477204311e00 (patch)
treeed58d70a971e74b808fc16258d6d7e31e7a3d20f
parent0c0ea98e6ad92bab5fd2aaab226b6a6f0e68f4d2 (diff)
Fix MLX validation loss accumulation
-rw-r--r--train_gpt_mlx.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index bf7c7d1..5eb7d73 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -772,7 +772,7 @@ def eval_val(
)
val_batch_seqs = val_batch_tokens // args.train_seq_len
total_seqs = (val_tokens.size - 1) // args.train_seq_len
- total_loss = mx.array(0.0, dtype=mx.float32)
+ total_loss_sum = 0.0
total_tokens = 0.0
total_bytes = 0.0
for batch_seq_start in range(0, total_seqs, val_batch_seqs):
@@ -785,7 +785,9 @@ def eval_val(
x = mx.array(x_np, dtype=mx.int32)
y = mx.array(y_np, dtype=mx.int32)
chunk_token_count = float(y.size)
- total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count
+ batch_loss = compiled_loss(x, y).astype(mx.float32)
+ mx.eval(batch_loss)
+ total_loss_sum += float(batch_loss.item()) * chunk_token_count
prev_ids = x_np.reshape(-1)
tgt_ids = y_np.reshape(-1)
bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True)
@@ -794,9 +796,7 @@ def eval_val(
).astype(np.int16, copy=False)
total_tokens += chunk_token_count
total_bytes += float(bytes_np.astype(np.float64).sum())
- total_loss = total_loss / total_tokens
- mx.eval(total_loss)
- val_loss = float(total_loss.item())
+ val_loss = total_loss_sum / total_tokens
bits_per_token = val_loss / math.log(2.0)
val_bpb = bits_per_token * (total_tokens / total_bytes)
return val_loss, val_bpb