From ef678d2e1ba70b1a9dadb78c73ed372f986aea13 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 9 Feb 2026 12:28:55 -0600
Subject: Fix NLL double-shift bug and head weight init
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- NLL loss was shifting labels twice (olmo_labels already shifted, then
  code did logits[:,:-1] vs labels[:,1:]). Fixed in 9 locations: trainer,
  pipeline, olmo_graph, sanity_check, eval.
- Head U/V weights init with std=0.01 (was Kaiming ~5.7 std) so UV^T≈0
  at init, ensuring Z≈logit_bias=15 and A≈0.953.
- Updated SVD rank test to subtract logit_bias before checking.

Co-Authored-By: Claude Opus 4.6
---
 src/model/predictor.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

(limited to 'src/model/predictor.py')

diff --git a/src/model/predictor.py b/src/model/predictor.py
index ed243ad..b5f9674 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -98,6 +98,16 @@ class PredictorMLP(nn.Module):
         self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
         self.head_V = nn.Linear(hidden_dim, num_nodes * rank)
 
+        # Initialize head_U and head_V with small weights so UV^T ≈ 0 at init.
+        # Default Kaiming init gives UV^T with std≈√rank≈5.7 which overwhelms
+        # the logit_bias. Small init ensures Z ≈ logit_bias ± small noise.
+        # std=0.01 gives UV^T std≈0.6 (with hidden_dim=1024, rank=32),
+        # small vs logit_bias=15 but enough for input-dependent gradients.
+        nn.init.normal_(self.head_U.weight, std=0.01)
+        nn.init.normal_(self.head_V.weight, std=0.01)
+        nn.init.zeros_(self.head_U.bias)
+        nn.init.zeros_(self.head_V.bias)
+
         # Learnable bias added to Z logits. Initialized positive so that
         # σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1)
         # at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95.
--
cgit v1.2.3
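
The double-shift fix described in the message reduces to the following. A
minimal sketch, assuming PyTorch's F.cross_entropy and that olmo_labels is
already shifted so olmo_labels[:, t] is the target for logits[:, t]; the
variable names, shapes, and the ignore_index=-100 convention are assumptions,
not the repository's actual code:

    import torch.nn.functional as F

    def nll_loss(logits, olmo_labels):
        # logits: (batch, seq, vocab); olmo_labels: (batch, seq), already
        # shifted one position, with ignored positions set to -100.
        #
        # Buggy version (double shift): shifting again compares each
        # prediction to the token two steps ahead:
        #   F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
        #                   olmo_labels[:, 1:].reshape(-1), ignore_index=-100)
        #
        # Fixed version: use both tensors as-is, since the shift already
        # happened when olmo_labels was built.
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            olmo_labels.reshape(-1),
            ignore_index=-100,
        )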
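
And a quick numerical check of the std arithmetic in the new init comment.
hidden_dim=1024, rank=32, tau_init=5.0, and logit_bias=15 come from the patch;
num_nodes=8, the sample count, and the unit-variance inputs are placeholder
assumptions:

    import torch

    hidden_dim, rank, num_nodes = 1024, 32, 8   # num_nodes is a placeholder
    head_U = torch.nn.Linear(hidden_dim, num_nodes * rank)
    head_V = torch.nn.Linear(hidden_dim, num_nodes * rank)
    torch.nn.init.normal_(head_U.weight, std=0.01)
    torch.nn.init.normal_(head_V.weight, std=0.01)
    torch.nn.init.zeros_(head_U.bias)
    torch.nn.init.zeros_(head_V.bias)

    h = torch.randn(10000, hidden_dim)          # assumed ~unit-variance input
    U = head_U(h).view(-1, num_nodes, rank)
    V = head_V(h).view(-1, num_nodes, rank)
    Z_lowrank = U @ V.transpose(1, 2)           # per-sample UV^T

    # Each U entry sums hidden_dim=1024 weighted inputs:
    #   std ≈ sqrt(1024) * 0.01 ≈ 0.32.
    # Each UV^T entry sums rank=32 such products:
    #   std ≈ sqrt(32) * 0.32^2 ≈ 0.6, small next to logit_bias = 15.
    print(Z_lowrank.std())                      # ≈ 0.6

    tau_init, logit_bias = 5.0, 15.0
    A = torch.sigmoid((logit_bias + Z_lowrank) / tau_init)
    print(A.mean())                             # ≈ sigmoid(3) ≈ 0.953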