diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
| commit | 6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch) | |
| tree | d7c63adcd19c4f5d46c8a937e5047fece55dea62 /models/residual_mlp.py | |
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching.
Credit bridge matches state bridge on linear system (~0.94 cosine).
CIFAR experiments in progress.
Diffstat (limited to 'models/residual_mlp.py')
| -rw-r--r-- | models/residual_mlp.py | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/models/residual_mlp.py b/models/residual_mlp.py new file mode 100644 index 0000000..c16778c --- /dev/null +++ b/models/residual_mlp.py @@ -0,0 +1,73 @@ +""" +Deep Residual MLP for classification. +Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head. +Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l)) +""" +import torch +import torch.nn as nn + + +class ResidualBlock(nn.Module): + """Single pre-LayerNorm residual MLP block.""" + + def __init__(self, d_hidden: int): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.w1 = nn.Linear(d_hidden, d_hidden) + self.w2 = nn.Linear(d_hidden, d_hidden) + # Small init for residual branch + nn.init.normal_(self.w2.weight, std=0.01) + nn.init.zeros_(self.w2.bias) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + """Returns the residual F_l(h), NOT h + F_l(h).""" + z = self.ln(h) + z = self.w1(z) + z = torch.nn.functional.gelu(z) + z = self.w2(z) + return z + + +class ResidualMLP(nn.Module): + """Deep residual MLP: embed -> L blocks -> LN -> output head.""" + + def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int): + super().__init__() + self.embed = nn.Linear(input_dim, d_hidden) + self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)]) + self.out_ln = nn.LayerNorm(d_hidden) + self.out_head = nn.Linear(d_hidden, num_classes) + self.num_blocks = num_blocks + self.d_hidden = d_hidden + + def forward(self, x: torch.Tensor, return_hidden: bool = False): + """ + Args: + x: (batch, input_dim) + return_hidden: if True, also return list of hidden states [h_0, ..., h_L] + Returns: + logits: (batch, num_classes) + hiddens: list of (batch, d_hidden) if return_hidden + """ + h = self.embed(x) + hiddens = [h] if return_hidden else None + + for block in self.blocks: + f = block(h) + h = h + f + if return_hidden: + hiddens.append(h) + + logits = self.out_head(self.out_ln(h)) + + if return_hidden: + return logits, hiddens + return logits + + def forward_from_layer(self, h: torch.Tensor, start_layer: int): + """Run forward from a given layer index to output. Used for perturbation tests.""" + for i in range(start_layer, self.num_blocks): + f = self.blocks[i](h) + h = h + f + logits = self.out_head(self.out_ln(h)) + return logits |
