From 3a6fec7941f3cc4187ac208496b6842f3563e18c Mon Sep 17 00:00:00 2001 From: Matthew Li <156706407+mattqlf@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:00:42 -0400 Subject: Fix: score final partial window in sliding window eval (#124) The window_starts filter dropped windows shorter than stride, silently skipping up to (stride-1) tokens at the end of the validation set. Now includes all windows with >= 1 scoreable token, and clamps the score start for short final windows. --- records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'records') diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py b/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py index 8555de6..6a8fd84 100644 --- a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py +++ b/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py @@ -856,9 +856,9 @@ def eval_val_sliding( seq_len = args.train_seq_len total_tokens = val_tokens.numel() - 1 - # Build windows; skip any too short to score a full stride + # Build windows; include final partial window if it has at least 1 token window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride] + if min(ws + seq_len, total_tokens) - ws >= 1] total_windows = len(window_starts) # Distribute across ranks @@ -899,7 +899,7 @@ def eval_val_sliding( for i, ws in enumerate(batch_ws): wlen = wlens[i] - s = 0 if ws == 0 else wlen - stride + s = 0 if ws == 0 else max(wlen - stride, 0) scored_nll = nll[i, s:wlen].to(torch.float64) loss_sum += scored_nll.sum() token_count += float(wlen - s) -- cgit v1.2.3