"""
Small CIFAR-10 ResNet for the FA-evaluation paper. Standard BatchNorm-based
post-activation residual blocks (no LayerNorm). 4 residual blocks at width 64.
Supports `num_blocks=0` (shallow baseline: just stem → pool → head) and frozen
blocks via `requires_grad=False` on `.blocks` parameters; a usage sketch of both
is at the bottom of this file.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Standard ResNet BasicBlock with BatchNorm. Pre-activation is NOT used; this is
    the post-activation form: relu(x + BN(W2 * relu(BN(W1 x)))). d_hidden in == d_hidden out.

    No stride / downsampling: all blocks operate at the same spatial resolution
    after the initial stem. This keeps the architecture simple and matches the
    "4 residual blocks at fixed width" structure of our ResMLP and ViT-Mini comparisons.
    """

    def __init__(self, d_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(d_hidden)
        self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(d_hidden)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = F.relu(x + out)
        return out


class SmallResNet(nn.Module):
    """Small CIFAR-10 ResNet:
    - 3x3 conv stem (3 → d_hidden) + BN + ReLU
    - num_blocks BasicBlocks at the same width and resolution
    - global average pool
    - linear classification head

    `num_blocks=0` gives the shallow baseline (just stem → pool → head).
    """

    def __init__(self, d_hidden=64, num_classes=10, num_blocks=4):
        super().__init__()
        self.stem_conv = nn.Conv2d(3, d_hidden, kernel_size=3, padding=1, bias=False)
        self.stem_bn = nn.BatchNorm2d(d_hidden)
        self.blocks = nn.ModuleList([BasicBlock(d_hidden) for _ in range(num_blocks)])
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.num_blocks = num_blocks
        self.d_hidden = d_hidden

    def stem(self, x):
        # x: (B, 3, 32, 32); also accept flattened (B, 3*32*32) inputs
        # and reshape them back to image form.
        if x.dim() == 2:
            x = x.view(x.size(0), 3, 32, 32)
        h = F.relu(self.stem_bn(self.stem_conv(x)))
        return h

    def forward(self, x, return_hidden=False):
        h = self.stem(x)  # (B, d, 32, 32)
        hiddens = [h] if return_hidden else None
        for block in self.blocks:
            h = block(h)
            if return_hidden:
                hiddens.append(h)
        h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1)  # (B, d)
        logits = self.out_head(h_pool)
        if return_hidden:
            return logits, hiddens
        return logits

    # Convenience alias for snapshot-script compatibility (treats the stem as the embed).
    def embed(self, x):
        return self.stem(x)
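

# Minimal usage sketch / smoke test (an illustrative assumption, not part of the
# paper's training pipeline): exercises the frozen-blocks setup and the
# `num_blocks=0` baseline described in the module docstring. The batch size of 8
# is arbitrary.
if __name__ == "__main__":
    model = SmallResNet(d_hidden=64, num_classes=10, num_blocks=4)

    # Freeze the residual blocks via requires_grad=False on .blocks parameters;
    # only the stem and the linear head remain trainable.
    for p in model.blocks.parameters():
        p.requires_grad = False

    x = torch.randn(8, 3, 32, 32)
    logits, hiddens = model(x, return_hidden=True)
    assert logits.shape == (8, 10)
    assert len(hiddens) == model.num_blocks + 1  # stem output + one per block

    # Shallow baseline: stem → pool → head, no residual blocks.
    baseline = SmallResNet(num_blocks=0)
    assert baseline(x).shape == (8, 10)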