summaryrefslogtreecommitdiff
path: root/protocol/examples/audit_d512.py
blob: 0c9fb26880fe15d2a6e10d6aa504773bbcca6cbb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Protocol audit table for the d=512 ResMLP variant. The main paper uses
d=256; the d=512 checkpoints serve as a width control. If the protocol's
verdicts agree across both widths, that supports the claim that they
generalize across model width.

Run:
    CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_d512
"""
import itertools
import json
import os
import sys

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# This file lives at <repo>/protocol/examples/, so the repository root is
# three dirname() steps up from its absolute path. It is prepended to
# sys.path so the project-local ``models`` / ``protocol`` imports resolve
# when the file is executed directly.
REPO_ROOT = os.path.abspath(__file__)
for _ in range(3):
    REPO_ROOT = os.path.dirname(REPO_ROOT)
sys.path.insert(0, REPO_ROOT)

from models.residual_mlp import ResidualMLP  # noqa: E402
from protocol import diagnose  # noqa: E402

# Input checkpoints (one ``<method>_s<seed>.pt`` per run) and the directory
# the audit JSON is written to, both resolved against the repository root.
D512_CKPT_DIR, OUT_DIR = (
    os.path.join(REPO_ROOT, "results/confirmatory/cifar_d512"),
    os.path.join(REPO_ROOT, "results/protocol_audit"),
)
# Created eagerly at import time so main() can write without checking.
os.makedirs(OUT_DIR, exist_ok=True)


def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0", root="./data"):
    """Return the first ``n_batches`` CIFAR-10 test batches, ready for the MLP.

    Images are converted to tensors, normalized with per-channel CIFAR-10
    statistics, flattened to one vector per image, and moved to ``device``.

    Args:
        n_batches: Number of leading test batches to keep. Values <= 0 yield
            an empty list.
        batch_size: Images per batch.
        device: Target device for both inputs and labels.
        root: Dataset directory passed to ``torchvision.datasets.CIFAR10``.
            The dataset is downloaded there if missing. Defaults to the
            previous hard-coded "./data", which is resolved against the
            current working directory.

    Returns:
        List of ``(x, y)`` tuples, where ``x`` is the flattened image batch
        and ``y`` holds the labels.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    te = torchvision.datasets.CIFAR10(root, train=False, download=True, transform=tv)
    loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
    # islice stops before pulling an extra batch and honors n_batches <= 0.
    # The previous append-then-check loop always returned at least one batch.
    # max(..., 0) guards against islice raising on a negative stop.
    return [
        (x.view(x.size(0), -1).to(device), y.to(device))
        for x, y in itertools.islice(loader, max(n_batches, 0))
    ]


def evaluate(model, device):
    """Compute top-1 accuracy of ``model`` on the full CIFAR-10 test split.

    Each image is normalized per channel and flattened to a vector before
    the forward pass. The model is switched to eval mode (and left there),
    and gradients are disabled during inference.

    Returns:
        Fraction of correctly classified test images, as a float.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    test_set = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=pipeline)
    batches = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=0)
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for images, labels in batches:
            flat = images.view(images.size(0), -1).to(device)
            labels = labels.to(device)
            hits = model(flat).argmax(-1) == labels
            n_correct += hits.sum().item()
            n_seen += flat.size(0)
    return n_correct / n_seen


def load_d512(method, seed, device):
    """Rebuild the d=512 ResidualMLP for ``method``/``seed`` from its checkpoint.

    Reads ``<D512_CKPT_DIR>/<method>_s<seed>.pt``. The file may hold either
    a bare state dict or a dict wrapping it under ``"state_dict"``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist (main()
            relies on this to skip missing runs).
    """
    ckpt_path = os.path.join(D512_CKPT_DIR, f"{method}_s{seed}.pt")
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only point this at trusted, locally produced checkpoints.
    payload = torch.load(ckpt_path, map_location=device, weights_only=False)
    is_wrapped = isinstance(payload, dict) and "state_dict" in payload
    state = payload["state_dict"] if is_wrapped else payload
    net = ResidualMLP(3072, 512, 10, 4).to(device)
    net.load_state_dict(state)
    return net


def _audit_checkpoints(seeds, methods, eval_batches, device):
    """Evaluate and diagnose every available (seed, method) checkpoint; return table rows."""
    rows = []
    for seed in seeds:
        for method in methods:
            try:
                model = load_d512(method, seed, device)
            except FileNotFoundError:
                print(f"  SKIPPED: {method}_s{seed} not found")
                continue
            acc = evaluate(model, device)
            report = diagnose(
                model=model,
                eval_batches=eval_batches,
                headline_acc=acc,
                frozen_baseline_acc=None,
                method_name=method.upper(),
                notes=f"4-block d=512 ResMLP, CIFAR-10, seed {seed}",
            )
            # Only the last-layer residual / BP-gradient norms are tabulated.
            rows.append({
                "method": method,
                "seed": seed,
                "acc": acc,
                "h_L": report.residual_norms[-1],
                "g_L": report.bp_grad_norms[-1],
                "stability": report.cross_batch_stability,
                "max_per_block": report.max_per_block_growth,
                "verdict": report.verdict,
            })
    return rows


def _print_table(rows, n_seeds):
    """Print audit rows as a fixed-width table, truncating verdicts to 60 chars."""
    print("=" * 110)
    print(f"d=512 ResMLP audit ({n_seeds} seeds)")
    print("=" * 110)
    print(f"{'method':<16}{'seed':>6}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
          f"{'max/block':>12}{'stab':>10}  verdict")
    print("-" * 110)
    for r in rows:
        print(
            f"{r['method']:<16}{r['seed']:>6}{r['acc']:>8.4f}{r['h_L']:>14.3e}"
            f"{r['g_L']:>14.3e}{r['max_per_block']:>12.2e}{r['stability']:>10.3f}  "
            f"{r['verdict'][:60]}"
        )


def main():
    """Audit all d=512 checkpoints, print a summary table, and save it as JSON.

    Missing checkpoints are skipped with a message. Results are written to
    ``<OUT_DIR>/audit_d512_3seed.json``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # One fixed set of diagnostic batches shared by every checkpoint.
    eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
    seeds = [42, 123, 456]
    methods = ["bp", "dfa", "state_bridge", "credit_bridge"]

    rows = _audit_checkpoints(seeds, methods, eval_batches, device)
    # Header count derives from the seed list so it cannot drift from it.
    _print_table(rows, n_seeds=len(seeds))

    out_path = os.path.join(OUT_DIR, "audit_d512_3seed.json")
    # NOTE(review): assumes the report fields copied into rows are plain
    # JSON-serializable Python values (floats / str), not tensors. Confirm
    # against protocol.diagnose.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)
    print(f"\nSaved {out_path}")


if __name__ == "__main__":
    main()