diff options
Diffstat (limited to 'models/residual_mlp.py')
| -rw-r--r-- | models/residual_mlp.py | 16 |
1 file changed, 9 insertions, 7 deletions
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
index c16778c..6827057 100644
--- a/models/residual_mlp.py
+++ b/models/residual_mlp.py
@@ -10,13 +10,13 @@ import torch.nn as nn
 class ResidualBlock(nn.Module):
     """Single pre-LayerNorm residual MLP block."""
 
-    def __init__(self, d_hidden: int):
+    def __init__(self, d_hidden: int, w2_std: float = 0.01):
         super().__init__()
         self.ln = nn.LayerNorm(d_hidden)
         self.w1 = nn.Linear(d_hidden, d_hidden)
         self.w2 = nn.Linear(d_hidden, d_hidden)
-        # Small init for residual branch
-        nn.init.normal_(self.w2.weight, std=0.01)
+        # Small init for residual branch (or larger if used as a non-residual stack)
+        nn.init.normal_(self.w2.weight, std=w2_std)
         nn.init.zeros_(self.w2.bias)
 
     def forward(self, h: torch.Tensor) -> torch.Tensor:
@@ -31,14 +31,16 @@ class ResidualBlock(nn.Module):
 class ResidualMLP(nn.Module):
     """Deep residual MLP: embed -> L blocks -> LN -> output head."""
 
-    def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int):
+    def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
+                 residual_add: bool = True, w2_std: float = 0.01):
         super().__init__()
         self.embed = nn.Linear(input_dim, d_hidden)
-        self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)])
+        self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)])
         self.out_ln = nn.LayerNorm(d_hidden)
         self.out_head = nn.Linear(d_hidden, num_classes)
         self.num_blocks = num_blocks
         self.d_hidden = d_hidden
+        self.residual_add = residual_add
 
     def forward(self, x: torch.Tensor, return_hidden: bool = False):
         """
@@ -54,7 +56,7 @@ class ResidualMLP(nn.Module):
 
         for block in self.blocks:
             f = block(h)
-            h = h + f
+            h = h + f if self.residual_add else f
             if return_hidden:
                 hiddens.append(h)
 
@@ -68,6 +70,6 @@ class ResidualMLP(nn.Module):
         """Run forward from a given layer index to output. Used for perturbation tests."""
         for i in range(start_layer, self.num_blocks):
             f = self.blocks[i](h)
-            h = h + f
+            h = h + f if self.residual_add else f
         logits = self.out_head(self.out_ln(h))
         return logits
