From 1118b7457c261de36ead6103503c00c321c75f9b Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Sun, 14 Jun 2026 20:32:31 -0500 Subject: Depth-utility ladder: trainable-block sweep (BP/FA/DFA) on ResMLP CIFAR-10 Appendix experiment triangulating the depth-utility diagnostic (D3) by varying the number of trainable residual blocks k (last-k trainable, first L-k frozen at init; embed/LN/head always trained). - d=256 L=4 and d=512 L=2, 3 seeds, recipe identical to the main audit. - BP climbs monotonically (+22-23pp); DFA peaks at the frozen baseline (k=0) and declines once any deep block is trained; FA shows partial/no net depth utility. - Cross-checks reproduce existing anchors (BP 0.617, DFA 0.301, FA 0.402, frozen 0.349). - frozen_init_identity_check quantifies frozen stack as a near-norm-preserving random feature map (per-block ||f||/||h||~0.10, stack cos 0.981), explaining the above-chance k=0 rung. Co-Authored-By: Claude Opus 4.8 (1M context) --- experiments/depth_utility_ladder.py | 317 ++++++++++++++++++++++++++++++ experiments/frozen_init_identity_check.py | 82 ++++++++ experiments/plot_depth_ladder.py | 63 ++++++ 3 files changed, 462 insertions(+) create mode 100644 experiments/depth_utility_ladder.py create mode 100644 experiments/frozen_init_identity_check.py create mode 100644 experiments/plot_depth_ladder.py (limited to 'experiments') diff --git a/experiments/depth_utility_ladder.py b/experiments/depth_utility_ladder.py new file mode 100644 index 0000000..c9de9e9 --- /dev/null +++ b/experiments/depth_utility_ladder.py @@ -0,0 +1,317 @@ +""" +Depth-utility ladder (appendix experiment for the FA-evaluation E&D paper). + +Turns the binary frozen-vs-trained block comparison into a CURVE: vary the number +of trainable residual blocks k, training the LAST k blocks (output side) and +freezing the first L-k at random init. Embedding / out_ln / out_head are ALWAYS +trained. Credit still propagates through frozen blocks (forward + FA feedback +matrices unchanged); only their weights stay at init. + +Question. As more blocks are made trainable, does test accuracy rise? + - BP (positive control): should climb monotonically with k. + - FA (Lillicrap vanilla): modest climb where depth is usable, flat where not. + - DFA (direct FA): flat at / below the frozen baseline (deep credit + is non-functional -> the D3 failure at every k). + +Output-side-first is deliberate: the deepest block receives the most direct +credit (FA's last block sees the exact output gradient), so it is the BEST case +for the method. If even these blocks add nothing, depth is unused. + +Recipe is identical to the main CIFAR audit (cifar_resmlp.py): AdamW, lr 1e-3, +wd 0.01, cosine, batch 128, 100 epochs, per-block independent optimizers and +rms-normalized local surrogate losses. + +k=0 reproduces the frozen-blocks baseline; k=L reproduces the full audit. + +Usage: + CUDA_VISIBLE_DEVICES=2 python experiments/depth_utility_ladder.py \ + --d_hidden 256 --num_blocks 4 --dataset cifar10 \ + --methods bp fa dfa --k_values 0 1 2 3 4 --seeds 42 123 456 \ + --epochs 100 --output_dir results/depth_ladder +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP + + +# --------------------------------------------------------------------------- +# Data / eval +# --------------------------------------------------------------------------- +def get_data(dataset, batch_size=128): + if dataset == 'cifar100': + mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) + DatasetClass, num_classes, input_dim = torchvision.datasets.CIFAR100, 100, 32 * 32 * 3 + else: + mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) + DatasetClass, num_classes, input_dim = torchvision.datasets.CIFAR10, 10, 32 * 32 * 3 + tf_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ]) + tf_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + tr = DatasetClass('./data', True, download=True, transform=tf_train) + te = DatasetClass('./data', False, download=True, transform=tf_test) + return ( + DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True), + input_dim, num_classes, + ) + + +def evaluate(model, loader, dev): + model.eval() + c = n = 0 + with torch.no_grad(): + for x, y in loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + c += (model(x).argmax(-1) == y).sum().item() + n += x.size(0) + return c / n + + +def freeze_first(model, k): + """Freeze the first L-k blocks (indices 0 .. L-k-1); leave the last k trainable. + Returns the set of trainable block indices.""" + L = model.num_blocks + n_frozen = L - k + trainable = set(range(n_frozen, L)) + for l, block in enumerate(model.blocks): + req = l in trainable + for p in block.parameters(): + p.requires_grad_(req) + return trainable + + +# --------------------------------------------------------------------------- +# Trainers (freeze-aware ports of cifar_resmlp.py) +# --------------------------------------------------------------------------- +def train_bp(model, train_loader, test_loader, dev, args, trainable): + """End-to-end BP; optimizer filters to requires_grad params (frozen blocks excluded). + Gradients still flow THROUGH frozen blocks to reach trainable blocks / embed.""" + opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), + lr=args.lr, weight_decay=args.wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs) + curve = [] + for ep in range(1, args.epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == 1 or ep == args.epochs: + acc = evaluate(model, test_loader, dev) + curve.append((ep, acc)) + print(f" [BP k] ep {ep}: test={acc:.4f}", flush=True) + return curve + + +def train_dfa(model, train_loader, test_loader, dev, args, trainable): + """DFA: each block reads output error directly via B_l (no sequential propagation). + Only TRAINABLE blocks are updated; embed / out_ln / out_head always trained.""" + d, C, L = model.d_hidden, args.num_classes, model.num_blocks + Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)] + + block_opts = {l: optim.AdamW(model.blocks[l].parameters(), lr=args.lr, weight_decay=args.wd) + for l in sorted(trainable)} + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd) + scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) + for o in list(block_opts.values()) + [embed_opt, head_opt]] + + curve = [] + for ep in range(1, args.epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + + # head: exact CE, h_L detached + hL = hiddens[-1].detach() + head_opt.zero_grad() + F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward() + head_opt.step() + + # trainable blocks: DFA local surrogate + for l in sorted(trainable): + a = (e_T @ Bs[l].T).detach() + a = a / ((a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6) + f_l = model.blocks[l](hiddens[l].detach()) + local = (f_l * a).sum(-1).mean() + block_opts[l].zero_grad(); local.backward(); block_opts[l].step() + + # embed: DFA credit at h_0 + a0 = (e_T @ Bs[0].T).detach() + a0 = a0 / ((a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6) + embed_loss = (model.embed(x) * a0).sum(-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + + for s in scheds: + s.step() + if ep % 10 == 0 or ep == 1 or ep == args.epochs: + acc = evaluate(model, test_loader, dev) + curve.append((ep, acc)) + print(f" [DFA k] ep {ep}: test={acc:.4f}", flush=True) + return curve + + +def train_fa(model, train_loader, test_loader, dev, args, trainable): + """Vanilla FA: credit propagates sequentially backward via fixed d×d B_l. + Frozen blocks STILL propagate credit (a_credit = a_credit @ B_l) so trainable + blocks / embed downstream receive it; only their weight update is skipped.""" + d, C, L = model.d_hidden, args.num_classes, model.num_blocks + Bs = [torch.randn(d, d, device=dev) / np.sqrt(d) for _ in range(L)] + + block_opts = {l: optim.AdamW(model.blocks[l].parameters(), lr=args.lr, weight_decay=args.wd) + for l in sorted(trainable)} + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd) + scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) + for o in list(block_opts.values()) + [embed_opt, head_opt]] + + curve = [] + for ep in range(1, args.epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + + # head: exact CE; a_credit = exact gradient at h_L (FA's starting credit) + hL = hiddens[-1].detach().requires_grad_(True) + head_opt.zero_grad() + F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward() + head_opt.step() + a_credit = hL.grad.detach() + + # blocks backward: update only trainable; ALWAYS propagate credit + for l in range(L - 1, -1, -1): + if l in trainable: + a = a_credit / ((a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6) + f_l = model.blocks[l](hiddens[l].detach()) + local = (f_l * a).sum(-1).mean() + block_opts[l].zero_grad(); local.backward(); block_opts[l].step() + a_credit = (a_credit @ Bs[l]).detach() + + # embed: FA credit at h_0 + a0 = a_credit / ((a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6) + embed_loss = (model.embed(x) * a0).sum(-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + + for s in scheds: + s.step() + if ep % 10 == 0 or ep == 1 or ep == args.epochs: + acc = evaluate(model, test_loader, dev) + curve.append((ep, acc)) + print(f" [FA k] ep {ep}: test={acc:.4f}", flush=True) + return curve + + +TRAINERS = {'bp': train_bp, 'dfa': train_dfa, 'fa': train_fa} + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- +def main(): + p = argparse.ArgumentParser() + p.add_argument('--d_hidden', type=int, default=256) + p.add_argument('--num_blocks', type=int, default=4) + p.add_argument('--dataset', type=str, default='cifar10') + p.add_argument('--methods', type=str, nargs='+', default=['bp', 'fa', 'dfa']) + p.add_argument('--k_values', type=int, nargs='+', default=[0, 1, 2, 3, 4]) + p.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456]) + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--lr', type=float, default=1e-3) + p.add_argument('--wd', type=float, default=0.01) + p.add_argument('--batch_size', type=int, default=128) + p.add_argument('--gpu', type=int, default=0) + p.add_argument('--output_dir', type=str, default='results/depth_ladder') + args = p.parse_args() + + dev = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + os.makedirs(args.output_dir, exist_ok=True) + L = args.num_blocks + tag = f"ladder_d{args.d_hidden}_L{L}_{args.dataset}" + out_path = os.path.join(args.output_dir, f"{tag}.json") + print(f"Device={dev} {tag} methods={args.methods} k={args.k_values} seeds={args.seeds} " + f"epochs={args.epochs}", flush=True) + + # incremental results: results[method][k][seed] = {final_acc, curve} + results = {} + if os.path.exists(out_path): + with open(out_path) as f: + results = json.load(f).get('results', {}) + print(f"Resuming; existing keys: " + f"{[(m, list(results[m].keys())) for m in results]}", flush=True) + + def save(): + with open(out_path, 'w') as f: + json.dump({'config': vars(args), 'results': results}, f, indent=2) + + for method in args.methods: + results.setdefault(method, {}) + for k in args.k_values: + if k > L: + continue + results[method].setdefault(str(k), {}) + for seed in args.seeds: + if str(seed) in results[method][str(k)]: + print(f" skip {method} k={k} seed={seed} (done)", flush=True) + continue + print(f"\n=== {method.upper()} k={k} (last {k} of {L} trainable) " + f"seed={seed} ===", flush=True) + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size) + args.num_classes = num_classes + + model = ResidualMLP(input_dim, args.d_hidden, num_classes, L).to(dev) + trainable = freeze_first(model, k) + n_train = sum(pp.numel() for pp in model.parameters() if pp.requires_grad) + print(f" trainable blocks: {sorted(trainable)} " + f"trainable params: {n_train:,}", flush=True) + + curve = TRAINERS[method](model, train_loader, test_loader, dev, args, trainable) + final_acc = evaluate(model, test_loader, dev) + results[method][str(k)][str(seed)] = {'final_acc': final_acc, 'curve': curve} + print(f" FINAL {method} k={k} seed={seed}: {final_acc:.4f}", flush=True) + save() + + # summary table + print(f"\n{'='*60}\nSUMMARY {tag} (mean ± ddof-1 std over seeds)\n{'='*60}", flush=True) + for method in args.methods: + row = [] + for k in args.k_values: + if k > L: + continue + accs = [v['final_acc'] for v in results[method][str(k)].values()] + if accs: + m = float(np.mean(accs)); s = float(np.std(accs, ddof=1)) if len(accs) > 1 else 0.0 + row.append(f"k={k}: {m:.4f}±{s:.4f}") + print(f" {method.upper():4s} " + " ".join(row), flush=True) + save() + print(f"\nSaved -> {out_path}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/experiments/frozen_init_identity_check.py b/experiments/frozen_init_identity_check.py new file mode 100644 index 0000000..3f58d7d --- /dev/null +++ b/experiments/frozen_init_identity_check.py @@ -0,0 +1,82 @@ +""" +Frozen-init identity check (supporting measurement for the depth-utility ladder). + +Quantifies how close a randomly-initialized, frozen ResidualMLP block stack is to +the identity map. This grounds the footnote explaining why the k=0 rung of the +ladder (all blocks frozen at init) already sits well above chance: the trained +embedding + readout are composed with a fixed, near-norm-preserving random feature +map, i.e. effectively a trained (near-)linear classifier on pixels. + +Reports, at random init, on a CIFAR-10 test batch (mean over seeds): + - per-block residual ratio ||f_l(h_l)|| / ||h_l|| (median over batch) + - whole-stack deviation ||h_L - h_0|| / ||h_0|| (median over batch) + - whole-stack direction cos(h_L, h_0) (median over batch) + +Usage: + CUDA_VISIBLE_DEVICES=2 python experiments/frozen_init_identity_check.py +""" +import os, sys, json +import numpy as np +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP + + +def main(): + d_hidden, L, C, n = 256, 4, 10, 256 + seeds = [42, 123, 456] + tf = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2470, 0.2435, 0.2616))]) + ds = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=tf) + x = torch.stack([ds[i][0] for i in range(n)]).view(n, -1) + + per_block, rel_dev, cos_dev = [], [], [] + seed_rows = {} + for seed in seeds: + torch.manual_seed(seed); np.random.seed(seed) + m = ResidualMLP(32 * 32 * 3, d_hidden, C, L).eval() + with torch.no_grad(): + h0 = m.embed(x); h = h0; ratios = [] + for blk in m.blocks: + f = blk(h) + ratios.append(float((f.norm(dim=-1) / h.norm(dim=-1)).median())) + h = h + f + rel = float(((h - h0).norm(dim=-1) / h0.norm(dim=-1)).median()) + cos = float(F.cosine_similarity(h, h0, dim=-1).median()) + per_block.append(ratios); rel_dev.append(rel); cos_dev.append(cos) + seed_rows[str(seed)] = {'per_block_ratio': ratios, 'rel_dev': rel, 'cos': cos} + print(f"seed {seed}: per-block ||f||/||h|| = " + f"{['%.4f' % r for r in ratios]} " + f"||h_L-h_0||/||h_0|| = {rel:.3f} cos(h_L,h_0) = {cos:.4f}", flush=True) + + pb = np.array(per_block) + summary = { + 'config': {'d_hidden': d_hidden, 'L': L, 'num_classes': C, 'batch': n, + 'dataset': 'cifar10-test', 'seeds': seeds}, + 'per_seed': seed_rows, + 'per_block_ratio_mean': pb.mean(0).tolist(), + 'per_block_ratio_grand_mean': float(pb.mean()), + 'rel_dev_mean': float(np.mean(rel_dev)), + 'rel_dev_std': float(np.std(rel_dev, ddof=1)), + 'cos_mean': float(np.mean(cos_dev)), + 'cos_std': float(np.std(cos_dev, ddof=1)), + } + print(f"\nMEAN over {len(seeds)} seeds: " + f"per-block ratio ≈ {summary['per_block_ratio_grand_mean']:.3f}, " + f"||h_L-h_0||/||h_0|| = {summary['rel_dev_mean']:.3f} ± {summary['rel_dev_std']:.3f}, " + f"cos = {summary['cos_mean']:.4f} ± {summary['cos_std']:.4f}", flush=True) + + out = 'results/depth_ladder/frozen_init_identity.json' + os.makedirs(os.path.dirname(out), exist_ok=True) + with open(out, 'w') as f: + json.dump(summary, f, indent=2) + print(f"Saved -> {out}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/experiments/plot_depth_ladder.py b/experiments/plot_depth_ladder.py new file mode 100644 index 0000000..a5709bf --- /dev/null +++ b/experiments/plot_depth_ladder.py @@ -0,0 +1,63 @@ +""" +Plot the depth-utility ladder: test accuracy vs number of trainable blocks k, +one curve per method (BP / FA / DFA), one panel per architecture. + +Usage: + python experiments/plot_depth_ladder.py +""" +import os, sys, json +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +CONFIGS = [ + ('results/depth_ladder/ladder_d256_L4_cifar10.json', 'ResMLP d=256, L=4', 4), + ('results/depth_ladder/ladder_d512_L2_cifar10.json', 'ResMLP d=512, L=2', 2), +] +METHODS = [('bp', 'BP', 'tab:green', 'o'), + ('fa', 'FA', 'tab:orange', 's'), + ('dfa', 'DFA', 'tab:red', '^')] + + +def agg(path, L): + d = json.load(open(path))['results'] + out = {} + for m, _, _, _ in METHODS: + ks, mu, sd = [], [], [] + for k in range(L + 1): + a = [v['final_acc'] for v in d[m][str(k)].values()] + ks.append(k); mu.append(np.mean(a)) + sd.append(np.std(a, ddof=1) if len(a) > 1 else 0.0) + out[m] = (np.array(ks), np.array(mu), np.array(sd)) + return out + + +def main(): + fig, axes = plt.subplots(1, len(CONFIGS), figsize=(11, 4.2)) + if len(CONFIGS) == 1: + axes = [axes] + for ax, (path, title, L) in zip(axes, CONFIGS): + data = agg(path, L) + for m, label, color, mk in METHODS: + ks, mu, sd = data[m] + ax.errorbar(ks, mu, yerr=sd, marker=mk, color=color, label=label, + capsize=3, lw=2, ms=7) + # frozen baseline reference (k=0, averaged across methods is ~chance-of-readout) + ax.axhline(0.10, ls=':', color='gray', lw=1) + ax.text(0.02, 0.105, 'chance', color='gray', fontsize=8, transform=ax.get_yaxis_transform()) + ax.set_xlabel('trainable blocks $k$ (last $k$ of $L$)') + ax.set_ylabel('CIFAR-10 test accuracy') + ax.set_title(title) + ax.set_xticks(range(L + 1)) + ax.grid(alpha=0.3) + ax.legend(loc='center right') + fig.suptitle('Depth-utility ladder: does training deeper blocks raise accuracy?', y=1.02) + fig.tight_layout() + out = 'results/depth_ladder/depth_ladder.png' + fig.savefig(out, dpi=150, bbox_inches='tight') + print(f"Saved -> {out}") + + +if __name__ == '__main__': + main() -- cgit v1.2.3