"""
Deep Residual MLP for classification.

Architecture:
    Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.

Each block (with ``residual_add=True``, the default):
    h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))

With ``residual_add=False`` the skip connection is dropped and the blocks form a
plain stack:
    h_{l+1} = W2 * GELU(W1 * LN(h_l))
"""

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Single pre-LayerNorm residual MLP block.

    Computes only the residual branch F(h) = W2 * GELU(W1 * LN(h)); the caller
    decides whether to add it back onto ``h``.

    Args:
        d_hidden: Width of the block input, the inner projection and the output.
        w2_std: Standard deviation of the normal init for ``w2.weight``.
            ``w2.bias`` is zero-initialised; ``w1`` keeps the default
            ``nn.Linear`` initialisation.
    """

    def __init__(self, d_hidden: int, w2_std: float = 0.01):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.w1 = nn.Linear(d_hidden, d_hidden)
        self.w2 = nn.Linear(d_hidden, d_hidden)
        # A small w2 (and zero bias) keeps F(h) small at init, so with a
        # residual add each block starts close to the identity. Use a larger
        # w2_std when the blocks are stacked without the residual add.
        nn.init.normal_(self.w2.weight, std=w2_std)
        nn.init.zeros_(self.w2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return the residual F(h), NOT h + F(h)."""
        z = self.ln(h)
        z = self.w1(z)
        z = torch.nn.functional.gelu(z)
        z = self.w2(z)
        return z


class ResidualMLP(nn.Module):
    """Deep residual MLP: embed -> L blocks -> LN -> output head.

    Args:
        input_dim: Size of the input feature dimension.
        d_hidden: Width of the hidden stream carried between blocks.
        num_classes: Number of output logits.
        num_blocks: Number of ``ResidualBlock`` layers (L).
        residual_add: If True, h_{l+1} = h_l + F_l(h_l); if False,
            h_{l+1} = F_l(h_l) (no skip connection).
        w2_std: Init std forwarded to every block's ``w2`` weight.
    """

    def __init__(self, input_dim: int, d_hidden: int, num_classes: int,
                 num_blocks: int, residual_add: bool = True,
                 w2_std: float = 0.01):
        super().__init__()
        self.embed = nn.Linear(input_dim, d_hidden)
        self.blocks = nn.ModuleList(
            [ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)]
        )
        self.out_ln = nn.LayerNorm(d_hidden)
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.num_blocks = num_blocks
        self.d_hidden = d_hidden
        self.residual_add = residual_add

    def _run_blocks(self, h: torch.Tensor, start_layer: int,
                    collect: "list[torch.Tensor] | None" = None) -> torch.Tensor:
        """Apply blocks ``start_layer .. num_blocks - 1`` to ``h``; return the last state.

        If ``collect`` is a list, every state produced after a block is appended
        to it in order. Shared by ``forward`` and ``forward_from_layer`` so both
        paths apply the blocks identically.
        """
        for i in range(start_layer, self.num_blocks):
            f = self.blocks[i](h)
            h = h + f if self.residual_add else f
            if collect is not None:
                collect.append(h)
        return h

    def _readout(self, h: torch.Tensor) -> torch.Tensor:
        """Map a final hidden state to logits via the output LayerNorm and head."""
        return self.out_head(self.out_ln(h))

    def forward(self, x: torch.Tensor, return_hidden: bool = False):
        """
        Args:
            x: (batch, input_dim)
            return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
        Returns:
            logits: (batch, num_classes)
            hiddens: list of (batch, d_hidden) if return_hidden
        """
        h = self.embed(x)
        hiddens = [h] if return_hidden else None
        h = self._run_blocks(h, 0, hiddens)
        logits = self._readout(h)
        if return_hidden:
            return logits, hiddens
        return logits

    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
        """Run forward from a given layer index to output. Used for perturbation tests.

        ``h`` plays the role of h_{start_layer}, i.e. ``hiddens[start_layer]``
        from ``forward(x, return_hidden=True)``. Blocks ``start_layer`` through
        ``num_blocks - 1`` are applied, then the output head.
        ``start_layer == num_blocks`` applies only the output head.

        Raises:
            ValueError: if ``start_layer`` is outside ``[0, num_blocks]``.
                Without this check a negative index would silently wrap around
                the ModuleList and apply extra blocks.
        """
        if not 0 <= start_layer <= self.num_blocks:
            raise ValueError(
                f"start_layer must be in [0, {self.num_blocks}], got {start_layer}"
            )
        return self._readout(self._run_blocks(h, start_layer))