diff options
Diffstat (limited to 'experiments/resmlp_frozen_blocks_baseline.py')
| -rw-r--r-- | experiments/resmlp_frozen_blocks_baseline.py | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/experiments/resmlp_frozen_blocks_baseline.py b/experiments/resmlp_frozen_blocks_baseline.py index 6040bd4..bd56890 100644 --- a/experiments/resmlp_frozen_blocks_baseline.py +++ b/experiments/resmlp_frozen_blocks_baseline.py @@ -31,23 +31,30 @@ import numpy as np from models.residual_mlp import ResidualMLP -def get_loaders(batch_size=128): +def get_loaders(batch_size=128, dataset='cifar10'): + if dataset == 'cifar100': + mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) + DatasetClass = torchvision.datasets.CIFAR100 + else: + mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) + DatasetClass = torchvision.datasets.CIFAR10 tv_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + transforms.Normalize(mean, std), ]) tv = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + transforms.Normalize(mean, std), ]) - tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) - te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + tr = DatasetClass('./data', True, download=True, transform=tv_train) + te = DatasetClass('./data', False, download=True, transform=tv) + num_classes = 100 if dataset == 'cifar100' else 10 return ( DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), - ) + ), num_classes def evaluate(model, loader, dev): @@ -84,14 +91,14 @@ def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label): return model -def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label): +def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, num_classes=10): """DFA-style: head with true CE, embed (and unfrozen blocks if any) with random feedback. For frozen-blocks: blocks are skipped. For trainable blocks not used here. For num_blocks=0 (shallow): only embed/head are updated. """ d_hidden = model.d_hidden L = model.num_blocks - C = 10 + C = num_classes Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(max(L, 1))] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) @@ -138,15 +145,15 @@ def main(): parser.add_argument('--wd', type=float, default=0.01) parser.add_argument('--d_hidden', type=int, default=256) parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100']) args = parser.parse_args() dev = torch.device('cuda:0') - print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True) - train_loader, test_loader = get_loaders(batch_size=128) + print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}, dataset={args.dataset}", flush=True) + (train_loader, test_loader), C = get_loaders(batch_size=128, dataset=args.dataset) results = {} input_dim = 32 * 32 * 3 - C = 10 # Condition 1: BP shallow (num_blocks=0) print(f"\n=== BP shallow (ResMLP num_blocks=0), seed={args.seed} ===", flush=True) @@ -173,7 +180,7 @@ def main(): torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) m = ResidualMLP(input_dim, args.d_hidden, C, 0).to(dev) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) - train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow') + train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow', num_classes=C) results['dfa_shallow'] = evaluate(m, test_loader, dev) print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True) @@ -183,7 +190,7 @@ def main(): m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev) freeze_blocks(m) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) - train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen') + train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen', num_classes=C) results['dfa_frozen'] = evaluate(m, test_loader, dev) print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True) |
