"""
Direction-quality direct test (codex round 13's option (c), finally executed).

After the residual-branch penalty experiment confirmed that the
||f_l(h_l)||^2 penalty (1) contains the residual stream by 4 OOM, (2) keeps
the BP gradient at hidden layers at ~10^-7 (well above the eps=1e-8 floor and
~5e-7 above the fp32 underflow region), but (3) only rescues acc by +5.5 pp
over vanilla DFA and only +1.4 pp over the shallow baseline, we hypothesized a
SECOND failure mode: even when the BP gradient at hidden layers is
well-resolved, DFA's local credit signal `e_T B^T` may not be aligned with it.

This script answers that hypothesis directly:

1. Train a 4-block d=256 ResMLP with DFA + residual-branch penalty
   (lam = 1e-2, the first penalty value we validated). Save the checkpoint
   when training is done.
2. On the trained network, on a held-out eval batch, compute:
   (a) the per-layer BP gradient `g_l = d L / d h_l` (this is what offline
       Γ uses as a reference)
   (b) the per-layer DFA local credit signal `a_l = e_T @ B^T` (the same
       signal DFA's training rule uses). Block k's update pushes
       `e_T @ B_k^T` into its residual-branch output, which is added into
       h_{k+1}; the embedding update pushes `e_T @ B_0^T` into h_0. Hence
       h_0 is paired with B_0 and h_l (l >= 1) with B_{l-1}.
   (c) the per-layer cosine similarity `cos(a_l, g_l)`
   (d) the same cosine on the *vanilla* DFA-trained checkpoint for
       comparison (the network where g_l is at the floor — Γ should be
       degenerate there, but the cosine value itself can still be computed)
3. Report side-by-side: vanilla-DFA cosine (degenerate reference) vs
   penalized-DFA cosine (healthy reference).

The penalized-DFA cosine is the *direct measurement* of the second failure
mode: it tells us whether DFA's random feedback signal aligns with BP credit
once the scale problem is fixed.

The pre-registered prediction (codex round 13): the penalized-DFA cosine will
still be near zero (~0.01-0.05), confirming that the direction quality of
DFA's signal is the second, *separable* failure mode.

Run:
    CUDA_VISIBLE_DEVICES=2 python experiments/dfa_direction_quality_test.py \\
        --seed 42 --epochs 100 --lam 1e-2
"""
import sys, os, argparse, json
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np

from models.residual_mlp import ResidualMLP


# --------------------------------------------------------------------------- #
# Data
# --------------------------------------------------------------------------- #
def get_loaders(batch_size=128):
    """Return (train_loader, test_loader) for CIFAR-10 under ./data.

    The train split uses random-crop + horizontal-flip augmentation; both
    splits are normalized with the CIFAR-10 channel mean/std. The test loader
    is not shuffled, so its first batches are deterministic.
    """
    tv_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
    return (
        DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
        DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
    )


def evaluate(model, loader, dev):
    """Return top-1 accuracy of ``model`` over ``loader`` (inputs flattened)."""
    model.eval()
    n = c = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            preds = model(x).argmax(-1)
            c += (preds == y).sum().item()
            n += x.size(0)
    return c / n


# --------------------------------------------------------------------------- #
# DFA training (vanilla and with residual-branch penalty)
# --------------------------------------------------------------------------- #
def train_dfa(model, train_loader, dev, epochs, lr, wd, lam, Bs):
    """Train ``model`` in place with DFA, optionally with a residual-branch penalty.

    lam=0 reproduces vanilla DFA.

    Per batch:
      * the output head (``out_ln`` + ``out_head``) is trained by ordinary BP
        on the detached final hidden state;
      * residual block ``l`` minimizes the local DFA loss
        ``<f_l(h_l), a_l / rms(a_l)> + lam * ||f_l(h_l)||^2`` with
        ``a_l = e_T @ Bs[l].T`` and ``e_T = softmax(logits) - one_hot(y)``;
        ``f_l`` is added into ``h_{l+1}``, so ``Bs[l]`` is the feedback
        matrix that credits ``h_{l+1}``;
      * the embedding (producing ``h_0``) uses the same rule with ``Bs[0]``.

    Args:
        model: network exposing ``num_blocks``, ``blocks``, ``embed``,
            ``out_ln``, ``out_head`` and ``forward(x, return_hidden=True)``.
        train_loader: yields (images, labels); images are flattened here.
        dev: device inputs are moved to.
        epochs: number of epochs (also the cosine-annealing period).
        lr, wd: AdamW learning rate and weight decay for every parameter group.
        lam: residual-branch penalty weight.
        Bs: list of ``num_blocks`` fixed random feedback matrices.
    """
    L = model.num_blocks
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=lr, weight_decay=wd,
    )
    scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [
        optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
        optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs),
    ]

    for _ in range(epochs):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            batch = x.size(0)

            # One graph-free forward: DFA only needs the activations and the
            # global output error e_T = softmax(logits) - one_hot(y).
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                e_T = logits.softmax(-1)
                e_T[torch.arange(batch, device=e_T.device), y] -= 1

            # Output head: ordinary BP on the detached final hidden state.
            hL_det = hiddens[-1].detach()
            logits_out = model.out_head(model.out_ln(hL_det))
            head_opt.zero_grad()
            F.cross_entropy(logits_out, y).backward()
            head_opt.step()

            # Residual blocks: local DFA loss on each block's branch output.
            for l in range(L):
                h_l = hiddens[l].detach()
                a_dfa = (e_T @ Bs[l].T).detach()
                # RMS-normalize the feedback so each sample's signal has RMS ~1.
                rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_norm = a_dfa / rms
                f_l = model.blocks[l](h_l)
                # d(local_dfa)/d(f_l) = a_norm / batch: the DFA pseudo-gradient
                # for this block's contribution to h_{l+1}.
                local_dfa = (f_l * a_norm).sum(-1).mean()
                penalty = lam * (f_l ** 2).sum(-1).mean()
                local_loss = local_dfa + penalty
                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

            # Embedding: same DFA rule fed by Bs[0], on a fresh forward.
            a_0 = (e_T @ Bs[0].T).detach()
            rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            h0_emb = model.embed(x)
            embed_loss = (h0_emb * (a_0 / rms_0)).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()

        for s in scheds:
            s.step()


# --------------------------------------------------------------------------- #
# Direction-quality measurement
# --------------------------------------------------------------------------- #
def measure_direction_quality(model, Bs, x, y, dev):
    """Per-layer cosine between DFA's credit signal and the true BP gradient.

    For every hidden state h_l (l = 0 .. L, with h_0 = embed(x) and
    h_{l+1} = h_l + blocks[l](h_l)), compute per sample:

        BP grad      g_l = d L / d h_l
        DFA credit   a_l = e_T @ Bs[k]^T,   e_T = softmax(logits) - one_hot(y)

    where k indexes the feedback matrix DFA actually uses to credit h_l in
    ``train_dfa``. The embedding (which produces h_0) is updated with Bs[0].
    Block j's residual branch, whose output is added into h_{j+1}, is updated
    with Bs[j]. Hence k = 0 for h_0 and k = l - 1 for l >= 1.

    Args:
        model: trained network exposing ``embed``, ``blocks``, ``out_ln``,
            ``out_head`` and ``num_blocks``.
        Bs: list of ``num_blocks`` feedback matrices; ``e_T @ B.T`` must have
            the hidden width (in this script each B is (d, C)).
        x: flattened inputs on the model's device.
        y: integer class labels on the model's device.
        dev: unused; kept for call-site compatibility.

    Returns:
        dict mapping ``"layer_{l}"`` to the mean/std/median per-sample cosine,
        the median norms of g_l and a_l, and the feedback-matrix index used.
    """
    L = model.num_blocks

    # 1) Forward pass retaining every hidden state so autograd can return
    #    d L / d h_l for all l at once.
    model.eval()
    with torch.enable_grad():
        h = model.embed(x)
        hiddens = [h]
        for block in model.blocks:
            h = h + block(h)
            hiddens.append(h)
        logits = model.out_head(model.out_ln(h))
        loss = F.cross_entropy(logits, y)
        # grads[l] is d L / d h_l (per sample, scaled by 1/N from the mean
        # reduction; the 1/N cancels in the cosine).
        grads = torch.autograd.grad(loss, hiddens)

    out: dict = {}
    with torch.no_grad():
        N = x.size(0)
        # 2) DFA error e_T from the SAME forward used for the BP gradients.
        e_T = F.softmax(logits.detach(), dim=-1)
        e_T[torch.arange(N, device=e_T.device), y] -= 1  # (N, C)

        for l in range(L + 1):
            g_l = grads[l].detach()  # (N, d)
            # Feedback matrix that credits h_l during training (see docstring):
            # Bs[0] for the embedding output h_0, Bs[l-1] for block l-1's
            # output h_l.
            b_idx = max(l - 1, 0)
            a_l = (e_T @ Bs[b_idx].T).detach()  # (N, d)

            # 3) Per-sample cosine, then summary stats. Use a tiny eps rather
            # than torch's default 1e-8: the gradient norms here can be
            # ~1e-7 or smaller, and a 1e-8 floor on the norm product would
            # shrink the cosine toward zero.
            eps = 1e-30
            ag = (a_l * g_l).sum(dim=-1)
            an = a_l.norm(dim=-1)
            gn = g_l.norm(dim=-1)
            cos = ag / (an * gn + eps)

            out[f"layer_{l}"] = {
                "cos_mean": float(cos.mean().item()),
                "cos_std": float(cos.std().item()),
                "cos_median": float(cos.median().item()),
                "g_norm_median": float(gn.median().item()),
                "a_norm_median": float(an.median().item()),
                "feedback_index": b_idx,
            }
    return out


# --------------------------------------------------------------------------- #
# Main
# --------------------------------------------------------------------------- #
def _train_and_measure(label, lam, args, train_loader, test_loader, x_eval, y_eval, dev):
    """Train one DFA condition from ``args.seed``, then report accuracy and cosines.

    The RNGs are re-seeded immediately before the model and feedback matrices
    are created, so every condition starts from the same draws.
    Returns ``(model, Bs, test_acc, direction_quality)``.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    model = ResidualMLP(3072, 256, 10, 4).to(dev)
    Bs = [torch.randn(256, 10, device=dev) / np.sqrt(10) for _ in range(4)]
    train_dfa(model, train_loader, dev, args.epochs, args.lr, args.wd, lam=lam, Bs=Bs)

    acc = evaluate(model, test_loader, dev)
    print(f" {label} test acc: {acc:.4f}")

    quality = measure_direction_quality(model, Bs, x_eval, y_eval, dev)
    print(f" {label} per-layer DFA-credit vs BP-grad cosine:")
    for k, v in quality.items():
        print(f" {k}: cos_mean={v['cos_mean']:+.4f} ||g||={v['g_norm_median']:.2e} ||a||={v['a_norm_median']:.2e}")
    return model, Bs, acc, quality


def main():
    """Run vanilla vs penalized DFA, measure direction quality, save results."""
    p = argparse.ArgumentParser()
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--wd', type=float, default=0.01)
    p.add_argument('--lam', type=float, default=1e-2)
    p.add_argument('--output_dir', type=str, default='results/dfa_direction_quality')
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    dev = torch.device('cuda:0')
    print(f"DFA direction-quality direct test: seed={args.seed}, lam={args.lam}", flush=True)
    train_loader, test_loader = get_loaders(batch_size=128)

    # Fixed eval batch (first 1024 test images; the test loader is unshuffled).
    n_eval = 1024
    xs, ys, seen = [], [], 0
    for x, y in test_loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        seen += x.size(0)
        if seen >= n_eval:
            break
    x_eval = torch.cat(xs)[:n_eval].to(dev)
    y_eval = torch.cat(ys)[:n_eval].to(dev)

    # ----- VANILLA DFA (lam=0) ----- #
    print("\n=== Vanilla DFA (lam=0) ===")
    _, _, acc_vanilla, quality_vanilla = _train_and_measure(
        "vanilla DFA", 0.0, args, train_loader, test_loader, x_eval, y_eval, dev)

    # ----- PENALIZED DFA (lam>0) ----- #
    print(f"\n=== Penalized DFA (lam={args.lam}) ===")
    m_pen, Bs_pen, acc_pen, quality_pen = _train_and_measure(
        "penalized DFA", args.lam, args, train_loader, test_loader, x_eval, y_eval, dev)

    # Save results
    out = {
        "config": vars(args),
        "vanilla": {
            "test_acc": acc_vanilla,
            "direction_quality": quality_vanilla,
        },
        "penalized": {
            "test_acc": acc_pen,
            "direction_quality": quality_pen,
        },
    }
    out_path = os.path.join(args.output_dir, f'direction_quality_lam{args.lam}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(out, f, indent=2)

    # Save the penalized checkpoint so the protocol can later be re-applied
    ckpt_path = os.path.join(args.output_dir, f'penalized_dfa_lam{args.lam}_s{args.seed}.pt')
    torch.save({
        "state_dict": m_pen.state_dict(),
        "Bs": [b.cpu() for b in Bs_pen],
        "config": vars(args),
        "test_acc": acc_pen,
    }, ckpt_path)
    print(f"\nSaved {out_path}")
    print(f"Saved {ckpt_path}")

    # Pre-registered interpretation summary
    print("\n" + "=" * 72)
    print("INTERPRETATION (vs codex round 13's pre-registered prediction)")
    print("=" * 72)
    g_vanilla = quality_vanilla["layer_2"]["g_norm_median"]
    g_pen = quality_pen["layer_2"]["g_norm_median"]
    cos_vanilla = quality_vanilla["layer_2"]["cos_mean"]
    cos_pen = quality_pen["layer_2"]["cos_mean"]
    print(f" vanilla DFA: ||g_2||={g_vanilla:.2e} cos(DFA, BP)={cos_vanilla:+.4f} -> reference at floor")
    print(f" penalty DFA: ||g_2||={g_pen:.2e} cos(DFA, BP)={cos_pen:+.4f} -> reference healthy")
    if g_pen > 1e-7:
        if abs(cos_pen) < 0.05:
            print(" -> Direction quality is POOR even with healthy reference. Second failure mode CONFIRMED.")
        elif abs(cos_pen) < 0.20:
            print(" -> Direction quality is mediocre with healthy reference. Second failure mode partially supported.")
        else:
            print(" -> Direction quality is reasonable with healthy reference. Second failure mode REJECTED — DFA's signal is OK, the gap to BP must come from something else.")
    else:
        print(" -> WARNING: penalized BP grad still below 1e-7; reference is not healthy. Try larger lam.")


if __name__ == '__main__':
    main()