|  |  |  |
|---|---|---|
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 14:40:31 -0600 |
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 14:40:31 -0600 |
| commit | 80579d6cc254d337a23e71404ae7ecab1849d1e5 | |
| tree | bc6790229c20af516da662d7a4b7c8c7f1c4cb8c | |
| parent | ef678d2e1ba70b1a9dadb78c73ed372f986aea13 | |
Layer 0 has no incoming edges structurally (no prior layers), but
receives the embedding as input. The cascading gate was killing its
outgoing edges (hard: g=0, soft: g=0.5), causing nll_hard to be
~2x worse than baseline. Fix: set g=1 for layer 0 nodes.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
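To make the regression concrete, here is a minimal standalone sketch of the gate arithmetic described above, using the 256-node, 16-heads-per-layer layout from the tests. The incoming-sum step (`inc`) is not visible in the patch hunks, so its reconstruction from the docstring's steps, including the axis convention for `A`, is an assumption.

```python
import torch

# Toy graph mirroring the tests: 256 nodes = 16 layers x 16 heads.
# Assumed convention: A[b, i, j] is the gate on edge i -> j.
A = torch.zeros(1, 256, 256)
A[0, 0, 16] = 1.0  # layer-0 head 0 feeds layer-1 head 0

inc = A.sum(dim=1)               # incoming sum per destination node: [batch, 256]
k = 5.0
g_soft = torch.sigmoid(k * inc)  # soft gate
g_hard = (inc > 0).float()       # hard gate

# Node 0 sits in layer 0: inc = 0 structurally, so before the fix
print(g_soft[0, 0])  # tensor(0.5000), soft gate halves its outgoing edges
print(g_hard[0, 0])  # tensor(0.), hard gate kills them entirely

# The fix: exempt the first heads_per_layer nodes (layer 0), forcing g = 1.
heads_per_layer = 16
exempt = torch.arange(256) < heads_per_layer
g_fixed = torch.where(exempt.unsqueeze(0), torch.ones_like(g_hard), g_hard)
print(g_fixed[0, 0])  # tensor(1.), layer 0 keeps its outgoing edges
```

With the soft gate stuck at 0.5 and the hard gate at 0 for every layer-0 node, all signal leaving layer 0 was attenuated or erased, which is consistent with the reported ~2x degradation of nll_hard.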
|  |  |
|---|---|
| src/model/predictor.py | 10 |
| tests/test_predictor.py | 21 |

2 files changed, 24 insertions(+), 7 deletions(-)
```diff
diff --git a/src/model/predictor.py b/src/model/predictor.py
index b5f9674..9ce5711 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -168,6 +168,7 @@ def cascading_gate(
     A: torch.Tensor,
     k: float = 5.0,
     hard: bool = False,
+    heads_per_layer: int = 16,
 ) -> torch.Tensor:
     """Apply cascading activation gate: kill outgoing edges from disconnected nodes.
 
@@ -176,6 +177,9 @@ def cascading_gate(
     2. Compute gates: g_j = σ(k * inc_j) (soft) or (inc_j > 0) (hard)
     3. Apply: A[j, :] *= g_j
 
+    Layer 0 nodes are exempted: they have inc=0 structurally (no prior layers)
+    but receive the embedding as input, so they are NOT disconnected.
+
     Uses ORIGINAL A values for incoming sums (before any gates applied).
     See CLAUDE.md §2.3 cascading gate section.
 
@@ -183,6 +187,7 @@ def cascading_gate(
         A: [batch, 256, 256] — gate matrix
         k: steepness of sigmoid gate (default: 5.0)
         hard: if True, use binary gates (for eval_hard mode)
+        heads_per_layer: number of heads per layer (default: 16)
 
     Returns:
         A_gated: [batch, 256, 256] — A with cascading gate applied
@@ -195,6 +200,11 @@ def cascading_gate(
     else:
         g = torch.sigmoid(k * inc)  # [batch, 256]
 
+    # Exempt layer 0: always g=1 (they receive embedding, not disconnected)
+    # Use non-in-place op to preserve autograd graph
+    exempt = torch.arange(g.shape[1], device=g.device) < heads_per_layer
+    g = torch.where(exempt.unsqueeze(0), torch.ones_like(g), g)
+
     # Gate outgoing edges: A[j, :] *= g[j]
     # g: [B, 256] → [B, 256, 1] to broadcast with A: [B, 256, 256]
     return A * g.unsqueeze(2)
diff --git a/tests/test_predictor.py b/tests/test_predictor.py
index 5a092d4..104c204 100644
--- a/tests/test_predictor.py
+++ b/tests/test_predictor.py
@@ -159,15 +159,22 @@ class TestCascadingGate:
         assert (A_gated >= 0).all() and (A_gated <= 1).all()
 
     def test_hard_mode_kills_disconnected(self):
-        """Nodes with no incoming edges should have all outgoing edges zeroed."""
+        """Non-layer-0 nodes with no incoming edges should have outgoing edges zeroed.
+
+        Layer 0 is exempt (receives embedding, not truly disconnected).
+        """
         A = torch.zeros(1, 256, 256)
-        # Only set edges from node 0 to node 16 (layer 0 → layer 1)
-        A[0, 0, 16] = 1.0
+        # Node 32 (layer 2, head 0) has no incoming edges but has outgoing to node 48
+        A[0, 32, 48] = 1.0
         A_gated = cascading_gate(A, k=5.0, hard=True)
-        # Node 0 has no incoming edges → its outgoing should be zeroed
-        assert A_gated[0, 0, 16] == 0.0, "Node 0 has no incoming but wasn't gated to 0"
-        # Node 16 has incoming from node 0 (but node 0 was gated to 0)
-        # In one-pass mode, inc uses ORIGINAL A, so node 16 has inc > 0
+        # Node 32 has no incoming → outgoing should be zeroed
+        assert A_gated[0, 32, 48] == 0.0, "Node 32 has no incoming but wasn't gated to 0"
+
+        # Layer 0 should be exempt: node 0 has no incoming but keeps outgoing
+        A2 = torch.zeros(1, 256, 256)
+        A2[0, 0, 16] = 1.0
+        A2_gated = cascading_gate(A2, k=5.0, hard=True)
+        assert A2_gated[0, 0, 16] == 1.0, "Layer 0 should be exempt from cascading gate"
 
     def test_hard_mode_preserves_connected(self):
         """Nodes with incoming edges keep their outgoing edges."""
```
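As a quick end-to-end check of both behaviors in one call, a sketch against the patched function follows; the import path is assumed from the diffstat and presumes the repo's `src` layout is importable (e.g. via an editable install or `PYTHONPATH`):

```python
import torch
from src.model.predictor import cascading_gate  # path assumed from the diffstat

A = torch.zeros(1, 256, 256)
A[0, 0, 16] = 1.0   # layer 0 -> layer 1: inc=0 structurally, but exempt
A[0, 32, 48] = 1.0  # layer-2 node with no incoming edges: truly disconnected

A_gated = cascading_gate(A, k=5.0, hard=True)
assert A_gated[0, 0, 16] == 1.0   # layer-0 edge survives the gate
assert A_gated[0, 32, 48] == 0.0  # disconnected node's outgoing edge is killed
```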
