| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-05-04 19:50:45 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-05-04 19:50:45 -0500 |
| commit | b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch) | |
| tree | f0e6afb5b3d448d1d6c35d9622d22d63073ca9a7 /models/residual_mlp.py | |
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol
(scale stability, reference validity, depth utility) from the NeurIPS 2026
E&D track paper. Includes models, metrics, and full reproduction pipeline.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'models/residual_mlp.py')
| -rw-r--r-- | models/residual_mlp.py | 75 |
1 file changed, 75 insertions, 0 deletions
```diff
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
new file mode 100644
index 0000000..6827057
--- /dev/null
+++ b/models/residual_mlp.py
@@ -0,0 +1,75 @@
+"""
+Deep Residual MLP for classification.
+Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
+Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
+"""
+import torch
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+    """Single pre-LayerNorm residual MLP block."""
+
+    def __init__(self, d_hidden: int, w2_std: float = 0.01):
+        super().__init__()
+        self.ln = nn.LayerNorm(d_hidden)
+        self.w1 = nn.Linear(d_hidden, d_hidden)
+        self.w2 = nn.Linear(d_hidden, d_hidden)
+        # Small init for residual branch (or larger if used as a non-residual stack)
+        nn.init.normal_(self.w2.weight, std=w2_std)
+        nn.init.zeros_(self.w2.bias)
+
+    def forward(self, h: torch.Tensor) -> torch.Tensor:
+        """Returns the residual F_l(h), NOT h + F_l(h)."""
+        z = self.ln(h)
+        z = self.w1(z)
+        z = torch.nn.functional.gelu(z)
+        z = self.w2(z)
+        return z
+
+
+class ResidualMLP(nn.Module):
+    """Deep residual MLP: embed -> L blocks -> LN -> output head."""
+
+    def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
+                 residual_add: bool = True, w2_std: float = 0.01):
+        super().__init__()
+        self.embed = nn.Linear(input_dim, d_hidden)
+        self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)])
+        self.out_ln = nn.LayerNorm(d_hidden)
+        self.out_head = nn.Linear(d_hidden, num_classes)
+        self.num_blocks = num_blocks
+        self.d_hidden = d_hidden
+        self.residual_add = residual_add
+
+    def forward(self, x: torch.Tensor, return_hidden: bool = False):
+        """
+        Args:
+            x: (batch, input_dim)
+            return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
+        Returns:
+            logits: (batch, num_classes)
+            hiddens: list of (batch, d_hidden) if return_hidden
+        """
+        h = self.embed(x)
+        hiddens = [h] if return_hidden else None
+
+        for block in self.blocks:
+            f = block(h)
+            h = h + f if self.residual_add else f
+            if return_hidden:
+                hiddens.append(h)
+
+        logits = self.out_head(self.out_ln(h))
+
+        if return_hidden:
+            return logits, hiddens
+        return logits
+
+    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+        """Run forward from a given layer index to output. Used for perturbation tests."""
+        for i in range(start_layer, self.num_blocks):
+            f = self.blocks[i](h)
+            h = h + f if self.residual_add else f
+        logits = self.out_head(self.out_ln(h))
+        return logits
```
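For orientation, here is a minimal usage sketch of the module added above. It is not part of the commit; the dimensions, batch size, perturbation scale, and layer index are illustrative placeholders, and the import path assumes the repository root is on `PYTHONPATH`.

```python
# Illustrative usage sketch for models/residual_mlp.py (not part of the commit).
import torch

from models.residual_mlp import ResidualMLP

# Hyperparameters below are arbitrary example values, not repo configs.
model = ResidualMLP(input_dim=32, d_hidden=128, num_classes=10, num_blocks=8)
x = torch.randn(4, 32)  # batch of 4 random inputs

# Forward pass that also collects hidden states [h_0, ..., h_L] (L = num_blocks).
logits, hiddens = model(x, return_hidden=True)
assert logits.shape == (4, 10)
assert len(hiddens) == model.num_blocks + 1

# Perturbation-style probe: nudge the state after block 3, then resume the
# forward pass from block index 3 (which maps h_3 -> h_4) to get new logits.
h3 = hiddens[3] + 0.1 * torch.randn_like(hiddens[3])
perturbed_logits = model.forward_from_layer(h3, start_layer=3)
assert perturbed_logits.shape == (4, 10)
```

The `return_hidden` states and the `forward_from_layer` hook appear to be the intended entry points for the protocol's perturbation-style diagnostics, per the "Used for perturbation tests" note in the docstring.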
