summaryrefslogtreecommitdiff
path: root/models/small_resnet.py
blob: 10b122eec029ec5dfaab7e9fc6ac49e4475ef86d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Small CIFAR-10 ResNet for the FA-evaluation paper. Standard BatchNorm-based
post-activation residual blocks (no LayerNorm). By default: 4 residual blocks at width 64.

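Supports `num_blocks=0` (shallow baseline: just stem conv → BN → ReLU → global
average pool → head) and frozen blocks via `requires_grad=False` on `.blocks`
parameters. Note that `requires_grad=False` freezes only the learnable weights:
BatchNorm running statistics inside frozen blocks keep updating in train mode
unless those blocks are switched to `.eval()`.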
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Standard ResNet BasicBlock with BatchNorm, in post-activation form.

    Computes::

        relu(x + BN2(conv2(relu(BN1(conv1(x))))))

    i.e. the final ReLU is applied *after* the residual addition, so the block
    output is non-negative. Pre-activation is NOT used.

    Input and output both have ``d_hidden`` channels. Both convolutions are 3x3
    with padding 1 and the default stride of 1, so there is no downsampling:
    every block operates at the same spatial resolution as its input (the
    resolution produced by the initial stem). This keeps the architecture simple
    and matches the "4 residual blocks at fixed width" structure of our ResMLP
    and ViT-Mini comparisons.

    Args:
        d_hidden: number of input and output channels of the block.
    """
    def __init__(self, d_hidden):
        super().__init__()
        # bias=False: each conv is immediately followed by BatchNorm, whose
        # learnable shift makes a separate conv bias redundant.
        self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(d_hidden)
        self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(d_hidden)

    def forward(self, x):
        """Apply the residual block; the output has the same shape as ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Identity shortcut, then ReLU on the sum (post-activation form).
        out = F.relu(x + out)
        return out


class SmallResNet(nn.Module):
    """Compact ResNet classifier for CIFAR-10-sized inputs.

    Pipeline:
        1. stem: 3x3 conv (3 -> d_hidden, no bias) -> BatchNorm -> ReLU
        2. ``num_blocks`` BasicBlocks, all at width ``d_hidden`` and at the
           stem's spatial resolution
        3. global average pooling over the spatial dimensions
        4. linear head (d_hidden -> num_classes)

    With ``num_blocks=0`` the model reduces to the shallow baseline
    stem -> pool -> head.

    Args:
        d_hidden: channel width of the stem and of every residual block.
        num_classes: number of output logits.
        num_blocks: how many BasicBlocks to stack after the stem.
    """
    def __init__(self, d_hidden=64, num_classes=10, num_blocks=4):
        super().__init__()
        # Submodules are built in this fixed order (stem conv, stem BN, blocks,
        # head); keeping it fixed keeps random initialisation and state_dict
        # key order stable.
        self.stem_conv = nn.Conv2d(3, d_hidden, kernel_size=3, padding=1, bias=False)
        self.stem_bn = nn.BatchNorm2d(d_hidden)
        self.blocks = nn.ModuleList(BasicBlock(d_hidden) for _ in range(num_blocks))
        self.out_head = nn.Linear(d_hidden, num_classes)
        # Record the configuration on the instance.
        self.d_hidden = d_hidden
        self.num_blocks = num_blocks

    def stem(self, x):
        """Map input images to a ``d_hidden``-channel feature map.

        A 2-D input is treated as flattened CIFAR images of shape (B, 3072) and
        viewed as (B, 3, 32, 32). Any other input is passed straight to the stem
        conv (expected layout (B, 3, H, W), e.g. (B, 3, 32, 32)).
        """
        if x.ndim == 2:
            # Restore channel/height/width layout for flattened rows.
            batch = x.shape[0]
            x = x.view(batch, 3, 32, 32)
        return F.relu(self.stem_bn(self.stem_conv(x)))

    def forward(self, x, return_hidden=False):
        """Classify ``x``.

        Returns the logits of shape (B, num_classes). When ``return_hidden`` is
        truthy, returns ``(logits, hiddens)`` instead, where ``hiddens`` is a
        list holding the stem output followed by the output of every block
        (``num_blocks + 1`` feature maps in total).
        """
        feats = self.stem(x)                  # (B, d_hidden, H, W)
        trace = [feats] if return_hidden else None
        for layer in self.blocks:
            feats = layer(feats)
            if trace is not None:
                trace.append(feats)
        pooled = torch.flatten(F.adaptive_avg_pool2d(feats, 1), start_dim=1)  # (B, d_hidden)
        scores = self.out_head(pooled)
        return (scores, trace) if return_hidden else scores

    def embed(self, x):
        """Alias for :meth:`stem`, so snapshot scripts can treat the stem as the embedding."""
        return self.stem(x)