Diffstat (limited to 'models')
-rw-r--r--   models/__init__.py       |   0
-rw-r--r--   models/residual_mlp.py   |  75
-rw-r--r--   models/small_resnet.py   |  74
-rw-r--r--   models/vit_mini.py       | 109
4 files changed, 258 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
new file mode 100644
index 0000000..6827057
--- /dev/null
+++ b/models/residual_mlp.py
@@ -0,0 +1,75 @@
+"""
+Deep Residual MLP for classification.
+Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
+Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
+"""
+import torch
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+ """Single pre-LayerNorm residual MLP block."""
+
+ def __init__(self, d_hidden: int, w2_std: float = 0.01):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w1 = nn.Linear(d_hidden, d_hidden)
+ self.w2 = nn.Linear(d_hidden, d_hidden)
+ # Small init for residual branch (or larger if used as a non-residual stack)
+ nn.init.normal_(self.w2.weight, std=w2_std)
+ nn.init.zeros_(self.w2.bias)
+
+ def forward(self, h: torch.Tensor) -> torch.Tensor:
+ """Returns the residual F_l(h), NOT h + F_l(h)."""
+ z = self.ln(h)
+ z = self.w1(z)
+ z = torch.nn.functional.gelu(z)
+ z = self.w2(z)
+ return z
+
+
+class ResidualMLP(nn.Module):
+ """Deep residual MLP: embed -> L blocks -> LN -> output head."""
+
+ def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
+ residual_add: bool = True, w2_std: float = 0.01):
+ super().__init__()
+ self.embed = nn.Linear(input_dim, d_hidden)
+ self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)])
+ self.out_ln = nn.LayerNorm(d_hidden)
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+ self.residual_add = residual_add
+
+ def forward(self, x: torch.Tensor, return_hidden: bool = False):
+ """
+ Args:
+ x: (batch, input_dim)
+ return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
+ Returns:
+ logits: (batch, num_classes)
+ hiddens: list of (batch, d_hidden) if return_hidden
+ """
+ h = self.embed(x)
+ hiddens = [h] if return_hidden else None
+
+ for block in self.blocks:
+ f = block(h)
+ h = h + f if self.residual_add else f
+ if return_hidden:
+ hiddens.append(h)
+
+ logits = self.out_head(self.out_ln(h))
+
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+ """Run forward from a given layer index to output. Used for perturbation tests."""
+ for i in range(start_layer, self.num_blocks):
+ f = self.blocks[i](h)
+ h = h + f if self.residual_add else f
+ logits = self.out_head(self.out_ln(h))
+ return logits
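Usage sketch for residual_mlp.py (not part of the commit): a minimal example of the forward / return_hidden / forward_from_layer API, assuming flattened CIFAR-10 inputs (3072 features) and illustrative hyperparameters.

    import torch
    from models.residual_mlp import ResidualMLP

    model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=8)
    x = torch.randn(4, 3072)                        # (batch, input_dim)
    logits, hiddens = model(x, return_hidden=True)  # hiddens = [h_0, ..., h_L], length num_blocks + 1
    # hiddens[i] is the input to block i, so re-running from layer i reproduces the logits;
    # perturbing it first gives the kind of perturbation test forward_from_layer is meant for.
    h4 = hiddens[4] + 0.01 * torch.randn_like(hiddens[4])
    perturbed_logits = model.forward_from_layer(h4, start_layer=4)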
diff --git a/models/small_resnet.py b/models/small_resnet.py
new file mode 100644
index 0000000..10b122e
--- /dev/null
+++ b/models/small_resnet.py
@@ -0,0 +1,74 @@
+"""
+Small CIFAR-10 ResNet for the FA-evaluation paper. Standard BatchNorm-based
+post-activation residual blocks (no LayerNorm). 4 residual blocks at width 64.
+
+Supports `num_blocks=0` (shallow baseline: just embed → bn → head) and frozen
+blocks via `requires_grad=False` on `.blocks` parameters.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+ """Standard ResNet BasicBlock with BatchNorm. Pre-activation NOT used; this is
+ the post-activation form: relu(BN(W2 * relu(BN(W1 x)))) + x. d_hidden in == d_hidden out.
+ No stride / downsampling — all blocks operate at the same spatial resolution
+ after the initial stem. This keeps the architecture simple and matches the
+ "4 residual blocks at fixed width" structure of our ResMLP and ViT-Mini comparisons.
+ """
+ def __init__(self, d_hidden):
+ super().__init__()
+ self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(d_hidden)
+ self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(d_hidden)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out = F.relu(x + out)
+ return out
+
+
+class SmallResNet(nn.Module):
+ """Small CIFAR-10 ResNet:
+ - 3x3 conv stem (3 → d_hidden) + BN + ReLU
+ - num_blocks BasicBlocks at the same width and resolution
+ - global average pool
+ - linear classification head
+
+ `num_blocks=0` gives the shallow baseline (just stem → pool → head).
+ """
+ def __init__(self, d_hidden=64, num_classes=10, num_blocks=4):
+ super().__init__()
+ self.stem_conv = nn.Conv2d(3, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.stem_bn = nn.BatchNorm2d(d_hidden)
+ self.blocks = nn.ModuleList([BasicBlock(d_hidden) for _ in range(num_blocks)])
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+
+ def stem(self, x):
+ # x: (B, 3, 32, 32)
+ if x.dim() == 2:
+ x = x.view(x.size(0), 3, 32, 32)
+ h = F.relu(self.stem_bn(self.stem_conv(x)))
+ return h
+
+ def forward(self, x, return_hidden=False):
+ h = self.stem(x) # (B, d, 32, 32)
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ h = block(h)
+ if return_hidden:
+ hiddens.append(h)
+ h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1) # (B, d)
+ logits = self.out_head(h_pool)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ # Convenience alias for snapshot script compatibility (treats stem as the embed)
+ def embed(self, x):
+ return self.stem(x)
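Usage sketch for small_resnet.py (not part of the commit), showing the shallow `num_blocks=0` baseline and the frozen-blocks variant described in the module docstring; batch size and call pattern are illustrative.

    import torch
    from models.small_resnet import SmallResNet

    model = SmallResNet(d_hidden=64, num_classes=10, num_blocks=4)
    x = torch.randn(8, 3, 32, 32)                    # flat (8, 3072) inputs also work via stem()
    logits, hiddens = model(x, return_hidden=True)   # hiddens = [stem output] + one state per block

    shallow = SmallResNet(num_blocks=0)              # shallow baseline: stem -> pool -> head

    # Frozen blocks: keep only the stem and head trainable.
    for p in model.blocks.parameters():
        p.requires_grad = False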
diff --git a/models/vit_mini.py b/models/vit_mini.py
new file mode 100644
index 0000000..af6ba60
--- /dev/null
+++ b/models/vit_mini.py
@@ -0,0 +1,109 @@
+"""
+Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before
+the classification head — the architecture P4 should target.
+
+Designed to be compatible with the snapshot evolution / DFA training framework.
+Each TransformerBlock is a "layer" for FA-style local credit purposes.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TransformerBlock(nn.Module):
+ """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x))."""
+ def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
+ self.ln2 = nn.LayerNorm(d_model)
+ mlp_hidden = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model, mlp_hidden),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(mlp_hidden, d_model),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Self-attention sublayer
+ x_norm = self.ln1(x)
+ attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
+ x = x + attn_out
+ # MLP sublayer
+ x = x + self.mlp(self.ln2(x))
+ return x
+
+
+class ViTMini(nn.Module):
+ """Minimal Vision Transformer for CIFAR-10.
+ Patch size 4x4 → 64 patches per image. Plus a learned cls token.
+ Pre-LN with terminal LayerNorm before the head.
+ """
+ def __init__(
+ self,
+ image_size: int = 32,
+ patch_size: int = 4,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ d_model: int = 128,
+ n_heads: int = 4,
+ num_blocks: int = 4,
+ mlp_ratio: float = 4.0,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ assert image_size % patch_size == 0
+ n_patches = (image_size // patch_size) ** 2
+ self.n_tokens = n_patches + 1 # +1 for cls token
+
+ # Patch embedding via Conv2d (equivalent to flatten + linear)
+ self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model))
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+ self.blocks = nn.ModuleList([
+ TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks)
+ ])
+ self.out_ln = nn.LayerNorm(d_model) # terminal LN — the P4 trigger
+ self.out_head = nn.Linear(d_model, num_classes)
+
+ self.num_blocks = num_blocks
+ self.d_model = d_model
+ self.d_hidden = d_model # alias for compatibility with snapshot script
+
+ def embed(self, x: torch.Tensor) -> torch.Tensor:
+ """Take a flat-CIFAR input (B, 3072) or image (B, 3, 32, 32) → token sequence (B, 65, d_model)."""
+ if x.dim() == 2: # flat input
+ x = x.view(x.size(0), 3, 32, 32)
+ # x: (B, 3, 32, 32)
+ x = self.patch_embed(x) # (B, d, 8, 8)
+ x = x.flatten(2).transpose(1, 2) # (B, 64, d)
+ cls = self.cls_token.expand(x.size(0), -1, -1)
+ x = torch.cat([cls, x], dim=1) # (B, 65, d)
+ x = x + self.pos_embed
+ return x
+
+ def forward(self, x: torch.Tensor, return_hidden: bool = False):
+ h = self.embed(x) # (B, 65, d)
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ h = block(h)
+ if return_hidden:
+ hiddens.append(h)
+ # Take cls token, normalize, classify
+ h_cls = self.out_ln(h[:, 0]) # (B, d)
+ logits = self.out_head(h_cls)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+ """Run forward from a given block index. h has shape (B, n_tokens, d)."""
+ for i in range(start_layer, self.num_blocks):
+ h = self.blocks[i](h)
+ h_cls = self.out_ln(h[:, 0])
+ return self.out_head(h_cls)
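Usage sketch for vit_mini.py (not part of the commit): token shapes under the CIFAR-10 defaults (4x4 patches -> 64 patch tokens + 1 cls token, d_model=128); the batch size is arbitrary.

    import torch
    from models.vit_mini import ViTMini

    model = ViTMini()                         # CIFAR-10 defaults
    x = torch.randn(2, 3, 32, 32)             # flat (2, 3072) inputs are reshaped inside embed()
    tokens = model.embed(x)                   # (2, 65, 128)
    logits, hiddens = model(x, return_hidden=True)

    # hiddens[i] is the token state entering block i, so this reproduces the same logits.
    same_logits = model.forward_from_layer(hiddens[2], start_layer=2)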