author     YurenHao0426 <blackhao0426@gmail.com>   2026-02-09 12:28:55 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>   2026-02-09 12:28:55 -0600
commit     ef678d2e1ba70b1a9dadb78c73ed372f986aea13 (patch)
tree       b90b5c53960b22a6a5498ca69fbfffad7e1832f8 /scripts
parent     93d77b197d457b1fdfa7341ecd59fc460b20d6b1 (diff)
Fix NLL double-shift bug and head weight init
- NLL loss was shifting labels twice (olmo_labels is already shifted, and the code then did logits[:, :-1] vs labels[:, 1:]). Fixed in 9 locations: trainer, pipeline, olmo_graph, sanity_check, eval.
- Head U/V weights now init with std=0.01 (was Kaiming, ~5.7 std) so UV^T ≈ 0 at init, ensuring Z ≈ logit_bias = 15 and A ≈ 0.953.
- Updated SVD rank test to subtract logit_bias before checking.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
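A minimal standalone sketch of the double-shift problem (toy shapes, random logits; not code from this repo): when labels are built as chunk[1:seq_len+1] they are already offset by one from input_ids, so the extra logits[:, :-1] / labels[:, 1:] slicing pairs position t with the token at t+2.

    import torch
    import torch.nn.functional as F

    vocab_size, seq_len = 11, 8
    chunk = torch.randint(0, vocab_size, (1, seq_len + 1))
    olmo_ids = chunk[:, :seq_len]      # positions 0 .. seq_len-1
    olmo_labels = chunk[:, 1:]         # positions 1 .. seq_len, already shifted by one

    logits = torch.randn(1, seq_len, vocab_size)  # stand-in for the model's output on olmo_ids

    # Fixed version: labels are pre-shifted, so use them directly.
    nll_fixed = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        olmo_labels.reshape(-1),
    )

    # Buggy version: shifting again scores logits[:, t] against the token at t+2
    # and silently drops one position.
    nll_buggy = F.cross_entropy(
        logits[:, :-1].reshape(-1, vocab_size),
        olmo_labels[:, 1:].reshape(-1),
    )
    print(nll_fixed.item(), nll_buggy.item())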
Diffstat (limited to 'scripts')
-rw-r--r--   scripts/eval.py          12
-rw-r--r--   scripts/sanity_check.py  17
2 files changed, 16 insertions, 13 deletions
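The head-init change itself lives outside scripts/ and is not in this diff; the sketch below uses assumed names and sizes (U, V, d_model, rank; logit_bias taken from the commit message) only to illustrate why std=0.01 makes the low-rank term negligible at init, leaving Z ≈ logit_bias.

    import torch
    import torch.nn as nn

    d_model, rank, logit_bias = 256, 16, 15.0   # assumed sizes; logit_bias from the commit message

    U = nn.Parameter(torch.empty(d_model, rank))
    V = nn.Parameter(torch.empty(d_model, rank))

    # Old-style init (Kaiming): entries are large enough that U @ V.T is comparable to logit_bias.
    nn.init.kaiming_normal_(U)
    nn.init.kaiming_normal_(V)
    print((U @ V.T).abs().max().item())   # on the order of 1 or more

    # New init: tiny std, so each entry of U @ V.T is ~ rank * 0.01^2, effectively zero.
    nn.init.normal_(U, std=0.01)
    nn.init.normal_(V, std=0.01)
    print((U @ V.T).abs().max().item())   # ~1e-3, so any score of the form (low-rank term) + logit_bias ≈ 15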
diff --git a/scripts/eval.py b/scripts/eval.py
index bc471dc..33314cf 100644
--- a/scripts/eval.py
+++ b/scripts/eval.py
@@ -88,8 +88,8 @@ def main():
A_soft = predictor(raw_texts, tau=tau, mode="eval_soft")
logits_soft = olmo_wrapper(olmo_ids, A_soft)
nll_soft = F.cross_entropy(
- logits_soft[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_soft.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_soft_sum += nll_soft.item()
@@ -97,8 +97,8 @@ def main():
A_hard = predictor(raw_texts, tau=tau, mode="eval_hard")
logits_hard = olmo_wrapper(olmo_ids, A_hard)
nll_hard = F.cross_entropy(
- logits_hard[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_hard.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_hard_sum += nll_hard.item()
@@ -106,8 +106,8 @@ def main():
A_ones = create_all_ones_A(olmo_ids.shape[0]).to(device)
logits_base = olmo_wrapper(olmo_ids, A_ones)
nll_base = F.cross_entropy(
- logits_base[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_base.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_baseline_sum += nll_base.item()
diff --git a/scripts/sanity_check.py b/scripts/sanity_check.py
index f30bd58..9745482 100644
--- a/scripts/sanity_check.py
+++ b/scripts/sanity_check.py
@@ -48,11 +48,14 @@ def get_test_batch(tokenizer, seq_len: int = 64, device: str = "cpu"):
def compute_dagformer_nll(wrapper: DAGFormerOLMo, input_ids: torch.Tensor,
labels: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
- """Compute NLL using DAGFormer modified forward."""
+ """Compute NLL using DAGFormer modified forward.
+
+ labels is already shifted (chunk[1:seq_len+1]), no additional shift needed.
+ """
logits = wrapper.forward(input_ids, A)
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
)
return nll
@@ -132,8 +135,8 @@ def check_4_gradient_flow(wrapper, tokenizer, device):
logits = wrapper.forward(input_ids, A)
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
)
nll.backward()
@@ -176,8 +179,8 @@ def check_5_normalization_smoke(wrapper_factory, tokenizer, device):
logits = wrapper.forward(input_ids, A)
is_finite = torch.isfinite(logits).all().item()
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
).item()
print(f" {method:12s}: NLL={nll:.4f}, finite={is_finite}")
if not is_finite: