summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWill DePue <williamd@openai.com>2026-03-18 16:26:02 -0700
committerGitHub <noreply@github.com>2026-03-18 16:26:02 -0700
commit5472f29be414fe6b50189058c6ccc9aa3d73566d (patch)
tree004912a4eaa1a91d657eb09d1ce1790a098a16c5
parent0c0ea98e6ad92bab5fd2aaab226b6a6f0e68f4d2 (diff)
parent3d1c8e6d741118ae7166f0d014f6ac784a2aa3e1 (diff)
Merge pull request #18 from berniwal/main
MLX Timing Mismatch with Main Script
-rw-r--r--train_gpt_mlx.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index bf7c7d1..ea7fdca 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -991,6 +991,7 @@ def main() -> None:
while True:
last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+ train_time_ms += 1000.0 * (time.perf_counter() - t0)
# Validation always scans the same fixed full validation split.
val_loss, val_bpb = eval_val(
args,
@@ -1000,7 +1001,6 @@ def main() -> None:
has_leading_space_lut,
is_boundary_token_lut,
)
- train_time_ms += 1000.0 * (time.perf_counter() - t0)
if step % 25 == 0 or last_step:
log(
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "