From ef678d2e1ba70b1a9dadb78c73ed372f986aea13 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 9 Feb 2026 12:28:55 -0600 Subject: Fix NLL double-shift bug and head weight init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - NLL loss was shifting labels twice (olmo_labels already shifted, then code did logits[:,:-1] vs labels[:,1:]). Fixed in 9 locations: trainer, pipeline, olmo_graph, sanity_check, eval. - Head U/V weights init with std=0.01 (was Kaiming ~5.7 std) so UV^T≈0 at init, ensuring Z≈logit_bias=15 and A≈0.953. - Updated SVD rank test to subtract logit_bias before checking. Co-Authored-By: Claude Opus 4.6 --- src/model/olmo_graph.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'src/model/olmo_graph.py') diff --git a/src/model/olmo_graph.py b/src/model/olmo_graph.py index af9f848..4056181 100644 --- a/src/model/olmo_graph.py +++ b/src/model/olmo_graph.py @@ -379,9 +379,10 @@ def compute_vanilla_nll( with torch.no_grad(): outputs = model(input_ids=input_ids) logits = outputs.logits + # labels is already shifted (chunk[1:seq_len+1]), no additional shift needed nll = F.cross_entropy( - logits[:, :-1].contiguous().view(-1, logits.size(-1)), - labels[:, 1:].contiguous().view(-1), + logits.contiguous().view(-1, logits.size(-1)), + labels.contiguous().view(-1), ) return nll -- cgit v1.2.3