"""Frozen-blocks and shallow baselines for the 4-block d=256 ResidualMLP on CIFAR-10.

This is the codex-round-8 control for ResMLP, parallel to the ViT-Mini
frozen-blocks experiment that walked back the "DFA trains a 4-block ViT" claim.

Conditions (4 per seed):
  - BP shallow        (num_blocks=0, just embed -> out_ln -> out_head)
  - BP frozen-blocks  (num_blocks=4, blocks frozen at random init; only
                       embed/LN/head trainable)
  - DFA shallow       (num_blocks=0)
  - DFA frozen-blocks (num_blocks=4, blocks frozen)

If DFA-frozen ≈ DFA-trainable: DFA-on-ResMLP has the same "blocks are
passengers" problem as ViT-Mini, and the strongest remaining DFA performance
result in the paper falls.

If DFA-frozen << DFA-trainable: DFA on ResMLP IS doing meaningful block
training, and the contrast with ViT becomes the most interesting result.

Usage:
    CUDA_VISIBLE_DEVICES=2 python experiments/resmlp_frozen_blocks_baseline.py --seed 42
"""
import argparse
import os
import sys
import warnings

# Put the parent of this script's directory on sys.path so that the
# project-local `models` package resolves when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np

from models.residual_mlp import ResidualMLP


def get_loaders(batch_size=128, dataset='cifar10'):
    """Build CIFAR train/test DataLoaders (downloading into ./data if needed).

    Args:
        batch_size: mini-batch size for both loaders.
        dataset: 'cifar100' selects CIFAR-100; any other value selects CIFAR-10.

    Returns:
        ((train_loader, test_loader), num_classes). The train loader shuffles
        and applies random-crop (padding 4) + horizontal-flip augmentation; the
        test loader is unshuffled and only normalizes. Both use per-dataset
        channel mean/std normalization.
    """
    if dataset == 'cifar100':
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
        DatasetClass = torchvision.datasets.CIFAR100
    else:
        mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
        DatasetClass = torchvision.datasets.CIFAR10
    tv_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    tr = DatasetClass('./data', True, download=True, transform=tv_train)
    te = DatasetClass('./data', False, download=True, transform=tv)
    num_classes = 100 if dataset == 'cifar100' else 10
    return (
        DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
        DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
    ), num_classes


def evaluate(model, loader, dev):
    """Return top-1 accuracy (fraction in [0, 1]) of `model` over `loader`.

    Each image batch is flattened to (batch, -1) before the forward pass.
    Leaves the model in eval mode; the training loops switch back with
    model.train() at the start of every epoch.
    """
    model.eval()
    n = c = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            preds = model(x).argmax(-1)
            c += (preds == y).sum().item()
            n += x.size(0)
    return c / n


def freeze_blocks(model):
    """Disable gradients for every parameter under `model.blocks` (in place)."""
    for p in model.blocks.parameters():
        p.requires_grad_(False)


def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label):
    """Train `model` end-to-end with backprop, AdamW and a cosine LR schedule.

    Only parameters with requires_grad=True are handed to the optimizer, so
    frozen blocks stay at their initial weights while gradients still flow
    through them to the embedding. Test accuracy is printed at epoch 1, every
    10th epoch, and the final epoch.

    Returns:
        The same `model` object, trained in place.
    """
    opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        if ep % 10 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(model, test_loader, dev)
            print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)
    return model


def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, num_classes=10):
    """Train with Direct Feedback Alignment on the embedding only.

    - out_ln + out_head are trained with the true cross-entropy gradient on the
      detached final hidden state.
    - The embedding is trained by projecting the output error
      (softmax - onehot) through a fixed random matrix B0 and using it as the
      embedding's error signal.
    - `model.blocks` are NEVER updated here. This matches the intended use
      (num_blocks=0, or blocks frozen via freeze_blocks); a warning is emitted
      if any block parameter still requires grad.

    Returns:
        The same `model` object, trained in place.
    """
    blocks = getattr(model, 'blocks', None)
    if blocks is not None and any(p.requires_grad for p in blocks.parameters()):
        # Silently leaving "trainable" blocks untouched would mislabel results.
        warnings.warn(
            "train_dfa never updates model.blocks; they will stay at their "
            "current weights. Call freeze_blocks(model) first to make this explicit.",
            stacklevel=2,
        )

    d_hidden = model.d_hidden
    L = model.num_blocks
    C = num_classes
    # Fixed random feedback matrices (d_hidden x C). Only Bs[0] (embedding
    # feedback) is used; the extra matrices are still drawn so the RNG
    # consumption on `dev` is unchanged.
    Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(max(L, 1))]

    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=lr, weight_decay=wd,
    )
    sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs)
    sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                # dCE/dlogits per sample: softmax(logits) - onehot(y).
                e_T = logits.softmax(-1)
                e_T[torch.arange(x.size(0)), y] -= 1
                # NOTE(review): assumes hiddens[-1] is the representation fed
                # into out_ln — confirm against ResidualMLP.forward.
                hL_det = hiddens[-1].detach()

            # Head update via the true CE gradient (no gradient reaches the trunk).
            logits_out = model.out_head(model.out_ln(hL_det))
            head_opt.zero_grad()
            F.cross_entropy(logits_out, y).backward()
            head_opt.step()

            # Embed update via DFA feedback. The surrogate loss <h0, a0/rms>
            # has gradient a0/rms w.r.t. h0, i.e. an RMS-normalized projected error.
            a0 = (e_T @ Bs[0].T).detach()
            rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            h0 = model.embed(x)
            embed_loss = (h0 * (a0 / rms)).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()

        sch1.step()
        sch2.step()
        if ep % 10 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(model, test_loader, dev)
            print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)
    return model


def _seed_everything(seed):
    """Reseed torch (CPU and all CUDA devices) and NumPy with `seed`."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def _build_model(input_dim, d_hidden, num_classes, num_blocks, frozen, dev):
    """Construct a ResidualMLP on `dev`, optionally freeze its blocks, and log parameter counts."""
    model = ResidualMLP(input_dim, d_hidden, num_classes, num_blocks).to(dev)
    if frozen:
        freeze_blocks(model)
    n_total = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" n_params: {n_total} ({n_trainable} trainable)", flush=True)
    return model


def main():
    """Run the four frozen/shallow conditions for one seed and print a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
    args = parser.parse_args()

    # Fall back to CPU instead of crashing when no GPU is visible.
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}, dataset={args.dataset}", flush=True)
    (train_loader, test_loader), C = get_loaders(batch_size=128, dataset=args.dataset)
    results = {}
    input_dim = 32 * 32 * 3
    L = args.num_blocks

    # (results key, header title, progress label, FINAL name, num_blocks, freeze blocks?, method)
    conditions = [
        ('bp_shallow', "BP shallow (ResMLP num_blocks=0)",
         'BP-shallow', 'BP-shallow', 0, False, 'bp'),
        ('bp_frozen', f"BP frozen-blocks (ResMLP num_blocks={L}, blocks frozen)",
         'BP-frozen', 'BP-frozen-blocks', L, True, 'bp'),
        ('dfa_shallow', "DFA shallow (ResMLP num_blocks=0)",
         'DFA-shallow', 'DFA-shallow', 0, False, 'dfa'),
        ('dfa_frozen', f"DFA frozen-blocks (ResMLP num_blocks={L}, blocks frozen)",
         'DFA-frozen', 'DFA-frozen-blocks', L, True, 'dfa'),
    ]

    for key, title, label, final_name, num_blocks, frozen, method in conditions:
        print(f"\n=== {title}, seed={args.seed} ===", flush=True)
        # Reseed before each build so every condition starts from the same RNG state.
        _seed_everything(args.seed)
        m = _build_model(input_dim, args.d_hidden, C, num_blocks, frozen, dev)
        if method == 'bp':
            train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, label)
        else:
            train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, label, num_classes=C)
        results[key] = evaluate(m, test_loader, dev)
        print(f"FINAL {final_name}: {results[key]:.4f}", flush=True)

    print(f"\n=== ResMLP frozen/shallow baseline summary, seed={args.seed} ===")
    print(f" BP-shallow: {results['bp_shallow']:.4f}")
    print(f" BP-frozen: {results['bp_frozen']:.4f}")
    print(f" DFA-shallow: {results['dfa_shallow']:.4f}")
    print(f" DFA-frozen: {results['dfa_frozen']:.4f}")
    print()
    print("Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep")
    # The reference numbers above are hard-coded for the default configuration.
    if (args.dataset, L, args.d_hidden) != ('cifar10', 4, 256):
        print(f" NOTE: reference numbers are for CIFAR-10, num_blocks=4, d_hidden=256; "
              f"this run used dataset={args.dataset}, num_blocks={L}, d_hidden={args.d_hidden}")
    print()
    print("Interpretation:")
    print(" If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT")
    print(" If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT)")


if __name__ == '__main__':
    main()