summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-26 09:31:30 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-26 09:31:30 -0500
commita501c1c84b6ac4ff7dbf2e4b92cebd3122eb7abe (patch)
tree25a83479302e211359bd4f49df44b2bf69d0aaee /experiments
parent9751e97dd190b8667c337215dcb70e0cab8f92ff (diff)
BP+EP audit for d=512 L=2 qualifying seeds + CIFAR-100 support
BP results for qualifying seeds (1, 2, 5) on d=512 L=2: BP s1: 0.606, s2: 0.608, s5: 0.607 (all above frozen 0.349) FA s1: 0.347, s2: 0.346, s5: 0.341 (all below frozen, cos +0.47-0.49) DFA s1: 0.298, s2: 0.297, s5: 0.296 (all below frozen, cos +0.18-0.21) EP did not save (likely architecture compatibility issue at d=512 L=2). Also: added CIFAR-100 dataset support to both cifar_resmlp.py and resmlp_frozen_blocks_baseline.py for the harder-task scan. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cifar_resmlp.py15
-rw-r--r--experiments/resmlp_frozen_blocks_baseline.py33
2 files changed, 35 insertions, 13 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
index 05a355d..435b484 100644
--- a/experiments/cifar_resmlp.py
+++ b/experiments/cifar_resmlp.py
@@ -47,6 +47,21 @@ def get_data(dataset='cifar10', batch_size=128):
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
input_dim = 32 * 32 * 3
num_classes = 10
+ elif dataset == 'cifar100':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
+ input_dim = 32 * 32 * 3
+ num_classes = 100
elif dataset == 'fashionmnist':
transform_train = transforms.Compose([
transforms.RandomCrop(28, padding=2),
diff --git a/experiments/resmlp_frozen_blocks_baseline.py b/experiments/resmlp_frozen_blocks_baseline.py
index 6040bd4..bd56890 100644
--- a/experiments/resmlp_frozen_blocks_baseline.py
+++ b/experiments/resmlp_frozen_blocks_baseline.py
@@ -31,23 +31,30 @@ import numpy as np
from models.residual_mlp import ResidualMLP
-def get_loaders(batch_size=128):
+def get_loaders(batch_size=128, dataset='cifar10'):
+ if dataset == 'cifar100':
+ mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
+ DatasetClass = torchvision.datasets.CIFAR100
+ else:
+ mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
+ DatasetClass = torchvision.datasets.CIFAR10
tv_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
- transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ transforms.Normalize(mean, std),
])
tv = transforms.Compose([
transforms.ToTensor(),
- transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ transforms.Normalize(mean, std),
])
- tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
- te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ tr = DatasetClass('./data', True, download=True, transform=tv_train)
+ te = DatasetClass('./data', False, download=True, transform=tv)
+ num_classes = 100 if dataset == 'cifar100' else 10
return (
DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
- )
+ ), num_classes
def evaluate(model, loader, dev):
@@ -84,14 +91,14 @@ def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label):
return model
-def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label):
+def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, num_classes=10):
"""DFA-style: head with true CE, embed (and unfrozen blocks if any) with random feedback.
For frozen-blocks: blocks are skipped. For trainable blocks not used here.
For num_blocks=0 (shallow): only embed/head are updated.
"""
d_hidden = model.d_hidden
L = model.num_blocks
- C = 10
+ C = num_classes
Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(max(L, 1))]
embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
@@ -138,15 +145,15 @@ def main():
parser.add_argument('--wd', type=float, default=0.01)
parser.add_argument('--d_hidden', type=int, default=256)
parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
args = parser.parse_args()
dev = torch.device('cuda:0')
- print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
- train_loader, test_loader = get_loaders(batch_size=128)
+ print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}, dataset={args.dataset}", flush=True)
+ (train_loader, test_loader), C = get_loaders(batch_size=128, dataset=args.dataset)
results = {}
input_dim = 32 * 32 * 3
- C = 10
# Condition 1: BP shallow (num_blocks=0)
print(f"\n=== BP shallow (ResMLP num_blocks=0), seed={args.seed} ===", flush=True)
@@ -173,7 +180,7 @@ def main():
torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
m = ResidualMLP(input_dim, args.d_hidden, C, 0).to(dev)
print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
- train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow')
+ train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow', num_classes=C)
results['dfa_shallow'] = evaluate(m, test_loader, dev)
print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True)
@@ -183,7 +190,7 @@ def main():
m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev)
freeze_blocks(m)
print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
- train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen')
+ train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen', num_classes=C)
results['dfa_frozen'] = evaluate(m, test_loader, dev)
print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True)