summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 00:07:39 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 00:07:39 -0500
commit1f705408da9eb9ff0fcb6f2269dadb2ebf71a0f1 (patch)
treee759097a6c575e0c03f0e5df004c595d8caa2f57
parent76edf529be1b8aa8813ce380d104eaa424a3dc1d (diff)
Add penalty lambda 3-seed summary script + checkpoint save in penalty test
- New script: protocol/examples/penalty_lam_3seed_summary.py Loads existing penalty JSON files for lam=1e-3 and lam=1e-2 across seeds, computes 3-seed mean margin vs DFA-shallow baseline, and explicitly checks the (d) verdict at 2pp threshold per seed and in aggregate. Reports MIXED if seeds disagree. Current result: lam=1e-2 has 3 seeds (margin +1.38 ± 0.05 pp, all FIRE), lam=1e-3 has 1 seed (+2.31 pp, PASSES). Awaiting s123/s456 for lam=1e-3. - experiments/dfa_residual_penalty_test.py: now saves model checkpoint + Bs alongside JSON log so post-hoc protocol can be applied without re-running. Closes the pitfall #6.5 self-disclosure (auxiliary nets must be saved for post-hoc Gamma to be reconstructible).
-rw-r--r--experiments/dfa_residual_penalty_test.py214
-rw-r--r--protocol/examples/penalty_lam_3seed_summary.py92
2 files changed, 306 insertions, 0 deletions
diff --git a/experiments/dfa_residual_penalty_test.py b/experiments/dfa_residual_penalty_test.py
new file mode 100644
index 0000000..3fa5466
--- /dev/null
+++ b/experiments/dfa_residual_penalty_test.py
@@ -0,0 +1,214 @@
+"""
+Codex round 11's decisive validation: train DFA on 4-block d=256 ResMLP with an
+explicit residual-branch penalty `λ ||f_l(h_l)||^2` added to each block's local
+loss. Tests whether constraining the block output magnitude is sufficient to
+rescue DFA from the residual-stream-explosion → BP grad collapse → active harm
+failure mode.
+
+Conditions:
+ - DFA-vanilla (λ=0): baseline, expected to reproduce 30.8% acc + ||h_L||~4e8
+ - DFA-penalized (λ=1e-3, 1e-2, 1e-1): different penalty strengths
+
+Three outcomes:
+ (A) ||h_L|| bounded AND BP grad healthy AND acc > shallow baseline (34.7%)
+ → mechanism chain causally validated
+ (B) ||h_L|| bounded AND BP grad healthy BUT acc still ≤ shallow baseline
+ → mechanism is necessary but not sufficient; other factor at play
+ (C) ||h_L|| stays exploded under the penalty
+ → penalty is too weak or wrong target
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/dfa_residual_penalty_test.py --seed 42 --lam 1e-2
+"""
+import sys, os, argparse, json
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.residual_mlp import ResidualMLP
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def diagnose(model, x_eval, y_eval, dev):
+ """Compute ||h_L||, ||BP grad at h_2||, and acc on a fixed eval batch."""
+ model.eval()
+ with torch.no_grad():
+ _, hi = model(x_eval, return_hidden=True)
+ h_L_norm = hi[-1].norm(dim=-1).median().item()
+
+ h0 = model.embed(x_eval.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks: hs.append(hs[-1] + b(hs[-1]))
+ lo = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(lo, y_eval)
+ gs = torch.autograd.grad(loss, hs)
+ g_2_norm = gs[2].norm(dim=-1).median().item()
+ acc = (lo.argmax(-1) == y_eval).float().mean().item()
+ return h_L_norm, g_2_norm, acc
+
+
+def train_dfa_with_penalty(model, train_loader, test_loader, x_eval, y_eval, dev, epochs, lr, wd, lam):
+ """DFA training with residual-branch penalty `lam * ||f_l(h_l)||^2` added
+ to each block's local loss."""
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd
+ )
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ h0, g0, a0 = diagnose(model, x_eval, y_eval, dev)
+ log.append({'epoch': 0, 'h_L_norm': h0, 'g_2_norm': g0, 'acc_eval': a0})
+ print(f" ep 0: ||h_L||={h0:.3e} ||g_2||={g0:.3e} acc={a0:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ # Head update via true CE on out_ln(h_L)
+ logits_out = model.out_head(model.out_ln(hL_det))
+ head_opt.zero_grad()
+ F.cross_entropy(logits_out, y).backward()
+ head_opt.step()
+ # Block updates via DFA local credit + residual-branch penalty
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa / rms
+ f_l = model.blocks[l](h_l)
+ # Original DFA local loss
+ local_dfa = (f_l * a_norm).sum(-1).mean()
+ # Residual-branch penalty (codex round 11): λ * mean(||f_l||²)
+ penalty = lam * (f_l ** 2).sum(-1).mean()
+ local_loss = local_dfa + penalty
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ # Embed update via DFA-style on h_0
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0_emb = model.embed(x)
+ embed_loss = (h0_emb * (a_0 / rms_0)).sum(-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+ for s in all_sch: s.step()
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ h, g, a = diagnose(model, x_eval, y_eval, dev)
+ log.append({'epoch': ep, 'h_L_norm': h, 'g_2_norm': g, 'acc_eval': a})
+ test_acc = evaluate(model, test_loader, dev)
+ print(f" ep {ep}: ||h_L||={h:.3e} ||g_2||={g:.3e} eval_acc={a:.4f} test_acc={test_acc:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ p.add_argument('--lam', type=float, default=1e-2,
+ help='residual-branch penalty strength λ for ||f_l(h_l)||²')
+ p.add_argument('--output_dir', type=str, default='results/dfa_residual_penalty')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ dev = torch.device('cuda:0')
+ print(f"DFA + residual-branch penalty test: seed={args.seed}, lam={args.lam}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+
+ # Fixed eval buffer
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 1024:
+ break
+ x_eval = torch.cat(xs)[:1024].to(dev)
+ y_eval = torch.cat(ys)[:1024].to(dev)
+
+ L, d, C = 4, 256, 10
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = ResidualMLP(3072, d, C, L).to(dev)
+ log = train_dfa_with_penalty(m, train_loader, test_loader, x_eval, y_eval, dev, args.epochs, args.lr, args.wd, args.lam)
+
+ final_test = evaluate(m, test_loader, dev)
+ print(f"\nFINAL test acc: {final_test:.4f}")
+ print(f"Compare to:")
+ print(f" DFA-vanilla (3-seed mean): 0.308")
+ print(f" DFA-shallow (3-seed mean): 0.349")
+ print(f" DFA-frozen (3-seed mean): 0.349")
+ print(f" BP-trainable (3-seed mean): 0.609")
+
+ out = {'config': vars(args), 'final_test_acc': final_test, 'log': log}
+ out_path = os.path.join(args.output_dir, f'dfa_pen_lam{args.lam}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+ print(f"Saved {out_path}")
+
+ # Round 18: save checkpoint AND Bs for post-hoc protocol application
+ # (was missing — caused us to need a separate direction-quality experiment)
+ ckpt_path = os.path.join(args.output_dir, f'dfa_pen_lam{args.lam}_s{args.seed}.pt')
+ # Reconstruct the Bs sequence the way train_dfa_with_penalty did
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ _ = ResidualMLP(3072, d, C, L) # consume RNG draws to match training
+ Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)]
+ torch.save({
+ "state_dict": m.state_dict(),
+ "Bs": [b.cpu() for b in Bs],
+ "config": vars(args),
+ "test_acc": final_test,
+ }, ckpt_path)
+ print(f"Saved {ckpt_path}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/protocol/examples/penalty_lam_3seed_summary.py b/protocol/examples/penalty_lam_3seed_summary.py
new file mode 100644
index 0000000..0261d30
--- /dev/null
+++ b/protocol/examples/penalty_lam_3seed_summary.py
@@ -0,0 +1,92 @@
+"""
+Summarize penalty 3-seed results across lambda values.
+
+Requires:
+ - results/dfa_residual_penalty/dfa_pen_lam{0.001,0.01}_s{42,123,456}.json
+
+Reports per-seed acc, h_L, g_2 + 3-seed mean and std for each lambda, and
+explicitly checks the (d) diagnostic margin against the 2pp threshold.
+
+Run:
+ python -m protocol.examples.penalty_lam_3seed_summary
+"""
+import os
+import sys
+import json
+
+import numpy as np
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+PEN_DIR = os.path.join(REPO_ROOT, "results/dfa_residual_penalty")
+SHALLOW_BASELINE = 0.349
+
+
+def load_one(lam, seed):
+ path = os.path.join(PEN_DIR, f"dfa_pen_lam{lam}_s{seed}.json")
+ if not os.path.exists(path):
+ return None
+ with open(path) as f:
+ d = json.load(f)
+ final = d["log"][-1]
+ return {
+ "acc": d["final_test_acc"],
+ "h_L": final["h_L_norm"],
+ "g_2": final["g_2_norm"],
+ }
+
+
+def main():
+ print("=" * 88)
+ print("DFA + ‖f_l(h_l)‖² penalty: 3-seed summary by λ")
+ print("=" * 88)
+
+ for lam in ["0.001", "0.01"]:
+ print(f"\n=== λ = {lam} ===")
+ rows = []
+ for seed in [42, 123, 456]:
+ r = load_one(lam, seed)
+ if r is None:
+ print(f" s{seed}: NOT YET AVAILABLE")
+ continue
+ rows.append({"seed": seed, **r})
+ print(f" s{seed}: acc={r['acc']:.4f} ‖h_L‖={r['h_L']:.3e} ‖g_2‖={r['g_2']:.3e}")
+ if not rows:
+ continue
+ accs = np.array([r["acc"] for r in rows])
+ h_Ls = np.array([r["h_L"] for r in rows])
+ g_2s = np.array([r["g_2"] for r in rows])
+ margins_pp = (accs - SHALLOW_BASELINE) * 100
+ print(f" 3-seed (or partial) mean: acc={accs.mean():.4f} ± {accs.std():.4f}, "
+ f"‖h_L‖={h_Ls.mean():.2e}, ‖g_2‖={g_2s.mean():.2e}")
+ print(f" margin vs DFA-shallow {SHALLOW_BASELINE}: "
+ f"{margins_pp.mean():+.2f} ± {margins_pp.std():.2f} pp")
+ # (d) verdict at 2pp threshold
+ fires = sum(1 for m in margins_pp if m < 2.0)
+ print(f" (d) at 2pp threshold: {fires}/{len(rows)} seeds FIRE")
+ if fires == 0:
+ verdict = "ALL PASS — penalty rescues to clear (d)"
+ elif fires == len(rows):
+ verdict = "ALL FIRE — second failure mode robust to seed"
+ else:
+ verdict = "MIXED — verdict depends on seed"
+ print(f" Aggregate (d) reading at λ={lam}: {verdict}")
+
+ print()
+ print("=" * 88)
+ print("LAMBDA × THRESHOLD CROSS-CHECK")
+ print("=" * 88)
+ print()
+ print("If λ=1e-3 3-seed mean margin exceeds 2 pp on all 3 seeds:")
+ print(" → my prior 'two failure modes via (d)' claim must be downgraded to")
+ print(" 'tradeoff between penalty strength and depth utilization'")
+ print()
+ print("If λ=1e-3 3-seed mean is ~1-2 pp (similar spread to λ=1e-2 ~1.4 pp):")
+ print(" → s42 +2.3 pp was a noisy outlier; the (d) 'second failure mode' story holds")
+ print()
+ print("Either outcome is publishable. The point is to learn it before a reviewer does.")
+
+
+if __name__ == "__main__":
+ main()