diff options
Diffstat (limited to 'scripts/sanity_check.py')
| -rw-r--r-- | scripts/sanity_check.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/scripts/sanity_check.py b/scripts/sanity_check.py index f30bd58..9745482 100644 --- a/scripts/sanity_check.py +++ b/scripts/sanity_check.py @@ -48,11 +48,14 @@ def get_test_batch(tokenizer, seq_len: int = 64, device: str = "cpu"): def compute_dagformer_nll(wrapper: DAGFormerOLMo, input_ids: torch.Tensor, labels: torch.Tensor, A: torch.Tensor) -> torch.Tensor: - """Compute NLL using DAGFormer modified forward.""" + """Compute NLL using DAGFormer modified forward. + + labels is already shifted (chunk[1:seq_len+1]), no additional shift needed. + """ logits = wrapper.forward(input_ids, A) nll = F.cross_entropy( - logits[:, :-1].contiguous().view(-1, logits.size(-1)), - labels[:, 1:].contiguous().view(-1), + logits.contiguous().view(-1, logits.size(-1)), + labels.contiguous().view(-1), ) return nll @@ -132,8 +135,8 @@ def check_4_gradient_flow(wrapper, tokenizer, device): logits = wrapper.forward(input_ids, A) nll = F.cross_entropy( - logits[:, :-1].contiguous().view(-1, logits.size(-1)), - labels[:, 1:].contiguous().view(-1), + logits.contiguous().view(-1, logits.size(-1)), + labels.contiguous().view(-1), ) nll.backward() @@ -176,8 +179,8 @@ def check_5_normalization_smoke(wrapper_factory, tokenizer, device): logits = wrapper.forward(input_ids, A) is_finite = torch.isfinite(logits).all().item() nll = F.cross_entropy( - logits[:, :-1].contiguous().view(-1, logits.size(-1)), - labels[:, 1:].contiguous().view(-1), + logits.contiguous().view(-1, logits.size(-1)), + labels.contiguous().view(-1), ).item() print(f" {method:12s}: NLL={nll:.4f}, finite={is_finite}") if not is_finite: |
