summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:29:00 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:29:00 -0500
commit111bab56e2d49c9fb1f3bfb9e55ea2028da4d008 (patch)
tree5963a8171c383023b3bd19ed3a86e460ebe99615 /protocol
parent7b64702ad970c16171142665365e16a8e1737190 (diff)
Add audit table example: protocol applied to BP/DFA/SB/CB/EP
5-method audit table on 4-block d=256 ResMLP CIFAR-10 seed 42: - BP: trustworthy (acc 0.615, h_L=2e2, g_L=4e-4, stab 0.099) - DFA: walked back via (a)+(b)+(d) — h_L=4e8, g_L=4e-9, undercuts frozen - State Bridge: walked back via all 4 diagnostics — stability 0.992 is the cleanest possible drift-dominated case - Credit Bridge: walked back via all 4 — stability 0.352, also drift mode - EP: trustworthy (acc 0.359, h_L=3e3, g_L=2e-4, stab -0.036) — paper's internal control case This is the §2 audit evidence for the main-track paper. Confirms that standard headline acc + Γ silently fails on 3 of 5 methods on this architecture, while the 4-diagnostic protocol catches all three.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/__init__.py0
-rw-r--r--protocol/examples/audit_table.py162
2 files changed, 162 insertions, 0 deletions
diff --git a/protocol/examples/__init__.py b/protocol/examples/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/protocol/examples/__init__.py
diff --git a/protocol/examples/audit_table.py b/protocol/examples/audit_table.py
new file mode 100644
index 0000000..1a75d96
--- /dev/null
+++ b/protocol/examples/audit_table.py
@@ -0,0 +1,162 @@
+"""
+Reproduce the §2 audit table: apply the diagnostic protocol to BP / DFA /
+State Bridge / Credit Bridge / EP checkpoints on the 4-block d=256 ResMLP /
+CIFAR-10 setup. Single seed 42 for the table; the paper uses 3-seed means
+elsewhere.
+
+Output is a per-method tabular summary that lists, for each diagnostic,
+the per-layer values and the verdict. This is the audit evidence behind the
+paper claim *"standard FA evaluation reports headline accuracy + Γ as
+evidence of training, but on modern pre-LN residual networks both signals
+silently fail for non-BP methods."*
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_table
+"""
+import os
+import sys
+import json
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+
+from models.residual_mlp import ResidualMLP # noqa: E402
+from protocol import diagnose # noqa: E402
+
+CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2")
+EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline")
+OUT_DIR = os.path.join(REPO_ROOT, "results/protocol_audit")
+os.makedirs(OUT_DIR, exist_ok=True)
+
+
+def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
+ batches = []
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batches.append((x, y))
+ if len(batches) >= n_batches:
+ break
+ return batches
+
+
+def evaluate(model, device):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def load_model(method: str, seed: int, device):
+ if method == "ep":
+ path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt")
+ else:
+ path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt")
+ ckpt = torch.load(path, map_location=device, weights_only=False)
+ sd = ckpt if not hasattr(ckpt, "state_dict") else ckpt.state_dict()
+ if isinstance(sd, dict) and "state_dict" in sd:
+ sd = sd["state_dict"]
+ model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device)
+ model.load_state_dict(sd)
+ return model
+
+
+# 3-seed mean shallow / frozen baseline accuracies (from
+# project_resmlp_walkback_dfa_destroys_value memory entry — these are the
+# same number for the DFA condition by design: the "deep blocks frozen at
+# random init" is informationally equivalent to "no deep blocks").
+FROZEN_BASELINE_ACC = {
+ "bp": None, # BP-frozen is 34.6%; not the right comparator for BP-trainable
+ "dfa": 0.349, # DFA-frozen / DFA-shallow 3-seed mean
+ "state_bridge": 0.349, # uses the same architecture-matched control
+ "credit_bridge": 0.349,
+ "ep": None, # EP frozen-control not run yet
+}
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+ eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
+
+ methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
+ rows = []
+ reports = {}
+ for method in methods:
+ print(f"\n### {method.upper()} (seed 42)")
+ model = load_model(method, 42, device)
+ acc = evaluate(model, device)
+ report = diagnose(
+ model=model,
+ eval_batches=eval_batches,
+ headline_acc=acc,
+ frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method),
+ method_name=method.upper(),
+ notes="4-block d=256 ResMLP, CIFAR-10, seed 42",
+ )
+ print(report)
+ reports[method] = report.to_dict()
+ rows.append({
+ "method": method,
+ "acc": acc,
+ "h_L": report.residual_norms[-1],
+ "g_L": report.bp_grad_norms[-1],
+ "stability": report.cross_batch_stability,
+ "frozen_acc": report.frozen_baseline_acc,
+ "verdict": report.verdict,
+ })
+
+ # Compact summary table
+ print("\n\n" + "=" * 100)
+ print("AUDIT SUMMARY (single seed 42, 4-block d=256 ResMLP, CIFAR-10)")
+ print("=" * 100)
+ header = (
+ f"{'method':<16}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
+ f"{'stab(L/2)':>12}{'frozen':>10} verdict"
+ )
+ print(header)
+ print("-" * 100)
+ for r in rows:
+ frozen = "n/a" if r["frozen_acc"] is None else f"{r['frozen_acc']:.4f}"
+ print(
+ f"{r['method']:<16}"
+ f"{r['acc']:>8.4f}"
+ f"{r['h_L']:>14.3e}"
+ f"{r['g_L']:>14.3e}"
+ f"{r['stability']:>12.3f}"
+ f"{frozen:>10} {r['verdict']}"
+ )
+
+ out_path = os.path.join(OUT_DIR, "audit_table_s42.json")
+ with open(out_path, "w") as f:
+ json.dump({"reports": reports, "summary": rows}, f, indent=2)
+ print(f"\nSaved {out_path}")
+
+
+if __name__ == "__main__":
+ main()