summaryrefslogtreecommitdiff
path: root/protocol/examples/random_init_sanity.py
blob: c4981fdc78bd4f1cd63d7ccc9f3d765441511370 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Sanity check: apply the diagnostic protocol to a *randomly initialized*,
untrained 4-block d=256 ResMLP. The protocol should report "trustworthy":
random init is not in the failure regime, even though its accuracy is
trivial. The protocol is meant to flag networks that were *trained into a
pathology*, not weak networks per se. (This script only prints the reports;
it does not assert the verdict.)

This answers the likely reviewer question: "is your protocol just flagging
anything that doesn't perform well?" No. Random init performs at chance
(~10% on CIFAR-10's 10 classes), and the protocol passes it. The 3 walked-back
trained methods are walked back because of the *measurements*, not their
accuracy.

Run:
    CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.random_init_sanity
"""
import os
import sys

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# This file sits at <repo>/protocol/examples/; step up three path components
# (file -> examples -> protocol -> repo root) and put the repo root first on
# sys.path so the project-local `models` / `protocol` imports below resolve
# even when the script is not launched from the repo root.
REPO_ROOT = os.path.normpath(
    os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir)
)
sys.path.insert(0, REPO_ROOT)

from models.residual_mlp import ResidualMLP  # noqa: E402
from protocol import diagnose  # noqa: E402


def get_eval(n_batches, batch_size, device, root="./data"):
    """Return the first ``n_batches`` CIFAR-10 test batches, flattened, on ``device``.

    Images go through ``ToTensor`` plus a per-channel ``Normalize`` with the
    constants below. Each image batch is then flattened to
    ``(batch, -1)`` so it can be fed straight into the MLP. The test split is
    read in its fixed order (``shuffle=False``), so repeated calls return the
    same batches.

    Args:
        n_batches: Maximum number of batches to return. A value <= 0 returns
            an empty list without constructing (or downloading) the dataset.
        batch_size: Number of examples requested per batch.
        device: Device the image and label tensors are moved to.
        root: Dataset directory passed to ``torchvision.datasets.CIFAR10``
            (``download=True`` is used). Defaults to ``"./data"``.

    Returns:
        A list of ``(x, y)`` tuples with at most ``n_batches`` entries.
    """
    from itertools import islice

    # Return early so a non-positive count yields no batches and skips the
    # dataset download entirely.
    if n_batches <= 0:
        return []

    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    te = torchvision.datasets.CIFAR10(root, train=False, download=True, transform=tv)
    loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
    # islice stops pulling from the loader after n_batches, so exactly the
    # requested number of batches (or fewer, if the loader runs out) is loaded.
    return [
        (x.view(x.size(0), -1).to(device), y.to(device))
        for x, y in islice(loader, n_batches)
    ]


def chance_acc(model, device):
    """Return ``model``'s top-1 accuracy over the full CIFAR-10 test split.

    Despite the name, this measures the model's actual accuracy. For an
    untrained network that number is expected to sit near chance. Inputs get
    ToTensor + per-channel Normalize preprocessing and are flattened per
    example before the forward pass, which runs under ``torch.no_grad()``.

    Side effect: switches ``model`` into eval mode and leaves it there.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    )
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])
    test_set = torchvision.datasets.CIFAR10(
        "./data", train=False, download=True, transform=preprocess
    )
    batches = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=0)

    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for images, labels in batches:
            flat = images.view(images.size(0), -1).to(device)
            targets = labels.to(device)
            hits = model(flat).argmax(-1) == targets
            n_correct += hits.sum().item()
            n_seen += flat.size(0)
    return n_correct / n_seen


def main(seeds=(42, 123, 456)):
    """Run the diagnostic protocol on freshly initialized, untrained ResMLPs.

    For each seed, this builds ``ResidualMLP(3072, 256, 10, 4)`` (the 4-block
    d=256 ResMLP named in the printed header). It then measures the model's
    CIFAR-10 test accuracy with ``chance_acc``, runs ``diagnose`` on a shared,
    fixed set of 10 evaluation batches of 128 examples, and prints the
    report. Nothing is asserted; the reports are for manual inspection.

    Args:
        seeds: Iterable of integer seeds passed to ``torch.manual_seed``
            before each model is constructed. Defaults to (42, 123, 456).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Loaded once and reused for every seed so the reports are comparable.
    eval_batches = get_eval(n_batches=10, batch_size=128, device=device)

    print("=" * 72)
    print("RANDOM-INIT SANITY CHECK (4-block d=256 ResMLP, CIFAR-10)")
    print("=" * 72)
    for seed in seeds:
        # Seed before construction so the random init is reproducible; the
        # model is built on CPU and then moved, so the CPU generator governs it.
        torch.manual_seed(seed)
        model = ResidualMLP(3072, 256, 10, 4).to(device)
        # No training step: parameters stay exactly as initialized. The
        # accuracy measurement below does run forward passes, but chance_acc
        # calls model.eval() first, so any normalization running statistics
        # (if the model has them) are not updated by it.
        acc = chance_acc(model, device)
        report = diagnose(
            model=model,
            eval_batches=eval_batches,
            headline_acc=acc,
            frozen_baseline_acc=None,
            method_name=f"RANDOM_INIT seed {seed}",
            notes="untrained, no parameter updates",
        )
        print(f"\n=== seed {seed} (chance acc: {acc:.4f}) ===")
        print(report)


if __name__ == "__main__":
    # Run the sanity check only when executed as a script/module (e.g.
    # `python -m protocol.examples.random_init_sanity`), not on import.
    main()