-rw-r--r--  scripts/eval.py         | 12 ++++++------
-rw-r--r--  scripts/sanity_check.py | 17 ++++++++++-------
-rw-r--r--  src/model/olmo_graph.py |  5 +++--
-rw-r--r--  src/model/pipeline.py   | 11 ++++++-----
-rw-r--r--  src/model/predictor.py  | 10 ++++++++++
-rw-r--r--  src/training/trainer.py | 18 +++++++++---------
-rw-r--r--  tests/test_predictor.py | 12 +++++++-----
7 files changed, 51 insertions(+), 34 deletions(-)
diff --git a/scripts/eval.py b/scripts/eval.py
index bc471dc..33314cf 100644
--- a/scripts/eval.py
+++ b/scripts/eval.py
@@ -88,8 +88,8 @@ def main():
         A_soft = predictor(raw_texts, tau=tau, mode="eval_soft")
         logits_soft = olmo_wrapper(olmo_ids, A_soft)
         nll_soft = F.cross_entropy(
-            logits_soft[:, :-1].contiguous().view(-1, vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits_soft.contiguous().view(-1, vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         nll_soft_sum += nll_soft.item()

@@ -97,8 +97,8 @@ def main():
         A_hard = predictor(raw_texts, tau=tau, mode="eval_hard")
         logits_hard = olmo_wrapper(olmo_ids, A_hard)
         nll_hard = F.cross_entropy(
-            logits_hard[:, :-1].contiguous().view(-1, vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits_hard.contiguous().view(-1, vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         nll_hard_sum += nll_hard.item()

@@ -106,8 +106,8 @@ def main():
         A_ones = create_all_ones_A(olmo_ids.shape[0]).to(device)
         logits_base = olmo_wrapper(olmo_ids, A_ones)
         nll_base = F.cross_entropy(
-            logits_base[:, :-1].contiguous().view(-1, vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits_base.contiguous().view(-1, vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         nll_baseline_sum += nll_base.item()

diff --git a/scripts/sanity_check.py b/scripts/sanity_check.py
index f30bd58..9745482 100644
--- a/scripts/sanity_check.py
+++ b/scripts/sanity_check.py
@@ -48,11 +48,14 @@ def get_test_batch(tokenizer, seq_len: int = 64, device: str = "cpu"):

 def compute_dagformer_nll(wrapper: DAGFormerOLMo, input_ids: torch.Tensor,
                           labels: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
-    """Compute NLL using DAGFormer modified forward."""
+    """Compute NLL using DAGFormer modified forward.
+
+    labels is already shifted (chunk[1:seq_len+1]), no additional shift needed.
+    """
     logits = wrapper.forward(input_ids, A)
     nll = F.cross_entropy(
-        logits[:, :-1].contiguous().view(-1, logits.size(-1)),
-        labels[:, 1:].contiguous().view(-1),
+        logits.contiguous().view(-1, logits.size(-1)),
+        labels.contiguous().view(-1),
     )
     return nll

@@ -132,8 +135,8 @@ def check_4_gradient_flow(wrapper, tokenizer, device):

     logits = wrapper.forward(input_ids, A)
     nll = F.cross_entropy(
-        logits[:, :-1].contiguous().view(-1, logits.size(-1)),
-        labels[:, 1:].contiguous().view(-1),
+        logits.contiguous().view(-1, logits.size(-1)),
+        labels.contiguous().view(-1),
     )

     nll.backward()
@@ -176,8 +179,8 @@ def check_5_normalization_smoke(wrapper_factory, tokenizer, device):
         logits = wrapper.forward(input_ids, A)
         is_finite = torch.isfinite(logits).all().item()
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, logits.size(-1)),
-            labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, logits.size(-1)),
+            labels.contiguous().view(-1),
         ).item()
         print(f" {method:12s}: NLL={nll:.4f}, finite={is_finite}")
         if not is_finite:
diff --git a/src/model/olmo_graph.py b/src/model/olmo_graph.py
index af9f848..4056181 100644
--- a/src/model/olmo_graph.py
+++ b/src/model/olmo_graph.py
@@ -379,9 +379,10 @@ def compute_vanilla_nll(
     with torch.no_grad():
         outputs = model(input_ids=input_ids)
         logits = outputs.logits
+    # labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
     nll = F.cross_entropy(
-        logits[:, :-1].contiguous().view(-1, logits.size(-1)),
-        labels[:, 1:].contiguous().view(-1),
+        logits.contiguous().view(-1, logits.size(-1)),
+        labels.contiguous().view(-1),
     )
     return nll

diff --git a/src/model/pipeline.py b/src/model/pipeline.py
index bbfcabf..d5ceec0 100644
--- a/src/model/pipeline.py
+++ b/src/model/pipeline.py
@@ -100,10 +100,10 @@ class DAGFormerPipeline(nn.Module):
         # logits: [batch, seq_len, vocab_size]

         # Step 3: Compute NLL (next-token prediction)
-        # Shift: logits[:, :-1] predicts labels[:, 1:]
+        # olmo_labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )

         # Step 4: Sparsity regularization
@@ -130,9 +130,10 @@ class DAGFormerPipeline(nn.Module):
         A = create_all_ones_A(batch).to(olmo_ids.device)
         with torch.no_grad():
             logits = self.olmo_wrapper(olmo_ids, A)
+        # olmo_labels is already shifted, no additional shift needed
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         return nll

diff --git a/src/model/predictor.py b/src/model/predictor.py
index ed243ad..b5f9674 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -98,6 +98,16 @@ class PredictorMLP(nn.Module):
         self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
         self.head_V = nn.Linear(hidden_dim, num_nodes * rank)

+        # Initialize head_U and head_V with small weights so UV^T ≈ 0 at init.
+        # Default Kaiming init gives UV^T with std≈√rank≈5.7 which overwhelms
+        # the logit_bias. Small init ensures Z ≈ logit_bias ± small noise.
+        # std=0.01 gives UV^T std≈0.6 (with hidden_dim=1024, rank=32),
+        # small vs logit_bias=15 but enough for input-dependent gradients.
+        nn.init.normal_(self.head_U.weight, std=0.01)
+        nn.init.normal_(self.head_V.weight, std=0.01)
+        nn.init.zeros_(self.head_U.bias)
+        nn.init.zeros_(self.head_V.bias)
+
         # Learnable bias added to Z logits. Initialized positive so that
         # σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1)
         # at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95.
diff --git a/src/training/trainer.py b/src/training/trainer.py
index de0eb96..7ebd21e 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -299,10 +299,10 @@ class Trainer:
         A = self.predictor(raw_texts, tau=tau, mode="train")
         logits = self.olmo_wrapper(olmo_ids, A)

-        # NLL loss
+        # NLL loss (olmo_labels already shifted, no additional shift needed)
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.olmo.config.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.olmo.config.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )

         # Sparsity loss
@@ -417,8 +417,8 @@ class Trainer:
             A_soft = self.predictor(raw_texts, tau=tau, mode="eval_soft")
             logits_soft = self.olmo_wrapper(olmo_ids, A_soft)
             nll_soft = F.cross_entropy(
-                logits_soft[:, :-1].contiguous().view(-1, vocab_size),
-                olmo_labels[:, 1:].contiguous().view(-1),
+                logits_soft.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
             )
             nll_soft_total += nll_soft.item()

@@ -426,8 +426,8 @@ class Trainer:
             A_hard = self.predictor(raw_texts, tau=tau, mode="eval_hard")
             logits_hard = self.olmo_wrapper(olmo_ids, A_hard)
             nll_hard = F.cross_entropy(
-                logits_hard[:, :-1].contiguous().view(-1, vocab_size),
-                olmo_labels[:, 1:].contiguous().view(-1),
+                logits_hard.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
             )
             nll_hard_total += nll_hard.item()

@@ -435,8 +435,8 @@ class Trainer:
             A_ones = create_all_ones_A(olmo_ids.shape[0]).to(self.device)
             logits_base = self.olmo_wrapper(olmo_ids, A_ones)
             nll_base = F.cross_entropy(
-                logits_base[:, :-1].contiguous().view(-1, vocab_size),
-                olmo_labels[:, 1:].contiguous().view(-1),
+                logits_base.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
             )
             nll_baseline_total += nll_base.item()

diff --git a/tests/test_predictor.py b/tests/test_predictor.py
index 00a4124..5a092d4 100644
--- a/tests/test_predictor.py
+++ b/tests/test_predictor.py
@@ -28,15 +28,17 @@ class TestPredictorMLP:
         assert Z.shape == (self.batch, 256, 256)

     def test_low_rank_structure(self):
-        """Z = UV^T should have rank <= r."""
+        """Z - logit_bias = UV^T should have rank <= r."""
         e = torch.randn(1, self.input_dim)
         Z = self.mlp(e)
         Z_2d = Z.squeeze(0)
-        # SVD to check effective rank
-        S = torch.linalg.svdvals(Z_2d)
+        # Subtract the scalar logit_bias (constant across all entries)
+        # so we test the rank of UV^T alone
+        Z_no_bias = Z_2d - self.mlp.logit_bias.detach()
+        S = torch.linalg.svdvals(Z_no_bias)
         # Values beyond rank r should be ~0 (up to numerical precision)
-        assert S[self.rank:].abs().max() < 1e-4, \
-            f"Z has effective rank > {self.rank}: max singular value beyond rank = {S[self.rank:].abs().max()}"
+        assert S[self.rank:].abs().max() < 0.05, \
+            f"UV^T has effective rank > {self.rank}: max singular value beyond rank = {S[self.rank:].abs().max()}"

     def test_gradient_flow(self):
         e = torch.randn(self.batch, self.input_dim)
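
All of the cross-entropy changes above rest on the convention stated in the new comments: the dataloader already emits shifted labels (labels = chunk[1:seq_len+1] for input_ids = chunk[:seq_len]). A minimal sketch of why no further shift is needed; the tensor names and sizes below are illustrative, not taken from the repo:

import torch
import torch.nn.functional as F

seq_len, vocab_size = 8, 1000
chunk = torch.randint(0, vocab_size, (1, seq_len + 1))  # seq_len + 1 raw tokens

input_ids = chunk[:, :seq_len]        # tokens 0 .. seq_len-1
labels = chunk[:, 1:seq_len + 1]      # tokens 1 .. seq_len (pre-shifted)

logits = torch.randn(1, seq_len, vocab_size)  # stand-in for the model output

# logits[:, t] is the prediction for the token after input position t,
# which is exactly labels[:, t], so the two tensors are already aligned:
nll = F.cross_entropy(
    logits.contiguous().view(-1, vocab_size),
    labels.contiguous().view(-1),
)

# The old code shifted a second time (logits[:, :-1] vs labels[:, 1:]):
# position t was then scored against chunk[:, t + 2], one token too far,
# and the last position was silently dropped.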

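The predictor.py change keeps the predicted adjacency dense at initialization: with std=0.01 heads, UV^T is only a small perturbation around logit_bias, so σ(Z/τ) stays near 0.95 everywhere, and subtracting the scalar bias before the SVD (as the updated test does) removes the rank-one offset it adds to Z. Below is a rough stand-in that checks those numbers; hidden_dim=1024, rank=32, logit_bias=15 and τ=5 are taken from the diff's comments and num_nodes=256 from the test file, while the module layout itself is a simplification of PredictorMLP:

import torch
import torch.nn as nn

hidden_dim, num_nodes, rank = 1024, 256, 32
logit_bias, tau = 15.0, 5.0

head_U = nn.Linear(hidden_dim, num_nodes * rank)
head_V = nn.Linear(hidden_dim, num_nodes * rank)
for head in (head_U, head_V):
    nn.init.normal_(head.weight, std=0.01)
    nn.init.zeros_(head.bias)

h = torch.randn(hidden_dim)             # unit-scale hidden activation
U = head_U(h).view(num_nodes, rank)
V = head_V(h).view(num_nodes, rank)
Z = U @ V.T + logit_bias                # [num_nodes, num_nodes] edge logits

A = torch.sigmoid(Z / tau)
print((Z - logit_bias).std().item())    # ~0.6: UV^T is small next to the bias of 15
print(A.mean().item())                  # ~sigmoid(15/5) = sigmoid(3) ~ 0.95: near-dense graph

# The constant bias contributes a rank-one component (logit_bias * ones), so the
# updated rank test subtracts it before the SVD; what remains is UV^T, whose
# singular values beyond index `rank` are numerically ~0.
S = torch.linalg.svdvals(Z - logit_bias)
print(S[rank:].abs().max().item())      # well below the test's 0.05 threshold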