path: root/models/residual_mlp.py
"""
Deep Residual MLP for classification.
Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
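
Example usage (the shapes and hyperparameters below are illustrative only):
    >>> model = ResidualMLP(input_dim=32, d_hidden=64, num_classes=10, num_blocks=4)
    >>> logits = model(torch.randn(8, 32))
    >>> logits.shape
    torch.Size([8, 10])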
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Single pre-LayerNorm residual MLP block."""

    def __init__(self, d_hidden: int, w2_std: float = 0.01):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.w1 = nn.Linear(d_hidden, d_hidden)
        self.w2 = nn.Linear(d_hidden, d_hidden)
        # Small init keeps the residual branch close to the identity map at the start
        # of training; use a larger w2_std if the blocks are stacked without residuals.
        nn.init.normal_(self.w2.weight, std=w2_std)
        nn.init.zeros_(self.w2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Returns the residual F_l(h), NOT h + F_l(h)."""
        z = self.ln(h)
        z = self.w1(z)
        z = F.gelu(z)
        z = self.w2(z)
        return z


class ResidualMLP(nn.Module):
    """Deep residual MLP: embed -> L blocks -> LN -> output head."""

    def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
                 residual_add: bool = True, w2_std: float = 0.01):
        super().__init__()
        self.embed = nn.Linear(input_dim, d_hidden)
        self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)])
        self.out_ln = nn.LayerNorm(d_hidden)
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.num_blocks = num_blocks
        self.d_hidden = d_hidden
        self.residual_add = residual_add

    def forward(self, x: torch.Tensor, return_hidden: bool = False):
        """
        Args:
            x: (batch, input_dim)
            return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
        Returns:
            logits: (batch, num_classes)
            hiddens: list of num_blocks + 1 tensors, each (batch, d_hidden), only if return_hidden
        """
        h = self.embed(x)
        hiddens = [h] if return_hidden else None

        for block in self.blocks:
            f = block(h)
            h = h + f if self.residual_add else f
            if return_hidden:
                hiddens.append(h)

        logits = self.out_head(self.out_ln(h))

        if return_hidden:
            return logits, hiddens
        return logits

    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
        """Run forward from a given layer index to output. Used for perturbation tests."""
        for i in range(start_layer, self.num_blocks):
            f = self.blocks[i](h)
            h = h + f if self.residual_add else f
        logits = self.out_head(self.out_ln(h))
        return logits
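

if __name__ == "__main__":
    # Minimal smoke test; the sizes below are illustrative and not tied to any
    # particular dataset or training configuration.
    torch.manual_seed(0)
    model = ResidualMLP(input_dim=32, d_hidden=64, num_classes=10, num_blocks=4)
    x = torch.randn(8, 32)

    # Full forward pass, capturing every hidden state h_0 ... h_L.
    logits, hiddens = model(x, return_hidden=True)
    assert logits.shape == (8, 10)
    assert len(hiddens) == model.num_blocks + 1

    # Resuming from hiddens[k] with start_layer=k should reproduce the same logits;
    # this is the invariant the perturbation tests rely on.
    resumed = model.forward_from_layer(hiddens[2], start_layer=2)
    assert torch.allclose(resumed, logits, atol=1e-6)
    print("smoke test passed:", tuple(logits.shape))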