diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 01:33:52 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 01:33:52 -0500 |
| commit | f92ebda7ec2ac155c1a4241b057801c863834e56 (patch) | |
| tree | 7f2595f6ad1c3ad476292b7dad3584fb7acda9ee /experiments | |
| parent | 2ca87f2bd4449b1d4ac715d8cf4fb5f20b7afdd8 (diff) | |
Add BP+penalty control (round 19's #4 critical experiment)
Trains end-to-end BP with the same lambda*||f_l(h_l)||^2 penalty used in
the DFA penalty rescue. Tests whether the penalty's depth utilization
loss in penalized DFA is intrinsic to DFA's random-feedback credit
quality (mode 2) or due to penalty-induced capacity regularization.
Decision rule:
BP+pen margin > 25 pp -> mode 2 confirmed (penalty is not the cap)
BP+pen margin < 5 pp -> penalty itself caps depth (capacity loss)
intermediate -> both effects present
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/bp_with_penalty_control.py | 146 |
1 files changed, 146 insertions, 0 deletions
diff --git a/experiments/bp_with_penalty_control.py b/experiments/bp_with_penalty_control.py new file mode 100644 index 0000000..b986dee --- /dev/null +++ b/experiments/bp_with_penalty_control.py @@ -0,0 +1,146 @@ +""" +Codex round 19's #4 control experiment: train BP with the same +λ ‖f_l(h_l)‖² penalty that's used in the DFA penalty rescue. + +If BP + penalty still clears the frozen baseline by a wide margin +(e.g., ~25 pp like normal BP): + → the penalty itself is not the reason penalized DFA's depth + utilization is capped at +1.4 pp; the cap is intrinsic to DFA's + random-feedback credit signal quality + → mode 2 (intrinsic credit quality) is real + +If BP + penalty drops to ~+1.4 pp margin too: + → the penalty is the reason for the cap, not credit quality + → mode 2 might be a regularization artifact, not a real failure mode + → would need to walk back walk-back #7 (back to "one unified mode") + +Run: + CUDA_VISIBLE_DEVICES=2 python experiments/bp_with_penalty_control.py --seed 42 --epochs 30 --lam 1e-2 +""" +import os +import sys +import argparse +import json + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP + + +def get_loaders(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return ( + DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), + ) + + +def evaluate(model, loader, dev): + model.eval() + n = c = 0 + with torch.no_grad(): + for x, y in loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + preds = model(x).argmax(-1) + c += (preds == y).sum().item() + n += x.size(0) + return c / n + + +def train_bp_with_penalty(model, train_loader, test_loader, dev, epochs, lr, wd, lam): + """End-to-end BP training with `lam * sum_l ||f_l(h_l)||^2` added to the + cross-entropy loss. The penalty is applied to the residual branch outputs + of every block.""" + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + log = [] + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + # Forward, capturing per-block residual outputs + h = model.embed(x) + penalty = torch.zeros((), device=dev) + for block in model.blocks: + f = block(h) + penalty = penalty + (f ** 2).sum(-1).mean() + h = h + f + logits = model.out_head(model.out_ln(h)) + ce = F.cross_entropy(logits, y) + loss = ce + lam * penalty + opt.zero_grad() + loss.backward() + opt.step() + sch.step() + if ep % 5 == 0 or ep == 1 or ep == epochs: + acc = evaluate(model, test_loader, dev) + log.append({"epoch": ep, "test_acc": acc}) + print(f" ep {ep}: test_acc={acc:.4f}", flush=True) + return log + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--seed", type=int, default=42) + p.add_argument("--epochs", type=int, default=30) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--wd", type=float, default=0.01) + p.add_argument("--lam", type=float, default=1e-2) + p.add_argument("--output_dir", type=str, default="results/bp_with_penalty") + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + dev = torch.device("cuda:0") + print(f"BP + ‖f‖² penalty: seed={args.seed}, lam={args.lam}, epochs={args.epochs}", flush=True) + train_loader, test_loader = get_loaders(batch_size=128) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = ResidualMLP(3072, 256, 10, 4).to(dev) + log = train_bp_with_penalty(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, args.lam) + final_acc = evaluate(m, test_loader, dev) + print(f"\nFINAL test acc: {final_acc:.4f}", flush=True) + print(f"Compare to:") + print(f" BP-trainable (3-seed mean): 0.609") + print(f" Penalized DFA lam=1e-2: 0.363") + print(f" DFA-shallow: 0.349") + margin = (final_acc - 0.349) * 100 + print(f"\nMargin vs DFA-shallow baseline: {margin:+.2f} pp") + if margin > 25: + print(" → BP+penalty still clears shallow by >25 pp") + print(" → mode 2 (intrinsic random-feedback alignment) is REAL") + print(" → walk-back #7 confirmed: two distinct failure modes") + elif margin < 5: + print(" → BP+penalty drops to a tiny margin like penalized DFA") + print(" → the penalty itself capped depth utilization") + print(" → mode 2 might be a regularization artifact") + print(" → consider walking back walk-back #7") + else: + print(" → BP+penalty intermediate; partial capacity loss + residual mode 2") + + out = {"config": vars(args), "final_acc": final_acc, "log": log, "margin_pp": margin} + out_path = os.path.join(args.output_dir, f"bp_pen_lam{args.lam}_s{args.seed}.json") + with open(out_path, "w") as f: + json.dump(out, f, indent=2) + print(f"Saved {out_path}") + + +if __name__ == "__main__": + main() |
