diff options
Diffstat (limited to 'src/model/pipeline.py')
| -rw-r--r-- | src/model/pipeline.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/src/model/pipeline.py b/src/model/pipeline.py index bbfcabf..d5ceec0 100644 --- a/src/model/pipeline.py +++ b/src/model/pipeline.py @@ -100,10 +100,10 @@ class DAGFormerPipeline(nn.Module): # logits: [batch, seq_len, vocab_size] # Step 3: Compute NLL (next-token prediction) - # Shift: logits[:, :-1] predicts labels[:, 1:] + # olmo_labels is already shifted (chunk[1:seq_len+1]), no additional shift needed nll = F.cross_entropy( - logits[:, :-1].contiguous().view(-1, self.vocab_size), - olmo_labels[:, 1:].contiguous().view(-1), + logits.contiguous().view(-1, self.vocab_size), + olmo_labels.contiguous().view(-1), ) # Step 4: Sparsity regularization @@ -130,9 +130,10 @@ class DAGFormerPipeline(nn.Module): A = create_all_ones_A(batch).to(olmo_ids.device) with torch.no_grad(): logits = self.olmo_wrapper(olmo_ids, A) + # olmo_labels is already shifted, no additional shift needed nll = F.cross_entropy( - logits[:, :-1].contiguous().view(-1, self.vocab_size), - olmo_labels[:, 1:].contiguous().view(-1), + logits.contiguous().view(-1, self.vocab_size), + olmo_labels.contiguous().view(-1), ) return nll |
