From 8dd65b2ec3df32749adabbf62c55101d5b00ae7b Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Wed, 8 Apr 2026 05:39:39 -0500 Subject: Round 32+33 H2 ablation: add no_residual_add flag; falsify residual-as-cause hypothesis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - models/residual_mlp.py: add residual_add and w2_std flags (default unchanged) - experiments/snapshot_evolution_residual_explosion.py: add --no_residual_add and --w2_std CLI flags - paper/main.tex §3 ¶3: add 1-sentence reference to no-residual control showing Mode 1 still fires - paper/main.tex Appendix I: full smoke-test table + interpretation - v2.2 main content stays at 8 pages (within 9-page E&D budget); 13 pages total Smoke test (3 ep, w2_std=0.5, seed 42): - DFA no-residual: ||h_L|| 4.69 -> 22050, ||g|| 1.6e-7 (Mode 1 (a) fires; (b) at floor) - BP no-residual: acc only 0.16 at ep 3 (architecture is partially degenerate) - Conclusion: residual skip is NOT necessary for Mode 1; the proximate trigger is more general - Codex round 33 verdict: WALK BACK H2; demote 100ep run to confirmatory Co-Authored-By: Claude Opus 4.6 (1M context) --- models/residual_mlp.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'models') diff --git a/models/residual_mlp.py b/models/residual_mlp.py index c16778c..6827057 100644 --- a/models/residual_mlp.py +++ b/models/residual_mlp.py @@ -10,13 +10,13 @@ import torch.nn as nn class ResidualBlock(nn.Module): """Single pre-LayerNorm residual MLP block.""" - def __init__(self, d_hidden: int): + def __init__(self, d_hidden: int, w2_std: float = 0.01): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.w1 = nn.Linear(d_hidden, d_hidden) self.w2 = nn.Linear(d_hidden, d_hidden) - # Small init for residual branch - nn.init.normal_(self.w2.weight, std=0.01) + # Small init for residual branch (or larger if used as a non-residual stack) + nn.init.normal_(self.w2.weight, std=w2_std) nn.init.zeros_(self.w2.bias) def forward(self, h: torch.Tensor) -> torch.Tensor: @@ -31,14 +31,16 @@ class ResidualBlock(nn.Module): class ResidualMLP(nn.Module): """Deep residual MLP: embed -> L blocks -> LN -> output head.""" - def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int): + def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int, + residual_add: bool = True, w2_std: float = 0.01): super().__init__() self.embed = nn.Linear(input_dim, d_hidden) - self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)]) + self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)]) self.out_ln = nn.LayerNorm(d_hidden) self.out_head = nn.Linear(d_hidden, num_classes) self.num_blocks = num_blocks self.d_hidden = d_hidden + self.residual_add = residual_add def forward(self, x: torch.Tensor, return_hidden: bool = False): """ @@ -54,7 +56,7 @@ class ResidualMLP(nn.Module): for block in self.blocks: f = block(h) - h = h + f + h = h + f if self.residual_add else f if return_hidden: hiddens.append(h) @@ -68,6 +70,6 @@ class ResidualMLP(nn.Module): """Run forward from a given layer index to output. Used for perturbation tests.""" for i in range(start_layer, self.num_blocks): f = self.blocks[i](h) - h = h + f + h = h + f if self.residual_add else f logits = self.out_head(self.out_ln(h)) return logits -- cgit v1.2.3