1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
|
"""
Reproduce the §2 audit table: apply the diagnostic protocol to BP / DFA /
State Bridge / Credit Bridge / EP checkpoints on the 4-block d=256 ResMLP /
CIFAR-10 setup. The table defaults to a single seed (42; pass --seeds to
override); the paper uses 3-seed means elsewhere.
Output is a per-method tabular summary that lists, for each diagnostic,
the per-layer values and the verdict. This is the audit evidence behind the
paper claim *"standard FA evaluation reports headline accuracy + Γ as
evidence of training, but on modern pre-LN residual networks both signals
silently fail for non-BP methods."*
Run:
CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_table
"""
import os
import sys
import json
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Repository root: three directory levels above this file
# (<repo>/protocol/examples/audit_table.py -> <repo>). It is placed first on
# sys.path so the project-local imports that follow resolve regardless of the
# directory the script is launched from.
REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
sys.path.insert(0, REPO_ROOT)
from models.residual_mlp import ResidualMLP # noqa: E402
from protocol import diagnose # noqa: E402
# Checkpoint inputs (non-EP methods, then EP) and the output directory for the
# audit JSON. All are resolved against REPO_ROOT so they do not depend on the
# current working directory. OUT_DIR is created eagerly at import time.
CHECKPOINT_DIR, EP_CHECKPOINT_DIR, OUT_DIR = (
    os.path.join(REPO_ROOT, relative)
    for relative in (
        "results/confirmatory/checkpoints_A2",
        "results/ep_baseline",
        "results/protocol_audit",
    )
)
os.makedirs(OUT_DIR, exist_ok=True)
def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"):
    """Return the first ``n_batches`` CIFAR-10 test batches, flattened and on ``device``.

    Images go through ``ToTensor`` and a per-channel mean/std ``Normalize``, then
    each batch is flattened to shape ``(batch, C*H*W)`` to match the MLP input.
    The test split is read (and downloaded on first use) from ``./data``
    relative to the current working directory. The loader is unshuffled, so the
    same batches are returned on every call.

    Args:
        n_batches: Maximum number of batches to return. Values <= 0 yield an
            empty list. Fewer are returned if the split runs out.
        batch_size: Samples per batch (the last batch may be smaller).
        device: Device the ``x`` and ``y`` tensors are moved to.

    Returns:
        A list of ``(x, y)`` tuples.
    """
    from itertools import islice

    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
    loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
    # islice yields at most n_batches items and none when n_batches <= 0.
    # islice rejects negative stops, hence the clamp.
    return [
        (x.view(x.size(0), -1).to(device), y.to(device))
        for x, y in islice(loader, max(n_batches, 0))
    ]
def evaluate(model, device):
    """Return the top-1 accuracy of ``model`` over the full CIFAR-10 test split.

    Images are normalized per channel and flattened to ``(batch, C*H*W)``
    before the forward pass. Batches of 256 are read unshuffled from ``./data``,
    which is downloaded on first use. ``model`` is switched to eval mode and
    left there. Inference runs without gradient tracking.

    Returns:
        Fraction of correctly classified test images, in ``[0, 1]``.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    test_set = torchvision.datasets.CIFAR10(
        "./data",
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
    )
    batches = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=0)

    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for images, labels in batches:
            flat = images.view(images.size(0), -1).to(device)
            targets = labels.to(device)
            n_correct += (model(flat).argmax(-1) == targets).sum().item()
            n_seen += flat.size(0)
    return n_correct / n_seen
def load_model(method: str, seed: int, device):
    """Build the 4-block d=256 ResidualMLP and load the ``(method, seed)`` checkpoint.

    EP checkpoints are read from ``EP_CHECKPOINT_DIR``. Every other method is
    read from ``CHECKPOINT_DIR``. Both use the file name ``{method}_s{seed}.pt``.
    The file may contain an object exposing ``state_dict()`` (used directly), a
    bare state dict, or a dict wrapping one under the ``"state_dict"`` key.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist (from
            ``torch.load``). Callers use this to skip missing checkpoints.
    """
    ckpt_dir = EP_CHECKPOINT_DIR if method == "ep" else CHECKPOINT_DIR
    path = os.path.join(ckpt_dir, f"{method}_s{seed}.pt")

    # weights_only=False is needed because the file may be a pickled object,
    # not only tensors. It unpickles arbitrary code, so only use it on trusted
    # local checkpoints.
    weights = torch.load(path, map_location=device, weights_only=False)
    if hasattr(weights, "state_dict"):
        weights = weights.state_dict()
    if isinstance(weights, dict) and "state_dict" in weights:
        weights = weights["state_dict"]

    net = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device)
    net.load_state_dict(weights)
    return net
# Frozen/shallow baseline accuracy used as each method's comparator: 3-seed
# means, taken from the project_resmlp_walkback_dfa_destroys_value memory
# entry. The DFA-frozen and DFA-shallow conditions share one number by design,
# because deep blocks frozen at random init are informationally equivalent to
# having no deep blocks. main() forwards each value (None included) to
# diagnose() as frozen_baseline_acc.
FROZEN_BASELINE_ACC = dict(
    bp=None,             # BP-frozen is 34.6%; not the right comparator for BP-trainable
    dfa=0.349,           # DFA-frozen / DFA-shallow 3-seed mean
    state_bridge=0.349,  # same architecture-matched control as DFA
    credit_bridge=0.349,
    ep=None,             # EP frozen-control not run yet
)
def _fmt_or_na(value, spec):
    """Format ``value`` with format ``spec``, or return "n/a" when it is None."""
    return "n/a" if value is None else format(value, spec)


def _print_summary(rows, seeds):
    """Print the fixed-width audit summary table, one line per (method, seed) row."""
    print("\n\n" + "=" * 110)
    print(f"AUDIT SUMMARY (seeds={seeds}, 4-block d=256 ResMLP, CIFAR-10)")
    print("=" * 110)
    header = (
        f"{'method':<16}{'seed':>6}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
        f"{'stab(L/2)':>12}{'frozen':>10} verdict"
    )
    print(header)
    print("-" * 110)
    for r in rows:
        # Any missing diagnostic renders as "n/a" (as the frozen column always
        # did) rather than raising a format error after all evaluation is done.
        print(
            f"{r['method']:<16}"
            f"{r['seed']:>6}"
            f"{r['acc']:>8.4f}"
            f"{_fmt_or_na(r['h_L'], '.3e'):>14}"
            f"{_fmt_or_na(r['g_L'], '.3e'):>14}"
            f"{_fmt_or_na(r['stability'], '.3f'):>12}"
            f"{_fmt_or_na(r['frozen_acc'], '.4f'):>10} {r['verdict']}"
        )


def main(argv=None):
    """Run the diagnostic audit over every (seed, method) checkpoint and save it.

    For each seed in ``--seeds`` (default: 42) and each method, this function:

    * loads the checkpoint. Missing files are reported and skipped.
    * measures headline test accuracy with ``evaluate``.
    * runs ``diagnose`` on the first 10 test batches of 128, using the method's
      entry in ``FROZEN_BASELINE_ACC`` as the frozen baseline.

    It then prints each report and a compact summary table, and writes all
    reports plus the summary rows to
    ``OUT_DIR/audit_table_s<seed>[_s<seed>...].json``.

    Args:
        argv: Argument list to parse. None (the default) parses ``sys.argv[1:]``,
            as the command-line entry point does.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seeds", type=int, nargs="+", default=[42])
    args = parser.parse_args(argv)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    # One fixed set of batches is shared by every diagnose() call.
    eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)

    methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
    rows = []
    reports = {}
    for seed in args.seeds:
        for method in methods:
            print(f"\n### {method.upper()} (seed {seed})")
            try:
                model = load_model(method, seed, device)
            except FileNotFoundError as e:
                print(f"  SKIPPED: checkpoint not found ({e})")
                continue
            acc = evaluate(model, device)
            report = diagnose(
                model=model,
                eval_batches=eval_batches,
                headline_acc=acc,
                frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method),
                method_name=method.upper(),
                notes=f"4-block d=256 ResMLP, CIFAR-10, seed {seed}",
            )
            print(report)
            reports[f"{method}_s{seed}"] = report.to_dict()
            # The summary keeps only the last-layer norms, plus stability,
            # frozen baseline and verdict.
            rows.append({
                "method": method,
                "seed": seed,
                "acc": acc,
                "h_L": report.residual_norms[-1],
                "g_L": report.bp_grad_norms[-1],
                "stability": report.cross_batch_stability,
                "frozen_acc": report.frozen_baseline_acc,
                "verdict": report.verdict,
            })

    _print_summary(rows, args.seeds)

    seeds_tag = "_".join(f"s{s}" for s in args.seeds)
    out_path = os.path.join(OUT_DIR, f"audit_table_{seeds_tag}.json")
    with open(out_path, "w") as f:
        json.dump({"reports": reports, "summary": rows}, f, indent=2)
    print(f"\nSaved {out_path}")


if __name__ == "__main__":
    main()
|