diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:26:32 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:26:32 -0500 |
| commit | 665c9bb4ab3a5126c6fc191eecf42be7b703eb0c (patch) | |
| tree | 9b1863521c05573925de77eff57eb2a4d2dee9ff /protocol | |
| parent | 8f67bdeebac543961871b9896a62cd07b7a5be26 (diff) | |
Add d=512 ResMLP audit table (3 seeds): cross-width validation
Same protocol applied to the 4-block d=512 ResMLP variant (vs the d=256
default). 4 methods × 3 seeds = 12 conditions:
BP @ d=512: trustworthy on all 3 seeds (acc 0.60-0.61)
DFA @ d=512: walked back on all 3 seeds via (a)+(b)
State Bridge @ d=512: walked back on all 3 seeds via (a)+(b), with
drift sub-mode on s123 (stability 0.879)
Credit Bridge @ d=512: walked back on all 3 seeds via (a)+(b)
Width effect: max-per-block growth is HIGHER at d=512 (6e3-7e4) than at
d=256 (~1e3). Larger width amplifies the explosion. The protocol
verdicts are robust to this — same binary outcome, more extreme
quantitative numbers.
This is the cross-width validation: the protocol's findings are not
d=256-specific. The §3 audit results generalize across the width
dimension.
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/audit_d512.py | 129 |
1 files changed, 129 insertions, 0 deletions
# protocol/examples/audit_d512.py -- new file added by this commit (mode 100644).
"""Audit table on the d=512 ResMLP variant.

The main paper uses d=256; the d=512 checkpoints act as a width control.
If the protocol reaches the same verdicts at both widths, its findings are
not specific to d=256.

For every (seed, method) checkpoint found under ``D512_CKPT_DIR`` the script
measures CIFAR-10 test accuracy, runs ``protocol.diagnose`` on a fixed set of
evaluation batches, prints a summary table and writes the rows to
``results/protocol_audit/audit_d512_3seed.json``.

Run:
    CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_d512
"""
import functools
import itertools
import json
import os
import sys

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Make the repository root importable so that ``models`` and ``protocol``
# resolve regardless of how the script is launched.
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.insert(0, REPO_ROOT)

from models.residual_mlp import ResidualMLP  # noqa: E402
from protocol import diagnose  # noqa: E402

D512_CKPT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/cifar_d512")
OUT_DIR = os.path.join(REPO_ROOT, "results/protocol_audit")
os.makedirs(OUT_DIR, exist_ok=True)

# Per-channel normalization constants applied to every CIFAR-10 image.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


@functools.lru_cache(maxsize=1)
def _cifar10_test_set():
    """Return the normalized CIFAR-10 test split, constructed at most once.

    ``load_eval_batches`` and ``evaluate`` both need the same dataset, and
    ``evaluate`` is called once per checkpoint; caching avoids re-creating
    (and re-checking the download of) the dataset on every call.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])
    return torchvision.datasets.CIFAR10(
        "./data", train=False, download=True, transform=preprocess
    )


def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"):
    """Return the first ``n_batches`` CIFAR-10 test batches as a list.

    Args:
        n_batches: Maximum number of batches to collect. Values <= 0 yield
            an empty list.
        batch_size: Number of images per batch.
        device: Device the tensors are moved to.

    Returns:
        A list of ``(x, y)`` tuples in dataset order (no shuffling), where
        ``x`` holds one flattened vector per image and ``y`` the labels.
    """
    loader = DataLoader(
        _cifar10_test_set(), batch_size=batch_size, shuffle=False, num_workers=0
    )
    # islice stops after n_batches without pulling an extra batch from the
    # loader, and (unlike append-then-check) returns nothing for n_batches=0.
    return [
        (x.view(x.size(0), -1).to(device), y.to(device))
        for x, y in itertools.islice(loader, max(n_batches, 0))
    ]


def evaluate(model, device):
    """Return top-1 accuracy of ``model`` on the full CIFAR-10 test split.

    Puts ``model`` into eval mode (and leaves it there) and runs without
    gradient tracking. Images are flattened to one vector each before being
    passed to the model.

    Args:
        model: Module mapping a batch of flattened images to class scores.
        device: Device the input batches are moved to.

    Returns:
        Fraction of correctly classified test images, as a float.
    """
    loader = DataLoader(
        _cifar10_test_set(), batch_size=256, shuffle=False, num_workers=0
    )
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            preds = model(x).argmax(-1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total


def load_d512(method, seed, device):
    """Load the d=512 checkpoint ``{method}_s{seed}.pt`` into a ResidualMLP.

    Args:
        method: Checkpoint name prefix (e.g. ``"bp"``).
        seed: Training seed encoded in the checkpoint file name.
        device: Device the weights and the model are placed on.

    Returns:
        The model with the checkpoint's weights loaded.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    path = os.path.join(D512_CKPT_DIR, f"{method}_s{seed}.pt")
    # SECURITY NOTE: weights_only=False performs a full unpickle and can run
    # arbitrary code from the file. Only load checkpoints produced by this
    # repository; switch to weights_only=True if they are plain state dicts.
    sd = torch.load(path, map_location=device, weights_only=False)
    # Checkpoints may be a bare state dict or wrap it under "state_dict".
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    # NOTE(review): positional args presumably are (input dim, width,
    # classes, blocks) -- confirm against ResidualMLP's signature.
    model = ResidualMLP(3072, 512, 10, 4).to(device)
    model.load_state_dict(sd)
    return model


def main():
    """Audit every available d=512 checkpoint, print a table, save JSON."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
    methods = ("bp", "dfa", "state_bridge", "credit_bridge")
    seeds = (42, 123, 456)

    rows = []
    for seed in seeds:
        for method in methods:
            try:
                model = load_d512(method, seed, device)
            except FileNotFoundError:
                # Missing checkpoints are skipped so a partial run still
                # produces a table for the conditions that exist.
                print(f"  SKIPPED: {method}_s{seed} not found")
                continue
            acc = evaluate(model, device)
            report = diagnose(
                model=model,
                eval_batches=eval_batches,
                headline_acc=acc,
                frozen_baseline_acc=None,
                method_name=method.upper(),
                notes=f"4-block d=512 ResMLP, CIFAR-10, seed {seed}",
            )
            # NOTE(review): json.dump below needs these to be plain Python
            # numbers/strings -- confirm diagnose() does not return tensors
            # or NumPy scalars in these fields.
            rows.append({
                "method": method,
                "seed": seed,
                "acc": acc,
                "h_L": report.residual_norms[-1],  # last entry only
                "g_L": report.bp_grad_norms[-1],  # last entry only
                "stability": report.cross_batch_stability,
                "max_per_block": report.max_per_block_growth,
                "verdict": report.verdict,
            })

    if not rows:
        print(f"  NOTE: no checkpoints found under {D512_CKPT_DIR}")

    print("=" * 110)
    print("d=512 ResMLP audit (3 seeds)")
    print("=" * 110)
    print(f"{'method':<16}{'seed':>6}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
          f"{'max/block':>12}{'stab':>10}  verdict")
    print("-" * 110)
    for r in rows:
        print(
            f"{r['method']:<16}{r['seed']:>6}{r['acc']:>8.4f}{r['h_L']:>14.3e}"
            f"{r['g_L']:>14.3e}{r['max_per_block']:>12.2e}{r['stability']:>10.3f}  "
            f"{r['verdict'][:60]}"  # verdict truncated to keep the table narrow
        )

    out_path = os.path.join(OUT_DIR, "audit_d512_3seed.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)
    print(f"\nSaved {out_path}")


if __name__ == "__main__":
    main()
