"""
Deep residual MLP for classification.

Architecture: input -> linear embedding -> L residual blocks -> LayerNorm -> linear head.
Each block computes h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l)).
"""
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """One pre-LayerNorm MLP block.

    `forward` returns only the branch value F_l(h); the caller is responsible
    for the residual addition (or for using the branch output directly when
    the stack is configured without residual connections).
    """

    def __init__(self, d_hidden: int, w2_std: float = 0.01):
        super().__init__()
        # Construction order (ln, w1, w2) is kept stable so seeded
        # initialization draws from the global RNG in a reproducible order.
        self.ln = nn.LayerNorm(d_hidden)
        self.w1 = nn.Linear(d_hidden, d_hidden)
        self.w2 = nn.Linear(d_hidden, d_hidden)
        # Keep the branch output near zero at init (small std on the final
        # projection); a larger std can be passed when the blocks are used
        # as a plain non-residual stack.
        nn.init.normal_(self.w2.weight, std=w2_std)
        nn.init.zeros_(self.w2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return the branch value F_l(h) = W2(GELU(W1(LN(h)))) — NOT h + F_l(h)."""
        return self.w2(torch.nn.functional.gelu(self.w1(self.ln(h))))


class ResidualMLP(nn.Module):
    """Deep residual MLP: embedding, a stack of blocks, final LayerNorm, linear head."""

    def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
                 residual_add: bool = True, w2_std: float = 0.01):
        super().__init__()
        # Module creation order (embed, blocks, out_ln, out_head) is fixed so
        # that seeded parameter initialization is reproducible.
        self.embed = nn.Linear(input_dim, d_hidden)
        self.blocks = nn.ModuleList(
            ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)
        )
        self.out_ln = nn.LayerNorm(d_hidden)
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.num_blocks = num_blocks
        self.d_hidden = d_hidden
        self.residual_add = residual_add

    def forward(self, x: torch.Tensor, return_hidden: bool = False):
        """Run the full network.

        Args:
            x: (batch, input_dim) input features.
            return_hidden: when True, also return the hidden-state trajectory
                [h_0, ..., h_L] (embedding output plus every post-block state).

        Returns:
            logits: (batch, num_classes).
            hiddens: list of (batch, d_hidden) tensors, only if return_hidden.
        """
        h = self.embed(x)
        trajectory = [h]
        for blk in self.blocks:
            branch = blk(h)
            if self.residual_add:
                h = h + branch
            else:
                # Non-residual ablation: the branch output replaces the state.
                h = branch
            trajectory.append(h)
        logits = self.out_head(self.out_ln(h))
        return (logits, trajectory) if return_hidden else logits

    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
        """Resume the forward pass at block `start_layer` from hidden state `h`.

        Used for perturbation tests: feed a (possibly modified) hidden state
        back in and run the remaining blocks plus the output head.
        """
        for blk in self.blocks[start_layer:self.num_blocks]:
            branch = blk(h)
            h = h + branch if self.residual_add else branch
        return self.out_head(self.out_ln(h))