1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
|
"""
Temporal validation of the diagnostic protocol: at what epoch during DFA
training does each diagnostic cross its degeneracy threshold?
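This uses the existing snapshot evolution data in
`results/snapshot_evolution_v2/` (default `--arch resmlp`; with `--arch vit`
the script instead reads `results/snapshot_vit_v1/`, whose norm keys are
`hidden_norms_cls` and `bp_grad_per_sample_l2_med`), which logs per-epoch: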
- hidden_norms (the (a) diagnostic)
- bp_grad_norms_per_sample_med (the (b) diagnostic)
- gamma_dfa (the field-standard reference number)
- acc_eval
over 100 epochs of both BP and DFA training on the standard 4-block d=256
ResMLP CIFAR-10 setup. We replay this data through the protocol's
threshold logic and report:
(i) the epoch at which each diagnostic first FIRES on DFA,
(ii) the per-epoch headline accuracy (so we can show that the diagnostic
fires BEFORE the headline acc has converged — i.e. the protocol
could have caught the pathology mid-training),
(iii) the trajectory on BP for comparison (which should never fire).
This is the temporal validation of the protocol's decision utility: the
protocol catches the pathology *as it happens*, not just retrospectively.
Run:
python -m protocol.examples.temporal_diagnostic_evolution [--seed N] [--arch {resmlp,vit}]
"""
import os
import json
import sys
# Repository root: the directory two levels above this file's own directory
# (i.e. <root>/protocol/examples/<this file> -> <root>).
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Put the repository root first on the import path so ``protocol.report``
# resolves even when this file is executed directly; this must run before the
# project import below (hence the E402 suppression).
sys.path.insert(0, REPO_ROOT)
from protocol.report import DiagnosticThresholds  # noqa: E402
# Threshold bundle (constructed with no arguments) shared by every diagnostic
# check in this script.
THRESHOLDS = DiagnosticThresholds()
def max_per_block_growth(h):
    """Return the largest ratio between consecutive entries of ``h``.

    ``h`` is a sequence of non-negative norms (``main`` passes the logged
    hidden norms). Each ratio is ``h[i + 1] / max(h[i], 1e-30)``; clamping the
    denominator keeps a zero norm from raising ``ZeroDivisionError``.
    A sequence with fewer than two entries has no consecutive pair, so the
    neutral growth factor 1.0 is returned.
    """
    if len(h) < 2:
        return 1.0
    ratios = [later / max(earlier, 1e-30) for earlier, later in zip(h, h[1:])]
    return max(ratios)
def diagnose_entry(entry):
    """Evaluate both degeneracy diagnostics on one logged epoch.

    Returns a pair ``(h_exploded, g_at_floor)``:

    - ``h_exploded`` (diagnostic (a)): the largest consecutive ratio of
      ``entry["hidden_norms"]`` exceeds ``THRESHOLDS.h_norm_explosion_ratio``.
    - ``g_at_floor`` (diagnostic (b)): the last element of
      ``entry["bp_grad_norms_per_sample_med"]`` is below
      ``THRESHOLDS.g_norm_floor``.
    """
    growth = max_per_block_growth(entry["hidden_norms"])
    last_grad_norm = entry["bp_grad_norms_per_sample_med"][-1]
    return (
        growth > THRESHOLDS.h_norm_explosion_ratio,
        last_grad_norm < THRESHOLDS.g_norm_floor,
    )
def first_fire_epoch(log, predicate):
    """Return the ``"epoch"`` value of the first entry in ``log`` matching ``predicate``.

    Entries are scanned in order and scanning stops at the first match;
    ``None`` is returned when no entry satisfies ``predicate``.
    """
    return next((entry["epoch"] for entry in log if predicate(entry)), None)
def _snapshot_config(arch, seed):
    """Return ``(snapshot_path, h_key, g_key, setup_label)`` for ``arch`` and ``seed``."""
    if arch == "resmlp":
        return (
            os.path.join(
                REPO_ROOT,
                f"results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json",
            ),
            "hidden_norms",
            "bp_grad_norms_per_sample_med",
            "4-block d=256 ResMLP, CIFAR-10",
        )
    return (
        os.path.join(REPO_ROOT, f"results/snapshot_vit_v1/snapshot_vit_s{seed}.json"),
        "hidden_norms_cls",
        "bp_grad_per_sample_l2_med",
        "ViT",
    )


def _normalize_log(raw_log, h_key, g_key):
    """Copy each entry, aliasing the arch-specific norm keys to the names the diagnostics read."""
    return [
        {**e, "hidden_norms": e[h_key], "bp_grad_norms_per_sample_med": e[g_key]}
        for e in raw_log
    ]


def _should_print(epoch, last_epoch):
    """Decide whether a trajectory row is shown: epochs <= 5, every 10th, and the last."""
    return epoch <= 5 or epoch % 10 == 0 or epoch == last_epoch


def _format_gamma(gamma):
    """Format ``gamma`` with 4 decimals; a missing (None) or NaN value renders as "nan"."""
    # ``gamma != gamma`` is True only for float NaN.
    if gamma is None or (isinstance(gamma, float) and gamma != gamma):
        return "nan"
    return f"{gamma:.4f}"


def _acc_at_epoch(log, epoch):
    """Return ``acc_eval`` of the first entry in ``log`` logged at ``epoch``."""
    return next(e["acc_eval"] for e in log if e["epoch"] == epoch)


def main(argv=None):
    """Replay a saved BP/DFA snapshot log through the protocol's thresholds.

    Loads the snapshot JSON selected by ``--arch`` / ``--seed`` and prints the
    DFA and BP trajectories with the (a)/(b) diagnostic flags. It then reports
    the first epoch at which each diagnostic fires on DFA and writes a summary
    JSON under ``results/protocol_audit/``.

    Args:
        argv: Argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns early (printing a message) when the snapshot file is missing or
    either log is empty.
    """
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--arch", type=str, default="resmlp", choices=["resmlp", "vit"])
    args = p.parse_args(argv)

    snapshot_path, h_key, g_key, setup_label = _snapshot_config(args.arch, args.seed)
    if not os.path.exists(snapshot_path):
        print(f"snapshot not found: {snapshot_path}")
        return
    with open(snapshot_path) as f:
        d = json.load(f)
    bp_log = _normalize_log(d["bp_log"], h_key, g_key)
    dfa_log = _normalize_log(d["dfa_log"], h_key, g_key)
    # Final-epoch summaries below index [-1]; bail out cleanly instead of IndexError.
    if not bp_log or not dfa_log:
        print(f"snapshot has an empty bp_log or dfa_log: {snapshot_path}")
        return

    print("=" * 88)
    # Label reflects the actual --arch/--seed instead of a hard-coded setup.
    print(f"TEMPORAL DIAGNOSTIC EVOLUTION ({setup_label}, seed {args.seed})")
    print("=" * 88)

    # ----- DFA trajectory ----- #
    print("\nDFA training trajectory (each row = one logged epoch):")
    print(
        f" {'epoch':>6} {'acc':>8} {'gamma':>10} "
        f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}"
    )
    fire_a_epoch = None
    fire_b_epoch = None
    last_dfa_epoch = dfa_log[-1]["epoch"]
    for entry in dfa_log:
        h = entry["hidden_norms"]
        g = entry["bp_grad_norms_per_sample_med"]
        h_exp, g_floor = diagnose_entry(entry)
        ep = entry["epoch"]
        # Keep only the first epoch at which each diagnostic fires.
        if h_exp and fire_a_epoch is None:
            fire_a_epoch = ep
        if g_floor and fire_b_epoch is None:
            fire_b_epoch = ep
        if _should_print(ep, last_dfa_epoch):
            flag_a = "FIRE" if h_exp else "ok"
            flag_b = "FIRE" if g_floor else "ok"
            gamma_s = _format_gamma(entry.get("gamma_dfa"))
            print(
                f" {ep:>6} {entry['acc_eval']:>8.4f} {gamma_s:>10} "
                f"{h[-1]:>14.3e} {g[-1]:>14.3e} {flag_a:>5} {flag_b:>5}"
            )
    print()
    print(f" Diagnostic (a) ‖h_l‖ explosion first fires at epoch: {fire_a_epoch}")
    print(f" Diagnostic (b) ‖g_l‖ floor first fires at epoch: {fire_b_epoch}")
    if fire_a_epoch is not None:
        print(
            " DFA test acc at the moment (a) fires: "
            f"{_acc_at_epoch(dfa_log, fire_a_epoch):.4f}"
        )
    else:
        print(" (a) never fires")
    if fire_b_epoch is not None:
        print(
            " DFA test acc at the moment (b) fires: "
            f"{_acc_at_epoch(dfa_log, fire_b_epoch):.4f}"
        )
    else:
        print(" (b) never fires")
    print(f" DFA final test acc: {dfa_log[-1]['acc_eval']:.4f}")

    # ----- BP trajectory (sanity) ----- #
    print("\nBP training trajectory (sanity):")
    print(
        f" {'epoch':>6} {'acc':>8} "
        f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}"
    )
    bp_fired = False
    last_bp_epoch = bp_log[-1]["epoch"]
    for entry in bp_log:
        h = entry["hidden_norms"]
        g = entry["bp_grad_norms_per_sample_med"]
        h_exp, g_floor = diagnose_entry(entry)
        if h_exp or g_floor:
            bp_fired = True
        if _should_print(entry["epoch"], last_bp_epoch):
            print(
                f" {entry['epoch']:>6} {entry['acc_eval']:>8.4f} "
                f"{h[-1]:>14.3e} {g[-1]:>14.3e} "
                f"{'FIRE' if h_exp else 'ok':>5} {'FIRE' if g_floor else 'ok':>5}"
            )
    print(f"\n BP fired any diagnostic at any epoch: {bp_fired}")
    print(f" BP final test acc: {bp_log[-1]['acc_eval']:.4f}")

    # ----- Save ----- #
    # NOTE: a NaN gamma is written as the non-standard JSON token NaN
    # (json.dump's default allow_nan=True).
    out = {
        "arch": args.arch,
        "seed": args.seed,
        "dfa": {
            "trajectory": [
                {
                    "epoch": e["epoch"],
                    "acc": e["acc_eval"],
                    "max_per_block_growth": max_per_block_growth(e["hidden_norms"]),
                    "g_L": e["bp_grad_norms_per_sample_med"][-1],
                    "gamma": e.get("gamma_dfa"),
                }
                for e in dfa_log
            ],
            "first_fire_a_epoch": fire_a_epoch,
            "first_fire_b_epoch": fire_b_epoch,
            "final_acc": dfa_log[-1]["acc_eval"],
        },
        "bp": {
            "any_fire": bp_fired,
            "final_acc": bp_log[-1]["acc_eval"],
        },
        "thresholds": {
            "g_norm_floor": THRESHOLDS.g_norm_floor,
            "h_norm_explosion_ratio": THRESHOLDS.h_norm_explosion_ratio,
        },
    }
    # Keep the historical ResMLP filename; other archs get a distinct name so a
    # ViT run no longer overwrites the ResMLP result for the same seed.
    if args.arch == "resmlp":
        out_name = f"temporal_evolution_s{args.seed}.json"
    else:
        out_name = f"temporal_evolution_{args.arch}_s{args.seed}.json"
    out_dir = os.path.join(REPO_ROOT, "results/protocol_audit")
    os.makedirs(out_dir, exist_ok=True)  # fresh checkouts lack this directory
    out_path = os.path.join(out_dir, out_name)
    with open(out_path, "w") as f:
        json.dump(out, f, indent=2)
    print(f"\nSaved {out_path}")
# Script entry point: main() parses --seed / --arch from the command line.
if __name__ == "__main__":
    main()
|