""" Audit table on the d=512 ResMLP variant. The main paper uses d=256; the d=512 set provides a width-control. If the protocol's verdicts generalize across width, it's a meaningful generalization claim. Run: CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_d512 """ import os import sys import json import torch import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) sys.path.insert(0, REPO_ROOT) from models.residual_mlp import ResidualMLP # noqa: E402 from protocol import diagnose # noqa: E402 D512_CKPT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/cifar_d512") OUT_DIR = os.path.join(REPO_ROOT, "results/protocol_audit") os.makedirs(OUT_DIR, exist_ok=True) def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0) batches = [] for x, y in loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batches.append((x, y)) if len(batches) >= n_batches: break return batches def evaluate(model, device): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) model.eval() correct = total = 0 with torch.no_grad(): for x, y in loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) preds = model(x).argmax(-1) correct += (preds == y).sum().item() total += x.size(0) return correct / total def load_d512(method, seed, device): path = os.path.join(D512_CKPT_DIR, f"{method}_s{seed}.pt") sd = torch.load(path, map_location=device, weights_only=False) if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"] model = ResidualMLP(3072, 512, 10, 4).to(device) model.load_state_dict(sd) return model def main(): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device) methods = ["bp", "dfa", "state_bridge", "credit_bridge"] rows = [] for seed in [42, 123, 456]: for method in methods: try: model = load_d512(method, seed, device) except FileNotFoundError: print(f" SKIPPED: {method}_s{seed} not found") continue acc = evaluate(model, device) report = diagnose( model=model, eval_batches=eval_batches, headline_acc=acc, frozen_baseline_acc=None, method_name=method.upper(), notes=f"4-block d=512 ResMLP, CIFAR-10, seed {seed}", ) rows.append({ "method": method, "seed": seed, "acc": acc, "h_L": report.residual_norms[-1], "g_L": report.bp_grad_norms[-1], "stability": report.cross_batch_stability, "max_per_block": report.max_per_block_growth, "verdict": report.verdict, }) print("=" * 110) print("d=512 ResMLP audit (3 seeds)") print("=" * 110) print(f"{'method':<16}{'seed':>6}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}" f"{'max/block':>12}{'stab':>10} verdict") print("-" * 110) for r in rows: print( f"{r['method']:<16}{r['seed']:>6}{r['acc']:>8.4f}{r['h_L']:>14.3e}" f"{r['g_L']:>14.3e}{r['max_per_block']:>12.2e}{r['stability']:>10.3f} " f"{r['verdict'][:60]}" ) out_path = os.path.join(OUT_DIR, "audit_d512_3seed.json") with open(out_path, "w") as f: json.dump(rows, f, indent=2) print(f"\nSaved {out_path}") if __name__ == "__main__": main()