1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
|
"""
Smoke test for the FA Diagnostic Protocol reference implementation.

Loads BP- and DFA-trained ResMLP checkpoints from
`results/confirmatory/checkpoints_A2/`, plus an EP-trained checkpoint from
`results/ep_baseline/`, and applies the protocol to each.

The protocol should:
- return verdict="trustworthy" on the BP checkpoint (residual norms bounded,
  BP grad ~5e-5, well above the floor; low cross-batch direction stability),
- flag the DFA checkpoint as "needs walk-back" (residual stream exploded
  by ~10^7×, BP grad at ~5e-10, drift-dominated direction),
- return verdict="trustworthy" on the EP checkpoint. This is the internal
  control: same architecture and dataset, also low accuracy, but no
  residual-stream blow-up.
Run with:
CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test
"""
import os
import sys
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Make `protocol` and `models` importable when invoked from repo root
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP # noqa: E402
from protocol import diagnose # noqa: E402
# Repository root: the parent of the directory holding this script (the
# script lives one package level below the repo root, e.g. protocol/).
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_ROOT = os.path.dirname(_SCRIPT_DIR)

# Checkpoint directories, anchored at REPO_ROOT so they resolve identically
# regardless of the current working directory.
CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2")  # bp / dfa
EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline")  # ep
def load_eval_batches(n_batches: int = 4, batch_size: int = 1024, device="cuda:0"):
    """Return the first ``n_batches`` CIFAR-10 test batches, flattened per sample.

    Batches are taken in dataset order (``shuffle=False``), so repeated calls
    return the same samples. Images are normalized with the per-channel
    mean/std below and reshaped with ``x.view(batch, -1)`` so they can be fed
    directly to the MLP.

    Args:
        n_batches: Number of batches to return; values ``<= 0`` give ``[]``.
        batch_size: Samples per batch. The final batch can be smaller if the
            test set is exhausted first.
        device: Device every returned tensor is moved to.

    Returns:
        A list of ``(x, y)`` tensor pairs.
    """
    from itertools import islice

    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    # Downloaded into ./data (relative to the current working directory) on first use.
    te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
    loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
    # islice stops pulling from the loader after n_batches, so no extra batch
    # is decoded; clamping makes a non-positive count yield nothing instead of
    # silently returning one batch.
    return [
        (x.view(x.size(0), -1).to(device), y.to(device))
        for x, y in islice(loader, max(0, n_batches))
    ]
def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Switches the model to eval mode (and leaves it there) and runs without
    gradient tracking. Each input batch is flattened per sample before the
    forward pass.

    Args:
        model: Module whose output's last dimension holds per-class scores.
        loader: Iterable of ``(x, y)`` batches.
        device: Device each batch is moved to.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            inputs = x.view(x.size(0), -1).to(device)
            preds = model(inputs).argmax(-1)
            correct += (preds == y.to(device)).sum().item()
            total += inputs.size(0)
    if total == 0:
        # Without this guard the division below fails with an opaque ZeroDivisionError.
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return correct / total
def load_model(method: str, seed: int, device):
    """Build the 4-block, d=256 ResidualMLP and load the ``method``/``seed`` checkpoint.

    EP checkpoints are read from ``EP_CHECKPOINT_DIR/ep_s{seed}.pt``; every
    other method from ``CHECKPOINT_DIR/{method}_s{seed}.pt``.

    Accepted checkpoint layouts:
      * a dict wrapping the weights under ``"state_dict"`` or ``"model"``,
      * a bare state dict,
      * a pickled object exposing ``state_dict()`` (e.g. a whole module).
    A wrapped value that is itself a module is reduced with ``.state_dict()``.

    Returns:
        The model on ``device`` with the checkpoint weights loaded.
    """
    if method == "ep":
        path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt")
    else:
        path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt")
    # SECURITY: weights_only=False unpickles arbitrary Python objects (required
    # for checkpoints that store a whole module). Only load trusted files.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    if isinstance(ckpt, dict):
        # Unwrap the first known wrapper key; otherwise treat as a bare state dict.
        if "state_dict" in ckpt:
            sd = ckpt["state_dict"]
        elif "model" in ckpt:
            sd = ckpt["model"]
        else:
            sd = ckpt
    else:
        sd = ckpt.state_dict()
    # e.g. torch.save({"model": model}) stores the module itself, which
    # load_state_dict cannot consume directly.
    if not isinstance(sd, dict) and hasattr(sd, "state_dict"):
        sd = sd.state_dict()
    model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device)
    model.load_state_dict(sd)
    return model
def main():
    """Diagnose the BP, DFA and EP seed-42 checkpoints and check each verdict.

    Returns:
        Process exit status: 0 if every verdict matched its expectation,
        1 otherwise.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    # Cross-batch stability is sensitive to batch size: with smaller batches
    # the per-sample noise component dominates the batch-mean direction on
    # healthy networks, giving the cleanest separation from the drift-
    # dominated failure mode (~0.10 healthy vs ~0.95 drift at N=128).
    eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
    # Full CIFAR-10 test set, used only for the headline accuracy.
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
    test_loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
    # BP: trustworthy. DFA: walked back. EP: trustworthy (the internal control
    # — same architecture/metric/dataset, but EP doesn't blow up the residual
    # stream, so the diagnostic protocol passes for EP even though EP's
    # accuracy is also low. This is the paper's central comparison.)
    expectations = [
        ("bp", "trustworthy"),
        ("dfa", "needs walk-back"),
        ("ep", "trustworthy"),
    ]
    mismatches = []
    for method, expected in expectations:
        name = method.upper()
        print()
        print(f"### {name} (seed 42)")
        model = load_model(method, 42, device)
        acc = evaluate(model, test_loader, device)
        report = diagnose(
            model=model,
            eval_batches=eval_batches,
            headline_acc=acc,
            frozen_baseline_acc=None,
            method_name=name,
            notes="4-block d=256 ResMLP, CIFAR-10",
        )
        print(report)
        # Sanity check: "trustworthy" must match exactly; "needs walk-back" is
        # matched by prefix so any qualifier appended to the verdict is tolerated.
        verdict = report.verdict
        if expected == "needs walk-back":
            matched = verdict.startswith("needs walk-back")
            wanted = "need walk-back"
        else:
            matched = verdict == expected
            wanted = f"be {expected}"
        if matched:
            print(f"OK: verdict matches expectation ('{expected}')")
        else:
            mismatches.append(name)
            print(f"!! UNEXPECTED: {name} should {wanted}, got '{verdict}'")
    if mismatches:
        print()
        print(f"!! {len(mismatches)} unexpected verdict(s): {', '.join(mismatches)}")
        return 1
    return 0


if __name__ == "__main__":
    # Propagate the result so CI / shell scripts can detect a failed smoke test.
    sys.exit(main())
|