summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--train_gpt_mlx.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index bf7c7d1..ea7fdca 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -991,6 +991,7 @@ def main() -> None:
while True:
last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+ train_time_ms += 1000.0 * (time.perf_counter() - t0)
# Validation always scans the same fixed full validation split.
val_loss, val_bpb = eval_val(
args,
@@ -1000,7 +1001,6 @@ def main() -> None:
has_leading_space_lut,
is_boundary_token_lut,
)
- train_time_ms += 1000.0 * (time.perf_counter() - t0)
if step % 25 == 0 or last_step:
log(
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "