"""
Codex round 11's decisive validation: train DFA on a 4-block d=256 ResMLP with an
explicit residual-branch penalty `λ ||f_l(h_l)||^2` added to each block's local loss.

Tests whether constraining the block output magnitude is sufficient to rescue DFA
from the residual-stream-explosion → BP grad collapse → active harm failure mode.

Conditions:
  - DFA-vanilla (λ=0): baseline, expected to reproduce 30.8% acc + ||h_L||~4e8
  - DFA-penalized (λ=1e-3, 1e-2, 1e-1): different penalty strengths

Three outcomes:
  (A) ||h_L|| bounded AND BP grad healthy AND acc > shallow baseline (34.7%)
      → mechanism chain causally validated
  (B) ||h_L|| bounded AND BP grad healthy BUT acc still ≤ shallow baseline
      → mechanism is necessary but not sufficient; another factor is at play
  (C) ||h_L|| stays exploded under the penalty
      → penalty is too weak or has the wrong target

Outputs (written to --output_dir):
  - dfa_pen_lam{lam}_s{seed}.json : config, final test accuracy, diagnostic log
  - dfa_pen_lam{lam}_s{seed}.pt   : model state_dict plus the exact DFA feedback
                                    matrices Bs that were used during training

Note: the reference numbers printed at the end of a run are 30-epoch values,
while --epochs defaults to 100.

Usage:
  CUDA_VISIBLE_DEVICES=2 python experiments/dfa_residual_penalty_test.py --seed 42 --lam 1e-2
"""
import argparse
import json
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from models.residual_mlp import ResidualMLP

# Per-channel normalization constants applied to CIFAR-10 inputs.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def get_loaders(batch_size=128):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 stored under ``./data``.

    The training set uses random-crop (padding 4) and horizontal-flip
    augmentation; both splits are converted to tensors and normalized with
    ``_CIFAR10_MEAN`` / ``_CIFAR10_STD``. The dataset is downloaded if missing.
    Only the training loader shuffles.
    """
    normalize = transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)
    tv_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    tv = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
    return (
        DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
        DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
    )


def evaluate(model, loader, dev):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Each image batch is flattened to ``(batch, -1)`` before the forward pass.
    The model is left in eval mode afterwards.
    """
    model.eval()
    n = c = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            preds = model(x).argmax(-1)
            c += (preds == y).sum().item()
            n += x.size(0)
    return c / n


def diagnose(model, x_eval, y_eval, dev):
    """Compute ``||h_L||``, ``||BP grad at h_2||`` and accuracy on a fixed eval batch.

    Returns a tuple ``(h_L_norm, g_2_norm, acc)``:
      - ``h_L_norm``: median over samples of the L2 norm of the last hidden
        state returned by ``model(x, return_hidden=True)``.
      - ``g_2_norm``: median over samples of the L2 norm of the true
        backprop gradient d(CE loss)/d h_2. It is obtained by re-running the
        residual stream ``h_{l+1} = h_l + block_l(h_l)`` by hand from the
        embedding output.
      - ``acc``: accuracy of that hand-rolled forward pass on the batch.

    NOTE(review): assumes the model's own forward follows the same residual
    recurrence and ``out_head(out_ln(h_L))`` readout as reconstructed here.
    Leaves the model in eval mode. ``dev`` is unused; kept for interface
    compatibility.
    """
    model.eval()
    with torch.no_grad():
        _, hi = model(x_eval, return_hidden=True)
        h_L_norm = hi[-1].norm(dim=-1).median().item()
        # Only d loss / d h_l is wanted, so the embedding needs no graph.
        h0 = model.embed(x_eval)
    hs = [h0.clone().requires_grad_(True)]
    for b in model.blocks:
        hs.append(hs[-1] + b(hs[-1]))
    lo = model.out_head(model.out_ln(hs[-1]))
    loss = F.cross_entropy(lo, y_eval)
    # Gradients w.r.t. the hidden states only; parameter .grad is untouched.
    gs = torch.autograd.grad(loss, hs)
    g_2_norm = gs[2].norm(dim=-1).median().item()
    acc = (lo.argmax(-1) == y_eval).float().mean().item()
    return h_L_norm, g_2_norm, acc


def _make_feedback_matrices(d_hidden, num_classes, num_blocks, dev):
    """Draw one fixed random DFA feedback matrix per block.

    Each matrix has shape ``(d_hidden, num_classes)`` with N(0, 1/num_classes)
    entries. The values come from the RNG of ``dev``, so they depend on that
    generator's state at call time.
    """
    return [
        torch.randn(d_hidden, num_classes, device=dev) / np.sqrt(num_classes)
        for _ in range(num_blocks)
    ]


def train_dfa_with_penalty(model, train_loader, test_loader, x_eval, y_eval,
                           dev, epochs, lr, wd, lam, Bs=None):
    """DFA training with residual-branch penalty ``lam * ||f_l(h_l)||^2`` added to each block's local loss.

    Per training batch:
      1. A no-grad forward pass gives logits and hidden states. The output
         error is ``e = softmax(logits) - onehot(y)``.
      2. The head (``out_ln`` + ``out_head``) is trained with true
         cross-entropy on the detached last hidden state.
      3. Each block ``l`` is trained on its detached input ``h_l``. The local
         loss is ``<f_l(h_l), a_l / rms(a_l)> + lam * ||f_l(h_l)||^2``, where
         ``a_l = e @ B_l^T``. Block gradients are clipped to norm 1.0.
      4. The embedding is trained with the same DFA signal using ``B_0``.

    Every optimizer is AdamW with cosine annealing over ``epochs``, stepped
    once per epoch. Diagnostics (see ``diagnose``) and test accuracy are
    logged at epoch 0, epoch 1, every 10th epoch and the final epoch.

    Args:
        Bs: optional list of ``model.num_blocks`` feedback matrices of shape
            ``(model.d_hidden, 10)`` on ``dev``. When ``None`` (default), they
            are drawn here exactly as before. Pass them in to keep a handle
            on the exact matrices used, e.g. for checkpointing.

    Returns:
        List of dicts with keys ``epoch``, ``h_L_norm``, ``g_2_norm``, ``acc_eval``.

    Raises:
        ValueError: if ``Bs`` is given with the wrong number of matrices.
    """
    d_hidden = model.d_hidden
    L = model.num_blocks
    C = 10
    if Bs is None:
        Bs = _make_feedback_matrices(d_hidden, C, L, dev)
    elif len(Bs) != L:
        raise ValueError(f"expected {L} feedback matrices, got {len(Bs)}")

    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=lr, weight_decay=wd,
    )
    all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [
        optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
        optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs),
    ]

    log = []
    h0, g0, a0 = diagnose(model, x_eval, y_eval, dev)
    log.append({'epoch': 0, 'h_L_norm': h0, 'g_2_norm': g0, 'acc_eval': a0})
    print(f"  ep 0: ||h_L||={h0:.3e} ||g_2||={g0:.3e} acc={a0:.4f}", flush=True)

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                # Output error of CE w.r.t. logits: softmax(logits) - onehot(y).
                e_T = logits.softmax(-1)
                e_T[torch.arange(batch, device=dev), y] -= 1
                hL_det = hiddens[-1].detach()

            # Head update via true CE on out_ln(h_L)
            logits_out = model.out_head(model.out_ln(hL_det))
            head_opt.zero_grad()
            F.cross_entropy(logits_out, y).backward()
            head_opt.step()

            # Block updates via DFA local credit + residual-branch penalty.
            # h_l is detached, so each local loss only reaches block l's params.
            for l in range(L):
                h_l = hiddens[l].detach()
                a_dfa = (e_T @ Bs[l].T).detach()
                # RMS-normalize the projected error so its per-sample scale is 1.
                rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_norm = a_dfa / rms
                f_l = model.blocks[l](h_l)
                # Original DFA local loss
                local_dfa = (f_l * a_norm).sum(-1).mean()
                # Residual-branch penalty (codex round 11): λ * mean(||f_l||²)
                penalty = lam * (f_l ** 2).sum(-1).mean()
                local_loss = local_dfa + penalty
                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

            # Embed update via DFA-style on h_0 (reuses block 0's feedback matrix)
            a_0 = (e_T @ Bs[0].T).detach()
            rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            h0_emb = model.embed(x)
            embed_loss = (h0_emb * (a_0 / rms_0)).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()

        for s in all_sch:
            s.step()

        if ep % 10 == 0 or ep == 1 or ep == epochs:
            h, g, a = diagnose(model, x_eval, y_eval, dev)
            log.append({'epoch': ep, 'h_L_norm': h, 'g_2_norm': g, 'acc_eval': a})
            test_acc = evaluate(model, test_loader, dev)
            print(f"  ep {ep}: ||h_L||={h:.3e} ||g_2||={g:.3e} eval_acc={a:.4f} test_acc={test_acc:.4f}", flush=True)

    return log


def main():
    """Parse CLI args, train one DFA+penalty run, and save a JSON log plus a checkpoint."""
    p = argparse.ArgumentParser()
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--wd', type=float, default=0.01)
    p.add_argument('--lam', type=float, default=1e-2,
                   help='residual-branch penalty strength λ for ||f_l(h_l)||²')
    p.add_argument('--output_dir', type=str, default='results/dfa_residual_penalty')
    args = p.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    # Fall back to CPU instead of crashing when no GPU is visible.
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"DFA + residual-branch penalty test: seed={args.seed}, lam={args.lam}", flush=True)

    train_loader, test_loader = get_loaders(batch_size=128)

    # Fixed eval buffer: the first 1024 test images, flattened.
    xs, ys = [], []
    for x, y in test_loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        if sum(xb.size(0) for xb in xs) >= 1024:
            break
    x_eval = torch.cat(xs)[:1024].to(dev)
    y_eval = torch.cat(ys)[:1024].to(dev)

    L, d, C = 4, 256, 10
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    m = ResidualMLP(3072, d, C, L).to(dev)
    # Draw the DFA feedback matrices at the same RNG position where
    # train_dfa_with_penalty used to draw them, so results are unchanged.
    # Keeping the handle means the checkpoint stores the exact tensors used
    # in training, instead of a re-seeded reconstruction.
    Bs = _make_feedback_matrices(m.d_hidden, C, m.num_blocks, dev)

    log = train_dfa_with_penalty(m, train_loader, test_loader, x_eval, y_eval, dev,
                                 args.epochs, args.lr, args.wd, args.lam, Bs=Bs)

    final_test = evaluate(m, test_loader, dev)
    print(f"\nFINAL test acc: {final_test:.4f}")
    print("Compare to (matched 30-epoch 3-seed values, see paper v2.32):")
    print("  DFA-vanilla 30ep (3-seed):  0.301 ± 0.005")
    print("  DFA-shallow / DFA-frozen:   0.349 ± 0.002")
    print("  BP-trainable no-pen 30ep:   0.585 ± 0.001")
    print("  BP+pen lam=1e-2 30ep:       0.532 ± 0.006")

    out = {'config': vars(args), 'final_test_acc': final_test, 'log': log}
    out_path = os.path.join(args.output_dir, f'dfa_pen_lam{args.lam}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(out, f, indent=2)
    print(f"Saved {out_path}")

    # Save checkpoint AND the exact feedback matrices for post-hoc analysis.
    ckpt_path = os.path.join(args.output_dir, f'dfa_pen_lam{args.lam}_s{args.seed}.pt')
    torch.save({
        "state_dict": m.state_dict(),
        "Bs": [b.cpu() for b in Bs],
        "config": vars(args),
        "test_acc": final_test,
    }, ckpt_path)
    print(f"Saved {ckpt_path}")


if __name__ == '__main__':
    main()