summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-14 20:32:31 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-14 20:32:31 -0500
commit1118b7457c261de36ead6103503c00c321c75f9b (patch)
tree7ea76b32f070cb58458caaa2897a5d8133561f48 /experiments
parentaa73718eb6427d7da3b9cb416275802d90c4b2ed (diff)
Depth-utility ladder: trainable-block sweep (BP/FA/DFA) on ResMLP CIFAR-10HEADmaster
Appendix experiment triangulating the depth-utility diagnostic (D3) by varying the number of trainable residual blocks k (last-k trainable, first L-k frozen at init; embed/LN/head always trained). - d=256 L=4 and d=512 L=2, 3 seeds, recipe identical to the main audit. - BP climbs monotonically (+22-23pp); DFA peaks at the frozen baseline (k=0) and declines once any deep block is trained; FA shows partial/no net depth utility. - Cross-checks reproduce existing anchors (BP 0.617, DFA 0.301, FA 0.402, frozen 0.349). - frozen_init_identity_check quantifies frozen stack as a near-norm-preserving random feature map (per-block ||f||/||h||~0.10, stack cos 0.981), explaining the above-chance k=0 rung. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/depth_utility_ladder.py317
-rw-r--r--experiments/frozen_init_identity_check.py82
-rw-r--r--experiments/plot_depth_ladder.py63
3 files changed, 462 insertions, 0 deletions
diff --git a/experiments/depth_utility_ladder.py b/experiments/depth_utility_ladder.py
new file mode 100644
index 0000000..c9de9e9
--- /dev/null
+++ b/experiments/depth_utility_ladder.py
@@ -0,0 +1,317 @@
+"""
+Depth-utility ladder (appendix experiment for the FA-evaluation E&D paper).
+
+Turns the binary frozen-vs-trained block comparison into a CURVE: vary the number
+of trainable residual blocks k, training the LAST k blocks (output side) and
+freezing the first L-k at random init. Embedding / out_ln / out_head are ALWAYS
+trained. Credit still propagates through frozen blocks (forward + FA feedback
+matrices unchanged); only their weights stay at init.
+
+Question. As more blocks are made trainable, does test accuracy rise?
+ - BP (positive control): should climb monotonically with k.
+ - FA (Lillicrap vanilla): modest climb where depth is usable, flat where not.
+ - DFA (direct FA): flat at / below the frozen baseline (deep credit
+ is non-functional -> the D3 failure at every k).
+
+Output-side-first is deliberate: the deepest block receives the most direct
+credit (FA's last block sees the exact output gradient), so it is the BEST case
+for the method. If even these blocks add nothing, depth is unused.
+
+Recipe is identical to the main CIFAR audit (cifar_resmlp.py): AdamW, lr 1e-3,
+wd 0.01, cosine, batch 128, 100 epochs, per-block independent optimizers and
+rms-normalized local surrogate losses.
+
+k=0 reproduces the frozen-blocks baseline; k=L reproduces the full audit.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/depth_utility_ladder.py \
+ --d_hidden 256 --num_blocks 4 --dataset cifar10 \
+ --methods bp fa dfa --k_values 0 1 2 3 4 --seeds 42 123 456 \
+ --epochs 100 --output_dir results/depth_ladder
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+
+
+# ---------------------------------------------------------------------------
+# Data / eval
+# ---------------------------------------------------------------------------
+def get_data(dataset, batch_size=128):
+ if dataset == 'cifar100':
+ mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
+ DatasetClass, num_classes, input_dim = torchvision.datasets.CIFAR100, 100, 32 * 32 * 3
+ else:
+ mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
+ DatasetClass, num_classes, input_dim = torchvision.datasets.CIFAR10, 10, 32 * 32 * 3
+ tf_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ ])
+ tf_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
+ tr = DatasetClass('./data', True, download=True, transform=tf_train)
+ te = DatasetClass('./data', False, download=True, transform=tf_test)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True),
+ input_dim, num_classes,
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ c = n = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ c += (model(x).argmax(-1) == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_first(model, k):
+ """Freeze the first L-k blocks (indices 0 .. L-k-1); leave the last k trainable.
+ Returns the set of trainable block indices."""
+ L = model.num_blocks
+ n_frozen = L - k
+ trainable = set(range(n_frozen, L))
+ for l, block in enumerate(model.blocks):
+ req = l in trainable
+ for p in block.parameters():
+ p.requires_grad_(req)
+ return trainable
+
+
+# ---------------------------------------------------------------------------
+# Trainers (freeze-aware ports of cifar_resmlp.py)
+# ---------------------------------------------------------------------------
+def train_bp(model, train_loader, test_loader, dev, args, trainable):
+ """End-to-end BP; optimizer filters to requires_grad params (frozen blocks excluded).
+ Gradients still flow THROUGH frozen blocks to reach trainable blocks / embed."""
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
+ lr=args.lr, weight_decay=args.wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)
+ curve = []
+ for ep in range(1, args.epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == 1 or ep == args.epochs:
+ acc = evaluate(model, test_loader, dev)
+ curve.append((ep, acc))
+ print(f" [BP k] ep {ep}: test={acc:.4f}", flush=True)
+ return curve
+
+
+def train_dfa(model, train_loader, test_loader, dev, args, trainable):
+ """DFA: each block reads output error directly via B_l (no sequential propagation).
+ Only TRAINABLE blocks are updated; embed / out_ln / out_head always trained."""
+ d, C, L = model.d_hidden, args.num_classes, model.num_blocks
+ Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)]
+
+ block_opts = {l: optim.AdamW(model.blocks[l].parameters(), lr=args.lr, weight_decay=args.wd)
+ for l in sorted(trainable)}
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs)
+ for o in list(block_opts.values()) + [embed_opt, head_opt]]
+
+ curve = []
+ for ep in range(1, args.epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+
+ # head: exact CE, h_L detached
+ hL = hiddens[-1].detach()
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+ head_opt.step()
+
+ # trainable blocks: DFA local surrogate
+ for l in sorted(trainable):
+ a = (e_T @ Bs[l].T).detach()
+ a = a / ((a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)
+ f_l = model.blocks[l](hiddens[l].detach())
+ local = (f_l * a).sum(-1).mean()
+ block_opts[l].zero_grad(); local.backward(); block_opts[l].step()
+
+ # embed: DFA credit at h_0
+ a0 = (e_T @ Bs[0].T).detach()
+ a0 = a0 / ((a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)
+ embed_loss = (model.embed(x) * a0).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+
+ for s in scheds:
+ s.step()
+ if ep % 10 == 0 or ep == 1 or ep == args.epochs:
+ acc = evaluate(model, test_loader, dev)
+ curve.append((ep, acc))
+ print(f" [DFA k] ep {ep}: test={acc:.4f}", flush=True)
+ return curve
+
+
+def train_fa(model, train_loader, test_loader, dev, args, trainable):
+ """Vanilla FA: credit propagates sequentially backward via fixed d×d B_l.
+ Frozen blocks STILL propagate credit (a_credit = a_credit @ B_l) so trainable
+ blocks / embed downstream receive it; only their weight update is skipped."""
+ d, C, L = model.d_hidden, args.num_classes, model.num_blocks
+ Bs = [torch.randn(d, d, device=dev) / np.sqrt(d) for _ in range(L)]
+
+ block_opts = {l: optim.AdamW(model.blocks[l].parameters(), lr=args.lr, weight_decay=args.wd)
+ for l in sorted(trainable)}
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs)
+ for o in list(block_opts.values()) + [embed_opt, head_opt]]
+
+ curve = []
+ for ep in range(1, args.epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+
+ # head: exact CE; a_credit = exact gradient at h_L (FA's starting credit)
+ hL = hiddens[-1].detach().requires_grad_(True)
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+ head_opt.step()
+ a_credit = hL.grad.detach()
+
+ # blocks backward: update only trainable; ALWAYS propagate credit
+ for l in range(L - 1, -1, -1):
+ if l in trainable:
+ a = a_credit / ((a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)
+ f_l = model.blocks[l](hiddens[l].detach())
+ local = (f_l * a).sum(-1).mean()
+ block_opts[l].zero_grad(); local.backward(); block_opts[l].step()
+ a_credit = (a_credit @ Bs[l]).detach()
+
+ # embed: FA credit at h_0
+ a0 = a_credit / ((a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)
+ embed_loss = (model.embed(x) * a0).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+
+ for s in scheds:
+ s.step()
+ if ep % 10 == 0 or ep == 1 or ep == args.epochs:
+ acc = evaluate(model, test_loader, dev)
+ curve.append((ep, acc))
+ print(f" [FA k] ep {ep}: test={acc:.4f}", flush=True)
+ return curve
+
+
+TRAINERS = {'bp': train_bp, 'dfa': train_dfa, 'fa': train_fa}
+
+
+# ---------------------------------------------------------------------------
+# Driver
+# ---------------------------------------------------------------------------
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--d_hidden', type=int, default=256)
+ p.add_argument('--num_blocks', type=int, default=4)
+ p.add_argument('--dataset', type=str, default='cifar10')
+ p.add_argument('--methods', type=str, nargs='+', default=['bp', 'fa', 'dfa'])
+ p.add_argument('--k_values', type=int, nargs='+', default=[0, 1, 2, 3, 4])
+ p.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ p.add_argument('--batch_size', type=int, default=128)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/depth_ladder')
+ args = p.parse_args()
+
+ dev = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ os.makedirs(args.output_dir, exist_ok=True)
+ L = args.num_blocks
+ tag = f"ladder_d{args.d_hidden}_L{L}_{args.dataset}"
+ out_path = os.path.join(args.output_dir, f"{tag}.json")
+ print(f"Device={dev} {tag} methods={args.methods} k={args.k_values} seeds={args.seeds} "
+ f"epochs={args.epochs}", flush=True)
+
+ # incremental results: results[method][k][seed] = {final_acc, curve}
+ results = {}
+ if os.path.exists(out_path):
+ with open(out_path) as f:
+ results = json.load(f).get('results', {})
+ print(f"Resuming; existing keys: "
+ f"{[(m, list(results[m].keys())) for m in results]}", flush=True)
+
+ def save():
+ with open(out_path, 'w') as f:
+ json.dump({'config': vars(args), 'results': results}, f, indent=2)
+
+ for method in args.methods:
+ results.setdefault(method, {})
+ for k in args.k_values:
+ if k > L:
+ continue
+ results[method].setdefault(str(k), {})
+ for seed in args.seeds:
+ if str(seed) in results[method][str(k)]:
+ print(f" skip {method} k={k} seed={seed} (done)", flush=True)
+ continue
+ print(f"\n=== {method.upper()} k={k} (last {k} of {L} trainable) "
+ f"seed={seed} ===", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size)
+ args.num_classes = num_classes
+
+ model = ResidualMLP(input_dim, args.d_hidden, num_classes, L).to(dev)
+ trainable = freeze_first(model, k)
+ n_train = sum(pp.numel() for pp in model.parameters() if pp.requires_grad)
+ print(f" trainable blocks: {sorted(trainable)} "
+ f"trainable params: {n_train:,}", flush=True)
+
+ curve = TRAINERS[method](model, train_loader, test_loader, dev, args, trainable)
+ final_acc = evaluate(model, test_loader, dev)
+ results[method][str(k)][str(seed)] = {'final_acc': final_acc, 'curve': curve}
+ print(f" FINAL {method} k={k} seed={seed}: {final_acc:.4f}", flush=True)
+ save()
+
+ # summary table
+ print(f"\n{'='*60}\nSUMMARY {tag} (mean ± ddof-1 std over seeds)\n{'='*60}", flush=True)
+ for method in args.methods:
+ row = []
+ for k in args.k_values:
+ if k > L:
+ continue
+ accs = [v['final_acc'] for v in results[method][str(k)].values()]
+ if accs:
+ m = float(np.mean(accs)); s = float(np.std(accs, ddof=1)) if len(accs) > 1 else 0.0
+ row.append(f"k={k}: {m:.4f}±{s:.4f}")
+ print(f" {method.upper():4s} " + " ".join(row), flush=True)
+ save()
+ print(f"\nSaved -> {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/frozen_init_identity_check.py b/experiments/frozen_init_identity_check.py
new file mode 100644
index 0000000..3f58d7d
--- /dev/null
+++ b/experiments/frozen_init_identity_check.py
@@ -0,0 +1,82 @@
+"""
+Frozen-init identity check (supporting measurement for the depth-utility ladder).
+
+Quantifies how close a randomly-initialized, frozen ResidualMLP block stack is to
+the identity map. This grounds the footnote explaining why the k=0 rung of the
+ladder (all blocks frozen at init) already sits well above chance: the trained
+embedding + readout are composed with a fixed, near-norm-preserving random feature
+map, i.e. effectively a trained (near-)linear classifier on pixels.
+
+Reports, at random init, on a CIFAR-10 test batch (mean over seeds):
+ - per-block residual ratio ||f_l(h_l)|| / ||h_l|| (median over batch)
+ - whole-stack deviation ||h_L - h_0|| / ||h_0|| (median over batch)
+ - whole-stack direction cos(h_L, h_0) (median over batch)
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/frozen_init_identity_check.py
+"""
+import os, sys, json
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+
+
+def main():
+ d_hidden, L, C, n = 256, 4, 10, 256
+ seeds = [42, 123, 456]
+ tf = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465),
+ (0.2470, 0.2435, 0.2616))])
+ ds = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=tf)
+ x = torch.stack([ds[i][0] for i in range(n)]).view(n, -1)
+
+ per_block, rel_dev, cos_dev = [], [], []
+ seed_rows = {}
+ for seed in seeds:
+ torch.manual_seed(seed); np.random.seed(seed)
+ m = ResidualMLP(32 * 32 * 3, d_hidden, C, L).eval()
+ with torch.no_grad():
+ h0 = m.embed(x); h = h0; ratios = []
+ for blk in m.blocks:
+ f = blk(h)
+ ratios.append(float((f.norm(dim=-1) / h.norm(dim=-1)).median()))
+ h = h + f
+ rel = float(((h - h0).norm(dim=-1) / h0.norm(dim=-1)).median())
+ cos = float(F.cosine_similarity(h, h0, dim=-1).median())
+ per_block.append(ratios); rel_dev.append(rel); cos_dev.append(cos)
+ seed_rows[str(seed)] = {'per_block_ratio': ratios, 'rel_dev': rel, 'cos': cos}
+ print(f"seed {seed}: per-block ||f||/||h|| = "
+ f"{['%.4f' % r for r in ratios]} "
+ f"||h_L-h_0||/||h_0|| = {rel:.3f} cos(h_L,h_0) = {cos:.4f}", flush=True)
+
+ pb = np.array(per_block)
+ summary = {
+ 'config': {'d_hidden': d_hidden, 'L': L, 'num_classes': C, 'batch': n,
+ 'dataset': 'cifar10-test', 'seeds': seeds},
+ 'per_seed': seed_rows,
+ 'per_block_ratio_mean': pb.mean(0).tolist(),
+ 'per_block_ratio_grand_mean': float(pb.mean()),
+ 'rel_dev_mean': float(np.mean(rel_dev)),
+ 'rel_dev_std': float(np.std(rel_dev, ddof=1)),
+ 'cos_mean': float(np.mean(cos_dev)),
+ 'cos_std': float(np.std(cos_dev, ddof=1)),
+ }
+ print(f"\nMEAN over {len(seeds)} seeds: "
+ f"per-block ratio ≈ {summary['per_block_ratio_grand_mean']:.3f}, "
+ f"||h_L-h_0||/||h_0|| = {summary['rel_dev_mean']:.3f} ± {summary['rel_dev_std']:.3f}, "
+ f"cos = {summary['cos_mean']:.4f} ± {summary['cos_std']:.4f}", flush=True)
+
+ out = 'results/depth_ladder/frozen_init_identity.json'
+ os.makedirs(os.path.dirname(out), exist_ok=True)
+ with open(out, 'w') as f:
+ json.dump(summary, f, indent=2)
+ print(f"Saved -> {out}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/plot_depth_ladder.py b/experiments/plot_depth_ladder.py
new file mode 100644
index 0000000..a5709bf
--- /dev/null
+++ b/experiments/plot_depth_ladder.py
@@ -0,0 +1,63 @@
+"""
+Plot the depth-utility ladder: test accuracy vs number of trainable blocks k,
+one curve per method (BP / FA / DFA), one panel per architecture.
+
+Usage:
+ python experiments/plot_depth_ladder.py
+"""
+import os, sys, json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+CONFIGS = [
+ ('results/depth_ladder/ladder_d256_L4_cifar10.json', 'ResMLP d=256, L=4', 4),
+ ('results/depth_ladder/ladder_d512_L2_cifar10.json', 'ResMLP d=512, L=2', 2),
+]
+METHODS = [('bp', 'BP', 'tab:green', 'o'),
+ ('fa', 'FA', 'tab:orange', 's'),
+ ('dfa', 'DFA', 'tab:red', '^')]
+
+
+def agg(path, L):
+ d = json.load(open(path))['results']
+ out = {}
+ for m, _, _, _ in METHODS:
+ ks, mu, sd = [], [], []
+ for k in range(L + 1):
+ a = [v['final_acc'] for v in d[m][str(k)].values()]
+ ks.append(k); mu.append(np.mean(a))
+ sd.append(np.std(a, ddof=1) if len(a) > 1 else 0.0)
+ out[m] = (np.array(ks), np.array(mu), np.array(sd))
+ return out
+
+
+def main():
+ fig, axes = plt.subplots(1, len(CONFIGS), figsize=(11, 4.2))
+ if len(CONFIGS) == 1:
+ axes = [axes]
+ for ax, (path, title, L) in zip(axes, CONFIGS):
+ data = agg(path, L)
+ for m, label, color, mk in METHODS:
+ ks, mu, sd = data[m]
+ ax.errorbar(ks, mu, yerr=sd, marker=mk, color=color, label=label,
+ capsize=3, lw=2, ms=7)
+ # frozen baseline reference (k=0, averaged across methods is ~chance-of-readout)
+ ax.axhline(0.10, ls=':', color='gray', lw=1)
+ ax.text(0.02, 0.105, 'chance', color='gray', fontsize=8, transform=ax.get_yaxis_transform())
+ ax.set_xlabel('trainable blocks $k$ (last $k$ of $L$)')
+ ax.set_ylabel('CIFAR-10 test accuracy')
+ ax.set_title(title)
+ ax.set_xticks(range(L + 1))
+ ax.grid(alpha=0.3)
+ ax.legend(loc='center right')
+ fig.suptitle('Depth-utility ladder: does training deeper blocks raise accuracy?', y=1.02)
+ fig.tight_layout()
+ out = 'results/depth_ladder/depth_ladder.png'
+ fig.savefig(out, dpi=150, bbox_inches='tight')
+ print(f"Saved -> {out}")
+
+
+if __name__ == '__main__':
+ main()