r"""Depth-utility ladder (appendix experiment for the FA-evaluation E&D paper).

Turns the binary frozen-vs-trained block comparison into a CURVE: vary the
number of trainable residual blocks k, training the LAST k blocks (output side)
and freezing the first L-k at their random initialization. The embedding,
out_ln and out_head are ALWAYS trained. Credit still propagates through frozen
blocks (the forward pass and the FA feedback matrices are unchanged); only
their weights stay at init.

Question: as more blocks are made trainable, does test accuracy rise?
  - BP (positive control): should climb monotonically with k.
  - FA (Lillicrap vanilla): modest climb where depth is usable, flat where not.
  - DFA (direct FA): flat at or below the frozen baseline (deep credit is
    non-functional -> the D3 failure at every k).

Training output-side blocks first is deliberate: the deepest block receives the
most direct credit (FA's last block sees the exact output gradient), so this is
the BEST case for the method. If even these blocks add nothing, depth is unused.

The recipe is identical to the main CIFAR audit (cifar_resmlp.py): AdamW,
lr 1e-3, wd 0.01, cosine schedule, batch 128, 100 epochs, and (for FA/DFA)
per-block independent optimizers with RMS-normalized local surrogate losses.
k=0 reproduces the frozen-blocks baseline; k=L reproduces the full audit.

Usage:
    CUDA_VISIBLE_DEVICES=2 python experiments/depth_utility_ladder.py \
        --d_hidden 256 --num_blocks 4 --dataset cifar10 \
        --methods bp fa dfa --k_values 0 1 2 3 4 --seeds 42 123 456 \
        --epochs 100 --output_dir results/depth_ladder
"""
import os
import sys
import json
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP


# ---------------------------------------------------------------------------
# Data / eval
# ---------------------------------------------------------------------------

# name -> (dataset class, num_classes, per-channel mean, per-channel std)
_DATASET_SPECS = {
    'cifar10': (torchvision.datasets.CIFAR10, 10,
                (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    'cifar100': (torchvision.datasets.CIFAR100, 100,
                 (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
}


def get_data(dataset, batch_size=128):
    """Build CIFAR train/test loaders.

    Training images get random-crop (pad 4) + horizontal-flip augmentation;
    both splits are normalized with per-channel dataset statistics.

    Args:
        dataset: 'cifar10' or 'cifar100'.
        batch_size: batch size for both loaders.

    Returns:
        (train_loader, test_loader, input_dim, num_classes), where input_dim is
        the flattened image size 32*32*3.

    Raises:
        ValueError: for an unknown dataset name (previously any other name
            silently fell back to CIFAR-10).
    """
    try:
        DatasetClass, num_classes, mean, std = _DATASET_SPECS[dataset]
    except KeyError:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of "
                         f"{sorted(_DATASET_SPECS)}") from None
    input_dim = 32 * 32 * 3
    tf_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    tf_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    tr = DatasetClass('./data', True, download=True, transform=tf_train)
    te = DatasetClass('./data', False, download=True, transform=tf_test)
    return (
        DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True),
        DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True),
        input_dim,
        num_classes,
    )


def evaluate(model, loader, dev):
    """Return top-1 accuracy of `model` over `loader`.

    Inputs are flattened to (batch, -1) before the forward pass. Leaves the
    model in eval mode; the trainers call model.train() at each epoch start.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            correct += (model(x).argmax(-1) == y).sum().item()
            total += x.size(0)
    return correct / total


def freeze_first(model, k):
    """Freeze the first L-k residual blocks and leave the last k trainable.

    Only the parameters of ``model.blocks`` are touched; embed / out_ln /
    out_head keep their current ``requires_grad``.

    Args:
        model: ResidualMLP exposing ``num_blocks`` and an iterable ``blocks``.
        k: number of output-side blocks to keep trainable, 0 <= k <= L.

    Returns:
        set[int]: indices of the trainable blocks, i.e. {L-k, ..., L-1}.

    Raises:
        ValueError: if k is outside [0, L]. Without this check a negative k
            silently acted like k=0, and k > L put negative indices into the
            returned set.
    """
    L = model.num_blocks
    if not 0 <= k <= L:
        raise ValueError(f"k must be in [0, {L}], got {k}")
    n_frozen = L - k
    trainable = set(range(n_frozen, L))
    for l, block in enumerate(model.blocks):
        req = l in trainable
        for p in block.parameters():
            p.requires_grad_(req)
    return trainable


# ---------------------------------------------------------------------------
# Shared trainer helpers
# ---------------------------------------------------------------------------

def _rms_normalize(a, eps=1e-6):
    """Scale each row of `a` (last dim) to unit RMS; eps guards all-zero rows."""
    return a / ((a ** 2).mean(-1, keepdim=True).sqrt() + eps)


def _make_local_optimizers(model, trainable, args):
    """Per-module AdamW optimizers + cosine schedulers for the local-loss trainers.

    Returns (block_opts, embed_opt, head_opt, scheds). block_opts maps each
    trainable block index to its own optimizer; out_ln and out_head share
    head_opt; scheds covers every optimizer, block optimizers first.
    """
    block_opts = {l: optim.AdamW(model.blocks[l].parameters(), lr=args.lr,
                                 weight_decay=args.wd)
                  for l in sorted(trainable)}
    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
                           lr=args.lr, weight_decay=args.wd)
    scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs)
              for o in list(block_opts.values()) + [embed_opt, head_opt]]
    return block_opts, embed_opt, head_opt, scheds


def _maybe_record(model, test_loader, dev, ep, epochs, curve, label):
    """At epoch 1, every 10th epoch and the last epoch: evaluate, append (ep, acc), log."""
    if ep % 10 == 0 or ep == 1 or ep == epochs:
        acc = evaluate(model, test_loader, dev)
        curve.append((ep, acc))
        print(f"  [{label}] ep {ep}: test={acc:.4f}", flush=True)


# ---------------------------------------------------------------------------
# Trainers (freeze-aware ports of cifar_resmlp.py)
# ---------------------------------------------------------------------------

def train_bp(model, train_loader, test_loader, dev, args, trainable):
    """End-to-end BP on every parameter that still requires grad.

    Frozen blocks are excluded from the optimizer, but gradients still flow
    THROUGH them to reach the trainable blocks and the embedding.

    Returns:
        list of (epoch, test_acc) checkpoints.
    """
    k = len(trainable)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)
    curve = []
    for ep in range(1, args.epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        _maybe_record(model, test_loader, dev, ep, args.epochs, curve, f"BP k={k}")
    return curve


def train_dfa(model, train_loader, test_loader, dev, args, trainable):
    """DFA: each block reads the output error directly via a fixed B_l.

    There is no sequential credit propagation. Only TRAINABLE blocks are
    updated; embed / out_ln / out_head are always trained. Requires
    ``args.num_classes`` to be set.

    Returns:
        list of (epoch, test_acc) checkpoints.
    """
    d, C, L = model.d_hidden, args.num_classes, model.num_blocks
    k = len(trainable)
    # One B_l (d x C) per block, drawn for ALL L blocks regardless of k, so at a
    # given seed the RNG stream (and hence every B_l) is the same for all k.
    Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)]
    block_opts, embed_opt, head_opt, scheds = _make_local_optimizers(model, trainable, args)
    trainable_order = sorted(trainable)
    curve = []
    for ep in range(1, args.epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            batch = x.size(0)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                # Per-sample output error softmax(logits) - onehot(y), i.e. the
                # gradient of per-sample cross-entropy w.r.t. the logits.
                e_T = logits.softmax(-1)
                e_T[torch.arange(batch, device=dev), y] -= 1

            # Head: exact CE on a detached h_L.
            hL = hiddens[-1].detach()
            head_opt.zero_grad()
            F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
            head_opt.step()

            # Trainable blocks: DFA local surrogate. Minimizing <f_l, a> moves
            # f_l along -a, the projected error direction.
            for l in trainable_order:
                a = _rms_normalize((e_T @ Bs[l].T).detach())
                f_l = model.blocks[l](hiddens[l].detach())
                local = (f_l * a).sum(-1).mean()
                block_opts[l].zero_grad()
                local.backward()
                block_opts[l].step()

            # Embed: DFA credit at h_0. NOTE: reuses block 0's feedback matrix
            # Bs[0], so the embed receives the same signal as block 0.
            a0 = _rms_normalize((e_T @ Bs[0].T).detach())
            embed_loss = (model.embed(x) * a0).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()
        for s in scheds:
            s.step()
        _maybe_record(model, test_loader, dev, ep, args.epochs, curve, f"DFA k={k}")
    return curve


def train_fa(model, train_loader, test_loader, dev, args, trainable):
    """Vanilla FA: credit propagates sequentially backward via fixed d x d B_l.

    Frozen blocks STILL propagate credit (a_credit = a_credit @ B_l), so
    trainable blocks and the embed below them receive it; only the frozen
    blocks' weight update is skipped.

    Returns:
        list of (epoch, test_acc) checkpoints.
    """
    d, L = model.d_hidden, model.num_blocks
    k = len(trainable)
    # One B_l (d x d) per block, drawn for ALL blocks regardless of k (keeps
    # the RNG stream, hence the feedback matrices, identical across k).
    Bs = [torch.randn(d, d, device=dev) / np.sqrt(d) for _ in range(L)]
    block_opts, embed_opt, head_opt, scheds = _make_local_optimizers(model, trainable, args)
    curve = []
    for ep in range(1, args.epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            with torch.no_grad():
                _, hiddens = model(x, return_hidden=True)

            # Head: exact CE; its gradient at h_L is FA's starting credit.
            hL = hiddens[-1].detach().requires_grad_(True)
            head_opt.zero_grad()
            F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
            head_opt.step()
            a_credit = hL.grad.detach()

            # Blocks, output side first: update only trainable ones, but ALWAYS
            # propagate credit so lower blocks / the embed still receive it.
            for l in range(L - 1, -1, -1):
                if l in trainable:
                    a = _rms_normalize(a_credit)
                    f_l = model.blocks[l](hiddens[l].detach())
                    local = (f_l * a).sum(-1).mean()
                    block_opts[l].zero_grad()
                    local.backward()
                    block_opts[l].step()
                # NOTE(review): credit is mapped through B_l alone. If blocks[l]
                # returns only the residual branch (h_{l+1} = h_l + f_l), the
                # identity skip path (a + a @ B_l) is not included here. Kept
                # unchanged for comparability with the main audit; confirm it
                # matches cifar_resmlp.py.
                a_credit = (a_credit @ Bs[l]).detach()

            # Embed: FA credit at h_0.
            a0 = _rms_normalize(a_credit)
            embed_loss = (model.embed(x) * a0).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()
        for s in scheds:
            s.step()
        _maybe_record(model, test_loader, dev, ep, args.epochs, curve, f"FA k={k}")
    return curve


TRAINERS = {'bp': train_bp, 'dfa': train_dfa, 'fa': train_fa}


# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------

# Hyperparameters that must match when resuming; differing values make stored
# and new runs incomparable (d_hidden / num_blocks / dataset are in the tag).
_RESUME_CHECK_KEYS = ('epochs', 'lr', 'wd', 'batch_size')


def main():
    """Run every (method, k, seed) cell, checkpointing results after each run.

    Results are stored incrementally as results[method][k][seed] =
    {'final_acc', 'curve'} in a JSON file; cells already present are skipped,
    so an interrupted sweep can be resumed by re-running the same command.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--d_hidden', type=int, default=256)
    p.add_argument('--num_blocks', type=int, default=4)
    p.add_argument('--dataset', type=str, default='cifar10', choices=sorted(_DATASET_SPECS))
    p.add_argument('--methods', type=str, nargs='+', default=['bp', 'fa', 'dfa'],
                   choices=sorted(TRAINERS))
    p.add_argument('--k_values', type=int, nargs='+', default=[0, 1, 2, 3, 4])
    p.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--wd', type=float, default=0.01)
    p.add_argument('--batch_size', type=int, default=128)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/depth_ladder')
    args = p.parse_args()

    dev = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    os.makedirs(args.output_dir, exist_ok=True)
    L = args.num_blocks
    tag = f"ladder_d{args.d_hidden}_L{L}_{args.dataset}"
    out_path = os.path.join(args.output_dir, f"{tag}.json")
    print(f"Device={dev}  {tag}  methods={args.methods}  k={args.k_values}  seeds={args.seeds}  "
          f"epochs={args.epochs}", flush=True)

    # Only 0 <= k <= L is meaningful; previously only k > L was filtered and a
    # negative k silently ran as k=0 under a wrong label.
    k_values = [k for k in args.k_values if 0 <= k <= L]
    dropped = [k for k in args.k_values if not 0 <= k <= L]
    if dropped:
        print(f"Ignoring k values outside [0, {L}]: {dropped}", flush=True)

    # incremental results: results[method][k][seed] = {final_acc, curve}
    results = {}
    if os.path.exists(out_path):
        with open(out_path) as f:
            stored = json.load(f)
        results = stored.get('results', {})
        print(f"Resuming; existing keys: "
              f"{[(m, list(results[m].keys())) for m in results]}", flush=True)
        old_cfg = stored.get('config', {})
        mismatched = {key: (old_cfg[key], getattr(args, key))
                      for key in _RESUME_CHECK_KEYS
                      if key in old_cfg and old_cfg[key] != getattr(args, key)}
        if mismatched:
            print(f"WARNING: resuming {out_path} with different hyperparameters "
                  f"(stored, current) = {mismatched}; stored and new runs will be "
                  f"mixed in the summary.", flush=True)

    def save():
        # Write to a temp file then atomically replace, so a crash mid-write
        # cannot corrupt the checkpoint (and lose every finished run).
        tmp_path = out_path + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump({'config': vars(args), 'results': results}, f, indent=2)
        os.replace(tmp_path, out_path)

    # Loaded lazily (once) on the first cell that actually needs training, so a
    # fully-resumed sweep never touches the dataset. Neither dataset nor
    # DataLoader construction consumes torch RNG, so reuse is seed-neutral.
    data = None
    for method in args.methods:
        method_results = results.setdefault(method, {})
        for k in k_values:
            k_results = method_results.setdefault(str(k), {})
            for seed in args.seeds:
                if str(seed) in k_results:
                    print(f"  skip {method} k={k} seed={seed} (done)", flush=True)
                    continue
                print(f"\n=== {method.upper()} k={k} (last {k} of {L} trainable) "
                      f"seed={seed} ===", flush=True)
                torch.manual_seed(seed)
                np.random.seed(seed)
                torch.cuda.manual_seed_all(seed)
                if data is None:
                    data = get_data(args.dataset, args.batch_size)
                train_loader, test_loader, input_dim, num_classes = data
                args.num_classes = num_classes
                model = ResidualMLP(input_dim, args.d_hidden, num_classes, L).to(dev)
                trainable = freeze_first(model, k)
                n_train = sum(pp.numel() for pp in model.parameters() if pp.requires_grad)
                print(f"  trainable blocks: {sorted(trainable)}  "
                      f"trainable params: {n_train:,}", flush=True)
                curve = TRAINERS[method](model, train_loader, test_loader, dev, args, trainable)
                final_acc = evaluate(model, test_loader, dev)
                k_results[str(seed)] = {'final_acc': final_acc, 'curve': curve}
                print(f"  FINAL {method} k={k} seed={seed}: {final_acc:.4f}", flush=True)
                save()

    # summary table
    print(f"\n{'='*60}\nSUMMARY {tag}  (mean ± ddof-1 std over seeds)\n{'='*60}", flush=True)
    for method in args.methods:
        row = []
        for k in k_values:
            accs = [v['final_acc']
                    for v in results.get(method, {}).get(str(k), {}).values()]
            if accs:
                m = float(np.mean(accs))
                s = float(np.std(accs, ddof=1)) if len(accs) > 1 else 0.0
                row.append(f"k={k}: {m:.4f}±{s:.4f}")
        print(f"  {method.upper():4s}  " + "  ".join(row), flush=True)
    save()
    print(f"\nSaved -> {out_path}", flush=True)


if __name__ == '__main__':
    main()