Diffstat (limited to 'tests/test_olmo_graph.py')
| -rw-r--r-- | tests/test_olmo_graph.py | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/tests/test_olmo_graph.py b/tests/test_olmo_graph.py
new file mode 100644
index 0000000..efeb57b
--- /dev/null
+++ b/tests/test_olmo_graph.py
@@ -0,0 +1,153 @@
+"""Unit tests for olmo_graph.py.
+
+Tests that don't require model download run with synthetic tensors.
+Integration tests (baseline reproduction) require the model and are
+skipped if the model is not available.
+"""
+
+import pytest
+import torch
+import torch.nn as nn
+
+import sys
+import os
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
+from src.model.olmo_graph import (
+    create_block_upper_triangular_mask,
+    InputNormalizer,
+)
+
+
+class TestBlockUpperTriangularMask:
+    """Test the DAG constraint mask."""
+
+    def test_shape(self):
+        mask = create_block_upper_triangular_mask(256, 16)
+        assert mask.shape == (256, 256)
+
+    def test_dtype(self):
+        mask = create_block_upper_triangular_mask(256, 16)
+        assert mask.dtype == torch.float32
+
+    def test_no_self_connections(self):
+        """Diagonal should be 0 — a node cannot connect to itself."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        assert mask.diag().sum() == 0
+
+    def test_no_same_layer_connections(self):
+        """Nodes in the same layer should NOT be connected."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        for layer in range(16):
+            start = layer * 16
+            end = start + 16
+            block = mask[start:end, start:end]
+            assert block.sum() == 0, f"Layer {layer} has same-layer connections"
+
+    def test_no_backward_connections(self):
+        """No connections from a higher layer back to a lower layer."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        for src_layer in range(16):
+            for tgt_layer in range(src_layer):  # tgt < src = backward
+                src_start = src_layer * 16
+                tgt_start = tgt_layer * 16
+                block = mask[src_start:src_start+16, tgt_start:tgt_start+16]
+                assert block.sum() == 0, f"Backward connection from layer {src_layer} to {tgt_layer}"
+
+    def test_forward_connections_exist(self):
+        """Every entry in a forward block (source layer below target layer) should be 1."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        for src_layer in range(15):
+            for tgt_layer in range(src_layer + 1, 16):
+                src_start = src_layer * 16
+                tgt_start = tgt_layer * 16
+                block = mask[src_start:src_start+16, tgt_start:tgt_start+16]
+                assert block.sum() == 16 * 16, \
+                    f"Missing connections from layer {src_layer} to layer {tgt_layer}"
+
+    def test_total_valid_entries(self):
+        """Exactly 30,720 valid entries: 120 forward layer pairs × 16 × 16."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        assert mask.sum().item() == 30720
+
+    def test_adjacent_connections_count(self):
+        """Adjacent-layer connections: 15 pairs × 16 × 16 = 3840."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        count = 0
+        for src_layer in range(15):
+            tgt_layer = src_layer + 1
+            src_start = src_layer * 16
+            tgt_start = tgt_layer * 16
+            count += mask[src_start:src_start+16, tgt_start:tgt_start+16].sum().item()
+        assert count == 3840
+
+    def test_skip_connections_count(self):
+        """Skip connections: 105 non-adjacent forward pairs × 16 × 16 = 26880."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        count = 0
+        for src_layer in range(14):
+            for tgt_layer in range(src_layer + 2, 16):
+                src_start = src_layer * 16
+                tgt_start = tgt_layer * 16
+                count += mask[src_start:src_start+16, tgt_start:tgt_start+16].sum().item()
+        assert count == 26880
+
+    def test_not_torch_triu(self):
+        """Verify the mask is block-, not element-, upper-triangular.
+
+        torch.triu would set mask[0,15]=1 (both nodes in layer 0), which is wrong.
+        """
+        mask = create_block_upper_triangular_mask(256, 16)
+        # Node 0 (layer 0, head 0) to node 15 (layer 0, head 15)
+        assert mask[0, 15] == 0, "Same-layer connection detected — did you use torch.triu()?"
+        # Node 0 (layer 0, head 0) to node 16 (layer 1, head 0)
+        assert mask[0, 16] == 1, "Adjacent-layer connection should be 1"
+
+
+class TestInputNormalizer:
+    """Test input normalization methods."""
+
+    def test_none(self):
+        norm = InputNormalizer("none")
+        x = torch.randn(2, 16, 32, 2048)
+        out = norm(x)
+        assert torch.allclose(out, x)
+
+    def test_gate_mean(self):
+        norm = InputNormalizer("gate_mean")
+        gated_sum = torch.randn(2, 16, 32, 2048)
+        A_slice = torch.rand(2, 48, 16)  # 3 prior layers
+        out = norm(gated_sum, A_slice=A_slice)
+        assert out.shape == gated_sum.shape
+        assert torch.isfinite(out).all()
+
+    def test_rms_post(self):
+        norm = InputNormalizer("rms_post", model_dim=2048)
+        x = torch.randn(2, 16, 32, 2048)
+        out = norm(x)
+        assert out.shape == x.shape
+        assert torch.isfinite(out).all()
+
+    def test_ln_post(self):
+        norm = InputNormalizer("ln_post", model_dim=2048)
+        x = torch.randn(2, 16, 32, 2048)
+        out = norm(x)
+        assert out.shape == x.shape
+        assert torch.isfinite(out).all()
+
+    def test_rms_pre(self):
+        norm = InputNormalizer("rms_pre", model_dim=64, num_nodes=32)  # small for test
+        prior = torch.randn(2, 32, 8, 64)
+        A_slice = torch.rand(2, 32, 4)
+        gated_sum = torch.einsum('bih,bisd->bhsd', A_slice, prior)
+        out = norm(gated_sum, A_slice=A_slice, prior_head_outs=prior)
+        assert out.shape == gated_sum.shape
+        assert torch.isfinite(out).all()
+
+    def test_unknown_method_raises(self):
+        with pytest.raises(ValueError, match="Unknown input_norm"):
+            InputNormalizer("unknown_method")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

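The InputNormalizer tests pin down only shapes, finiteness, and the ValueError message, so its internals are less constrained. Under the shape conventions visible in the tests (gated_sum as (batch, heads, seq, model_dim), A_slice as (batch, num_prior_nodes, heads), prior_head_outs as (batch, num_prior_nodes, seq, model_dim)), one module that would pass them is sketched below; the eps value, the gain-free RMS normalization, and the unused num_nodes argument are assumptions, and the real implementation may differ:

    import torch
    import torch.nn as nn

    class InputNormalizer(nn.Module):
        """Sketch: normalizes the gated sum of prior-node outputs feeding a node."""

        METHODS = ("none", "gate_mean", "rms_post", "ln_post", "rms_pre")

        def __init__(self, method, model_dim=None, num_nodes=None, eps=1e-6):
            super().__init__()
            if method not in self.METHODS:
                raise ValueError(f"Unknown input_norm: {method}")
            self.method = method
            self.eps = eps
            self.num_nodes = num_nodes  # e.g. for per-node parameters; unused in this sketch
            if method == "ln_post":
                self.ln = nn.LayerNorm(model_dim)

        def _rms(self, x):
            # Plain RMS normalization over the feature dimension (no learned gain).
            return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        def forward(self, gated_sum, A_slice=None, prior_head_outs=None):
            if self.method == "none":
                return gated_sum
            if self.method == "gate_mean":
                # Turn the gated sum into a gated mean: divide each head's input
                # by the total gate mass flowing into it.
                denom = A_slice.sum(dim=1).clamp_min(self.eps)  # (batch, heads)
                return gated_sum / denom[:, :, None, None]
            if self.method == "rms_post":
                return self._rms(gated_sum)
            if self.method == "ln_post":
                return self.ln(gated_sum)
            # "rms_pre": normalize each prior node's output before re-aggregating.
            normed = self._rms(prior_head_outs)
            return torch.einsum('bih,bisd->bhsd', A_slice, normed)

The gate_mean branch divides by each head's total incoming gate mass, and the rms_pre branch re-aggregates with the same einsum contraction that test_rms_pre uses to build its expected gated_sum.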