summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:37:49 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:37:49 -0500
commit0c245f5683cceba448d20d9dfc2090adb3503f14 (patch)
tree0467408974b504888ae8cbb8551fbb206e3c2b53 /experiments
parent111bab56e2d49c9fb1f3bfb9e55ea2028da4d008 (diff)
Add DFA direction-quality direct test (codex round 13 option c)
Trains both vanilla DFA (lam=0) and penalized DFA (lam=1e-2) from the same seed, then directly measures the per-layer cosine between DFA's local credit signal e_T @ B_l^T and the BP gradient at hidden layers. Uses the training Bs (not fresh ones, per the Bs-specificity finding from earlier). The penalized run is the key measurement: in that condition the BP grad is ~10^-7 (well above the eps=1e-8 floor), so a near-zero cosine here would be the direct evidence of the second failure mode (direction-quality ceiling) that codex round 13 hypothesized. Pre-registered prediction: penalized cos(DFA, BP) ~ 0.01-0.05 -> direction quality is the second, separable failure mode. Saves the penalized checkpoint so the diagnostic protocol can be re-applied to it (where (a) and (b) should pass, (d) should still fail).
Diffstat (limited to 'experiments')
-rw-r--r--experiments/dfa_direction_quality_test.py308
1 files changed, 308 insertions, 0 deletions
diff --git a/experiments/dfa_direction_quality_test.py b/experiments/dfa_direction_quality_test.py
new file mode 100644
index 0000000..8df60c8
--- /dev/null
+++ b/experiments/dfa_direction_quality_test.py
@@ -0,0 +1,308 @@
+"""
+Direction-quality direct test (codex round 13's option (c), finally executed).
+
+After the residual-branch penalty experiment confirmed that the
+||f_l(h_l)||^2 penalty (1) contains the residual stream 4 OOM, (2) keeps the
+BP gradient at hidden layers ~10^-7 (well above the eps=1e-8 floor and
+~5e-7 above the fp32 underflow region), but (3) only rescues acc by +5.5 pp
+over vanilla DFA and only +1.4 pp over the shallow baseline, we hypothesized
+a SECOND failure mode: even when the BP gradient at hidden layers is
+well-resolved, DFA's local credit signal `e_T B_l^T` may not be aligned with
+it.
+
+This script answers that hypothesis directly:
+
+ 1. Train a 4-block d=256 ResMLP with DFA + residual-branch penalty
+ (lam = 1e-2, the first penalty value we validated). Save the checkpoint
+ when training is done.
+ 2. On the trained network, on a held-out eval batch, compute:
+ (a) the per-layer BP gradient `g_l = d L / d h_l` (this is what offline
+ Γ uses as a reference)
+ (b) the per-layer DFA local credit signal `a_l = e_T @ B_l^T` (the same
+ signal DFA's training rule uses)
+ (c) the per-layer cosine similarity `cos(a_l, g_l)`
+ (d) the same cosine on the *vanilla* DFA-trained checkpoint for
+ comparison (the network where g_l is at the floor — Γ should be
+ degenerate there but the cosine value itself can still be computed)
+
+ 3. Report side-by-side: vanilla-DFA cosine (degenerate-reference) vs
+ penalized-DFA cosine (healthy-reference). The penalized-DFA cosine is
+ the *direct measurement* of the second failure mode — it tells us
+ whether DFA's random feedback signal aligns with BP credit when the
+ scale is fixed.
+
+The pre-registered prediction (codex round 13): the penalized-DFA cosine
+will still be near zero (~0.01-0.05), confirming that the direction quality
+of DFA's signal is the second, *separable* failure mode.
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python experiments/dfa_direction_quality_test.py \
+ --seed 42 --epochs 100 --lam 1e-2
+"""
+import sys, os, argparse, json
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+import numpy as np
+
+from models.residual_mlp import ResidualMLP
+
+
+# --------------------------------------------------------------------------- #
+# Data
+# --------------------------------------------------------------------------- #
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+# --------------------------------------------------------------------------- #
+# DFA training (vanilla and with residual-branch penalty)
+# --------------------------------------------------------------------------- #
+
+def train_dfa(model, train_loader, dev, epochs, lr, wd, lam, Bs):
+ """DFA training. lam=0 reproduces vanilla DFA."""
+ L = model.num_blocks
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd
+ )
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [
+ optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs),
+ ]
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det))
+ head_opt.zero_grad()
+ F.cross_entropy(logits_out, y).backward()
+ head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa / rms
+ f_l = model.blocks[l](h_l)
+ local_dfa = (f_l * a_norm).sum(-1).mean()
+ penalty = lam * (f_l ** 2).sum(-1).mean()
+ local_loss = local_dfa + penalty
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0_emb = model.embed(x)
+ embed_loss = (h0_emb * (a_0 / rms_0)).sum(-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+ for s in scheds: s.step()
+
+
+# --------------------------------------------------------------------------- #
+# Direction-quality measurement
+# --------------------------------------------------------------------------- #
+
+def measure_direction_quality(model, Bs, x, y, dev):
+ """For each layer l, compute the per-sample cosine between:
+ DFA local credit a_l = e_T @ B_l^T
+ BP grad at h_l g_l = d L / d h_l
+ Return per-layer mean cosine, plus the magnitudes of both signals.
+ """
+ L = model.num_blocks
+
+ # 1) Forward pass with hidden states retained for BP grad computation.
+ model.eval()
+ with torch.enable_grad():
+ h = model.embed(x)
+ hiddens = [h]
+ for block in model.blocks:
+ h = h + block(h)
+ hiddens.append(h)
+ logits = model.out_head(model.out_ln(h))
+ loss = F.cross_entropy(logits, y)
+ grads = torch.autograd.grad(loss, hiddens)
+ # grads[l] is d L / d h_l (per-sample, scaled by 1/N from the mean reduction)
+
+ # 2) DFA local credit signal: e_T @ B_l^T using the model's trained Bs and
+ # the SAME forward we just did
+ with torch.no_grad():
+ N = x.size(0)
+ # The DFA signal uses softmax(logits) - one_hot(y) (the "error" e_T).
+ e_T = F.softmax(logits.detach(), dim=-1)
+ e_T[torch.arange(N), y] -= 1 # (N, C)
+
+ out: dict = {}
+ for l in range(L + 1):
+ g_l = grads[l].detach() # (N, d)
+ # DFA's local credit signal at layer l is e_T @ B_{min(l, L-1)}^T
+ # (the embedding update uses Bs[0]; block l update uses Bs[l]; for
+ # the deepest hidden state h_L there is no block beyond it, so we
+ # report Bs[L-1] which is the closest comparator)
+ b_idx = min(l, L - 1)
+ a_l = (e_T @ Bs[b_idx].T).detach() # (N, d)
+
+ # Per-sample cosines, then mean
+ eps = 1e-30 # NOT torch's default 1e-8 — we want the true cosine
+ ag = (a_l * g_l).sum(dim=-1)
+ an = a_l.norm(dim=-1)
+ gn = g_l.norm(dim=-1)
+ cos = ag / (an * gn + eps)
+ out[f"layer_{l}"] = {
+ "cos_mean": float(cos.mean().item()),
+ "cos_std": float(cos.std().item()),
+ "cos_median": float(cos.median().item()),
+ "g_norm_median": float(gn.median().item()),
+ "a_norm_median": float(an.median().item()),
+ }
+ return out
+
+
+# --------------------------------------------------------------------------- #
+# Main
+# --------------------------------------------------------------------------- #
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ p.add_argument('--lam', type=float, default=1e-2)
+ p.add_argument('--output_dir', type=str, default='results/dfa_direction_quality')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ dev = torch.device('cuda:0')
+ print(f"DFA direction-quality direct test: seed={args.seed}, lam={args.lam}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+
+ # Eval batch for direction-quality measurement
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 1024:
+ break
+ x_eval = torch.cat(xs)[:1024].to(dev)
+ y_eval = torch.cat(ys)[:1024].to(dev)
+
+ # ----- VANILLA DFA (lam=0) ----- #
+ print("\n=== Vanilla DFA (lam=0) ===")
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m_vanilla = ResidualMLP(3072, 256, 10, 4).to(dev)
+ Bs_vanilla = [torch.randn(256, 10, device=dev) / np.sqrt(10) for _ in range(4)]
+ train_dfa(m_vanilla, train_loader, dev, args.epochs, args.lr, args.wd, lam=0.0, Bs=Bs_vanilla)
+ acc_vanilla = evaluate(m_vanilla, test_loader, dev)
+ print(f" vanilla DFA test acc: {acc_vanilla:.4f}")
+ quality_vanilla = measure_direction_quality(m_vanilla, Bs_vanilla, x_eval, y_eval, dev)
+ print(" vanilla DFA per-layer DFA-credit vs BP-grad cosine:")
+ for k, v in quality_vanilla.items():
+ print(f" {k}: cos_mean={v['cos_mean']:+.4f} ||g||={v['g_norm_median']:.2e} ||a||={v['a_norm_median']:.2e}")
+
+ # ----- PENALIZED DFA (lam>0) ----- #
+ print(f"\n=== Penalized DFA (lam={args.lam}) ===")
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m_pen = ResidualMLP(3072, 256, 10, 4).to(dev)
+ Bs_pen = [torch.randn(256, 10, device=dev) / np.sqrt(10) for _ in range(4)]
+ train_dfa(m_pen, train_loader, dev, args.epochs, args.lr, args.wd, lam=args.lam, Bs=Bs_pen)
+ acc_pen = evaluate(m_pen, test_loader, dev)
+ print(f" penalized DFA test acc: {acc_pen:.4f}")
+ quality_pen = measure_direction_quality(m_pen, Bs_pen, x_eval, y_eval, dev)
+ print(" penalized DFA per-layer DFA-credit vs BP-grad cosine:")
+ for k, v in quality_pen.items():
+ print(f" {k}: cos_mean={v['cos_mean']:+.4f} ||g||={v['g_norm_median']:.2e} ||a||={v['a_norm_median']:.2e}")
+
+ # Save results
+ out = {
+ "config": vars(args),
+ "vanilla": {
+ "test_acc": acc_vanilla,
+ "direction_quality": quality_vanilla,
+ },
+ "penalized": {
+ "test_acc": acc_pen,
+ "direction_quality": quality_pen,
+ },
+ }
+ out_path = os.path.join(args.output_dir, f'direction_quality_lam{args.lam}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+
+ # Save the penalized checkpoint so the protocol can later be re-applied
+ ckpt_path = os.path.join(args.output_dir, f'penalized_dfa_lam{args.lam}_s{args.seed}.pt')
+ torch.save({
+ "state_dict": m_pen.state_dict(),
+ "Bs": [b.cpu() for b in Bs_pen],
+ "config": vars(args),
+ "test_acc": acc_pen,
+ }, ckpt_path)
+ print(f"\nSaved {out_path}")
+ print(f"Saved {ckpt_path}")
+
+ # Pre-registered interpretation summary
+ print("\n" + "=" * 72)
+ print("INTERPRETATION (vs codex round 13's pre-registered prediction)")
+ print("=" * 72)
+ g_vanilla = quality_vanilla["layer_2"]["g_norm_median"]
+ g_pen = quality_pen["layer_2"]["g_norm_median"]
+ cos_vanilla = quality_vanilla["layer_2"]["cos_mean"]
+ cos_pen = quality_pen["layer_2"]["cos_mean"]
+ print(f" vanilla DFA: ||g_2||={g_vanilla:.2e} cos(DFA, BP)={cos_vanilla:+.4f} -> reference at floor")
+ print(f" penalty DFA: ||g_2||={g_pen:.2e} cos(DFA, BP)={cos_pen:+.4f} -> reference healthy")
+ if g_pen > 1e-7:
+ if abs(cos_pen) < 0.05:
+ print(" -> Direction quality is POOR even with healthy reference. Second failure mode CONFIRMED.")
+ elif abs(cos_pen) < 0.20:
+ print(" -> Direction quality is mediocre with healthy reference. Second failure mode partially supported.")
+ else:
+ print(" -> Direction quality is reasonable with healthy reference. Second failure mode REJECTED — DFA's signal is OK, the gap to BP must come from something else.")
+ else:
+ print(" -> WARNING: penalized BP grad still below 1e-7; reference is not healthy. Try larger lam.")
+
+
+if __name__ == '__main__':
+ main()