path: root/models/residual_mlp.py
author    YurenHao0426 <Blackhao0426@gmail.com>    2026-05-04 19:50:45 -0500
committer YurenHao0426 <Blackhao0426@gmail.com>    2026-05-04 19:50:45 -0500
commit    b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch)
tree      f0e6afb5b3d448d1d6c35d9622d22d63073ca9a7 /models/residual_mlp.py
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and full reproduction pipeline. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
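For orientation before the diff: a minimal sketch of how the ResidualMLP class added by this commit can be instantiated, assuming the constructor signature shown in the diff below; the dimensions (784/256/10) and depth are illustrative, not values taken from the paper.

    import torch
    from models.residual_mlp import ResidualMLP

    model = ResidualMLP(input_dim=784, d_hidden=256, num_classes=10, num_blocks=16)
    x = torch.randn(32, 784)   # illustrative batch of flattened inputs
    logits = model(x)          # shape (32, 10)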
Diffstat (limited to 'models/residual_mlp.py')
-rw-r--r-- models/residual_mlp.py | 75
1 file changed, 75 insertions, 0 deletions
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
new file mode 100644
index 0000000..6827057
--- /dev/null
+++ b/models/residual_mlp.py
@@ -0,0 +1,75 @@
+"""
+Deep Residual MLP for classification.
+Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
+Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
+"""
+import torch
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+ """Single pre-LayerNorm residual MLP block."""
+
+ def __init__(self, d_hidden: int, w2_std: float = 0.01):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w1 = nn.Linear(d_hidden, d_hidden)
+ self.w2 = nn.Linear(d_hidden, d_hidden)
+ # Small init for residual branch (or larger if used as a non-residual stack)
+ nn.init.normal_(self.w2.weight, std=w2_std)
+ nn.init.zeros_(self.w2.bias)
+
+ def forward(self, h: torch.Tensor) -> torch.Tensor:
+ """Returns the residual F_l(h), NOT h + F_l(h)."""
+ z = self.ln(h)
+ z = self.w1(z)
+ z = torch.nn.functional.gelu(z)
+ z = self.w2(z)
+ return z
+
+
+class ResidualMLP(nn.Module):
+ """Deep residual MLP: embed -> L blocks -> LN -> output head."""
+
+ def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
+ residual_add: bool = True, w2_std: float = 0.01):
+ super().__init__()
+ self.embed = nn.Linear(input_dim, d_hidden)
+ self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)])
+ self.out_ln = nn.LayerNorm(d_hidden)
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+ self.residual_add = residual_add
+
+ def forward(self, x: torch.Tensor, return_hidden: bool = False):
+ """
+ Args:
+ x: (batch, input_dim)
+ return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
+ Returns:
+ logits: (batch, num_classes)
+ hiddens: list of (batch, d_hidden) if return_hidden
+ """
+ h = self.embed(x)
+ hiddens = [h] if return_hidden else None
+
+ for block in self.blocks:
+ f = block(h)
+ h = h + f if self.residual_add else f
+ if return_hidden:
+ hiddens.append(h)
+
+ logits = self.out_head(self.out_ln(h))
+
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+        """Run forward from block start_layer to the output head; h must be the hidden state entering that block (e.g. hiddens[start_layer] from forward with return_hidden=True). Used for perturbation tests."""
+        for i in range(start_layer, self.num_blocks):
+            f = self.blocks[i](h)
+            h = h + f if self.residual_add else f
+        logits = self.out_head(self.out_ln(h))
+        return logits
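A sketch of how forward_from_layer can pair with return_hidden for the perturbation tests its docstring mentions; the layer index and noise scale here are illustrative assumptions, not the paper's protocol.

    import torch
    from models.residual_mlp import ResidualMLP

    model = ResidualMLP(input_dim=784, d_hidden=256, num_classes=10, num_blocks=16)
    x = torch.randn(32, 784)
    logits, hiddens = model(x, return_hidden=True)  # hiddens[l] is the input to block l

    l = 8                                           # illustrative layer index
    h_pert = hiddens[l] + 0.1 * torch.randn_like(hiddens[l])  # illustrative noise scale
    logits_pert = model.forward_from_layer(h_pert, start_layer=l)

    # forward_from_layer(hiddens[l], start_layer=l) reproduces `logits` exactly,
    # so any drift in logits_pert is attributable to the injected perturbation.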