author    YurenHao0426 <blackhao0426@gmail.com>  2026-02-09 14:40:31 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2026-02-09 14:40:31 -0600
commit    80579d6cc254d337a23e71404ae7ecab1849d1e5 (patch)
tree      bc6790229c20af516da662d7a4b7c8c7f1c4cb8c /src/model
parent    ef678d2e1ba70b1a9dadb78c73ed372f986aea13 (diff)
Fix cascading gate: exempt layer 0 from disconnection check (HEAD, main)
Layer 0 has no incoming edges structurally (no prior layers), but receives
the embedding as input. The cascading gate was killing its outgoing edges
(hard: g=0, soft: g=0.5), causing nll_hard to be ~2x worse than baseline.
Fix: set g=1 for layer 0 nodes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
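The failure mode follows directly from the gate formulas in the patch below: at inc=0 the soft gate sits at sigmoid(0) = 0.5 and the hard gate at (0 > 0) = 0, so every outgoing edge of a layer-0 node gets halved or zeroed. A two-line check in plain PyTorch, independent of this repo:

    import torch

    inc, k = torch.tensor(0.0), 5.0
    print(torch.sigmoid(k * inc))  # tensor(0.5000): soft gate halves layer-0 rows
    print((inc > 0).float())       # tensor(0.): hard gate zeroes them entirely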
Diffstat (limited to 'src/model')
 src/model/predictor.py | 10 ++++++++++
 1 file changed, 10 insertions(+), 0 deletions(-)
diff --git a/src/model/predictor.py b/src/model/predictor.py
index b5f9674..9ce5711 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -168,6 +168,7 @@ def cascading_gate(
     A: torch.Tensor,
     k: float = 5.0,
     hard: bool = False,
+    heads_per_layer: int = 16,
 ) -> torch.Tensor:
     """Apply cascading activation gate: kill outgoing edges from disconnected nodes.
 
@@ -176,6 +177,9 @@ def cascading_gate(
     2. Compute gates: g_j = σ(k * inc_j) (soft) or (inc_j > 0) (hard)
     3. Apply: A[j, :] *= g_j
 
+    Layer 0 nodes are exempted: they have inc=0 structurally (no prior layers)
+    but receive the embedding as input, so they are NOT disconnected.
+
     Uses ORIGINAL A values for incoming sums (before any gates applied).
     See CLAUDE.md §2.3 cascading gate section.
@@ -183,6 +187,7 @@ def cascading_gate(
         A: [batch, 256, 256] — gate matrix
         k: steepness of sigmoid gate (default: 5.0)
         hard: if True, use binary gates (for eval_hard mode)
+        heads_per_layer: number of heads per layer (default: 16)
 
     Returns:
         A_gated: [batch, 256, 256] — A with cascading gate applied
@@ -195,6 +200,11 @@ def cascading_gate(
     else:
         g = torch.sigmoid(k * inc)  # [batch, 256]
 
+    # Exempt layer 0: always g=1 (they receive embedding, not disconnected)
+    # Use non-in-place op to preserve autograd graph
+    exempt = torch.arange(g.shape[1], device=g.device) < heads_per_layer
+    g = torch.where(exempt.unsqueeze(0), torch.ones_like(g), g)
+
     # Gate outgoing edges: A[j, :] *= g[j]
     # g: [B, 256] → [B, 256, 1] to broadcast with A: [B, 256, 256]
     return A * g.unsqueeze(2)
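For readers without the full checkout, here is a minimal, self-contained sketch of cascading_gate as it reads after this patch. Only the hunks above are authoritative; the incoming-sum step falls outside the diff context, so computing inc as a column sum of the original A is an assumption based on the docstring ("incoming sums", taken over ORIGINAL A values).

    import torch

    def cascading_gate(
        A: torch.Tensor,
        k: float = 5.0,
        hard: bool = False,
        heads_per_layer: int = 16,
    ) -> torch.Tensor:
        """Kill outgoing edges from disconnected nodes; layer 0 is exempt."""
        # 1. Incoming mass per target node j, from the ungated A
        #    (assumed: inc_j = sum_i A[i, j]).
        inc = A.sum(dim=1)                       # [batch, 256]

        # 2. Gates: binary for eval_hard mode, sigmoid otherwise.
        if hard:
            g = (inc > 0).float()                # [batch, 256]
        else:
            g = torch.sigmoid(k * inc)           # [batch, 256]

        # 3. Exempt layer 0: inc=0 there is structural (no prior layers),
        #    not disconnection. Non-in-place to preserve the autograd graph.
        exempt = torch.arange(g.shape[1], device=g.device) < heads_per_layer
        g = torch.where(exempt.unsqueeze(0), torch.ones_like(g), g)

        # 4. Gate outgoing edges: A[j, :] *= g[j], via broadcasting.
        return A * g.unsqueeze(2)

A quick smoke test of the exemption, using the [batch, 256, 256] shapes from the docstring: zero the incoming columns of the first 16 nodes (layer 0 by construction) and confirm their outgoing rows survive hard gating.

    A = torch.rand(2, 256, 256)
    A[:, :, :16] = 0.0                           # layer 0: no incoming edges
    out = cascading_gate(A, hard=True)
    assert torch.equal(out[:, :16, :], A[:, :16, :])  # rows kept, not zeroed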