1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
"""
Audit table for the d=512 ResMLP variant. The main paper uses d=256;
the d=512 set serves as a width control. If the protocol's verdicts hold
at both widths, that supports the claim that they generalize across width.
Run:
CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_d512
"""
import os
import sys
import json
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Repository root: the directory two levels above the one containing this
# file (the module is run as ``protocol.examples.audit_d512``).
REPO_ROOT = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Put the repo root first on sys.path so the project imports below resolve
# against this checkout; this must happen before those imports run.
sys.path.insert(0, REPO_ROOT)
from models.residual_mlp import ResidualMLP # noqa: E402
from protocol import diagnose # noqa: E402
# Directory the d=512 checkpoints are read from.
D512_CKPT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/cifar_d512")
# Directory the audit JSON is written to; created at import time if missing.
OUT_DIR = os.path.join(REPO_ROOT, "results/protocol_audit")
os.makedirs(OUT_DIR, exist_ok=True)
def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"):
    """Collect the first ``n_batches`` CIFAR-10 test batches as flat tensors.

    The test split is iterated in its stored order (``shuffle=False``), so
    repeated calls yield the same samples. Every image batch is reshaped to
    ``(batch, -1)`` and moved to ``device`` together with its labels.

    The dataset is read from ``./data`` relative to the current working
    directory and is requested with ``download=True``.

    Args:
        n_batches: Maximum number of batches to return. Zero or a negative
            value returns an empty list without touching the dataset.
        batch_size: Batch size handed to the ``DataLoader``.
        device: Device the tensors are moved to.

    Returns:
        A list of ``(x, y)`` tuples. It is shorter than ``n_batches`` when
        the test split is exhausted first.
    """
    # The size check below only runs after an append, so without this guard a
    # request for zero batches would still return one.
    if n_batches <= 0:
        return []

    tv = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel mean / std passed to Normalize.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
    loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)

    batches = []
    for x, y in loader:
        # Flatten each image so it can be fed to an MLP.
        x = x.view(x.size(0), -1).to(device)
        y = y.to(device)
        batches.append((x, y))
        # Stop as soon as enough batches are held; no extra batch is fetched.
        if len(batches) >= n_batches:
            break
    return batches
def evaluate(model, device):
    """Return the top-1 accuracy of ``model`` on the CIFAR-10 test split.

    The whole test split is loaded from ``./data`` (``download=True``),
    flattened per image, and scored in batches of 256 with gradients
    disabled. The model is switched to eval mode and left in that mode.

    Args:
        model: Callable module mapping a ``(batch, features)`` tensor to
            per-class scores.
        device: Device the inputs and labels are moved to before scoring.

    Returns:
        Fraction of test samples whose arg-max prediction equals the label.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    )
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])
    test_set = torchvision.datasets.CIFAR10(
        "./data", train=False, download=True, transform=preprocess
    )
    test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=0)

    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in test_loader:
            flat = images.view(images.size(0), -1).to(device)
            labels = labels.to(device)
            hits = model(flat).argmax(-1) == labels
            n_correct += hits.sum().item()
            n_seen += flat.size(0)
    return n_correct / n_seen
def load_d512(method, seed, device):
    """Load the d=512 checkpoint for ``method`` / ``seed`` into a fresh model.

    The checkpoint file is ``<D512_CKPT_DIR>/<method>_s<seed>.pt``. It may be
    either a bare state dict or a dict that wraps one under ``"state_dict"``.

    Args:
        method: Training-method tag used in the checkpoint file name.
        seed: Seed used in the checkpoint file name.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        A ``ResidualMLP`` on ``device`` with the checkpoint weights loaded.
    """
    ckpt_path = os.path.join(D512_CKPT_DIR, f"{method}_s{seed}.pt")
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only point this at checkpoints produced by this project.
    payload = torch.load(ckpt_path, map_location=device, weights_only=False)
    is_wrapped = isinstance(payload, dict) and "state_dict" in payload
    state = payload["state_dict"] if is_wrapped else payload
    # Presumably (input dim, width, classes, blocks) -- TODO confirm against
    # models.residual_mlp.ResidualMLP.
    net = ResidualMLP(3072, 512, 10, 4).to(device)
    net.load_state_dict(state)
    return net
def _print_audit_table(rows):
    """Print the banner, column header and one fixed-width line per row."""
    print("=" * 110)
    print("d=512 ResMLP audit (3 seeds)")
    print("=" * 110)
    print(f"{'method':<16}{'seed':>6}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
          f"{'max/block':>12}{'stab':>10} verdict")
    print("-" * 110)
    for r in rows:
        print(
            f"{r['method']:<16}{r['seed']:>6}{r['acc']:>8.4f}{r['h_L']:>14.3e}"
            f"{r['g_L']:>14.3e}{r['max_per_block']:>12.2e}{r['stability']:>10.3f} "
            f"{r['verdict'][:60]}"
        )


def main():
    """Audit every available d=512 checkpoint and save the results as JSON.

    For each seed in (42, 123, 456) and each method tag, the checkpoint is
    loaded, its CIFAR-10 test accuracy is measured, and ``diagnose`` is run
    on a fixed set of ten evaluation batches. Checkpoints that do not exist
    are reported as skipped. The collected rows are printed as a table and
    written to ``audit_d512_3seed.json`` inside ``OUT_DIR``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
    methods = ["bp", "dfa", "state_bridge", "credit_bridge"]
    rows = []
    for seed in [42, 123, 456]:
        for method in methods:
            try:
                model = load_d512(method, seed, device)
            except FileNotFoundError:
                print(f" SKIPPED: {method}_s{seed} not found")
                continue
            acc = evaluate(model, device)
            report = diagnose(
                model=model,
                eval_batches=eval_batches,
                headline_acc=acc,
                frozen_baseline_acc=None,
                method_name=method.upper(),
                notes=f"4-block d=512 ResMLP, CIFAR-10, seed {seed}",
            )
            # The report's field types are not visible from this file.
            # float() is a no-op for Python floats and converts 0-dim tensors
            # or NumPy scalars, which the table formatting accepts but
            # json.dump would reject only after all the work is done.
            rows.append({
                "method": method,
                "seed": seed,
                "acc": acc,
                "h_L": float(report.residual_norms[-1]),
                "g_L": float(report.bp_grad_norms[-1]),
                "stability": float(report.cross_batch_stability),
                "max_per_block": float(report.max_per_block_growth),
                "verdict": report.verdict,
            })

    _print_audit_table(rows)

    out_path = os.path.join(OUT_DIR, "audit_d512_3seed.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)
    print(f"\nSaved {out_path}")


if __name__ == "__main__":
    main()
|