summaryrefslogtreecommitdiff
path: root/experiments/depth_utility_ladder.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-14 20:32:31 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-14 20:32:31 -0500
commit1118b7457c261de36ead6103503c00c321c75f9b (patch)
tree7ea76b32f070cb58458caaa2897a5d8133561f48 /experiments/depth_utility_ladder.py
parentaa73718eb6427d7da3b9cb416275802d90c4b2ed (diff)
Depth-utility ladder: trainable-block sweep (BP/FA/DFA) on ResMLP CIFAR-10HEADmaster
Appendix experiment triangulating the depth-utility diagnostic (D3) by varying the number of trainable residual blocks k (last-k trainable, first L-k frozen at init; embed/LN/head always trained). - d=256 L=4 and d=512 L=2, 3 seeds, recipe identical to the main audit. - BP climbs monotonically (+22-23pp); DFA peaks at the frozen baseline (k=0) and declines once any deep block is trained; FA shows partial/no net depth utility. - Cross-checks reproduce existing anchors (BP 0.617, DFA 0.301, FA 0.402, frozen 0.349). - frozen_init_identity_check quantifies frozen stack as a near-norm-preserving random feature map (per-block ||f||/||h||~0.10, stack cos 0.981), explaining the above-chance k=0 rung. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments/depth_utility_ladder.py')
-rw-r--r--experiments/depth_utility_ladder.py317
1 files changed, 317 insertions, 0 deletions
diff --git a/experiments/depth_utility_ladder.py b/experiments/depth_utility_ladder.py
new file mode 100644
index 0000000..c9de9e9
--- /dev/null
+++ b/experiments/depth_utility_ladder.py
@@ -0,0 +1,317 @@
+"""
+Depth-utility ladder (appendix experiment for the FA-evaluation E&D paper).
+
+Turns the binary frozen-vs-trained block comparison into a CURVE: vary the number
+of trainable residual blocks k, training the LAST k blocks (output side) and
+freezing the first L-k at random init. Embedding / out_ln / out_head are ALWAYS
+trained. Credit still propagates through frozen blocks (forward + FA feedback
+matrices unchanged); only their weights stay at init.
+
+Question. As more blocks are made trainable, does test accuracy rise?
+ - BP (positive control): should climb monotonically with k.
+ - FA (Lillicrap vanilla): modest climb where depth is usable, flat where not.
+ - DFA (direct FA): flat at / below the frozen baseline (deep credit
+ is non-functional -> the D3 failure at every k).
+
+Output-side-first is deliberate: the deepest block receives the most direct
+credit (FA's last block sees the exact output gradient), so it is the BEST case
+for the method. If even these blocks add nothing, depth is unused.
+
+Recipe is identical to the main CIFAR audit (cifar_resmlp.py): AdamW, lr 1e-3,
+wd 0.01, cosine, batch 128, 100 epochs, per-block independent optimizers and
+rms-normalized local surrogate losses.
+
+k=0 reproduces the frozen-blocks baseline; k=L reproduces the full audit.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/depth_utility_ladder.py \
+ --d_hidden 256 --num_blocks 4 --dataset cifar10 \
+ --methods bp fa dfa --k_values 0 1 2 3 4 --seeds 42 123 456 \
+ --epochs 100 --output_dir results/depth_ladder
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+
+
+# ---------------------------------------------------------------------------
+# Data / eval
+# ---------------------------------------------------------------------------
+def get_data(dataset, batch_size=128):
+ if dataset == 'cifar100':
+ mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
+ DatasetClass, num_classes, input_dim = torchvision.datasets.CIFAR100, 100, 32 * 32 * 3
+ else:
+ mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
+ DatasetClass, num_classes, input_dim = torchvision.datasets.CIFAR10, 10, 32 * 32 * 3
+ tf_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ ])
+ tf_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
+ tr = DatasetClass('./data', True, download=True, transform=tf_train)
+ te = DatasetClass('./data', False, download=True, transform=tf_test)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True),
+ input_dim, num_classes,
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ c = n = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ c += (model(x).argmax(-1) == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_first(model, k):
+ """Freeze the first L-k blocks (indices 0 .. L-k-1); leave the last k trainable.
+ Returns the set of trainable block indices."""
+ L = model.num_blocks
+ n_frozen = L - k
+ trainable = set(range(n_frozen, L))
+ for l, block in enumerate(model.blocks):
+ req = l in trainable
+ for p in block.parameters():
+ p.requires_grad_(req)
+ return trainable
+
+
+# ---------------------------------------------------------------------------
+# Trainers (freeze-aware ports of cifar_resmlp.py)
+# ---------------------------------------------------------------------------
+def train_bp(model, train_loader, test_loader, dev, args, trainable):
+ """End-to-end BP; optimizer filters to requires_grad params (frozen blocks excluded).
+ Gradients still flow THROUGH frozen blocks to reach trainable blocks / embed."""
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
+ lr=args.lr, weight_decay=args.wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)
+ curve = []
+ for ep in range(1, args.epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == 1 or ep == args.epochs:
+ acc = evaluate(model, test_loader, dev)
+ curve.append((ep, acc))
+ print(f" [BP k] ep {ep}: test={acc:.4f}", flush=True)
+ return curve
+
+
+def train_dfa(model, train_loader, test_loader, dev, args, trainable):
+ """DFA: each block reads output error directly via B_l (no sequential propagation).
+ Only TRAINABLE blocks are updated; embed / out_ln / out_head always trained."""
+ d, C, L = model.d_hidden, args.num_classes, model.num_blocks
+ Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)]
+
+ block_opts = {l: optim.AdamW(model.blocks[l].parameters(), lr=args.lr, weight_decay=args.wd)
+ for l in sorted(trainable)}
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs)
+ for o in list(block_opts.values()) + [embed_opt, head_opt]]
+
+ curve = []
+ for ep in range(1, args.epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+
+ # head: exact CE, h_L detached
+ hL = hiddens[-1].detach()
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+ head_opt.step()
+
+ # trainable blocks: DFA local surrogate
+ for l in sorted(trainable):
+ a = (e_T @ Bs[l].T).detach()
+ a = a / ((a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)
+ f_l = model.blocks[l](hiddens[l].detach())
+ local = (f_l * a).sum(-1).mean()
+ block_opts[l].zero_grad(); local.backward(); block_opts[l].step()
+
+ # embed: DFA credit at h_0
+ a0 = (e_T @ Bs[0].T).detach()
+ a0 = a0 / ((a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)
+ embed_loss = (model.embed(x) * a0).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+
+ for s in scheds:
+ s.step()
+ if ep % 10 == 0 or ep == 1 or ep == args.epochs:
+ acc = evaluate(model, test_loader, dev)
+ curve.append((ep, acc))
+ print(f" [DFA k] ep {ep}: test={acc:.4f}", flush=True)
+ return curve
+
+
+def train_fa(model, train_loader, test_loader, dev, args, trainable):
+ """Vanilla FA: credit propagates sequentially backward via fixed d×d B_l.
+ Frozen blocks STILL propagate credit (a_credit = a_credit @ B_l) so trainable
+ blocks / embed downstream receive it; only their weight update is skipped."""
+ d, C, L = model.d_hidden, args.num_classes, model.num_blocks
+ Bs = [torch.randn(d, d, device=dev) / np.sqrt(d) for _ in range(L)]
+
+ block_opts = {l: optim.AdamW(model.blocks[l].parameters(), lr=args.lr, weight_decay=args.wd)
+ for l in sorted(trainable)}
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs)
+ for o in list(block_opts.values()) + [embed_opt, head_opt]]
+
+ curve = []
+ for ep in range(1, args.epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+
+ # head: exact CE; a_credit = exact gradient at h_L (FA's starting credit)
+ hL = hiddens[-1].detach().requires_grad_(True)
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+ head_opt.step()
+ a_credit = hL.grad.detach()
+
+ # blocks backward: update only trainable; ALWAYS propagate credit
+ for l in range(L - 1, -1, -1):
+ if l in trainable:
+ a = a_credit / ((a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)
+ f_l = model.blocks[l](hiddens[l].detach())
+ local = (f_l * a).sum(-1).mean()
+ block_opts[l].zero_grad(); local.backward(); block_opts[l].step()
+ a_credit = (a_credit @ Bs[l]).detach()
+
+ # embed: FA credit at h_0
+ a0 = a_credit / ((a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)
+ embed_loss = (model.embed(x) * a0).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+
+ for s in scheds:
+ s.step()
+ if ep % 10 == 0 or ep == 1 or ep == args.epochs:
+ acc = evaluate(model, test_loader, dev)
+ curve.append((ep, acc))
+ print(f" [FA k] ep {ep}: test={acc:.4f}", flush=True)
+ return curve
+
+
+TRAINERS = {'bp': train_bp, 'dfa': train_dfa, 'fa': train_fa}
+
+
+# ---------------------------------------------------------------------------
+# Driver
+# ---------------------------------------------------------------------------
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--d_hidden', type=int, default=256)
+ p.add_argument('--num_blocks', type=int, default=4)
+ p.add_argument('--dataset', type=str, default='cifar10')
+ p.add_argument('--methods', type=str, nargs='+', default=['bp', 'fa', 'dfa'])
+ p.add_argument('--k_values', type=int, nargs='+', default=[0, 1, 2, 3, 4])
+ p.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ p.add_argument('--batch_size', type=int, default=128)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/depth_ladder')
+ args = p.parse_args()
+
+ dev = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ os.makedirs(args.output_dir, exist_ok=True)
+ L = args.num_blocks
+ tag = f"ladder_d{args.d_hidden}_L{L}_{args.dataset}"
+ out_path = os.path.join(args.output_dir, f"{tag}.json")
+ print(f"Device={dev} {tag} methods={args.methods} k={args.k_values} seeds={args.seeds} "
+ f"epochs={args.epochs}", flush=True)
+
+ # incremental results: results[method][k][seed] = {final_acc, curve}
+ results = {}
+ if os.path.exists(out_path):
+ with open(out_path) as f:
+ results = json.load(f).get('results', {})
+ print(f"Resuming; existing keys: "
+ f"{[(m, list(results[m].keys())) for m in results]}", flush=True)
+
+ def save():
+ with open(out_path, 'w') as f:
+ json.dump({'config': vars(args), 'results': results}, f, indent=2)
+
+ for method in args.methods:
+ results.setdefault(method, {})
+ for k in args.k_values:
+ if k > L:
+ continue
+ results[method].setdefault(str(k), {})
+ for seed in args.seeds:
+ if str(seed) in results[method][str(k)]:
+ print(f" skip {method} k={k} seed={seed} (done)", flush=True)
+ continue
+ print(f"\n=== {method.upper()} k={k} (last {k} of {L} trainable) "
+ f"seed={seed} ===", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size)
+ args.num_classes = num_classes
+
+ model = ResidualMLP(input_dim, args.d_hidden, num_classes, L).to(dev)
+ trainable = freeze_first(model, k)
+ n_train = sum(pp.numel() for pp in model.parameters() if pp.requires_grad)
+ print(f" trainable blocks: {sorted(trainable)} "
+ f"trainable params: {n_train:,}", flush=True)
+
+ curve = TRAINERS[method](model, train_loader, test_loader, dev, args, trainable)
+ final_acc = evaluate(model, test_loader, dev)
+ results[method][str(k)][str(seed)] = {'final_acc': final_acc, 'curve': curve}
+ print(f" FINAL {method} k={k} seed={seed}: {final_acc:.4f}", flush=True)
+ save()
+
+ # summary table
+ print(f"\n{'='*60}\nSUMMARY {tag} (mean ± ddof-1 std over seeds)\n{'='*60}", flush=True)
+ for method in args.methods:
+ row = []
+ for k in args.k_values:
+ if k > L:
+ continue
+ accs = [v['final_acc'] for v in results[method][str(k)].values()]
+ if accs:
+ m = float(np.mean(accs)); s = float(np.std(accs, ddof=1)) if len(accs) > 1 else 0.0
+ row.append(f"k={k}: {m:.4f}±{s:.4f}")
+ print(f" {method.upper():4s} " + " ".join(row), flush=True)
+ save()
+ print(f"\nSaved -> {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()