Diffstat (limited to 'src')
 src/model/olmo_graph.py |  5
 src/model/pipeline.py   | 11
 src/model/predictor.py  | 10
 src/training/trainer.py | 18
 4 files changed, 28 insertions(+), 16 deletions(-)
diff --git a/src/model/olmo_graph.py b/src/model/olmo_graph.py
index af9f848..4056181 100644
--- a/src/model/olmo_graph.py
+++ b/src/model/olmo_graph.py
@@ -379,9 +379,10 @@ def compute_vanilla_nll(
     with torch.no_grad():
         outputs = model(input_ids=input_ids)
         logits = outputs.logits
 
+    # labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
     nll = F.cross_entropy(
-        logits[:, :-1].contiguous().view(-1, logits.size(-1)),
-        labels[:, 1:].contiguous().view(-1),
+        logits.contiguous().view(-1, logits.size(-1)),
+        labels.contiguous().view(-1),
     )
     return nll
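The comment in this hunk relies on how the data loader builds (input, label) pairs. As a minimal sketch (the chunking helper below is hypothetical, not the repo's actual loader), labels produced as chunk[1:seq_len+1] are already the next-token targets for inputs chunk[:seq_len], so the loss must not shift them again:

import torch

def make_pair(chunk, seq_len):
    # Hypothetical loader convention implied by the comment:
    # labels[t] is the token that follows input_ids[t].
    input_ids = chunk[:seq_len]
    labels = chunk[1:seq_len + 1]
    return input_ids, labels

chunk = torch.arange(9)                 # toy token ids 0..8
input_ids, labels = make_pair(chunk, seq_len=8)
print(input_ids.tolist())               # [0, 1, 2, 3, 4, 5, 6, 7]
print(labels.tolist())                  # [1, 2, 3, 4, 5, 6, 7, 8]
# Shifting again (logits[:, :-1] vs labels[:, 1:]) would compare position t
# against token t+2, which is the off-by-one bug this patch removes.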
diff --git a/src/model/pipeline.py b/src/model/pipeline.py
index bbfcabf..d5ceec0 100644
--- a/src/model/pipeline.py
+++ b/src/model/pipeline.py
@@ -100,10 +100,10 @@ class DAGFormerPipeline(nn.Module):
         # logits: [batch, seq_len, vocab_size]
 
         # Step 3: Compute NLL (next-token prediction)
-        # Shift: logits[:, :-1] predicts labels[:, 1:]
+        # olmo_labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
 
         # Step 4: Sparsity regularization
@@ -130,9 +130,10 @@ class DAGFormerPipeline(nn.Module):
         A = create_all_ones_A(batch).to(olmo_ids.device)
         with torch.no_grad():
             logits = self.olmo_wrapper(olmo_ids, A)
 
+        # olmo_labels is already shifted, no additional shift needed
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         return nll
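For reference, every F.cross_entropy call in this patch flattens [batch, seq_len, vocab] logits and [batch, seq_len] labels and uses the default reduction="mean", so each nll is an average per-token negative log-likelihood. A minimal sketch with synthetic tensors, just to show the shapes:

import torch
import torch.nn.functional as F

B, T, V = 2, 4, 11                       # toy sizes; the real vocab is far larger
logits = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))     # already aligned with logits, as in the patch

# Flatten to [B*T, V] and [B*T]; reduction="mean" averages over all B*T positions.
nll = F.cross_entropy(logits.view(-1, V), labels.view(-1))
ppl = nll.exp()                          # per-token perplexity, if desired
print(nll.item(), ppl.item())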
diff --git a/src/model/predictor.py b/src/model/predictor.py
index ed243ad..b5f9674 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -98,6 +98,16 @@ class PredictorMLP(nn.Module):
         self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
         self.head_V = nn.Linear(hidden_dim, num_nodes * rank)
 
+        # Initialize head_U and head_V with small weights so UV^T ≈ 0 at init.
+        # Default Kaiming init gives UV^T with std≈√rank≈5.7 which overwhelms
+        # the logit_bias. Small init ensures Z ≈ logit_bias ± small noise.
+        # std=0.01 gives UV^T std≈0.6 (with hidden_dim=1024, rank=32),
+        # small vs logit_bias=15 but enough for input-dependent gradients.
+        nn.init.normal_(self.head_U.weight, std=0.01)
+        nn.init.normal_(self.head_V.weight, std=0.01)
+        nn.init.zeros_(self.head_U.bias)
+        nn.init.zeros_(self.head_V.bias)
+
         # Learnable bias added to Z logits. Initialized positive so that
         # σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1)
         # at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95.
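The numbers quoted in the new comment can be checked empirically. A minimal sketch, assuming unit-variance head inputs and the hidden_dim=1024, rank=32 values from the comment; num_nodes, the sample count, and the reshape below are illustrative for this check only, not the predictor's actual forward pass:

import torch
import torch.nn as nn

hidden_dim, num_nodes, rank = 1024, 8, 32
head_U = nn.Linear(hidden_dim, num_nodes * rank)
head_V = nn.Linear(hidden_dim, num_nodes * rank)
nn.init.normal_(head_U.weight, std=0.01)
nn.init.normal_(head_V.weight, std=0.01)
nn.init.zeros_(head_U.bias)
nn.init.zeros_(head_V.bias)

h = torch.randn(4096, hidden_dim)                  # assumed unit-variance features
U = head_U(h).view(-1, num_nodes, rank)
V = head_V(h).view(-1, num_nodes, rank)
Z = U @ V.transpose(-1, -2)                        # low-rank UV^T logits
print(Z.std().item())                              # ≈ 0.6, matching the comment
# Each U/V entry has std ≈ sqrt(1024) * 0.01 ≈ 0.32; summing rank=32 products
# gives std ≈ sqrt(32) * 0.32**2 ≈ 0.58, small next to logit_bias = 15.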
diff --git a/src/training/trainer.py b/src/training/trainer.py
index de0eb96..7ebd21e 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -299,10 +299,10 @@ class Trainer:
         A = self.predictor(raw_texts, tau=tau, mode="train")
         logits = self.olmo_wrapper(olmo_ids, A)
 
-        # NLL loss
+        # NLL loss (olmo_labels already shifted, no additional shift needed)
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.olmo.config.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.olmo.config.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
 
         # Sparsity loss
@@ -417,8 +417,8 @@ class Trainer:
             A_soft = self.predictor(raw_texts, tau=tau, mode="eval_soft")
             logits_soft = self.olmo_wrapper(olmo_ids, A_soft)
             nll_soft = F.cross_entropy(
-                logits_soft[:, :-1].contiguous().view(-1, vocab_size),
-                olmo_labels[:, 1:].contiguous().view(-1),
+                logits_soft.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
             )
             nll_soft_total += nll_soft.item()
 
@@ -426,8 +426,8 @@ class Trainer:
             A_hard = self.predictor(raw_texts, tau=tau, mode="eval_hard")
             logits_hard = self.olmo_wrapper(olmo_ids, A_hard)
             nll_hard = F.cross_entropy(
-                logits_hard[:, :-1].contiguous().view(-1, vocab_size),
-                olmo_labels[:, 1:].contiguous().view(-1),
+                logits_hard.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
             )
             nll_hard_total += nll_hard.item()
 
@@ -435,8 +435,8 @@ class Trainer:
             A_ones = create_all_ones_A(olmo_ids.shape[0]).to(self.device)
             logits_base = self.olmo_wrapper(olmo_ids, A_ones)
             nll_base = F.cross_entropy(
-                logits_base[:, :-1].contiguous().view(-1, vocab_size),
-                olmo_labels[:, 1:].contiguous().view(-1),
+                logits_base.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
             )
             nll_baseline_total += nll_base.item()
 
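The trainer changes are the same alignment fix as in the other files. A toy check of what it corrects (synthetic tensors only; the indexing is the point): when labels are already shifted, an extra shift scores each position against the token two steps ahead.

import torch
import torch.nn.functional as F

vocab = 10
tokens = torch.arange(9).unsqueeze(0)        # one toy sequence: tokens 0..8
input_ids = tokens[:, :8]                    # 0..7
labels = tokens[:, 1:9]                      # 1..8, already the next-token targets

# Pretend the model is perfect: all probability mass on the true next token.
logits = F.one_hot(labels, vocab).float() * 100.0

aligned = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1))
double_shift = F.cross_entropy(
    logits[:, :-1].reshape(-1, vocab), labels[:, 1:].reshape(-1)
)
print(aligned.item())        # ~0: predictions line up with targets
print(double_shift.item())   # ~100: every position is scored against token t+2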