summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:26:32 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:26:32 -0500
commit665c9bb4ab3a5126c6fc191eecf42be7b703eb0c (patch)
tree9b1863521c05573925de77eff57eb2a4d2dee9ff /protocol
parent8f67bdeebac543961871b9896a62cd07b7a5be26 (diff)
Add d=512 ResMLP audit table (3 seeds): cross-width validation
Same protocol applied to the 4-block d=512 ResMLP variant (vs the d=256 default). 4 methods × 3 seeds = 12 conditions: BP @ d=512: trustworthy on all 3 seeds (acc 0.60-0.61) DFA @ d=512: walked back on all 3 seeds via (a)+(b) State Bridge @ d=512: walked back on all 3 seeds via (a)+(b), with drift sub-mode on s123 (stability 0.879) Credit Bridge @ d=512: walked back on all 3 seeds via (a)+(b) Width effect: max-per-block growth is HIGHER at d=512 (6e3-7e4) than at d=256 (~1e3). Larger width amplifies the explosion. The protocol verdicts are robust to this — same binary outcome, more extreme quantitative numbers. This is the cross-width validation: the protocol's findings are not d=256-specific. The §3 audit results generalize across the width dimension.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/audit_d512.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/protocol/examples/audit_d512.py b/protocol/examples/audit_d512.py
new file mode 100644
index 0000000..0c9fb26
--- /dev/null
+++ b/protocol/examples/audit_d512.py
@@ -0,0 +1,129 @@
+"""
+Audit table on the d=512 ResMLP variant. The main paper uses d=256;
+the d=512 set provides a width-control. If the protocol's verdicts
+generalize across width, it's a meaningful generalization claim.
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_d512
+"""
+import os
+import sys
+import json
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+
+from models.residual_mlp import ResidualMLP # noqa: E402
+from protocol import diagnose # noqa: E402
+
+D512_CKPT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/cifar_d512")
+OUT_DIR = os.path.join(REPO_ROOT, "results/protocol_audit")
+os.makedirs(OUT_DIR, exist_ok=True)
+
+
+def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
+ batches = []
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batches.append((x, y))
+ if len(batches) >= n_batches:
+ break
+ return batches
+
+
+def evaluate(model, device):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def load_d512(method, seed, device):
+ path = os.path.join(D512_CKPT_DIR, f"{method}_s{seed}.pt")
+ sd = torch.load(path, map_location=device, weights_only=False)
+ if isinstance(sd, dict) and "state_dict" in sd:
+ sd = sd["state_dict"]
+ model = ResidualMLP(3072, 512, 10, 4).to(device)
+ model.load_state_dict(sd)
+ return model
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
+ methods = ["bp", "dfa", "state_bridge", "credit_bridge"]
+ rows = []
+ for seed in [42, 123, 456]:
+ for method in methods:
+ try:
+ model = load_d512(method, seed, device)
+ except FileNotFoundError:
+ print(f" SKIPPED: {method}_s{seed} not found")
+ continue
+ acc = evaluate(model, device)
+ report = diagnose(
+ model=model,
+ eval_batches=eval_batches,
+ headline_acc=acc,
+ frozen_baseline_acc=None,
+ method_name=method.upper(),
+ notes=f"4-block d=512 ResMLP, CIFAR-10, seed {seed}",
+ )
+ rows.append({
+ "method": method,
+ "seed": seed,
+ "acc": acc,
+ "h_L": report.residual_norms[-1],
+ "g_L": report.bp_grad_norms[-1],
+ "stability": report.cross_batch_stability,
+ "max_per_block": report.max_per_block_growth,
+ "verdict": report.verdict,
+ })
+
+ print("=" * 110)
+ print("d=512 ResMLP audit (3 seeds)")
+ print("=" * 110)
+ print(f"{'method':<16}{'seed':>6}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
+ f"{'max/block':>12}{'stab':>10} verdict")
+ print("-" * 110)
+ for r in rows:
+ print(
+ f"{r['method']:<16}{r['seed']:>6}{r['acc']:>8.4f}{r['h_L']:>14.3e}"
+ f"{r['g_L']:>14.3e}{r['max_per_block']:>12.2e}{r['stability']:>10.3f} "
+ f"{r['verdict'][:60]}"
+ )
+
+ out_path = os.path.join(OUT_DIR, "audit_d512_3seed.json")
+ with open(out_path, "w") as f:
+ json.dump(rows, f, indent=2)
+ print(f"\nSaved {out_path}")
+
+
+if __name__ == "__main__":
+ main()