Diffstat (limited to 'models')
-rw-r--r--   models/__init__.py        0
-rw-r--r--   models/residual_mlp.py   75
-rw-r--r--   models/small_resnet.py   74
-rw-r--r--   models/vit_mini.py      109
4 files changed, 258 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
new file mode 100644
index 0000000..6827057
--- /dev/null
+++ b/models/residual_mlp.py
@@ -0,0 +1,75 @@
+"""
+Deep Residual MLP for classification.
+Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
+Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
+"""
+import torch
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+    """Single pre-LayerNorm residual MLP block."""
+
+    def __init__(self, d_hidden: int, w2_std: float = 0.01):
+        super().__init__()
+        self.ln = nn.LayerNorm(d_hidden)
+        self.w1 = nn.Linear(d_hidden, d_hidden)
+        self.w2 = nn.Linear(d_hidden, d_hidden)
+        # Small init for residual branch (or larger if used as a non-residual stack)
+        nn.init.normal_(self.w2.weight, std=w2_std)
+        nn.init.zeros_(self.w2.bias)
+
+    def forward(self, h: torch.Tensor) -> torch.Tensor:
+        """Returns the residual F_l(h), NOT h + F_l(h)."""
+        z = self.ln(h)
+        z = self.w1(z)
+        z = torch.nn.functional.gelu(z)
+        z = self.w2(z)
+        return z
+
+
+class ResidualMLP(nn.Module):
+    """Deep residual MLP: embed -> L blocks -> LN -> output head."""
+
+    def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
+                 residual_add: bool = True, w2_std: float = 0.01):
+        super().__init__()
+        self.embed = nn.Linear(input_dim, d_hidden)
+        self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)])
+        self.out_ln = nn.LayerNorm(d_hidden)
+        self.out_head = nn.Linear(d_hidden, num_classes)
+        self.num_blocks = num_blocks
+        self.d_hidden = d_hidden
+        self.residual_add = residual_add
+
+    def forward(self, x: torch.Tensor, return_hidden: bool = False):
+        """
+        Args:
+            x: (batch, input_dim)
+            return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
+        Returns:
+            logits: (batch, num_classes)
+            hiddens: list of (batch, d_hidden) if return_hidden
+        """
+        h = self.embed(x)
+        hiddens = [h] if return_hidden else None
+
+        for block in self.blocks:
+            f = block(h)
+            h = h + f if self.residual_add else f
+            if return_hidden:
+                hiddens.append(h)
+
+        logits = self.out_head(self.out_ln(h))
+
+        if return_hidden:
+            return logits, hiddens
+        return logits
+
+    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+        """Run forward from a given layer index to output. Used for perturbation tests."""
+        for i in range(start_layer, self.num_blocks):
+            f = self.blocks[i](h)
+            h = h + f if self.residual_add else f
+        logits = self.out_head(self.out_ln(h))
+        return logits
diff --git a/models/small_resnet.py b/models/small_resnet.py
new file mode 100644
index 0000000..10b122e
--- /dev/null
+++ b/models/small_resnet.py
@@ -0,0 +1,74 @@
+"""
+Small CIFAR-10 ResNet for the FA-evaluation paper. Standard BatchNorm-based
+post-activation residual blocks (no LayerNorm). 4 residual blocks at width 64.
+
+Supports `num_blocks=0` (shallow baseline: just stem → pool → head) and frozen
+blocks via `requires_grad=False` on `.blocks` parameters.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+    """Standard ResNet BasicBlock with BatchNorm. Pre-activation NOT used; this is
+    the post-activation form: relu(x + BN(W2 * relu(BN(W1 x)))). d_hidden in == d_hidden out.
+    No stride / downsampling — all blocks operate at the same spatial resolution
+    after the initial stem. This keeps the architecture simple and matches the
+    "4 residual blocks at fixed width" structure of our ResMLP and ViT-Mini comparisons.
+    """
+    def __init__(self, d_hidden):
+        super().__init__()
+        self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(d_hidden)
+        self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(d_hidden)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out = F.relu(x + out)
+        return out
+
+
+class SmallResNet(nn.Module):
+    """Small CIFAR-10 ResNet:
+    - 3x3 conv stem (3 → d_hidden) + BN + ReLU
+    - num_blocks BasicBlocks at the same width and resolution
+    - global average pool
+    - linear classification head
+
+    `num_blocks=0` gives the shallow baseline (just stem → pool → head).
+    """
+    def __init__(self, d_hidden=64, num_classes=10, num_blocks=4):
+        super().__init__()
+        self.stem_conv = nn.Conv2d(3, d_hidden, kernel_size=3, padding=1, bias=False)
+        self.stem_bn = nn.BatchNorm2d(d_hidden)
+        self.blocks = nn.ModuleList([BasicBlock(d_hidden) for _ in range(num_blocks)])
+        self.out_head = nn.Linear(d_hidden, num_classes)
+        self.num_blocks = num_blocks
+        self.d_hidden = d_hidden
+
+    def stem(self, x):
+        # x: (B, 3, 32, 32)
+        if x.dim() == 2:
+            x = x.view(x.size(0), 3, 32, 32)
+        h = F.relu(self.stem_bn(self.stem_conv(x)))
+        return h
+
+    def forward(self, x, return_hidden=False):
+        h = self.stem(x)  # (B, d, 32, 32)
+        hiddens = [h] if return_hidden else None
+        for block in self.blocks:
+            h = block(h)
+            if return_hidden:
+                hiddens.append(h)
+        h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1)  # (B, d)
+        logits = self.out_head(h_pool)
+        if return_hidden:
+            return logits, hiddens
+        return logits
+
+    # Convenience alias for snapshot script compatibility (treats stem as the embed)
+    def embed(self, x):
+        return self.stem(x)
diff --git a/models/vit_mini.py b/models/vit_mini.py
new file mode 100644
index 0000000..af6ba60
--- /dev/null
+++ b/models/vit_mini.py
@@ -0,0 +1,109 @@
+"""
+Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before
+the classification head — the architecture P4 should target.
+
+Designed to be compatible with the snapshot evolution / DFA training framework.
+Each TransformerBlock is a "layer" for FA-style local credit purposes.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TransformerBlock(nn.Module):
+    """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x))."""
+    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(d_model)
+        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
+        self.ln2 = nn.LayerNorm(d_model)
+        mlp_hidden = int(d_model * mlp_ratio)
+        self.mlp = nn.Sequential(
+            nn.Linear(d_model, mlp_hidden),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(mlp_hidden, d_model),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Self-attention sublayer
+        x_norm = self.ln1(x)
+        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
+        x = x + attn_out
+        # MLP sublayer
+        x = x + self.mlp(self.ln2(x))
+        return x
+
+
+class ViTMini(nn.Module):
+    """Minimal Vision Transformer for CIFAR-10.
+    Patch size 4x4 → 64 patches per image, plus a learned cls token.
+    Pre-LN with terminal LayerNorm before the head.
+    """
+    def __init__(
+        self,
+        image_size: int = 32,
+        patch_size: int = 4,
+        in_channels: int = 3,
+        num_classes: int = 10,
+        d_model: int = 128,
+        n_heads: int = 4,
+        num_blocks: int = 4,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        assert image_size % patch_size == 0
+        n_patches = (image_size // patch_size) ** 2
+        self.n_tokens = n_patches + 1  # +1 for cls token
+
+        # Patch embedding via Conv2d (equivalent to flatten + linear)
+        self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model))
+        nn.init.trunc_normal_(self.cls_token, std=0.02)
+        nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+        self.blocks = nn.ModuleList([
+            TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks)
+        ])
+        self.out_ln = nn.LayerNorm(d_model)  # terminal LN — the P4 trigger
+        self.out_head = nn.Linear(d_model, num_classes)
+
+        self.num_blocks = num_blocks
+        self.d_model = d_model
+        self.d_hidden = d_model  # alias for compatibility with snapshot script
+
+    def embed(self, x: torch.Tensor) -> torch.Tensor:
+        """Take a flat-CIFAR input (B, 3072) or image (B, 3, 32, 32) → token sequence (B, 65, d_model)."""
+        if x.dim() == 2:  # flat input
+            x = x.view(x.size(0), 3, 32, 32)
+        # x: (B, 3, 32, 32)
+        x = self.patch_embed(x)  # (B, d, 8, 8)
+        x = x.flatten(2).transpose(1, 2)  # (B, 64, d)
+        cls = self.cls_token.expand(x.size(0), -1, -1)
+        x = torch.cat([cls, x], dim=1)  # (B, 65, d)
+        x = x + self.pos_embed
+        return x
+
+    def forward(self, x: torch.Tensor, return_hidden: bool = False):
+        h = self.embed(x)  # (B, 65, d)
+        hiddens = [h] if return_hidden else None
+        for block in self.blocks:
+            h = block(h)
+            if return_hidden:
+                hiddens.append(h)
+        # Take cls token, normalize, classify
+        h_cls = self.out_ln(h[:, 0])  # (B, d)
+        logits = self.out_head(h_cls)
+        if return_hidden:
+            return logits, hiddens
+        return logits
+
+    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+        """Run forward from a given block index. h has shape (B, n_tokens, d)."""
+        for i in range(start_layer, self.num_blocks):
+            h = self.blocks[i](h)
+        h_cls = self.out_ln(h[:, 0])
+        return self.out_head(h_cls)
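Below is a minimal smoke-test sketch (not part of this commit) showing how the three new modules fit together. The batch size, the d_hidden=256 choice for the MLP, and running from the repo root are illustrative assumptions, not values taken from the training scripts.

# smoke test (illustrative, not committed): forward passes and shape checks
import torch

from models.residual_mlp import ResidualMLP
from models.small_resnet import SmallResNet
from models.vit_mini import ViTMini

x_img = torch.randn(8, 3, 32, 32)   # image-form CIFAR input
x_flat = x_img.flatten(1)           # flat (B, 3072) input

mlp = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4)
resnet = SmallResNet(d_hidden=64, num_classes=10, num_blocks=4)
vit = ViTMini()  # defaults: d_model=128, 4 heads, 4 blocks

# Each model returns (logits, [h_0, ..., h_L]) when return_hidden=True,
# where h_0 is the embedding/stem output and there is one state per block.
logits, hiddens = mlp(x_flat, return_hidden=True)
assert logits.shape == (8, 10) and len(hiddens) == mlp.num_blocks + 1

logits, hiddens = resnet(x_img, return_hidden=True)
assert logits.shape == (8, 10) and len(hiddens) == resnet.num_blocks + 1

logits, hiddens = vit(x_flat, return_hidden=True)  # flat input is reshaped internally
assert logits.shape == (8, 10) and len(hiddens) == vit.num_blocks + 1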

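The forward_from_layer helpers in ResidualMLP and ViTMini are documented as being for perturbation tests. The sketch below shows one plausible such test; the Gaussian perturbation and its 0.1 scale are assumptions for illustration, not the paper's protocol. ViTMini works the same way, with hidden states of shape (B, n_tokens, d_model); SmallResNet does not define this method.

# perturbation-test sketch (illustrative): perturb the state entering block l,
# rerun the rest of the network, and measure how far the logits move
import torch

from models.residual_mlp import ResidualMLP

model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4)
model.eval()

x = torch.randn(8, 3072)
with torch.no_grad():
    logits, hiddens = model(x, return_hidden=True)
    for layer in range(model.num_blocks):
        # hiddens[layer] is the state that enters block `layer`
        h_pert = hiddens[layer] + 0.1 * torch.randn_like(hiddens[layer])
        logits_pert = model.forward_from_layer(h_pert, start_layer=layer)
        drift = (logits_pert - logits).norm(dim=-1).mean()
        print(f"block {layer}: mean logit drift {drift.item():.4f}")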