"""
Deep Residual MLP for classification.
Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Single pre-LayerNorm residual MLP block."""

    def __init__(self, d_hidden: int, w2_std: float = 0.01):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.w1 = nn.Linear(d_hidden, d_hidden)
        self.w2 = nn.Linear(d_hidden, d_hidden)
        # Small init for residual branch (or larger if used as a non-residual stack)
        nn.init.normal_(self.w2.weight, std=w2_std)
        nn.init.zeros_(self.w2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Returns the residual F_l(h), NOT h + F_l(h)."""
        z = self.ln(h)
        z = self.w1(z)
        z = F.gelu(z)
        z = self.w2(z)
        return z


class ResidualMLP(nn.Module):
    """Deep residual MLP: embed -> L blocks -> LN -> output head."""

    def __init__(self, input_dim: int, d_hidden: int, num_classes: int,
                 num_blocks: int, residual_add: bool = True, w2_std: float = 0.01):
        super().__init__()
        self.embed = nn.Linear(input_dim, d_hidden)
        self.blocks = nn.ModuleList(
            [ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)]
        )
        self.out_ln = nn.LayerNorm(d_hidden)
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.num_blocks = num_blocks
        self.d_hidden = d_hidden
        self.residual_add = residual_add

    def forward(self, x: torch.Tensor, return_hidden: bool = False):
        """
        Args:
            x: (batch, input_dim)
            return_hidden: if True, also return the list of hidden states [h_0, ..., h_L]

        Returns:
            logits: (batch, num_classes)
            hiddens: list of (batch, d_hidden) tensors, only if return_hidden
        """
        h = self.embed(x)
        hiddens = [h] if return_hidden else None
        for block in self.blocks:
            f = block(h)
            h = h + f if self.residual_add else f
            if return_hidden:
                hiddens.append(h)
        logits = self.out_head(self.out_ln(h))
        if return_hidden:
            return logits, hiddens
        return logits

    def forward_from_layer(self, h: torch.Tensor, start_layer: int) -> torch.Tensor:
        """Run the forward pass from block index `start_layer` to the output logits.

        Used for perturbation tests: inject a (possibly modified) hidden state
        at a given depth and read out its effect on the logits.
        """
        for i in range(start_layer, self.num_blocks):
            f = self.blocks[i](h)
            h = h + f if self.residual_add else f
        logits = self.out_head(self.out_ln(h))
        return logits
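

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below (dimensions, seed, noise scale, and
# the choice of perturbed layer) is illustrative and assumed, not part of the
# original module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = ResidualMLP(input_dim=32, d_hidden=64, num_classes=10, num_blocks=8)
    x = torch.randn(4, 32)

    # Full forward pass, keeping every hidden state h_0, ..., h_L.
    logits, hiddens = model(x, return_hidden=True)
    print(logits.shape)   # torch.Size([4, 10])
    print(len(hiddens))   # num_blocks + 1 = 9

    # Perturbation test via forward_from_layer: hiddens[k] is the state after
    # blocks 0..k-1, so forward_from_layer(hiddens[k], start_layer=k)
    # reproduces the original logits exactly. Perturb that state and compare.
    k = 4
    noisy = hiddens[k] + 0.1 * torch.randn_like(hiddens[k])
    perturbed = model.forward_from_layer(noisy, start_layer=k)
    print((perturbed - logits).abs().max())  # shift induced by the perturbation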