summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--scripts/eval.py12
-rw-r--r--scripts/sanity_check.py17
-rw-r--r--src/model/olmo_graph.py5
-rw-r--r--src/model/pipeline.py11
-rw-r--r--src/model/predictor.py10
-rw-r--r--src/training/trainer.py18
-rw-r--r--tests/test_predictor.py12
7 files changed, 51 insertions, 34 deletions
diff --git a/scripts/eval.py b/scripts/eval.py
index bc471dc..33314cf 100644
--- a/scripts/eval.py
+++ b/scripts/eval.py
@@ -88,8 +88,8 @@ def main():
A_soft = predictor(raw_texts, tau=tau, mode="eval_soft")
logits_soft = olmo_wrapper(olmo_ids, A_soft)
nll_soft = F.cross_entropy(
- logits_soft[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_soft.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_soft_sum += nll_soft.item()
@@ -97,8 +97,8 @@ def main():
A_hard = predictor(raw_texts, tau=tau, mode="eval_hard")
logits_hard = olmo_wrapper(olmo_ids, A_hard)
nll_hard = F.cross_entropy(
- logits_hard[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_hard.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_hard_sum += nll_hard.item()
@@ -106,8 +106,8 @@ def main():
A_ones = create_all_ones_A(olmo_ids.shape[0]).to(device)
logits_base = olmo_wrapper(olmo_ids, A_ones)
nll_base = F.cross_entropy(
- logits_base[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_base.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_baseline_sum += nll_base.item()
diff --git a/scripts/sanity_check.py b/scripts/sanity_check.py
index f30bd58..9745482 100644
--- a/scripts/sanity_check.py
+++ b/scripts/sanity_check.py
@@ -48,11 +48,14 @@ def get_test_batch(tokenizer, seq_len: int = 64, device: str = "cpu"):
def compute_dagformer_nll(wrapper: DAGFormerOLMo, input_ids: torch.Tensor,
labels: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
- """Compute NLL using DAGFormer modified forward."""
+ """Compute NLL using DAGFormer modified forward.
+
+ labels is already shifted (chunk[1:seq_len+1]), no additional shift needed.
+ """
logits = wrapper.forward(input_ids, A)
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
)
return nll
@@ -132,8 +135,8 @@ def check_4_gradient_flow(wrapper, tokenizer, device):
logits = wrapper.forward(input_ids, A)
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
)
nll.backward()
@@ -176,8 +179,8 @@ def check_5_normalization_smoke(wrapper_factory, tokenizer, device):
logits = wrapper.forward(input_ids, A)
is_finite = torch.isfinite(logits).all().item()
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
).item()
print(f" {method:12s}: NLL={nll:.4f}, finite={is_finite}")
if not is_finite:
diff --git a/src/model/olmo_graph.py b/src/model/olmo_graph.py
index af9f848..4056181 100644
--- a/src/model/olmo_graph.py
+++ b/src/model/olmo_graph.py
@@ -379,9 +379,10 @@ def compute_vanilla_nll(
with torch.no_grad():
outputs = model(input_ids=input_ids)
logits = outputs.logits
+ # labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
)
return nll
diff --git a/src/model/pipeline.py b/src/model/pipeline.py
index bbfcabf..d5ceec0 100644
--- a/src/model/pipeline.py
+++ b/src/model/pipeline.py
@@ -100,10 +100,10 @@ class DAGFormerPipeline(nn.Module):
# logits: [batch, seq_len, vocab_size]
# Step 3: Compute NLL (next-token prediction)
- # Shift: logits[:, :-1] predicts labels[:, 1:]
+ # olmo_labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, self.vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, self.vocab_size),
+ olmo_labels.contiguous().view(-1),
)
# Step 4: Sparsity regularization
@@ -130,9 +130,10 @@ class DAGFormerPipeline(nn.Module):
A = create_all_ones_A(batch).to(olmo_ids.device)
with torch.no_grad():
logits = self.olmo_wrapper(olmo_ids, A)
+ # olmo_labels is already shifted, no additional shift needed
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, self.vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, self.vocab_size),
+ olmo_labels.contiguous().view(-1),
)
return nll
diff --git a/src/model/predictor.py b/src/model/predictor.py
index ed243ad..b5f9674 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -98,6 +98,16 @@ class PredictorMLP(nn.Module):
self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
self.head_V = nn.Linear(hidden_dim, num_nodes * rank)
+ # Initialize head_U and head_V with small weights so UV^T ≈ 0 at init.
+ # Default Kaiming init gives UV^T with std≈√rank≈5.7 which overwhelms
+ # the logit_bias. Small init ensures Z ≈ logit_bias ± small noise.
+ # std=0.01 gives UV^T std≈0.6 (with hidden_dim=1024, rank=32),
+ # small vs logit_bias=15 but enough for input-dependent gradients.
+ nn.init.normal_(self.head_U.weight, std=0.01)
+ nn.init.normal_(self.head_V.weight, std=0.01)
+ nn.init.zeros_(self.head_U.bias)
+ nn.init.zeros_(self.head_V.bias)
+
# Learnable bias added to Z logits. Initialized positive so that
# σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1)
# at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95.
diff --git a/src/training/trainer.py b/src/training/trainer.py
index de0eb96..7ebd21e 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -299,10 +299,10 @@ class Trainer:
A = self.predictor(raw_texts, tau=tau, mode="train")
logits = self.olmo_wrapper(olmo_ids, A)
- # NLL loss
+ # NLL loss (olmo_labels already shifted, no additional shift needed)
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, self.olmo.config.vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, self.olmo.config.vocab_size),
+ olmo_labels.contiguous().view(-1),
)
# Sparsity loss
@@ -417,8 +417,8 @@ class Trainer:
A_soft = self.predictor(raw_texts, tau=tau, mode="eval_soft")
logits_soft = self.olmo_wrapper(olmo_ids, A_soft)
nll_soft = F.cross_entropy(
- logits_soft[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_soft.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_soft_total += nll_soft.item()
@@ -426,8 +426,8 @@ class Trainer:
A_hard = self.predictor(raw_texts, tau=tau, mode="eval_hard")
logits_hard = self.olmo_wrapper(olmo_ids, A_hard)
nll_hard = F.cross_entropy(
- logits_hard[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_hard.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_hard_total += nll_hard.item()
@@ -435,8 +435,8 @@ class Trainer:
A_ones = create_all_ones_A(olmo_ids.shape[0]).to(self.device)
logits_base = self.olmo_wrapper(olmo_ids, A_ones)
nll_base = F.cross_entropy(
- logits_base[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_base.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_baseline_total += nll_base.item()
diff --git a/tests/test_predictor.py b/tests/test_predictor.py
index 00a4124..5a092d4 100644
--- a/tests/test_predictor.py
+++ b/tests/test_predictor.py
@@ -28,15 +28,17 @@ class TestPredictorMLP:
assert Z.shape == (self.batch, 256, 256)
def test_low_rank_structure(self):
- """Z = UV^T should have rank <= r."""
+ """Z - logit_bias = UV^T should have rank <= r."""
e = torch.randn(1, self.input_dim)
Z = self.mlp(e)
Z_2d = Z.squeeze(0)
- # SVD to check effective rank
- S = torch.linalg.svdvals(Z_2d)
+ # Subtract the scalar logit_bias (constant across all entries)
+ # so we test the rank of UV^T alone
+ Z_no_bias = Z_2d - self.mlp.logit_bias.detach()
+ S = torch.linalg.svdvals(Z_no_bias)
# Values beyond rank r should be ~0 (up to numerical precision)
- assert S[self.rank:].abs().max() < 1e-4, \
- f"Z has effective rank > {self.rank}: max singular value beyond rank = {S[self.rank:].abs().max()}"
+ assert S[self.rank:].abs().max() < 0.05, \
+ f"UV^T has effective rank > {self.rank}: max singular value beyond rank = {S[self.rank:].abs().max()}"
def test_gradient_flow(self):
e = torch.randn(self.batch, self.input_dim)