1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
"""
Sanity check: apply the diagnostic protocol to a *randomly initialized*
4-block d=256 ResMLP, with no training. The protocol should report
"trustworthy" — random init is not in the failure regime, even though it
also has trivial accuracy. The protocol is supposed to flag *trained-into-
pathology* networks, not weak networks per se.
This is the answer to the likely reviewer question: "is your protocol just
flagging anything that doesn't perform well?" Answer: no. Random init
performs at chance (10%) and the protocol passes it. The 3 walked-back
trained methods are walked back because of the *measurements*, not the
accuracy.
Run:
CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.random_init_sanity
"""
import os
import sys
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# This file lives at <repo>/protocol/examples/<script>.py (see the `-m
# protocol.examples...` invocation in the module docstring), so stripping
# three path components yields the repository root. Putting it first on
# sys.path lets the top-level `models` and `protocol` packages be imported
# below.
_this_file = os.path.abspath(__file__)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_this_file)))
sys.path.insert(0, REPO_ROOT)
from models.residual_mlp import ResidualMLP # noqa: E402
from protocol import diagnose # noqa: E402
# Per-channel (mean, std) normalization applied to the CIFAR-10 test images.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def _cifar10_test_set():
    """Return the normalized CIFAR-10 test split stored under ./data (download=True)."""
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])
    return torchvision.datasets.CIFAR10(
        "./data", train=False, download=True, transform=tv
    )


def _flat_batches(loader, device):
    """Yield (x, y) from *loader* with x flattened to (batch, -1); both moved to *device*."""
    for x, y in loader:
        yield x.view(x.size(0), -1).to(device), y.to(device)


def get_eval(n_batches, batch_size, device, dataset=None):
    """Return the first *n_batches* evaluation batches, flattened and on *device*.

    Batches are taken in dataset order (no shuffling), so the same call
    always yields the same samples.

    Args:
        n_batches: maximum number of batches to return; values <= 0 give [].
        batch_size: samples per batch (the final batch may be smaller).
        device: torch device the tensors are moved to.
        dataset: optional map-style dataset of (image, label) pairs.
            Defaults to the normalized CIFAR-10 test split.

    Returns:
        list of (x, y) tensor tuples, x flattened to (batch, -1).
    """
    # Checked before touching the dataset: the previous append-then-check
    # loop returned one batch even when zero were requested.
    if n_batches <= 0:
        return []
    if dataset is None:
        dataset = _cifar10_test_set()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    batches = []
    for batch in _flat_batches(loader, device):
        batches.append(batch)
        # Stop as soon as we have enough, so no extra batch is loaded.
        if len(batches) >= n_batches:
            break
    return batches


def chance_acc(model, device, dataset=None):
    """Return *model*'s top-1 accuracy over the entire evaluation set.

    Despite the name, this is simply the measured accuracy. For an untrained
    model it is expected to sit near chance (~10% for CIFAR-10's ten
    classes).

    Side effect: puts *model* into eval mode and leaves it there. Inference
    runs under ``torch.no_grad()``, so no gradients are recorded.

    Args:
        model: module mapping flattened inputs (batch, -1) to class logits.
        device: torch device inputs and labels are moved to.
        dataset: optional map-style dataset of (image, label) pairs.
            Defaults to the normalized CIFAR-10 test split.

    Returns:
        Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if the dataset contains no samples.
    """
    if dataset is None:
        dataset = _cifar10_test_set()
    loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=0)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in _flat_batches(loader, device):
            preds = model(x).argmax(-1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    if total == 0:
        raise ValueError("chance_acc: evaluation dataset is empty")
    return correct / total
def main(seeds=(42, 123, 456)):
    """Run the diagnostic protocol on untrained ResMLPs and print each report.

    For every seed, builds a freshly initialized 4-block d=256 ResidualMLP
    and measures its CIFAR-10 test accuracy with ``chance_acc``. The model
    is then passed to ``diagnose`` together with one fixed set of eval
    batches that every seed shares.

    Args:
        seeds: torch RNG seeds; one model is built and diagnosed per seed.
            Defaults to (42, 123, 456).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The first 10 unshuffled test batches of 128, reused for every seed so
    # the reports are comparable.
    eval_batches = get_eval(n_batches=10, batch_size=128, device=device)

    rule = "=" * 72
    print(rule)
    print("RANDOM-INIT SANITY CHECK (4-block d=256 ResMLP, CIFAR-10)")
    print(rule)

    for seed in seeds:
        torch.manual_seed(seed)
        # NOTE(review): positional args presumably (in_dim=3072 flattened
        # 3x32x32 image, hidden=256, classes=10, blocks=4); confirm against
        # models.residual_mlp.ResidualMLP.
        model = ResidualMLP(3072, 256, 10, 4).to(device)
        # No training step is taken: parameters stay exactly as initialized.
        # chance_acc DOES run forward passes over the test set, but in eval
        # mode under torch.no_grad(), so no gradients or parameter updates
        # occur. If ResidualMLP contains normalization layers with running
        # statistics, eval mode should also keep those frozen. TODO confirm
        # the layer types.
        acc = chance_acc(model, device)
        report = diagnose(
            model=model,
            eval_batches=eval_batches,
            headline_acc=acc,
            frozen_baseline_acc=None,
            method_name=f"RANDOM_INIT seed {seed}",
            notes="untrained, no parameter updates",
        )
        print(f"\n=== seed {seed} (chance acc: {acc:.4f}) ===")
        print(report)


if __name__ == "__main__":
    main()
|