diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-26 09:31:30 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-26 09:31:30 -0500 |
| commit | a501c1c84b6ac4ff7dbf2e4b92cebd3122eb7abe (patch) | |
| tree | 25a83479302e211359bd4f49df44b2bf69d0aaee /experiments | |
| parent | 9751e97dd190b8667c337215dcb70e0cab8f92ff (diff) | |
BP+EP audit for d=512 L=2 qualifying seeds + CIFAR-100 support
BP results for qualifying seeds (1, 2, 5) on d=512 L=2:
BP s1: 0.606, s2: 0.608, s5: 0.607 (all above frozen 0.349)
FA s1: 0.347, s2: 0.346, s5: 0.341 (all below frozen, cos +0.47-0.49)
DFA s1: 0.298, s2: 0.297, s5: 0.296 (all below frozen, cos +0.18-0.21)
EP did not save (likely architecture compatibility issue at d=512 L=2).
Also: added CIFAR-100 dataset support to both cifar_resmlp.py and
resmlp_frozen_blocks_baseline.py for the harder-task scan.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cifar_resmlp.py | 15 | ||||
| -rw-r--r-- | experiments/resmlp_frozen_blocks_baseline.py | 33 |
2 files changed, 35 insertions, 13 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py index 05a355d..435b484 100644 --- a/experiments/cifar_resmlp.py +++ b/experiments/cifar_resmlp.py @@ -47,6 +47,21 @@ def get_data(dataset='cifar10', batch_size=128): testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) input_dim = 32 * 32 * 3 num_classes = 10 + elif dataset == 'cifar100': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ]) + trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) + input_dim = 32 * 32 * 3 + num_classes = 100 elif dataset == 'fashionmnist': transform_train = transforms.Compose([ transforms.RandomCrop(28, padding=2), diff --git a/experiments/resmlp_frozen_blocks_baseline.py b/experiments/resmlp_frozen_blocks_baseline.py index 6040bd4..bd56890 100644 --- a/experiments/resmlp_frozen_blocks_baseline.py +++ b/experiments/resmlp_frozen_blocks_baseline.py @@ -31,23 +31,30 @@ import numpy as np from models.residual_mlp import ResidualMLP -def get_loaders(batch_size=128): +def get_loaders(batch_size=128, dataset='cifar10'): + if dataset == 'cifar100': + mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) + DatasetClass = torchvision.datasets.CIFAR100 + else: + mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) + DatasetClass = torchvision.datasets.CIFAR10 tv_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + transforms.Normalize(mean, std), ]) tv = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + transforms.Normalize(mean, std), ]) - tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) - te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + tr = DatasetClass('./data', True, download=True, transform=tv_train) + te = DatasetClass('./data', False, download=True, transform=tv) + num_classes = 100 if dataset == 'cifar100' else 10 return ( DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), - ) + ), num_classes def evaluate(model, loader, dev): @@ -84,14 +91,14 @@ def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label): return model -def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label): +def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, num_classes=10): """DFA-style: head with true CE, embed (and unfrozen blocks if any) with random feedback. For frozen-blocks: blocks are skipped. For trainable blocks not used here. For num_blocks=0 (shallow): only embed/head are updated. """ d_hidden = model.d_hidden L = model.num_blocks - C = 10 + C = num_classes Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(max(L, 1))] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) @@ -138,15 +145,15 @@ def main(): parser.add_argument('--wd', type=float, default=0.01) parser.add_argument('--d_hidden', type=int, default=256) parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100']) args = parser.parse_args() dev = torch.device('cuda:0') - print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True) - train_loader, test_loader = get_loaders(batch_size=128) + print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}, dataset={args.dataset}", flush=True) + (train_loader, test_loader), C = get_loaders(batch_size=128, dataset=args.dataset) results = {} input_dim = 32 * 32 * 3 - C = 10 # Condition 1: BP shallow (num_blocks=0) print(f"\n=== BP shallow (ResMLP num_blocks=0), seed={args.seed} ===", flush=True) @@ -173,7 +180,7 @@ def main(): torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) m = ResidualMLP(input_dim, args.d_hidden, C, 0).to(dev) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) - train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow') + train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow', num_classes=C) results['dfa_shallow'] = evaluate(m, test_loader, dev) print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True) @@ -183,7 +190,7 @@ def main(): m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev) freeze_blocks(m) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) - train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen') + train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen', num_classes=C) results['dfa_frozen'] = evaluate(m, test_loader, dev) print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True) |
