-rw-r--r--  src/model/predictor.py   10
-rw-r--r--  tests/test_predictor.py  21
2 files changed, 24 insertions, 7 deletions
diff --git a/src/model/predictor.py b/src/model/predictor.py
index b5f9674..9ce5711 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -168,6 +168,7 @@ def cascading_gate(
A: torch.Tensor,
k: float = 5.0,
hard: bool = False,
+ heads_per_layer: int = 16,
) -> torch.Tensor:
"""Apply cascading activation gate: kill outgoing edges from disconnected nodes.
@@ -176,6 +177,9 @@ def cascading_gate(
2. Compute gates: g_j = σ(k * inc_j) (soft) or (inc_j > 0) (hard)
3. Apply: A[j, :] *= g_j
+ Layer 0 nodes are exempted: they have inc=0 structurally (no prior layers)
+ but receive the embedding as input, so they are NOT disconnected.
+
Uses ORIGINAL A values for incoming sums (before any gates applied).
See CLAUDE.md §2.3 cascading gate section.
@@ -183,6 +187,7 @@ def cascading_gate(
A: [batch, 256, 256] — gate matrix
k: steepness of sigmoid gate (default: 5.0)
hard: if True, use binary gates (for eval_hard mode)
+ heads_per_layer: number of heads per layer (default: 16)
Returns:
A_gated: [batch, 256, 256] — A with cascading gate applied
@@ -195,6 +200,11 @@ def cascading_gate(
else:
g = torch.sigmoid(k * inc) # [batch, 256]
+ # Exempt layer 0: always g=1 (they receive embedding, not disconnected)
+ # Use non-in-place op to preserve autograd graph
+ exempt = torch.arange(g.shape[1], device=g.device) < heads_per_layer
+ g = torch.where(exempt.unsqueeze(0), torch.ones_like(g), g)
+
# Gate outgoing edges: A[j, :] *= g[j]
# g: [B, 256] → [B, 256, 1] to broadcast with A: [B, 256, 256]
return A * g.unsqueeze(2)
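For reference, a minimal standalone sketch of the gated computation described in the hunks above (not the repository's exact implementation). It assumes an edge i→j is stored at A[i, j], so incoming sums are column sums, and that layer 0 occupies the first heads_per_layer node indices:

    import torch

    def cascading_gate_sketch(A, k=5.0, hard=False, heads_per_layer=16):
        # 1. Incoming sum per node j over the ORIGINAL A: inc_j = sum_i A[i, j]
        inc = A.sum(dim=1)                                        # [batch, 256]
        # 2. Gates: binary (inc_j > 0) in hard mode, sigmoid(k * inc_j) otherwise
        g = (inc > 0).float() if hard else torch.sigmoid(k * inc)
        # Exempt layer 0: those nodes receive the embedding, so they are never gated off
        exempt = torch.arange(g.shape[1], device=g.device) < heads_per_layer
        g = torch.where(exempt.unsqueeze(0), torch.ones_like(g), g)
        # 3. Scale row j (the outgoing edges of node j) by g_j
        return A * g.unsqueeze(2)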
diff --git a/tests/test_predictor.py b/tests/test_predictor.py
index 5a092d4..104c204 100644
--- a/tests/test_predictor.py
+++ b/tests/test_predictor.py
@@ -159,15 +159,22 @@ class TestCascadingGate:
assert (A_gated >= 0).all() and (A_gated <= 1).all()
def test_hard_mode_kills_disconnected(self):
- """Nodes with no incoming edges should have all outgoing edges zeroed."""
+ """Non-layer-0 nodes with no incoming edges should have outgoing edges zeroed.
+
+ Layer 0 is exempt (receives embedding, not truly disconnected).
+ """
A = torch.zeros(1, 256, 256)
- # Only set edges from node 0 to node 16 (layer 0 → layer 1)
- A[0, 0, 16] = 1.0
+ # Node 32 (layer 2, head 0) has no incoming edges but has outgoing to node 48
+ A[0, 32, 48] = 1.0
A_gated = cascading_gate(A, k=5.0, hard=True)
- # Node 0 has no incoming edges → its outgoing should be zeroed
- assert A_gated[0, 0, 16] == 0.0, "Node 0 has no incoming but wasn't gated to 0"
- # Node 16 has incoming from node 0 (but node 0 was gated to 0)
- # In one-pass mode, inc uses ORIGINAL A, so node 16 has inc > 0
+ # Node 32 has no incoming → outgoing should be zeroed
+ assert A_gated[0, 32, 48] == 0.0, "Node 32 has no incoming but wasn't gated to 0"
+
+ # Layer 0 should be exempt: node 0 has no incoming but keeps outgoing
+ A2 = torch.zeros(1, 256, 256)
+ A2[0, 0, 16] = 1.0
+ A2_gated = cascading_gate(A2, k=5.0, hard=True)
+ assert A2_gated[0, 0, 16] == 1.0, "Layer 0 should be exempt from cascading gate"
def test_hard_mode_preserves_connected(self):
"""Nodes with incoming edges keep their outgoing edges."""