""" Reproduce the §2 audit table: apply the diagnostic protocol to BP / DFA / State Bridge / Credit Bridge / EP checkpoints on the 4-block d=256 ResMLP / CIFAR-10 setup. Single seed 42 for the table; the paper uses 3-seed means elsewhere. Output is a per-method tabular summary that lists, for each diagnostic, the per-layer values and the verdict. This is the audit evidence behind the paper claim *"standard FA evaluation reports headline accuracy + Γ as evidence of training, but on modern pre-LN residual networks both signals silently fail for non-BP methods."* Run: CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_table """ import os import sys import json import torch import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) sys.path.insert(0, REPO_ROOT) from models.residual_mlp import ResidualMLP # noqa: E402 from protocol import diagnose # noqa: E402 CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2") EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline") OUT_DIR = os.path.join(REPO_ROOT, "results/protocol_audit") os.makedirs(OUT_DIR, exist_ok=True) def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0) batches = [] for x, y in loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batches.append((x, y)) if len(batches) >= n_batches: break return batches def evaluate(model, device): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, 
transform=tv) loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) model.eval() correct = total = 0 with torch.no_grad(): for x, y in loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) preds = model(x).argmax(-1) correct += (preds == y).sum().item() total += x.size(0) return correct / total def load_model(method: str, seed: int, device): if method == "ep": path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt") else: path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt") ckpt = torch.load(path, map_location=device, weights_only=False) sd = ckpt if not hasattr(ckpt, "state_dict") else ckpt.state_dict() if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"] model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device) model.load_state_dict(sd) return model # 3-seed mean shallow / frozen baseline accuracies (from # project_resmlp_walkback_dfa_destroys_value memory entry — these are the # same number for the DFA condition by design: the "deep blocks frozen at # random init" is informationally equivalent to "no deep blocks"). 
FROZEN_BASELINE_ACC = {
    "bp": None,  # BP-frozen is 34.6%; not the right comparator for BP-trainable
    "dfa": 0.349,  # DFA-frozen / DFA-shallow 3-seed mean
    "state_bridge": 0.349,  # uses the same architecture-matched control
    "credit_bridge": 0.349,
    "ep": None,  # EP frozen-control not run yet
}


def main(seed=42, methods=None):
    """Audit each method's checkpoint for ``seed`` and save a JSON summary.

    For every method: load its checkpoint, measure test accuracy, run
    ``diagnose`` on a fixed set of eval batches and print the full report.
    Then print a compact summary table and write both the full reports and
    the summary rows to ``OUT_DIR/audit_table_s{seed}.json``.

    Args:
        seed: checkpoint seed to audit (the paper's table uses 42).
        methods: method keys to audit; defaults to all five methods.
    """
    if methods is None:
        methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    # One shared batch set, so every method is diagnosed on identical inputs.
    eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)

    rows = []
    reports = {}
    for method in methods:
        print(f"\n### {method.upper()} (seed {seed})")
        model = load_model(method, seed, device)
        acc = evaluate(model, device)
        report = diagnose(
            model=model,
            eval_batches=eval_batches,
            headline_acc=acc,
            frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method),
            method_name=method.upper(),
            notes=f"4-block d=256 ResMLP, CIFAR-10, seed {seed}",
        )
        print(report)
        reports[method] = report.to_dict()
        rows.append({
            "method": method,
            "acc": acc,
            "h_L": report.residual_norms[-1],
            "g_L": report.bp_grad_norms[-1],
            "stability": report.cross_batch_stability,
            "frozen_acc": report.frozen_baseline_acc,
            "verdict": report.verdict,
        })

    # Compact summary table
    print("\n\n" + "=" * 100)
    print(f"AUDIT SUMMARY (single seed {seed}, 4-block d=256 ResMLP, CIFAR-10)")
    print("=" * 100)
    header = (
        f"{'method':<16}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
        f"{'stab(L/2)':>12}{'frozen':>10} verdict"
    )
    print(header)
    print("-" * 100)
    for r in rows:
        frozen = "n/a" if r["frozen_acc"] is None else f"{r['frozen_acc']:.4f}"
        print(
            f"{r['method']:<16}"
            f"{r['acc']:>8.4f}"
            f"{r['h_L']:>14.3e}"
            f"{r['g_L']:>14.3e}"
            f"{r['stability']:>12.3f}"
            f"{frozen:>10} {r['verdict']}"
        )

    # Ensure the output directory exists right before writing (idempotent).
    os.makedirs(OUT_DIR, exist_ok=True)
    out_path = os.path.join(OUT_DIR, f"audit_table_s{seed}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"reports": reports, "summary": rows}, f, indent=2)
    print(f"\nSaved {out_path}")


if __name__ == "__main__":
    main()