Diffstat (limited to 'scripts/sanity_check.py')
-rw-r--r--  scripts/sanity_check.py | 17
1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/scripts/sanity_check.py b/scripts/sanity_check.py
index f30bd58..9745482 100644
--- a/scripts/sanity_check.py
+++ b/scripts/sanity_check.py
@@ -48,11 +48,14 @@ def get_test_batch(tokenizer, seq_len: int = 64, device: str = "cpu"):
def compute_dagformer_nll(wrapper: DAGFormerOLMo, input_ids: torch.Tensor,
labels: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
- """Compute NLL using DAGFormer modified forward."""
+ """Compute NLL using DAGFormer modified forward.
+
+ labels is already shifted (chunk[1:seq_len+1]), no additional shift needed.
+ """
logits = wrapper.forward(input_ids, A)
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
)
return nll
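
Note: the removed [:, :-1] / [:, 1:] slicing is unnecessary because, as the new docstring states, get_test_batch already shifts the labels by one position. A minimal, self-contained sketch of that convention (illustrative values only, not the actual get_test_batch body):

    import torch
    import torch.nn.functional as F

    seq_len, vocab = 8, 16
    chunk = torch.arange(seq_len + 1)            # stand-in token ids 0 .. seq_len

    input_ids = chunk[:seq_len].unsqueeze(0)     # tokens at positions 0 .. seq_len-1
    labels = chunk[1:seq_len + 1].unsqueeze(0)   # tokens at positions 1 .. seq_len (shifted once)

    logits = torch.randn(1, seq_len, vocab)      # stand-in for wrapper.forward(input_ids, A)

    # labels[:, t] is already the target for logits[:, t], so the loss runs over the full length:
    nll = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))

    # The old logits[:, :-1] / labels[:, 1:] form shifted a second time, comparing the
    # prediction for position t+1 against the token at position t+2.
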
@@ -132,8 +135,8 @@ def check_4_gradient_flow(wrapper, tokenizer, device):
logits = wrapper.forward(input_ids, A)
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
)
nll.backward()
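
The gradient assertions themselves lie outside this hunk. As a loose sketch of the kind of check check_4_gradient_flow might run after nll.backward() (assuming the wrapper is an nn.Module; the loop below is illustrative, not the script's actual code):

    for name, p in wrapper.named_parameters():
        if p.requires_grad and p.grad is not None:
            assert torch.isfinite(p.grad).all(), f"non-finite grad in {name}"
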
@@ -176,8 +179,8 @@ def check_5_normalization_smoke(wrapper_factory, tokenizer, device):
logits = wrapper.forward(input_ids, A)
is_finite = torch.isfinite(logits).all().item()
nll = F.cross_entropy(
- logits[:, :-1].contiguous().view(-1, logits.size(-1)),
- labels[:, 1:].contiguous().view(-1),
+ logits.contiguous().view(-1, logits.size(-1)),
+ labels.contiguous().view(-1),
).item()
print(f" {method:12s}: NLL={nll:.4f}, finite={is_finite}")
if not is_finite: