summaryrefslogtreecommitdiff
path: root/records
diff options
context:
space:
mode:
authorMatthew Li <156706407+mattqlf@users.noreply.github.com>2026-03-19 17:00:42 -0400
committerGitHub <noreply@github.com>2026-03-19 14:00:42 -0700
commit3a6fec7941f3cc4187ac208496b6842f3563e18c (patch)
tree928fdd23b019ea48234189fd45bb729adba8a991 /records
parentd84a3e819100504d96879e1e36d022efa5cbb81b (diff)
Fix: score final partial window in sliding window eval (#124)
The window_starts filter dropped windows shorter than stride, silently skipping up to (stride-1) tokens at the end of the validation set. Now includes all windows with >= 1 scoreable token, and clamps the score start for short final windows.
Diffstat (limited to 'records')
-rw-r--r--records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py b/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py
index 8555de6..6a8fd84 100644
--- a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py
+++ b/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py
@@ -856,9 +856,9 @@ def eval_val_sliding(
seq_len = args.train_seq_len
total_tokens = val_tokens.numel() - 1
- # Build windows; skip any too short to score a full stride
+ # Build windows; include final partial window if it has at least 1 token
window_starts = [ws for ws in range(0, total_tokens, stride)
- if min(ws + seq_len, total_tokens) - ws >= stride]
+ if min(ws + seq_len, total_tokens) - ws >= 1]
total_windows = len(window_starts)
# Distribute across ranks
@@ -899,7 +899,7 @@ def eval_val_sliding(
for i, ws in enumerate(batch_ws):
wlen = wlens[i]
- s = 0 if ws == 0 else wlen - stride
+ s = 0 if ws == 0 else max(wlen - stride, 0)
scored_nll = nll[i, s:wlen].to(torch.float64)
loss_sum += scored_nll.sum()
token_count += float(wlen - s)