"""Control experiment (Codex round 19, item #4): BP trained with the DFA penalty.

Train end-to-end backprop (BP) with the same ``lam * sum_l ||f_l(h_l)||^2``
penalty on the residual-branch outputs that is used in the penalized-DFA
rescue, then compare its test accuracy with the frozen-block DFA-shallow
baseline.

Interpretation:

* If BP + penalty still clears the frozen baseline by a wide margin
  (e.g. ~25 pp, like normal BP):
    -> the penalty itself is not why penalized DFA's depth utilization is
       capped at +1.4 pp; the cap is intrinsic to the quality of DFA's
       random-feedback credit signal;
    -> mode 2 (intrinsic credit quality) is real.
* If BP + penalty also drops to a ~+1.4 pp margin:
    -> the penalty, not credit quality, is the reason for the cap;
    -> mode 2 might be a regularization artifact, not a real failure mode;
    -> walk-back #7 would have to be walked back (back to "one unified mode").

Run:
    CUDA_VISIBLE_DEVICES=2 python experiments/bp_with_penalty_control.py --seed 42 --epochs 30 --lam 1e-2
"""
import os
import sys
import argparse
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP

# CIFAR-10 per-channel normalization statistics, applied to both splits.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Reference test accuracies (matched 30-epoch, 3-seed values, paper v2.32).
BP_NOPEN_ACC = 0.585
PEN_DFA_ACC = 0.360
DFA_SHALLOW_ACC = 0.349

# Verdict thresholds, in percentage points of margin over DFA-shallow.
# NOTE(review): unpenalized BP itself sits at (0.585 - 0.349) * 100 = 23.6 pp,
# i.e. below CLEAR_MARGIN_PP, so a BP+penalty run that fully matches
# unpenalized BP is reported as "intermediate". Confirm this is intended.
CLEAR_MARGIN_PP = 25
CAPPED_MARGIN_PP = 5


def get_loaders(batch_size=128, root="./data", num_workers=2):
    """Build the CIFAR-10 train/test DataLoaders.

    The training split is augmented with a 32x32 random crop (padding 4) and a
    random horizontal flip. Both splits are converted to tensors and normalized
    with the CIFAR-10 channel statistics. The data is downloaded into ``root``
    if it is not already present.

    Args:
        batch_size: Mini-batch size used by both loaders.
        root: Dataset directory (defaults to ``./data``, as before).
        num_workers: Worker processes per loader (defaults to 2, as before).

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    train_set = torchvision.datasets.CIFAR10(root, True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(root, False, download=True, transform=test_tf)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers),
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers),
    )


def evaluate(model, loader, dev):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Each input batch is flattened to ``(batch, -1)`` before the forward pass.
    The model is switched to eval mode and left in it. An empty loader raises
    ``ZeroDivisionError``.
    """
    model.eval()
    n = c = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            preds = model(x).argmax(-1)
            c += (preds == y).sum().item()
            n += x.size(0)
    return c / n


def _forward_with_penalty(model, x):
    """Run the residual MLP forward by hand; return ``(logits, penalty)``.

    ``penalty`` is ``sum_l mean_batch ||f_l(h_l)||^2``: the squared L2 norm of
    each block's output, summed over the last dimension, averaged over the batch
    and summed over blocks.
    """
    # NOTE(review): assumes this mirrors ResidualMLP.forward
    # (embed -> h = h + block(h) per block -> out_ln -> out_head), and that each
    # block returns only its residual branch f_l(h_l). Verify against
    # models/residual_mlp.py.
    h = model.embed(x)
    penalty = torch.zeros((), device=x.device)
    for block in model.blocks:
        f = block(h)
        penalty = penalty + (f ** 2).sum(-1).mean()
        h = h + f
    logits = model.out_head(model.out_ln(h))
    return logits, penalty


def train_bp_with_penalty(model, train_loader, test_loader, dev, epochs, lr, wd, lam):
    """Train end-to-end with BP on ``cross_entropy + lam * sum_l ||f_l(h_l)||^2``.

    The penalty is applied to the residual-branch outputs of every block (see
    ``_forward_with_penalty``). The optimizer is AdamW with a per-epoch cosine
    learning-rate schedule over ``epochs``. Test accuracy is evaluated after
    epoch 1, every 5th epoch and the final epoch.

    Returns:
        A list of ``{"epoch": int, "test_acc": float}`` entries.
    """
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    log = []
    for ep in range(1, epochs + 1):
        model.train()  # re-enable train mode; evaluate() leaves the model in eval mode
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            logits, penalty = _forward_with_penalty(model, x)
            ce = F.cross_entropy(logits, y)
            loss = ce + lam * penalty
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        if ep % 5 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(model, test_loader, dev)
            log.append({"epoch": ep, "test_acc": acc})
            print(f" ep {ep}: test_acc={acc:.4f}", flush=True)
    return log


def _interpret_margin(margin):
    """Print the experiment verdict for ``margin`` (pp over DFA-shallow)."""
    if margin > CLEAR_MARGIN_PP:
        print(f" → BP+penalty still clears shallow by >{CLEAR_MARGIN_PP} pp")
        print(" → mode 2 (intrinsic random-feedback alignment) is REAL")
        print(" → walk-back #7 confirmed: two distinct failure modes")
    elif margin < CAPPED_MARGIN_PP:
        print(" → BP+penalty drops to a tiny margin like penalized DFA")
        print(" → the penalty itself capped depth utilization")
        print(" → mode 2 might be a regularization artifact")
        print(" → consider walking back walk-back #7")
    else:
        print(" → BP+penalty intermediate; partial capacity loss + residual mode 2")


def main():
    """Parse CLI args, train BP + penalty, print the verdict and save a JSON summary."""
    p = argparse.ArgumentParser(description="BP + ||f||^2 penalty control experiment")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--wd", type=float, default=0.01)
    p.add_argument("--lam", type=float, default=1e-2)
    p.add_argument("--output_dir", type=str, default="results/bp_with_penalty")
    args = p.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Fall back to CPU when CUDA is unavailable. Under CUDA_VISIBLE_DEVICES,
    # "cuda:0" is the first visible GPU, so the documented run is unchanged.
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"BP + ‖f‖² penalty: seed={args.seed}, lam={args.lam}, epochs={args.epochs}", flush=True)

    # Seed every RNG before any loader or model is built.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_loaders(batch_size=128)
    # 3072 = 3 * 32 * 32 flattened CIFAR-10 pixels; the remaining positional args
    # are presumably (width=256, classes=10, blocks=4) — TODO confirm against ResidualMLP.
    m = ResidualMLP(3072, 256, 10, 4).to(dev)
    log = train_bp_with_penalty(m, train_loader, test_loader, dev,
                                args.epochs, args.lr, args.wd, args.lam)

    final_acc = evaluate(m, test_loader, dev)
    print(f"\nFINAL test acc: {final_acc:.4f}", flush=True)
    print("Compare to (matched 30-epoch 3-seed values, see paper v2.32):")
    print(f" BP-trainable no-pen (3-seed): {BP_NOPEN_ACC:.3f} ± 0.001")
    print(f" Penalized DFA lam=1e-2: {PEN_DFA_ACC:.3f} ± 0.001")
    print(f" DFA-shallow (frozen blocks): {DFA_SHALLOW_ACC:.3f} ± 0.002")

    margin = (final_acc - DFA_SHALLOW_ACC) * 100
    print(f"\nMargin vs DFA-shallow baseline: {margin:+.2f} pp")
    _interpret_margin(margin)

    out = {"config": vars(args), "final_acc": final_acc, "log": log, "margin_pp": margin}
    out_path = os.path.join(args.output_dir, f"bp_pen_lam{args.lam}_s{args.seed}.json")
    with open(out_path, "w") as fh:
        json.dump(out, fh, indent=2)
    print(f"Saved {out_path}")


if __name__ == "__main__":
    main()