Diffstat (limited to 'models/small_resnet.py')
-rw-r--r--  models/small_resnet.py  90
1 file changed, 90 insertions, 0 deletions
diff --git a/models/small_resnet.py b/models/small_resnet.py
new file mode 100644
index 0000000..10b122e
--- /dev/null
+++ b/models/small_resnet.py
@@ -0,0 +1,90 @@
+"""
+Small CIFAR-10 ResNet for the FA-evaluation paper. Standard BatchNorm-based
+post-activation residual blocks (no LayerNorm). 4 residual blocks at width 64.
+
+Supports `num_blocks=0` (shallow baseline: just stem → pool → head) and frozen
+blocks via `requires_grad=False` on `.blocks` parameters.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+ """Standard ResNet BasicBlock with BatchNorm. Pre-activation NOT used; this is
+ the post-activation form: relu(BN(W2 * relu(BN(W1 x)))) + x. d_hidden in == d_hidden out.
+ No stride / downsampling — all blocks operate at the same spatial resolution
+ after the initial stem. This keeps the architecture simple and matches the
+ "4 residual blocks at fixed width" structure of our ResMLP and ViT-Mini comparisons.
+ """
+ def __init__(self, d_hidden):
+ super().__init__()
+ self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(d_hidden)
+ self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(d_hidden)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out = F.relu(x + out)
+ return out
+
+
+class SmallResNet(nn.Module):
+ """Small CIFAR-10 ResNet:
+ - 3x3 conv stem (3 → d_hidden) + BN + ReLU
+ - num_blocks BasicBlocks at the same width and resolution
+ - global average pool
+ - linear classification head
+
+ `num_blocks=0` gives the shallow baseline (just stem → pool → head).
+ """
+ def __init__(self, d_hidden=64, num_classes=10, num_blocks=4):
+ super().__init__()
+ self.stem_conv = nn.Conv2d(3, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.stem_bn = nn.BatchNorm2d(d_hidden)
+ self.blocks = nn.ModuleList([BasicBlock(d_hidden) for _ in range(num_blocks)])
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+
+ def stem(self, x):
+ # x: (B, 3, 32, 32)
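+        # (flat (B, 3*32*32) inputs are reshaped back to image form first)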
+ if x.dim() == 2:
+ x = x.view(x.size(0), 3, 32, 32)
+ h = F.relu(self.stem_bn(self.stem_conv(x)))
+ return h
+
+ def forward(self, x, return_hidden=False):
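+        # with return_hidden=True, also collects [stem output, each block's output]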
+ h = self.stem(x) # (B, d, 32, 32)
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ h = block(h)
+ if return_hidden:
+ hiddens.append(h)
+ h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1) # (B, d)
+ logits = self.out_head(h_pool)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ # Convenience alias for snapshot script compatibility (treats stem as the embed)
+ def embed(self, x):
+ return self.stem(x)
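+
+
+if __name__ == "__main__":
+    # Minimal smoke test (an illustrative sketch, not part of the paper's
+    # training pipeline): checks output shapes and the frozen-blocks setup
+    # described in the module docstring.
+    model = SmallResNet(d_hidden=64, num_classes=10, num_blocks=4)
+    for p in model.blocks.parameters():  # freeze all residual blocks
+        p.requires_grad = False
+    x = torch.randn(2, 3, 32, 32)
+    logits, hiddens = model(x, return_hidden=True)
+    assert logits.shape == (2, 10)
+    assert len(hiddens) == model.num_blocks + 1  # stem + one per block
+    print("ok:", tuple(logits.shape), "hiddens:", len(hiddens))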