From ef678d2e1ba70b1a9dadb78c73ed372f986aea13 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 9 Feb 2026 12:28:55 -0600
Subject: Fix NLL double-shift bug and head weight init
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- NLL loss was shifting labels twice: olmo_labels is already shifted, and
  the code then sliced logits[:,:-1] against labels[:,1:]. Fixed in 9
  locations: trainer, pipeline, olmo_graph, sanity_check, eval.
- Head U/V weights are now initialized with std=0.01 (was Kaiming, std ≈ 5.7),
  so UV^T≈0 at init, ensuring Z≈logit_bias=15 and A≈0.953.
- Updated the SVD rank test to subtract logit_bias before checking.

Co-Authored-By: Claude Opus 4.6
---
 src/training/trainer.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

(limited to 'src/training/trainer.py')

diff --git a/src/training/trainer.py b/src/training/trainer.py
index de0eb96..7ebd21e 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -299,10 +299,10 @@ class Trainer:
         A = self.predictor(raw_texts, tau=tau, mode="train")
         logits = self.olmo_wrapper(olmo_ids, A)
 
-        # NLL loss
+        # NLL loss (olmo_labels already shifted, no additional shift needed)
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.olmo.config.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.olmo.config.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
 
         # Sparsity loss
@@ -417,8 +417,8 @@
         A_soft = self.predictor(raw_texts, tau=tau, mode="eval_soft")
         logits_soft = self.olmo_wrapper(olmo_ids, A_soft)
         nll_soft = F.cross_entropy(
-            logits_soft[:, :-1].contiguous().view(-1, vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits_soft.contiguous().view(-1, vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         nll_soft_total += nll_soft.item()
 
@@ -426,8 +426,8 @@
         A_hard = self.predictor(raw_texts, tau=tau, mode="eval_hard")
         logits_hard = self.olmo_wrapper(olmo_ids, A_hard)
         nll_hard = F.cross_entropy(
-            logits_hard[:, :-1].contiguous().view(-1, vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits_hard.contiguous().view(-1, vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         nll_hard_total += nll_hard.item()
 
@@ -435,8 +435,8 @@
         A_ones = create_all_ones_A(olmo_ids.shape[0]).to(self.device)
         logits_base = self.olmo_wrapper(olmo_ids, A_ones)
         nll_base = F.cross_entropy(
-            logits_base[:, :-1].contiguous().view(-1, vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits_base.contiguous().view(-1, vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         nll_baseline_total += nll_base.item()
 
--
cgit v1.2.3
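
Illustration (not part of the patch): because the labels are pre-shifted
upstream (per the commit message), the extra [:, :-1] / [:, 1:] slicing
scored each position against the token two steps ahead instead of one. A
minimal standalone sketch of the alignment; tensor names follow the diff,
while the toy sizes and random logits are made up:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    B, T, V = 2, 8, 50                 # batch, sequence length, vocab (toy)

    tokens = torch.randint(0, V, (B, T + 1))
    olmo_ids = tokens[:, :-1]          # inputs fed to the model
    olmo_labels = tokens[:, 1:]        # labels ALREADY shifted upstream

    logits = torch.randn(B, T, V)      # stand-in for self.olmo_wrapper(olmo_ids, A)

    # Fixed: logits[:, t] is the prediction for olmo_labels[:, t]; no extra shift.
    nll_fixed = F.cross_entropy(
        logits.contiguous().view(-1, V),
        olmo_labels.contiguous().view(-1),
    )

    # Buggy: slicing both tensors again compares logits[:, t] with
    # olmo_labels[:, t + 1], i.e. the token at position t + 2 of the
    # original sequence, so every prediction is scored against the wrong target.
    nll_buggy = F.cross_entropy(
        logits[:, :-1].contiguous().view(-1, V),
        olmo_labels[:, 1:].contiguous().view(-1),
    )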
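
Likewise for the head init, a sketch of why std=0.01 gives Z ≈ logit_bias at
init. Module names, shapes, and the rank are assumptions; only the std=0.01
normal init, UV^T ≈ 0, and logit_bias = 15 come from the commit message (the
Z -> A mapping behind A ≈ 0.953 is model-specific and not reproduced here):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    d, r = 512, 16                     # hidden size and head rank (toy values)
    U = nn.Linear(d, r, bias=False)    # hypothetical U head
    V = nn.Linear(d, r, bias=False)    # hypothetical V head

    # Near-zero init instead of Kaiming: each entry of U.weight.T @ V.weight
    # sums r products of two N(0, 0.01^2) draws, so its std is ~1e-4 * sqrt(r).
    nn.init.normal_(U.weight, std=0.01)
    nn.init.normal_(V.weight, std=0.01)

    logit_bias = 15.0
    S = U.weight.T @ V.weight          # (d, d) low-rank term, entries ≈ 0
    Z = logit_bias + S
    print(S.abs().max().item())        # tiny, on the order of 1e-3
    print(Z.mean().item())             # ≈ 15.0: the head starts at the bias point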