summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 01:33:52 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 01:33:52 -0500
commitf92ebda7ec2ac155c1a4241b057801c863834e56 (patch)
tree7f2595f6ad1c3ad476292b7dad3584fb7acda9ee /experiments
parent2ca87f2bd4449b1d4ac715d8cf4fb5f20b7afdd8 (diff)
Add BP+penalty control (round 19's #4 critical experiment)
Trains end-to-end BP with the same lambda*||f_l(h_l)||^2 penalty used in the DFA penalty rescue. Tests whether the penalty's depth utilization loss in penalized DFA is intrinsic to DFA's random-feedback credit quality (mode 2) or due to penalty-induced capacity regularization. Decision rule: BP+pen margin > 25 pp -> mode 2 confirmed (penalty is not the cap) BP+pen margin < 5 pp -> penalty itself caps depth (capacity loss) intermediate -> both effects present
Diffstat (limited to 'experiments')
-rw-r--r--experiments/bp_with_penalty_control.py146
1 files changed, 146 insertions, 0 deletions
diff --git a/experiments/bp_with_penalty_control.py b/experiments/bp_with_penalty_control.py
new file mode 100644
index 0000000..b986dee
--- /dev/null
+++ b/experiments/bp_with_penalty_control.py
@@ -0,0 +1,146 @@
+"""
+Codex round 19's #4 control experiment: train BP with the same
+λ ‖f_l(h_l)‖² penalty that's used in the DFA penalty rescue.
+
+If BP + penalty still clears the frozen baseline by a wide margin
+(e.g., ~25 pp like normal BP):
+ → the penalty itself is not the reason penalized DFA's depth
+ utilization is capped at +1.4 pp; the cap is intrinsic to DFA's
+ random-feedback credit signal quality
+ → mode 2 (intrinsic credit quality) is real
+
+If BP + penalty drops to ~+1.4 pp margin too:
+ → the penalty is the reason for the cap, not credit quality
+ → mode 2 might be a regularization artifact, not a real failure mode
+ → would need to walk back walk-back #7 (back to "one unified mode")
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python experiments/bp_with_penalty_control.py --seed 42 --epochs 30 --lam 1e-2
+"""
+import os
+import sys
+import argparse
+import json
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def train_bp_with_penalty(model, train_loader, test_loader, dev, epochs, lr, wd, lam):
+ """End-to-end BP training with `lam * sum_l ||f_l(h_l)||^2` added to the
+ cross-entropy loss. The penalty is applied to the residual branch outputs
+ of every block."""
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ log = []
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ # Forward, capturing per-block residual outputs
+ h = model.embed(x)
+ penalty = torch.zeros((), device=dev)
+ for block in model.blocks:
+ f = block(h)
+ penalty = penalty + (f ** 2).sum(-1).mean()
+ h = h + f
+ logits = model.out_head(model.out_ln(h))
+ ce = F.cross_entropy(logits, y)
+ loss = ce + lam * penalty
+ opt.zero_grad()
+ loss.backward()
+ opt.step()
+ sch.step()
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ log.append({"epoch": ep, "test_acc": acc})
+ print(f" ep {ep}: test_acc={acc:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--epochs", type=int, default=30)
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--wd", type=float, default=0.01)
+ p.add_argument("--lam", type=float, default=1e-2)
+ p.add_argument("--output_dir", type=str, default="results/bp_with_penalty")
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ dev = torch.device("cuda:0")
+ print(f"BP + ‖f‖² penalty: seed={args.seed}, lam={args.lam}, epochs={args.epochs}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = ResidualMLP(3072, 256, 10, 4).to(dev)
+ log = train_bp_with_penalty(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, args.lam)
+ final_acc = evaluate(m, test_loader, dev)
+ print(f"\nFINAL test acc: {final_acc:.4f}", flush=True)
+ print(f"Compare to:")
+ print(f" BP-trainable (3-seed mean): 0.609")
+ print(f" Penalized DFA lam=1e-2: 0.363")
+ print(f" DFA-shallow: 0.349")
+ margin = (final_acc - 0.349) * 100
+ print(f"\nMargin vs DFA-shallow baseline: {margin:+.2f} pp")
+ if margin > 25:
+ print(" → BP+penalty still clears shallow by >25 pp")
+ print(" → mode 2 (intrinsic random-feedback alignment) is REAL")
+ print(" → walk-back #7 confirmed: two distinct failure modes")
+ elif margin < 5:
+ print(" → BP+penalty drops to a tiny margin like penalized DFA")
+ print(" → the penalty itself capped depth utilization")
+ print(" → mode 2 might be a regularization artifact")
+ print(" → consider walking back walk-back #7")
+ else:
+ print(" → BP+penalty intermediate; partial capacity loss + residual mode 2")
+
+ out = {"config": vars(args), "final_acc": final_acc, "log": log, "margin_pp": margin}
+ out_path = os.path.join(args.output_dir, f"bp_pen_lam{args.lam}_s{args.seed}.json")
+ with open(out_path, "w") as f:
+ json.dump(out, f, indent=2)
+ print(f"Saved {out_path}")
+
+
+if __name__ == "__main__":
+ main()