Diffstat (limited to 'src/model')
| -rw-r--r-- | src/model/olmo_graph.py |  5 |
| -rw-r--r-- | src/model/pipeline.py   | 11 |
| -rw-r--r-- | src/model/predictor.py  | 10 |
3 files changed, 19 insertions, 7 deletions
diff --git a/src/model/olmo_graph.py b/src/model/olmo_graph.py
index af9f848..4056181 100644
--- a/src/model/olmo_graph.py
+++ b/src/model/olmo_graph.py
@@ -379,9 +379,10 @@ def compute_vanilla_nll(
     with torch.no_grad():
         outputs = model(input_ids=input_ids)
         logits = outputs.logits
 
+    # labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
     nll = F.cross_entropy(
-        logits[:, :-1].contiguous().view(-1, logits.size(-1)),
-        labels[:, 1:].contiguous().view(-1),
+        logits.contiguous().view(-1, logits.size(-1)),
+        labels.contiguous().view(-1),
     )
     return nll
diff --git a/src/model/pipeline.py b/src/model/pipeline.py
index bbfcabf..d5ceec0 100644
--- a/src/model/pipeline.py
+++ b/src/model/pipeline.py
@@ -100,10 +100,10 @@ class DAGFormerPipeline(nn.Module):
         # logits: [batch, seq_len, vocab_size]
 
         # Step 3: Compute NLL (next-token prediction)
-        # Shift: logits[:, :-1] predicts labels[:, 1:]
+        # olmo_labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
 
         # Step 4: Sparsity regularization
@@ -130,9 +130,10 @@ class DAGFormerPipeline(nn.Module):
         A = create_all_ones_A(batch).to(olmo_ids.device)
         with torch.no_grad():
             logits = self.olmo_wrapper(olmo_ids, A)
 
+        # olmo_labels is already shifted, no additional shift needed
         nll = F.cross_entropy(
-            logits[:, :-1].contiguous().view(-1, self.vocab_size),
-            olmo_labels[:, 1:].contiguous().view(-1),
+            logits.contiguous().view(-1, self.vocab_size),
+            olmo_labels.contiguous().view(-1),
         )
         return nll
diff --git a/src/model/predictor.py b/src/model/predictor.py
index ed243ad..b5f9674 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -98,6 +98,16 @@ class PredictorMLP(nn.Module):
         self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
         self.head_V = nn.Linear(hidden_dim, num_nodes * rank)
 
+        # Initialize head_U and head_V with small weights so UV^T ≈ 0 at init.
+        # Default Kaiming init gives UV^T with std≈√rank≈5.7 which overwhelms
+        # the logit_bias. Small init ensures Z ≈ logit_bias ± small noise.
+        # std=0.01 gives UV^T std≈0.6 (with hidden_dim=1024, rank=32),
+        # small vs logit_bias=15 but enough for input-dependent gradients.
+        nn.init.normal_(self.head_U.weight, std=0.01)
+        nn.init.normal_(self.head_V.weight, std=0.01)
+        nn.init.zeros_(self.head_U.bias)
+        nn.init.zeros_(self.head_V.bias)
+
         # Learnable bias added to Z logits. Initialized positive so that
         # σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1)
         # at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95.
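Context for the shift fix: all three call sites now assume the labels tensor arrives pre-shifted by the data loader (input ids from `chunk[:seq_len]`, labels from `chunk[1:seq_len+1]`), so `logits[:, t]` already lines up with `labels[:, t]`. A minimal sketch of that convention, with sizes and tensor names assumed from the comments above rather than taken from the repo:

```python
import torch
import torch.nn.functional as F

seq_len, vocab_size = 8, 50304          # illustrative sizes, not from the repo
chunk = torch.randint(0, vocab_size, (1, seq_len + 1))

input_ids = chunk[:, :seq_len]          # tokens 0 .. seq_len-1 fed to the model
labels = chunk[:, 1:seq_len + 1]        # already shifted: labels[:, t] is token t+1

logits = torch.randn(1, seq_len, vocab_size)  # stand-in for model(input_ids).logits

# New code path: no extra shift, logits[:, t] is scored against labels[:, t].
nll = F.cross_entropy(
    logits.contiguous().view(-1, vocab_size),
    labels.contiguous().view(-1),
)

# The old path shifted again (logits[:, :-1] vs labels[:, 1:]), which dropped one
# position and scored each logit against the token two steps ahead of it.
```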

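The numbers quoted in the new predictor.py comments can be sanity-checked with a quick sketch. Everything below (hidden_dim=1024, rank=32, a unit-variance hidden state, the num_nodes value, logit_bias=15, τ=5) is assumed from those comments, not read from the rest of the file:

```python
import torch
import torch.nn as nn

hidden_dim, num_nodes, rank = 1024, 64, 32   # num_nodes chosen only for illustration
head_U = nn.Linear(hidden_dim, num_nodes * rank)
head_V = nn.Linear(hidden_dim, num_nodes * rank)
nn.init.normal_(head_U.weight, std=0.01)
nn.init.normal_(head_V.weight, std=0.01)
nn.init.zeros_(head_U.bias)
nn.init.zeros_(head_V.bias)

h = torch.randn(4096, hidden_dim)            # unit-variance stand-in for the hidden state
U = head_U(h).view(-1, num_nodes, rank)
V = head_V(h).view(-1, num_nodes, rank)
Z_lowrank = U @ V.transpose(-1, -2)          # [batch, num_nodes, num_nodes]
print(Z_lowrank.std())                       # ≈ 0.6: small relative to logit_bias = 15

logit_bias, tau = 15.0, 5.0
print(torch.sigmoid(torch.tensor(logit_bias / tau)))  # σ(3) ≈ 0.95, i.e. A ≈ 1 at init
```

The point of the check is only that, with the small init, the low-rank term starts as small noise around the bias, so σ(Z / τ) stays near 1 and training begins from dense connectivity, as the comments in the hunk describe.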