summaryrefslogtreecommitdiff
path: root/experiments/cifar_resmlp.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-26 09:31:30 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-26 09:31:30 -0500
commita501c1c84b6ac4ff7dbf2e4b92cebd3122eb7abe (patch)
tree25a83479302e211359bd4f49df44b2bf69d0aaee /experiments/cifar_resmlp.py
parent9751e97dd190b8667c337215dcb70e0cab8f92ff (diff)
BP+EP audit for d=512 L=2 qualifying seeds + CIFAR-100 support
BP results for qualifying seeds (1, 2, 5) on d=512 L=2: BP s1: 0.606, s2: 0.608, s5: 0.607 (all above frozen 0.349) FA s1: 0.347, s2: 0.346, s5: 0.341 (all below frozen, cos +0.47-0.49) DFA s1: 0.298, s2: 0.297, s5: 0.296 (all below frozen, cos +0.18-0.21) EP did not save (likely architecture compatibility issue at d=512 L=2). Also: added CIFAR-100 dataset support to both cifar_resmlp.py and resmlp_frozen_blocks_baseline.py for the harder-task scan. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments/cifar_resmlp.py')
-rw-r--r--experiments/cifar_resmlp.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
index 05a355d..435b484 100644
--- a/experiments/cifar_resmlp.py
+++ b/experiments/cifar_resmlp.py
@@ -47,6 +47,21 @@ def get_data(dataset='cifar10', batch_size=128):
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
input_dim = 32 * 32 * 3
num_classes = 10
+ elif dataset == 'cifar100':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
+ input_dim = 32 * 32 * 3
+ num_classes = 100
elif dataset == 'fashionmnist':
transform_train = transforms.Compose([
transforms.RandomCrop(28, padding=2),