summaryrefslogtreecommitdiff
path: root/src/model/pipeline.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/model/pipeline.py')
-rw-r--r--src/model/pipeline.py11
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/src/model/pipeline.py b/src/model/pipeline.py
index bbfcabf..d5ceec0 100644
--- a/src/model/pipeline.py
+++ b/src/model/pipeline.py
@@ -100,10 +100,10 @@ class DAGFormerPipeline(nn.Module):
# logits: [batch, seq_len, vocab_size]
# Step 3: Compute NLL (next-token prediction)
- # Shift: logits[:, :-1] predicts labels[:, 1:]
+ # olmo_labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, self.vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, self.vocab_size),
+ olmo_labels.contiguous().view(-1),
)
# Step 4: Sparsity regularization
@@ -130,9 +130,10 @@ class DAGFormerPipeline(nn.Module):
A = create_all_ones_A(batch).to(olmo_ids.device)
with torch.no_grad():
logits = self.olmo_wrapper(olmo_ids, A)
+ # olmo_labels is already shifted, no additional shift needed
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, self.vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, self.vocab_size),
+ olmo_labels.contiguous().view(-1),
)
return nll