1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
|
"""
Protocol decision-utility ablation: starting from the 5-method audit table,
evaluate what each *subset* of the protocol would catch.
For each method we ask: under each evaluation strategy, would a reviewer
have walked back the headline accuracy claim?
Strategies considered (S2–S5 each add exactly ONE diagnostic on top of S1;
they are not cumulative — only S_full combines all four diagnostics):
S0: Headline accuracy only (the conventional reporting)
S1: Headline accuracy + Γ (the field's standard FA evaluation)
S2: S1 + diagnostic (a) per-layer ‖h_l‖ — catches scale pathology
S3: S1 + diagnostic (b) per-layer ‖g_l‖ — catches reference at floor
S4: S1 + diagnostic (c) cross-batch dir stability — catches drift dominance
S5: S1 + diagnostic (d) frozen-blocks baseline — catches passive blocks
S_full: full protocol (a)+(b)+(c)+(d)
S1 corresponds to the field's status quo. S_full is what this paper proposes.
The "decision utility" of the protocol is the set of cases where S_full
flags but S1 does not.
Run:
CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.ablation_decision_utility
"""
import os
import sys
import json
# Repository root: two directories above the one holding this file
# (equivalently, three ``dirname`` steps up from this file's absolute path).
REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
# Seed-42 audit table consumed by main().
AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42.json")
def gamma_proxy(method: str, g_norm: float) -> str:
    """Describe the verdict that strategy S1 (headline acc + Γ) reaches.

    Γ reads as "high-ish" for DFA/SB/CB at the noise floor
    (0.10 / 0.005 / 0.07) and is ~1.0 for BP and ~0.008 for EP; in every case
    the value LOOKS plausible to a reviewer who is not also looking at ‖g‖.
    S1 therefore never yields a walk-back signal, so both arguments are
    accepted only for interface symmetry and are deliberately unused.
    """
    verdict = "no walk-back (looks fine)"
    return verdict
def max_per_block_growth(h):
    """Return the largest ratio ``h[i + 1] / h[i]`` over consecutive entries.

    A sequence with fewer than two entries has no consecutive pair and
    yields 1.0. Each denominator is clamped from below at 1e-30 so a zero
    entry cannot cause a division by zero.
    """
    if len(h) >= 2:
        return max(nxt / max(cur, 1e-30) for cur, nxt in zip(h, h[1:]))
    return 1.0
def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str:
    """Return whether *strategy* would have walked back the claim, and why.

    Args:
        strategy: One of "S0", "S1", "S2", "S3", "S4", "S5", "S_full".
            S2-S5 each enable exactly one diagnostic ((a)-(d) respectively);
            S_full enables all four.
        method: Method name; forwarded to ``gamma_proxy`` for S1.
        report: Audit report for the method. Reads ``residual_norms``,
            ``bp_grad_norms``, ``cross_batch_stability``, the optional
            ``frozen_baseline_acc`` and the ``thresholds`` sub-dict.
        headline_acc: Headline accuracy on the same scale as
            ``frozen_baseline_acc``. The gap is multiplied by 100 and compared
            with a percentage-point margin, so presumably a fraction in
            [0, 1] -- confirm against the audit table.

    Returns:
        "no walk-back (acc only)" for S0, the ``gamma_proxy`` verdict for S1,
        "no walk-back" when no enabled diagnostic fires, and otherwise
        "WALK-BACK: " followed by the fired flags joined with " + ".

    Raises:
        ValueError: if *strategy* is not a recognised name. Silently answering
            "no walk-back" for a typo would be indistinguishable from a clean
            verdict.
    """
    # S0/S1 never consult the diagnostics, so answer them before reading any
    # diagnostic field (reports lacking those fields must not crash them).
    if strategy == "S0":
        return "no walk-back (acc only)"
    if strategy == "S1":
        return gamma_proxy(method, report["bp_grad_norms"][-1])
    if strategy not in ("S2", "S3", "S4", "S5", "S_full"):
        raise ValueError(f"unknown strategy: {strategy!r}")

    thresholds = report["thresholds"]
    # (a) some consecutive residual-norm ratio exceeds the explosion ratio.
    h_exploded = (
        max_per_block_growth(report["residual_norms"])
        > thresholds["h_norm_explosion_ratio"]
    )
    # (b) the last entry of bp_grad_norms is below the configured floor.
    g_at_floor = report["bp_grad_norms"][-1] < thresholds["g_norm_floor"]
    # (c) cross-batch stability value exceeds the drift ceiling.
    drift = report["cross_batch_stability"] > thresholds["stability_drift_ceiling"]
    # (d) headline accuracy beats the frozen-blocks baseline by less than the
    # required margin (in percentage points); skipped if no baseline exists.
    frozen = report.get("frozen_baseline_acc")
    undercut = (
        frozen is not None
        and (headline_acc - frozen) * 100 < thresholds["frozen_acc_margin_pp"]
    )

    flags = []
    if strategy in ("S2", "S_full") and h_exploded:
        flags.append("(a)scale")
    if strategy in ("S3", "S_full") and g_at_floor:
        flags.append("(b)floor")
    if strategy in ("S4", "S_full") and drift:
        flags.append("(c)drift")
    if strategy in ("S5", "S_full") and undercut:
        flags.append("(d)passive")
    if not flags:
        return "no walk-back"
    return "WALK-BACK: " + " + ".join(flags)
def _find_report(reports, method):
    """Return *method*'s audit report from *reports*, or None if it has none.

    The audit JSON keys reports either as "<method>" (legacy) or
    "<method>_s42" (current); the legacy key is checked first.
    """
    for key in (method, f"{method}_s42"):
        if key in reports:
            return reports[key]
    return None


def main():
    """Run the decision-utility ablation over the seed-42 audit table.

    Reads ``AUDIT_PATH`` and evaluates every strategy for each method that has
    both a report and a seed-42 accuracy row; other methods are reported as
    skipped. It then:

    - prints the per-method verdicts;
    - prints the decision-utility set (walked back by S_full but not S1);
    - prints the per-diagnostic recall;
    - writes everything to ``ablation_decision_utility.json`` under
      ``results/protocol_audit``.
    """
    with open(AUDIT_PATH, encoding="utf-8") as f:
        data = json.load(f)
    methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
    # Use the seed-42 row for the ablation; the table is single-seed by design
    # (rows without a "seed" field are treated as seed 42; a later row for the
    # same method overrides an earlier one).
    method_acc = {
        row["method"]: row["acc"]
        for row in data["summary"]
        if row.get("seed", 42) == 42
    }
    strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"]
    strategy_label = {
        "S0": "headline acc only",
        "S1": "+ Γ (field standard)",
        "S2": "+ diagnostic (a) ‖h_l‖",
        "S3": "+ diagnostic (b) ‖g_l‖",
        "S4": "+ diagnostic (c) stability",
        "S5": "+ diagnostic (d) frozen baseline",
        "S_full": "full protocol",
    }
    print("=" * 100)
    print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)")
    print("=" * 100)

    table = {}
    skipped = []
    for method in methods:
        report = _find_report(data["reports"], method)
        if report is None:
            print(f" SKIPPED (no report): {method}")
            skipped.append(method)
            continue
        if method not in method_acc:
            print(f" SKIPPED (no seed-42 accuracy): {method}")
            skipped.append(method)
            continue
        acc = method_acc[method]
        table[method] = {
            s: evaluate_strategy(s, method, report, acc) for s in strategies
        }
    # Only evaluated methods appear below; iterating the full ``methods`` list
    # would raise KeyError on any skipped method.
    evaluated = list(table)

    # Print row-by-method
    for method in evaluated:
        print(f"\n## {method.upper()} (acc {method_acc[method]:.4f})")
        for s in strategies:
            print(f" {s:<8} ({strategy_label[s]:<30}): {table[method][s]}")

    # Decision utility = methods caught by S_full but missed by S1
    missed_by_s1 = [
        m for m in evaluated
        if "WALK-BACK" in table[m]["S_full"] and "WALK-BACK" not in table[m]["S1"]
    ]
    print("\n" + "=" * 100)
    print("DECISION UTILITY: methods walked back by S_full but NOT by S1 (status quo)")
    print("=" * 100)
    for method in missed_by_s1:
        s1 = table[method]["S1"]
        sf = table[method]["S_full"]
        print(f" {method.upper():<16} S1='{s1}' -> S_full='{sf}'")

    # Per-diagnostic recall: which method does each diagnostic catch alone?
    print("\n" + "=" * 100)
    print("PER-DIAGNOSTIC RECALL: which methods does each single diagnostic catch?")
    print("=" * 100)
    diag_strats = {"S2": "(a) ‖h_l‖", "S3": "(b) ‖g_l‖", "S4": "(c) stability", "S5": "(d) frozen"}
    for s, name in diag_strats.items():
        caught = [m for m in evaluated if "WALK-BACK" in table[m][s]]
        print(f" {name:<16}: catches {caught}")

    # Save; "skipped" records methods absent from the table so the summary
    # lists are not mistaken for covering every method.
    out = {
        "table": table,
        "strategies": strategy_label,
        "summary": {
            "missed_by_S1": missed_by_s1,
            "trustworthy_by_S_full": [
                m for m in evaluated if "WALK-BACK" not in table[m]["S_full"]
            ],
            "skipped": skipped,
        },
    }
    out_path = os.path.join(REPO_ROOT, "results/protocol_audit/ablation_decision_utility.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"\nSaved {out_path}")
# Script entry point: run the ablation only when executed directly (e.g.
# ``python -m protocol.examples.ablation_decision_utility``), not on import.
if __name__ == "__main__":
    main()
|