Diffstat (limited to 'src/training/trainer.py')
-rw-r--r--  src/training/trainer.py  18  +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/src/training/trainer.py b/src/training/trainer.py
index de0eb96..7ebd21e 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -299,10 +299,10 @@ class Trainer:
A = self.predictor(raw_texts, tau=tau, mode="train")
logits = self.olmo_wrapper(olmo_ids, A)
- # NLL loss
+ # NLL loss (olmo_labels already shifted, no additional shift needed)
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, self.olmo.config.vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, self.olmo.config.vocab_size),
+ olmo_labels.contiguous().view(-1),
)
# Sparsity loss
@@ -417,8 +417,8 @@ class Trainer:
A_soft = self.predictor(raw_texts, tau=tau, mode="eval_soft")
logits_soft = self.olmo_wrapper(olmo_ids, A_soft)
nll_soft = F.cross_entropy(
- logits_soft[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_soft.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_soft_total += nll_soft.item()
@@ -426,8 +426,8 @@ class Trainer:
A_hard = self.predictor(raw_texts, tau=tau, mode="eval_hard")
logits_hard = self.olmo_wrapper(olmo_ids, A_hard)
nll_hard = F.cross_entropy(
- logits_hard[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_hard.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_hard_total += nll_hard.item()
@@ -435,8 +435,8 @@ class Trainer:
A_ones = create_all_ones_A(olmo_ids.shape[0]).to(self.device)
logits_base = self.olmo_wrapper(olmo_ids, A_ones)
nll_base = F.cross_entropy(
- logits_base[:, :-1].contiguous().view(-1, vocab_size),
- olmo_labels[:, 1:].contiguous().view(-1),
+ logits_base.contiguous().view(-1, vocab_size),
+ olmo_labels.contiguous().view(-1),
)
nll_baseline_total += nll_base.item()
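
Note: the patch removes the manual `logits[:, :-1]` / `olmo_labels[:, 1:]` slicing on the assumption, stated in the new comment, that `olmo_labels` arrives from the dataloader already shifted left by one. A minimal sketch of that convention follows; `make_shifted_labels` and the toy shapes are illustrative assumptions, not code from this repository.

```python
# Sketch of the pre-shifted label convention the patch relies on.
# make_shifted_labels is a hypothetical helper: labels[:, t] holds the
# token at t + 1, so the loss needs no further slicing.
import torch
import torch.nn.functional as F

def make_shifted_labels(input_ids: torch.Tensor, pad_id: int = -100) -> torch.Tensor:
    """Shift input_ids left by one; the final position has no target."""
    labels = input_ids.new_full(input_ids.shape, pad_id)
    labels[:, :-1] = input_ids[:, 1:]
    return labels

batch, seq, vocab = 2, 8, 100  # toy sizes for illustration
olmo_ids = torch.randint(0, vocab, (batch, seq))
olmo_labels = make_shifted_labels(olmo_ids)
logits = torch.randn(batch, seq, vocab)  # stand-in for model output

# Post-patch loss: logits[:, t] is scored against the token at t + 1.
nll = F.cross_entropy(
    logits.contiguous().view(-1, vocab),
    olmo_labels.contiguous().view(-1),
    ignore_index=-100,  # masks the last position, which has no next token
)
```

Under this convention, the pre-patch slicing applied a second shift on top of already-shifted labels, comparing the prediction at position t against the token at t + 2, an off-by-one misalignment that inflated every NLL term this diff touches (train, eval_soft, eval_hard, and the all-ones baseline).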