summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 05:39:39 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 05:39:39 -0500
commit8dd65b2ec3df32749adabbf62c55101d5b00ae7b (patch)
tree3a329bfdf9867ae13889dfcecd65ef216734947b /models
parent68cfa13af2f026b7ff388aae4420eba0f0db804a (diff)
Round 32+33 H2 ablation: add no_residual_add flag; falsify residual-as-cause hypothesis
- models/residual_mlp.py: add residual_add and w2_std flags (default unchanged) - experiments/snapshot_evolution_residual_explosion.py: add --no_residual_add and --w2_std CLI flags - paper/main.tex §3 ¶3: add 1-sentence reference to no-residual control showing Mode 1 still fires - paper/main.tex Appendix I: full smoke-test table + interpretation - v2.2 main content stays at 8 pages (within 9-page E&D budget); 13 pages total Smoke test (3 ep, w2_std=0.5, seed 42): - DFA no-residual: ||h_L|| 4.69 -> 22050, ||g|| 1.6e-7 (Mode 1 (a) fires; (b) at floor) - BP no-residual: acc only 0.16 at ep 3 (architecture is partially degenerate) - Conclusion: residual skip is NOT necessary for Mode 1; the proximate trigger is more general - Codex round 33 verdict: WALK BACK H2; demote 100ep run to confirmatory Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'models')
-rw-r--r--models/residual_mlp.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
index c16778c..6827057 100644
--- a/models/residual_mlp.py
+++ b/models/residual_mlp.py
@@ -10,13 +10,13 @@ import torch.nn as nn
class ResidualBlock(nn.Module):
"""Single pre-LayerNorm residual MLP block."""
- def __init__(self, d_hidden: int):
+ def __init__(self, d_hidden: int, w2_std: float = 0.01):
super().__init__()
self.ln = nn.LayerNorm(d_hidden)
self.w1 = nn.Linear(d_hidden, d_hidden)
self.w2 = nn.Linear(d_hidden, d_hidden)
- # Small init for residual branch
- nn.init.normal_(self.w2.weight, std=0.01)
+ # Small init for residual branch (or larger if used as a non-residual stack)
+ nn.init.normal_(self.w2.weight, std=w2_std)
nn.init.zeros_(self.w2.bias)
def forward(self, h: torch.Tensor) -> torch.Tensor:
@@ -31,14 +31,16 @@ class ResidualBlock(nn.Module):
class ResidualMLP(nn.Module):
"""Deep residual MLP: embed -> L blocks -> LN -> output head."""
- def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int):
+ def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
+ residual_add: bool = True, w2_std: float = 0.01):
super().__init__()
self.embed = nn.Linear(input_dim, d_hidden)
- self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)])
+ self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)])
self.out_ln = nn.LayerNorm(d_hidden)
self.out_head = nn.Linear(d_hidden, num_classes)
self.num_blocks = num_blocks
self.d_hidden = d_hidden
+ self.residual_add = residual_add
def forward(self, x: torch.Tensor, return_hidden: bool = False):
"""
@@ -54,7 +56,7 @@ class ResidualMLP(nn.Module):
for block in self.blocks:
f = block(h)
- h = h + f
+ h = h + f if self.residual_add else f
if return_hidden:
hiddens.append(h)
@@ -68,6 +70,6 @@ class ResidualMLP(nn.Module):
"""Run forward from a given layer index to output. Used for perturbation tests."""
for i in range(start_layer, self.num_blocks):
f = self.blocks[i](h)
- h = h + f
+ h = h + f if self.residual_add else f
logits = self.out_head(self.out_ln(h))
return logits