diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 05:39:39 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 05:39:39 -0500 |
| commit | 8dd65b2ec3df32749adabbf62c55101d5b00ae7b (patch) | |
| tree | 3a329bfdf9867ae13889dfcecd65ef216734947b /models | |
| parent | 68cfa13af2f026b7ff388aae4420eba0f0db804a (diff) | |
Round 32+33 H2 ablation: add no_residual_add flag; falsify residual-as-cause hypothesis
- models/residual_mlp.py: add residual_add and w2_std flags (default unchanged)
- experiments/snapshot_evolution_residual_explosion.py: add --no_residual_add and --w2_std CLI flags
- paper/main.tex §3 ¶3: add 1-sentence reference to no-residual control showing Mode 1 still fires
- paper/main.tex Appendix I: full smoke-test table + interpretation
- v2.2 main content stays at 8 pages (within 9-page E&D budget); 13 pages total
Smoke test (3 ep, w2_std=0.5, seed 42):
- DFA no-residual: ||h_L|| 4.69 -> 22050, ||g|| 1.6e-7 (Mode 1 (a) fires; (b) at floor)
- BP no-residual: acc only 0.16 at ep 3 (architecture is partially degenerate)
- Conclusion: residual skip is NOT necessary for Mode 1; the proximate trigger is more general
- Codex round 33 verdict: WALK BACK H2; demote 100ep run to confirmatory
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'models')
| -rw-r--r-- | models/residual_mlp.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/models/residual_mlp.py b/models/residual_mlp.py index c16778c..6827057 100644 --- a/models/residual_mlp.py +++ b/models/residual_mlp.py @@ -10,13 +10,13 @@ import torch.nn as nn class ResidualBlock(nn.Module): """Single pre-LayerNorm residual MLP block.""" - def __init__(self, d_hidden: int): + def __init__(self, d_hidden: int, w2_std: float = 0.01): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.w1 = nn.Linear(d_hidden, d_hidden) self.w2 = nn.Linear(d_hidden, d_hidden) - # Small init for residual branch - nn.init.normal_(self.w2.weight, std=0.01) + # Small init for residual branch (or larger if used as a non-residual stack) + nn.init.normal_(self.w2.weight, std=w2_std) nn.init.zeros_(self.w2.bias) def forward(self, h: torch.Tensor) -> torch.Tensor: @@ -31,14 +31,16 @@ class ResidualBlock(nn.Module): class ResidualMLP(nn.Module): """Deep residual MLP: embed -> L blocks -> LN -> output head.""" - def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int): + def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int, + residual_add: bool = True, w2_std: float = 0.01): super().__init__() self.embed = nn.Linear(input_dim, d_hidden) - self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)]) + self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)]) self.out_ln = nn.LayerNorm(d_hidden) self.out_head = nn.Linear(d_hidden, num_classes) self.num_blocks = num_blocks self.d_hidden = d_hidden + self.residual_add = residual_add def forward(self, x: torch.Tensor, return_hidden: bool = False): """ @@ -54,7 +56,7 @@ class ResidualMLP(nn.Module): for block in self.blocks: f = block(h) - h = h + f + h = h + f if self.residual_add else f if return_hidden: hiddens.append(h) @@ -68,6 +70,6 @@ class ResidualMLP(nn.Module): """Run forward from a given layer index to output. Used for perturbation tests.""" for i in range(start_layer, self.num_blocks): f = self.blocks[i](h) - h = h + f + h = h + f if self.residual_add else f logits = self.out_head(self.out_ln(h)) return logits |
