summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_predictor.py21
1 files changed, 14 insertions, 7 deletions
diff --git a/tests/test_predictor.py b/tests/test_predictor.py
index 5a092d4..104c204 100644
--- a/tests/test_predictor.py
+++ b/tests/test_predictor.py
@@ -159,15 +159,22 @@ class TestCascadingGate:
assert (A_gated >= 0).all() and (A_gated <= 1).all()
def test_hard_mode_kills_disconnected(self):
- """Nodes with no incoming edges should have all outgoing edges zeroed."""
+ """Non-layer-0 nodes with no incoming edges should have outgoing edges zeroed.
+
+ Layer 0 is exempt (receives embedding, not truly disconnected).
+ """
A = torch.zeros(1, 256, 256)
- # Only set edges from node 0 to node 16 (layer 0 → layer 1)
- A[0, 0, 16] = 1.0
+ # Node 32 (layer 2, head 0) has no incoming edges but has outgoing to node 48
+ A[0, 32, 48] = 1.0
A_gated = cascading_gate(A, k=5.0, hard=True)
- # Node 0 has no incoming edges → its outgoing should be zeroed
- assert A_gated[0, 0, 16] == 0.0, "Node 0 has no incoming but wasn't gated to 0"
- # Node 16 has incoming from node 0 (but node 0 was gated to 0)
- # In one-pass mode, inc uses ORIGINAL A, so node 16 has inc > 0
+ # Node 32 has no incoming → outgoing should be zeroed
+ assert A_gated[0, 32, 48] == 0.0, "Node 32 has no incoming but wasn't gated to 0"
+
+ # Layer 0 should be exempt: node 0 has no incoming but keeps outgoing
+ A2 = torch.zeros(1, 256, 256)
+ A2[0, 0, 16] = 1.0
+ A2_gated = cascading_gate(A2, k=5.0, hard=True)
+ assert A2_gated[0, 0, 16] == 1.0, "Layer 0 should be exempt from cascading gate"
def test_hard_mode_preserves_connected(self):
"""Nodes with incoming edges keep their outgoing edges."""