summaryrefslogtreecommitdiff
path: root/train_gpt_mlx.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_gpt_mlx.py')
-rw-r--r--train_gpt_mlx.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index 40175a7..e20794d 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -997,6 +997,7 @@ def main() -> None:
while True:
last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+ train_time_ms += 1000.0 * (time.perf_counter() - t0)
# Validation always scans the same fixed full validation split.
val_loss, val_bpb = eval_val(
args,
@@ -1007,7 +1008,6 @@ def main() -> None:
is_boundary_token_lut,
log_fn=log,
)
- train_time_ms += 1000.0 * (time.perf_counter() - t0)
if step % 25 == 0 or last_step:
log(
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "