diff options
Diffstat (limited to 'tests/test_predictor.py')
| -rw-r--r-- | tests/test_predictor.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/tests/test_predictor.py b/tests/test_predictor.py index 5a092d4..104c204 100644 --- a/tests/test_predictor.py +++ b/tests/test_predictor.py @@ -159,15 +159,22 @@ class TestCascadingGate: assert (A_gated >= 0).all() and (A_gated <= 1).all() def test_hard_mode_kills_disconnected(self): - """Nodes with no incoming edges should have all outgoing edges zeroed.""" + """Non-layer-0 nodes with no incoming edges should have outgoing edges zeroed. + + Layer 0 is exempt (receives embedding, not truly disconnected). + """ A = torch.zeros(1, 256, 256) - # Only set edges from node 0 to node 16 (layer 0 → layer 1) - A[0, 0, 16] = 1.0 + # Node 32 (layer 2, head 0) has no incoming edges but has outgoing to node 48 + A[0, 32, 48] = 1.0 A_gated = cascading_gate(A, k=5.0, hard=True) - # Node 0 has no incoming edges → its outgoing should be zeroed - assert A_gated[0, 0, 16] == 0.0, "Node 0 has no incoming but wasn't gated to 0" - # Node 16 has incoming from node 0 (but node 0 was gated to 0) - # In one-pass mode, inc uses ORIGINAL A, so node 16 has inc > 0 + # Node 32 has no incoming → outgoing should be zeroed + assert A_gated[0, 32, 48] == 0.0, "Node 32 has no incoming but wasn't gated to 0" + + # Layer 0 should be exempt: node 0 has no incoming but keeps outgoing + A2 = torch.zeros(1, 256, 256) + A2[0, 0, 16] = 1.0 + A2_gated = cascading_gate(A2, k=5.0, hard=True) + assert A2_gated[0, 0, 16] == 1.0, "Layer 0 should be exempt from cascading gate" def test_hard_mode_preserves_connected(self): """Nodes with incoming edges keep their outgoing edges.""" |
