diff options
Diffstat (limited to 'scripts/eval.py')
| -rw-r--r-- | scripts/eval.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/scripts/eval.py b/scripts/eval.py index bc471dc..33314cf 100644 --- a/scripts/eval.py +++ b/scripts/eval.py @@ -88,8 +88,8 @@ def main(): A_soft = predictor(raw_texts, tau=tau, mode="eval_soft") logits_soft = olmo_wrapper(olmo_ids, A_soft) nll_soft = F.cross_entropy( - logits_soft[:, :-1].contiguous().view(-1, vocab_size), - olmo_labels[:, 1:].contiguous().view(-1), + logits_soft.contiguous().view(-1, vocab_size), + olmo_labels.contiguous().view(-1), ) nll_soft_sum += nll_soft.item() @@ -97,8 +97,8 @@ def main(): A_hard = predictor(raw_texts, tau=tau, mode="eval_hard") logits_hard = olmo_wrapper(olmo_ids, A_hard) nll_hard = F.cross_entropy( - logits_hard[:, :-1].contiguous().view(-1, vocab_size), - olmo_labels[:, 1:].contiguous().view(-1), + logits_hard.contiguous().view(-1, vocab_size), + olmo_labels.contiguous().view(-1), ) nll_hard_sum += nll_hard.item() @@ -106,8 +106,8 @@ def main(): A_ones = create_all_ones_A(olmo_ids.shape[0]).to(device) logits_base = olmo_wrapper(olmo_ids, A_ones) nll_base = F.cross_entropy( - logits_base[:, :-1].contiguous().view(-1, vocab_size), - olmo_labels[:, 1:].contiguous().view(-1), + logits_base.contiguous().view(-1, vocab_size), + olmo_labels.contiguous().view(-1), ) nll_baseline_sum += nll_base.item() |
