From 321e82cf0bd5277d0becedd76a3b477204311e00 Mon Sep 17 00:00:00 2001 From: Michael Diskin Date: Thu, 19 Mar 2026 01:27:44 +0300 Subject: Fix MLX validation loss accumulation --- train_gpt_mlx.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py index bf7c7d1..5eb7d73 100644 --- a/train_gpt_mlx.py +++ b/train_gpt_mlx.py @@ -772,7 +772,7 @@ def eval_val( ) val_batch_seqs = val_batch_tokens // args.train_seq_len total_seqs = (val_tokens.size - 1) // args.train_seq_len - total_loss = mx.array(0.0, dtype=mx.float32) + total_loss_sum = 0.0 total_tokens = 0.0 total_bytes = 0.0 for batch_seq_start in range(0, total_seqs, val_batch_seqs): @@ -785,7 +785,9 @@ def eval_val( x = mx.array(x_np, dtype=mx.int32) y = mx.array(y_np, dtype=mx.int32) chunk_token_count = float(y.size) - total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count + batch_loss = compiled_loss(x, y).astype(mx.float32) + mx.eval(batch_loss) + total_loss_sum += float(batch_loss.item()) * chunk_token_count prev_ids = x_np.reshape(-1) tgt_ids = y_np.reshape(-1) bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) @@ -794,9 +796,7 @@ def eval_val( ).astype(np.int16, copy=False) total_tokens += chunk_token_count total_bytes += float(bytes_np.astype(np.float64).sum()) - total_loss = total_loss / total_tokens - mx.eval(total_loss) - val_loss = float(total_loss.item()) + val_loss = total_loss_sum / total_tokens bits_per_token = val_loss / math.log(2.0) val_bpb = bits_per_token * (total_tokens / total_bytes) return val_loss, val_bpb -- cgit v1.2.3 From e17ed019f49ab5bc45b1f2e8488809e9f4f5e314 Mon Sep 17 00:00:00 2001 From: Michael Diskin Date: Thu, 19 Mar 2026 01:56:24 +0300 Subject: Log MLX validation progress --- train_gpt_mlx.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py index 5eb7d73..40175a7 100644 --- a/train_gpt_mlx.py +++ b/train_gpt_mlx.py @@ -759,6 +759,7 @@ def eval_val( base_bytes_lut: np.ndarray, has_leading_space_lut: np.ndarray, is_boundary_token_lut: np.ndarray, + log_fn: Callable[[str], None] | None = None, ) -> tuple[float, float]: # Validation computes two metrics: # - val_loss: token cross-entropy (natural log) @@ -772,10 +773,11 @@ def eval_val( ) val_batch_seqs = val_batch_tokens // args.train_seq_len total_seqs = (val_tokens.size - 1) // args.train_seq_len + total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) total_loss_sum = 0.0 total_tokens = 0.0 total_bytes = 0.0 - for batch_seq_start in range(0, total_seqs, val_batch_seqs): + for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) raw_start = batch_seq_start * args.train_seq_len raw_end = batch_seq_end * args.train_seq_len + 1 @@ -796,6 +798,10 @@ def eval_val( ).astype(np.int16, copy=False) total_tokens += chunk_token_count total_bytes += float(bytes_np.astype(np.float64).sum()) + if log_fn is not None and total_batches > 1 and ( + batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 + ): + log_fn(f"val_progress:{batch_idx}/{total_batches}") val_loss = total_loss_sum / total_tokens bits_per_token = val_loss / math.log(2.0) val_bpb = bits_per_token * (total_tokens / total_bytes) @@ -999,6 +1005,7 @@ def main() -> None: base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + log_fn=log, ) train_time_ms += 1000.0 * (time.perf_counter() - t0) if step % 25 == 0 or last_step: @@ -1078,6 +1085,7 @@ def main() -> None: base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + log_fn=log, ) q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") -- cgit v1.2.3