"""
Small CIFAR-10 ResNet for the FA-evaluation paper.

Standard BatchNorm-based post-activation residual blocks (no LayerNorm).
Defaults: 4 residual blocks at width 64.

Supports `num_blocks=0` (shallow baseline: just stem → global average pool →
linear head) and frozen blocks via `requires_grad=False` on `.blocks`
parameters.

NOTE: `requires_grad=False` only stops gradient updates. While the model is in
`train()` mode, the BatchNorm layers inside "frozen" blocks still normalize
with batch statistics and keep updating their running mean/variance; call
`.eval()` on `.blocks` as well if those statistics must stay fixed.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Standard ResNet BasicBlock with BatchNorm.

    Pre-activation is NOT used; this is the post-activation form:

        relu(x + BN(W2 * relu(BN(W1 * x))))

    i.e. the final ReLU is applied AFTER the residual sum, so the block's
    output is non-negative. d_hidden in == d_hidden out.

    No stride / downsampling — all blocks operate at the same spatial
    resolution after the initial stem (3x3 convs with padding=1). This keeps
    the architecture simple and matches the "4 residual blocks at fixed width"
    structure of our ResMLP and ViT-Mini comparisons.

    Args:
        d_hidden: number of input and output channels.
    """

    def __init__(self, d_hidden: int):
        super().__init__()
        self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(d_hidden)
        self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(d_hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, d_hidden, H, W) to a tensor of the same shape."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Post-activation: add the identity shortcut, then apply ReLU.
        out = F.relu(x + out)
        return out


class SmallResNet(nn.Module):
    """Small CIFAR-10 ResNet:

    - 3x3 conv stem (in_channels → d_hidden) + BN + ReLU
    - num_blocks BasicBlocks at the same width and resolution
    - global average pool
    - linear classification head

    `num_blocks=0` gives the shallow baseline (just stem → pool → head).

    Args:
        d_hidden: channel width of the stem and of every residual block.
        num_classes: number of output logits.
        num_blocks: number of BasicBlocks after the stem (0 is allowed).
        in_channels: channels of the input images (3 for CIFAR-10).
    """

    def __init__(self, d_hidden: int = 64, num_classes: int = 10,
                 num_blocks: int = 4, in_channels: int = 3):
        super().__init__()
        # Submodule names/order are kept stable: they determine state_dict keys
        # and the order in which parameters are initialized from the RNG.
        self.stem_conv = nn.Conv2d(in_channels, d_hidden, kernel_size=3, padding=1, bias=False)
        self.stem_bn = nn.BatchNorm2d(d_hidden)
        self.blocks = nn.ModuleList([BasicBlock(d_hidden) for _ in range(num_blocks)])
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.num_blocks = num_blocks
        self.d_hidden = d_hidden

    def stem(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv stem: conv3x3 → BN → ReLU.

        Args:
            x: image batch of shape (B, C, H, W), or a flattened batch of shape
               (B, C * S * S) for square S x S images. The flattened form is
               reshaped to (B, C, S, S) first, where C is the stem's input
               channel count (for CIFAR-10: (B, 3072) → (B, 3, 32, 32)).

        Returns:
            Feature map of shape (B, d_hidden, H, W); padding=1 keeps H and W.

        Raises:
            RuntimeError: if a flattened input's length is not C * S * S for
                some integer S (raised by ``Tensor.view``).
        """
        if x.dim() == 2:
            # Read C from the conv itself so it is the single source of truth.
            channels = self.stem_conv.in_channels
            side = math.isqrt(x.size(1) // channels)
            # view() rejects lengths that are not channels * side * side.
            x = x.view(x.size(0), channels, side, side)
        h = F.relu(self.stem_bn(self.stem_conv(x)))
        return h

    def forward(self, x: torch.Tensor, return_hidden: bool = False):
        """Compute class logits.

        Args:
            x: input batch; see `stem` for accepted shapes.
            return_hidden: if True, also return intermediate feature maps.

        Returns:
            logits of shape (B, num_classes); or, if `return_hidden`, a tuple
            (logits, hiddens) where `hiddens` is a list of num_blocks + 1
            feature maps of shape (B, d_hidden, H, W): the stem output
            followed by the output of each residual block.
        """
        h = self.stem(x)  # (B, d, H, W)
        hiddens = [h] if return_hidden else None
        for block in self.blocks:
            h = block(h)
            if return_hidden:
                hiddens.append(h)
        h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1)  # (B, d)
        logits = self.out_head(h_pool)
        if return_hidden:
            return logits, hiddens
        return logits

    def embed(self, x: torch.Tensor) -> torch.Tensor:
        """Alias for `stem`, kept for snapshot-script compatibility.

        The snapshot scripts treat the stem as the model's "embed" layer.
        """
        return self.stem(x)