From b480d0cdc21f944e4adccf6e81cc939b0450c5e9 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 4 May 2026 19:50:45 -0500
Subject: Initial submission code: FA evaluation protocol + reproduction scripts

Reference implementation of the three-diagnostic FA evaluation protocol
(scale stability, reference validity, depth utility) from the NeurIPS 2026
E&D track paper. Includes models, metrics, and the full reproduction pipeline.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 models/small_resnet.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 74 insertions(+)
 create mode 100644 models/small_resnet.py

diff --git a/models/small_resnet.py b/models/small_resnet.py
new file mode 100644
index 0000000..10b122e
--- /dev/null
+++ b/models/small_resnet.py
@@ -0,0 +1,74 @@
+"""
+Small CIFAR-10 ResNet for the FA-evaluation paper. Standard BatchNorm-based
+post-activation residual blocks (no LayerNorm); defaults to 4 residual blocks at width 64.
+
+Supports `num_blocks=0` (shallow baseline: just stem → pool → head) and frozen
+blocks via `requires_grad=False` on `.blocks` parameters.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+    """Standard ResNet BasicBlock with BatchNorm. Pre-activation is NOT used; this is
+    the post-activation form: relu(x + BN(W2 * relu(BN(W1 x)))). d_hidden in == d_hidden out.
+    No stride / downsampling — all blocks operate at the same spatial resolution
+    after the initial stem. This keeps the architecture simple and matches the
+    "4 residual blocks at fixed width" structure of our ResMLP and ViT-Mini comparisons.
+    """
+    def __init__(self, d_hidden):
+        super().__init__()
+        self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(d_hidden)
+        self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(d_hidden)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out = F.relu(x + out)  # residual add, then the final ReLU (post-activation)
+        return out
+
+
+class SmallResNet(nn.Module):
+    """Small CIFAR-10 ResNet:
+    - 3x3 conv stem (3 → d_hidden) + BN + ReLU
+    - num_blocks BasicBlocks at the same width and resolution
+    - global average pool
+    - linear classification head
+
+    `num_blocks=0` gives the shallow baseline (just stem → pool → head).
+    """
+    def __init__(self, d_hidden=64, num_classes=10, num_blocks=4):
+        super().__init__()
+        self.stem_conv = nn.Conv2d(3, d_hidden, kernel_size=3, padding=1, bias=False)
+        self.stem_bn = nn.BatchNorm2d(d_hidden)
+        self.blocks = nn.ModuleList([BasicBlock(d_hidden) for _ in range(num_blocks)])
+        self.out_head = nn.Linear(d_hidden, num_classes)
+        self.num_blocks = num_blocks
+        self.d_hidden = d_hidden
+
+    def stem(self, x):
+        # x: (B, 3, 32, 32); flattened (B, 3*32*32) inputs are reshaped to images.
+        if x.dim() == 2:
+            x = x.view(x.size(0), 3, 32, 32)
+        h = F.relu(self.stem_bn(self.stem_conv(x)))
+        return h
+
+    def forward(self, x, return_hidden=False):
+        h = self.stem(x)  # (B, d, 32, 32)
+        hiddens = [h] if return_hidden else None
+        for block in self.blocks:
+            h = block(h)
+            if return_hidden:
+                hiddens.append(h)
+        h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1)  # (B, d)
+        logits = self.out_head(h_pool)
+        if return_hidden:
+            return logits, hiddens
+        return logits
+
+    # Convenience alias for snapshot-script compatibility (treats the stem as the embed).
+    def embed(self, x):
+        return self.stem(x)
--
cgit v1.2.3
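
For orientation, a minimal usage sketch of the module this patch adds. It assumes the repo root is on PYTHONPATH so the file is importable as models.small_resnet (an assumption about the repo layout, not something this diff establishes); all shapes and behaviors follow the docstrings above.

    import torch
    from models.small_resnet import SmallResNet  # assumed import path per the file location

    model = SmallResNet(d_hidden=64, num_classes=10, num_blocks=4)
    x = torch.randn(8, 3, 32, 32)                # CIFAR-10-sized batch

    logits = model(x)                            # (8, 10)
    logits, hiddens = model(x, return_hidden=True)
    # hiddens holds num_blocks + 1 feature maps, each (8, 64, 32, 32):
    # the stem output followed by each residual block's output.

    shallow = SmallResNet(num_blocks=0)          # shallow baseline: stem → pool → head

    for p in model.blocks.parameters():          # frozen-blocks variant from the module docstring
        p.requires_grad = False

The sketch exercises only what models/small_resnet.py itself documents; the metrics and reproduction scripts named in the commit message are outside this diff.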