summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:29:00 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:29:00 -0500
commit111bab56e2d49c9fb1f3bfb9e55ea2028da4d008 (patch)
tree5963a8171c383023b3bd19ed3a86e460ebe99615
parent7b64702ad970c16171142665365e16a8e1737190 (diff)
Add audit table example: protocol applied to BP/DFA/SB/CB/EP
5-method audit table on 4-block d=256 ResMLP CIFAR-10 seed 42: - BP: trustworthy (acc 0.615, h_L=2e2, g_L=4e-4, stab 0.099) - DFA: walked back via (a)+(b)+(d) — h_L=4e8, g_L=4e-9, undercuts frozen - State Bridge: walked back via all 4 diagnostics — stability 0.992 is the cleanest possible drift-dominated case - Credit Bridge: walked back via all 4 — stability 0.352, also drift mode - EP: trustworthy (acc 0.359, h_L=3e3, g_L=2e-4, stab -0.036) — paper's internal control case This is the §2 audit evidence for the main-track paper. Confirms that standard headline acc + Γ silently fails on 3 of 5 methods on this architecture, while the 4-diagnostic protocol catches all three.
-rw-r--r--protocol/examples/__init__.py0
-rw-r--r--protocol/examples/audit_table.py162
-rw-r--r--results/protocol_audit/audit_table_s42.json196
3 files changed, 358 insertions, 0 deletions
diff --git a/protocol/examples/__init__.py b/protocol/examples/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/protocol/examples/__init__.py
diff --git a/protocol/examples/audit_table.py b/protocol/examples/audit_table.py
new file mode 100644
index 0000000..1a75d96
--- /dev/null
+++ b/protocol/examples/audit_table.py
@@ -0,0 +1,162 @@
+"""
+Reproduce the §2 audit table: apply the diagnostic protocol to BP / DFA /
+State Bridge / Credit Bridge / EP checkpoints on the 4-block d=256 ResMLP /
+CIFAR-10 setup. Single seed 42 for the table; the paper uses 3-seed means
+elsewhere.
+
+Output is a per-method tabular summary that lists, for each diagnostic,
+the per-layer values and the verdict. This is the audit evidence behind the
+paper claim *"standard FA evaluation reports headline accuracy + Γ as
+evidence of training, but on modern pre-LN residual networks both signals
+silently fail for non-BP methods."*
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_table
+"""
+import os
+import sys
+import json
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+
+from models.residual_mlp import ResidualMLP # noqa: E402
+from protocol import diagnose # noqa: E402
+
+CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2")
+EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline")
+OUT_DIR = os.path.join(REPO_ROOT, "results/protocol_audit")
+os.makedirs(OUT_DIR, exist_ok=True)
+
+
+def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
+ batches = []
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batches.append((x, y))
+ if len(batches) >= n_batches:
+ break
+ return batches
+
+
+def evaluate(model, device):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def load_model(method: str, seed: int, device):
+ if method == "ep":
+ path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt")
+ else:
+ path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt")
+ ckpt = torch.load(path, map_location=device, weights_only=False)
+ sd = ckpt if not hasattr(ckpt, "state_dict") else ckpt.state_dict()
+ if isinstance(sd, dict) and "state_dict" in sd:
+ sd = sd["state_dict"]
+ model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device)
+ model.load_state_dict(sd)
+ return model
+
+
+# 3-seed mean shallow / frozen baseline accuracies (from
+# project_resmlp_walkback_dfa_destroys_value memory entry — these are the
+# same number for the DFA condition by design: the "deep blocks frozen at
+# random init" is informationally equivalent to "no deep blocks").
+FROZEN_BASELINE_ACC = {
+ "bp": None, # BP-frozen is 34.6%; not the right comparator for BP-trainable
+ "dfa": 0.349, # DFA-frozen / DFA-shallow 3-seed mean
+ "state_bridge": 0.349, # uses the same architecture-matched control
+ "credit_bridge": 0.349,
+ "ep": None, # EP frozen-control not run yet
+}
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+ eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
+
+ methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
+ rows = []
+ reports = {}
+ for method in methods:
+ print(f"\n### {method.upper()} (seed 42)")
+ model = load_model(method, 42, device)
+ acc = evaluate(model, device)
+ report = diagnose(
+ model=model,
+ eval_batches=eval_batches,
+ headline_acc=acc,
+ frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method),
+ method_name=method.upper(),
+ notes="4-block d=256 ResMLP, CIFAR-10, seed 42",
+ )
+ print(report)
+ reports[method] = report.to_dict()
+ rows.append({
+ "method": method,
+ "acc": acc,
+ "h_L": report.residual_norms[-1],
+ "g_L": report.bp_grad_norms[-1],
+ "stability": report.cross_batch_stability,
+ "frozen_acc": report.frozen_baseline_acc,
+ "verdict": report.verdict,
+ })
+
+ # Compact summary table
+ print("\n\n" + "=" * 100)
+ print("AUDIT SUMMARY (single seed 42, 4-block d=256 ResMLP, CIFAR-10)")
+ print("=" * 100)
+ header = (
+ f"{'method':<16}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
+ f"{'stab(L/2)':>12}{'frozen':>10} verdict"
+ )
+ print(header)
+ print("-" * 100)
+ for r in rows:
+ frozen = "n/a" if r["frozen_acc"] is None else f"{r['frozen_acc']:.4f}"
+ print(
+ f"{r['method']:<16}"
+ f"{r['acc']:>8.4f}"
+ f"{r['h_L']:>14.3e}"
+ f"{r['g_L']:>14.3e}"
+ f"{r['stability']:>12.3f}"
+ f"{frozen:>10} {r['verdict']}"
+ )
+
+ out_path = os.path.join(OUT_DIR, "audit_table_s42.json")
+ with open(out_path, "w") as f:
+ json.dump({"reports": reports, "summary": rows}, f, indent=2)
+ print(f"\nSaved {out_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/results/protocol_audit/audit_table_s42.json b/results/protocol_audit/audit_table_s42.json
new file mode 100644
index 0000000..d1c1f84
--- /dev/null
+++ b/results/protocol_audit/audit_table_s42.json
@@ -0,0 +1,196 @@
+{
+ "reports": {
+ "bp": {
+ "method_name": "BP",
+ "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42",
+ "residual_norms": [
+ 251.83087158203125,
+ 226.57342529296875,
+ 212.16461181640625,
+ 205.60723876953125,
+ 205.75946044921875
+ ],
+ "bp_grad_norms": [
+ 0.0004396044823806733,
+ 0.0004709330096375197,
+ 0.0004792391264345497,
+ 0.00045345001854002476,
+ 0.0003701267414726317
+ ],
+ "stability_layer": 2,
+ "cross_batch_stability": 0.09898398886952135,
+ "headline_acc": 0.6149,
+ "frozen_baseline_acc": null,
+ "verdict": "trustworthy",
+ "thresholds": {
+ "g_norm_floor": 1e-07,
+ "h_norm_explosion_ratio": 50.0,
+ "stability_drift_ceiling": 0.3,
+ "frozen_acc_margin_pp": 2.0
+ }
+ },
+ "dfa": {
+ "method_name": "DFA",
+ "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42",
+ "residual_norms": [
+ 35824.796875,
+ 73202040.0,
+ 174312304.0,
+ 339040960.0,
+ 435299520.0
+ ],
+ "bp_grad_norms": [
+ 4.39066155877299e-07,
+ 4.1912620041273385e-09,
+ 4.183721813433294e-09,
+ 4.174094847542165e-09,
+ 4.174704582027289e-09
+ ],
+ "stability_layer": 2,
+ "cross_batch_stability": 0.047060725092887876,
+ "headline_acc": 0.3107,
+ "frozen_baseline_acc": 0.349,
+ "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; deep blocks fail to beat frozen-random baseline",
+ "thresholds": {
+ "g_norm_floor": 1e-07,
+ "h_norm_explosion_ratio": 50.0,
+ "stability_drift_ceiling": 0.3,
+ "frozen_acc_margin_pp": 2.0
+ }
+ },
+ "state_bridge": {
+ "method_name": "STATE_BRIDGE",
+ "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42",
+ "residual_norms": [
+ 906.3201293945312,
+ 11583499.0,
+ 34872504.0,
+ 208111168.0,
+ 228665568.0
+ ],
+ "bp_grad_norms": [
+ 8.369566785404459e-06,
+ 1.996277365634569e-09,
+ 1.9812380624983916e-09,
+ 1.8405569290891322e-09,
+ 1.8411722146893794e-09
+ ],
+ "stability_layer": 2,
+ "cross_batch_stability": 0.99180050028695,
+ "headline_acc": 0.1695,
+ "frozen_baseline_acc": 0.349,
+ "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; BP grad direction is drift-dominated; deep blocks fail to beat frozen-random baseline",
+ "thresholds": {
+ "g_norm_floor": 1e-07,
+ "h_norm_explosion_ratio": 50.0,
+ "stability_drift_ceiling": 0.3,
+ "frozen_acc_margin_pp": 2.0
+ }
+ },
+ "credit_bridge": {
+ "method_name": "CREDIT_BRIDGE",
+ "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42",
+ "residual_norms": [
+ 13249.662109375,
+ 24119914.0,
+ 554824896.0,
+ 548816832.0,
+ 606231552.0
+ ],
+ "bp_grad_norms": [
+ 7.185065555859182e-07,
+ 1.1024462454045647e-09,
+ 9.061909000962487e-10,
+ 9.013046420314197e-10,
+ 9.011226209665324e-10
+ ],
+ "stability_layer": 2,
+ "cross_batch_stability": 0.3516695586343606,
+ "headline_acc": 0.2562,
+ "frozen_baseline_acc": 0.349,
+ "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; BP grad direction is drift-dominated; deep blocks fail to beat frozen-random baseline",
+ "thresholds": {
+ "g_norm_floor": 1e-07,
+ "h_norm_explosion_ratio": 50.0,
+ "stability_drift_ceiling": 0.3,
+ "frozen_acc_margin_pp": 2.0
+ }
+ },
+ "ep": {
+ "method_name": "EP",
+ "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42",
+ "residual_norms": [
+ 518.3867797851562,
+ 579.6542358398438,
+ 680.764892578125,
+ 1145.8692626953125,
+ 3286.841064453125
+ ],
+ "bp_grad_norms": [
+ 0.00022257285309024155,
+ 0.00022327345504891127,
+ 0.00021209640544839203,
+ 0.00021204684162512422,
+ 0.00016422539192717522
+ ],
+ "stability_layer": 2,
+ "cross_batch_stability": -0.03589460700750351,
+ "headline_acc": 0.359,
+ "frozen_baseline_acc": null,
+ "verdict": "trustworthy",
+ "thresholds": {
+ "g_norm_floor": 1e-07,
+ "h_norm_explosion_ratio": 50.0,
+ "stability_drift_ceiling": 0.3,
+ "frozen_acc_margin_pp": 2.0
+ }
+ }
+ },
+ "summary": [
+ {
+ "method": "bp",
+ "acc": 0.6149,
+ "h_L": 205.75946044921875,
+ "g_L": 0.0003701267414726317,
+ "stability": 0.09898398886952135,
+ "frozen_acc": null,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "dfa",
+ "acc": 0.3107,
+ "h_L": 435299520.0,
+ "g_L": 4.174704582027289e-09,
+ "stability": 0.047060725092887876,
+ "frozen_acc": 0.349,
+ "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; deep blocks fail to beat frozen-random baseline"
+ },
+ {
+ "method": "state_bridge",
+ "acc": 0.1695,
+ "h_L": 228665568.0,
+ "g_L": 1.8411722146893794e-09,
+ "stability": 0.99180050028695,
+ "frozen_acc": 0.349,
+ "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; BP grad direction is drift-dominated; deep blocks fail to beat frozen-random baseline"
+ },
+ {
+ "method": "credit_bridge",
+ "acc": 0.2562,
+ "h_L": 606231552.0,
+ "g_L": 9.011226209665324e-10,
+ "stability": 0.3516695586343606,
+ "frozen_acc": 0.349,
+ "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; BP grad direction is drift-dominated; deep blocks fail to beat frozen-random baseline"
+ },
+ {
+ "method": "ep",
+ "acc": 0.359,
+ "h_L": 3286.841064453125,
+ "g_L": 0.00016422539192717522,
+ "stability": -0.03589460700750351,
+ "frozen_acc": null,
+ "verdict": "trustworthy"
+ }
+ ]
+} \ No newline at end of file