summaryrefslogtreecommitdiff
path: root/src/model/olmo_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/model/olmo_graph.py')
-rw-r--r--src/model/olmo_graph.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/model/olmo_graph.py b/src/model/olmo_graph.py
index af9f848..4056181 100644
--- a/src/model/olmo_graph.py
+++ b/src/model/olmo_graph.py
@@ -379,9 +379,10 @@ def compute_vanilla_nll(
with torch.no_grad():
outputs = model(input_ids=input_ids)
logits = outputs.logits
+ # labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
)
return nll