summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-09 14:40:31 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-09 14:40:31 -0600
commit80579d6cc254d337a23e71404ae7ecab1849d1e5 (patch)
treebc6790229c20af516da662d7a4b7c8c7f1c4cb8c /tests
parentef678d2e1ba70b1a9dadb78c73ed372f986aea13 (diff)
Fix cascading gate: exempt layer 0 from disconnection checkHEADmain
Layer 0 has no incoming edges structurally (no prior layers), but receives the embedding as input. The cascading gate was killing its outgoing edges (hard: g=0, soft: g=0.5), causing nll_hard to be ~2x worse than baseline. Fix: set g=1 for layer 0 nodes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/test_predictor.py21
1 files changed, 14 insertions, 7 deletions
diff --git a/tests/test_predictor.py b/tests/test_predictor.py
index 5a092d4..104c204 100644
--- a/tests/test_predictor.py
+++ b/tests/test_predictor.py
@@ -159,15 +159,22 @@ class TestCascadingGate:
assert (A_gated >= 0).all() and (A_gated <= 1).all()
def test_hard_mode_kills_disconnected(self):
- """Nodes with no incoming edges should have all outgoing edges zeroed."""
+ """Non-layer-0 nodes with no incoming edges should have outgoing edges zeroed.
+
+ Layer 0 is exempt (receives embedding, not truly disconnected).
+ """
A = torch.zeros(1, 256, 256)
- # Only set edges from node 0 to node 16 (layer 0 → layer 1)
- A[0, 0, 16] = 1.0
+ # Node 32 (layer 2, head 0) has no incoming edges but has outgoing to node 48
+ A[0, 32, 48] = 1.0
A_gated = cascading_gate(A, k=5.0, hard=True)
- # Node 0 has no incoming edges → its outgoing should be zeroed
- assert A_gated[0, 0, 16] == 0.0, "Node 0 has no incoming but wasn't gated to 0"
- # Node 16 has incoming from node 0 (but node 0 was gated to 0)
- # In one-pass mode, inc uses ORIGINAL A, so node 16 has inc > 0
+ # Node 32 has no incoming → outgoing should be zeroed
+ assert A_gated[0, 32, 48] == 0.0, "Node 32 has no incoming but wasn't gated to 0"
+
+ # Layer 0 should be exempt: node 0 has no incoming but keeps outgoing
+ A2 = torch.zeros(1, 256, 256)
+ A2[0, 0, 16] = 1.0
+ A2_gated = cascading_gate(A2, k=5.0, hard=True)
+ assert A2_gated[0, 0, 16] == 1.0, "Layer 0 should be exempt from cascading gate"
def test_hard_mode_preserves_connected(self):
"""Nodes with incoming edges keep their outgoing edges."""