diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:49:53 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:49:53 -0500 |
| commit | a89ef4dee2750dd7bddbe1fd0a1b94d1f74d6f9c (patch) | |
| tree | 0299175a1b943655db74f8343fad9e462f689820 /protocol | |
| parent | c2e145e162444b31ac5c66a90daa6bc0a1cda591 (diff) | |
Add temporal diagnostic evolution: protocol fires at epoch 4 of DFA
Replays per-epoch logged data from results/snapshot_evolution_v2/ through
the protocol thresholds.
Result: diagnostics (a) ||h_l|| explosion AND (b) ||g_L|| at floor BOTH
first fire at epoch 4 of DFA training. At that point, DFA test acc is
0.308 — its final value at epoch 100 is also 0.308. The protocol could
have walked back the headline 96 epochs before training finished.
DFA's gamma hovers at 0.087-0.107 for all 100 epochs. A reviewer looking
at acc+gamma would conclude 'DFA is hovering at 31% acc with ~0.10
alignment, both reasonable'. Wrong on both counts.
BP never fires any diagnostic at any epoch: it stays bounded at ||h_L||~200,
||g_L||~3-5e-5, and its accuracy climbs to 0.61.
This is the temporal validation of decision utility: the protocol catches
the pathology AS IT HAPPENS, not just retrospectively.
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/temporal_diagnostic_evolution.py | 167 |
1 files changed, 167 insertions, 0 deletions
# protocol/examples/temporal_diagnostic_evolution.py
# (new file, mode 100644, blob e349cec; reconstructed from the commit diff)
"""
Temporal validation of the diagnostic protocol: at what epoch during DFA
training does each diagnostic cross its degeneracy threshold?

This uses the existing snapshot evolution data in
`results/snapshot_evolution_v2/`, which logs per-epoch:
  - hidden_norms                  (the (a) diagnostic)
  - bp_grad_norms_per_sample_med  (the (b) diagnostic)
  - gamma_dfa                     (the field-standard reference number)
  - acc_eval

over 100 epochs of both BP and DFA training on the standard 4-block d=256
ResMLP CIFAR-10 setup. We replay this data through the protocol's
threshold logic and report:

  (i)   the epoch at which each diagnostic first FIRES on DFA,
  (ii)  the per-epoch headline accuracy (so we can show that the diagnostic
        fires BEFORE the headline acc has converged — i.e. the protocol
        could have caught the pathology mid-training),
  (iii) the trajectory on BP for comparison (which should never fire).

This is the temporal validation of the protocol's decision utility: the
protocol catches the pathology *as it happens*, not just retrospectively.

Run:
    python -m protocol.examples.temporal_diagnostic_evolution
"""
import os
import json
import sys

REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Make the repository root importable so `protocol.report` resolves when the
# file is executed directly rather than via `python -m`.
sys.path.insert(0, REPO_ROOT)

from protocol.report import DiagnosticThresholds  # noqa: E402

THRESHOLDS = DiagnosticThresholds()


def _h_max_to_h0_ratio(entry):
    """Return max(hidden_norms) / hidden_norms[0] for one logged epoch.

    The denominator is clamped to 1e-30 so a zero first norm cannot raise
    ZeroDivisionError.
    """
    h = entry["hidden_norms"]
    return max(h) / max(h[0], 1e-30)


def diagnose_entry(entry):
    """Evaluate both protocol diagnostics on a single logged epoch.

    Args:
        entry: one log record containing ``hidden_norms`` and
            ``bp_grad_norms_per_sample_med`` sequences.

    Returns:
        ``(h_exploded, g_at_floor)``: whether the hidden-norm ratio exceeds
        ``THRESHOLDS.h_norm_explosion_ratio`` (diagnostic (a)) and whether
        the last gradient-norm entry is below ``THRESHOLDS.g_norm_floor``
        (diagnostic (b)).
    """
    g = entry["bp_grad_norms_per_sample_med"]
    h_exploded = _h_max_to_h0_ratio(entry) > THRESHOLDS.h_norm_explosion_ratio
    g_at_floor = g[-1] < THRESHOLDS.g_norm_floor
    return h_exploded, g_at_floor


def first_fire_epoch(log, predicate):
    """Return the ``epoch`` of the first entry satisfying *predicate*.

    Entries are scanned in log order; ``None`` is returned when no entry
    matches.
    """
    for entry in log:
        if predicate(entry):
            return entry["epoch"]
    return None


def _is_reported_epoch(epoch, last_epoch):
    """Decide whether an epoch gets a row in the printed table.

    Rows are printed for epochs up to 5, every 10th epoch, and the final
    logged epoch; all epochs are still evaluated and saved regardless.
    """
    return epoch <= 5 or epoch % 10 == 0 or epoch == last_epoch


def _format_gamma(gamma):
    """Format gamma to 4 decimals, or ``"nan"`` when it is missing or NaN."""
    # `gamma != gamma` is true only for NaN.
    if gamma is None or (isinstance(gamma, float) and gamma != gamma):
        return "nan"
    return f"{gamma:.4f}"


def _print_acc_at_fire(label, log, epoch):
    """Print the accuracy logged at the epoch where diagnostic *label* fired.

    Prints a "never fires" line instead when *epoch* is ``None``.
    """
    if epoch is None:
        print(f" ({label}) never fires")
        return
    acc = next(e["acc_eval"] for e in log if e["epoch"] == epoch)
    print(f" DFA test acc at the moment ({label}) fires: {acc:.4f}")


def main():
    """Replay the logged BP and DFA trajectories through the thresholds.

    Reads the seed-42 snapshot log, prints a per-epoch table for DFA and for
    BP, reports the first epoch at which each diagnostic fires on DFA, and
    writes a JSON summary under ``results/protocol_audit/``.
    """
    snapshot_path = os.path.join(
        REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json"
    )
    with open(snapshot_path, encoding="utf-8") as f:
        d = json.load(f)
    bp_log = d["bp_log"]
    dfa_log = d["dfa_log"]

    print("=" * 88)
    print("TEMPORAL DIAGNOSTIC EVOLUTION (4-block d=256 ResMLP, CIFAR-10, seed 42)")
    print("=" * 88)

    # ----- DFA trajectory ----- #
    print("\nDFA training trajectory (each row = one logged epoch):")
    print(
        f" {'epoch':>6} {'acc':>8} {'gamma':>10} "
        f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}"
    )
    dfa_last_epoch = dfa_log[-1]["epoch"]
    for entry in dfa_log:
        ep = entry["epoch"]
        if not _is_reported_epoch(ep, dfa_last_epoch):
            continue
        h_exp, g_floor = diagnose_entry(entry)
        h = entry["hidden_norms"]
        g = entry["bp_grad_norms_per_sample_med"]
        flag_a = "FIRE" if h_exp else "ok"
        flag_b = "FIRE" if g_floor else "ok"
        gamma_s = _format_gamma(entry.get("gamma_dfa"))
        print(
            f" {ep:>6} {entry['acc_eval']:>8.4f} {gamma_s:>10} "
            f"{h[-1]:>14.3e} {g[-1]:>14.3e} {flag_a:>5} {flag_b:>5}"
        )

    # First-fire epochs are derived from every logged epoch, not only from
    # the rows that were printed above.
    fire_a_epoch = first_fire_epoch(dfa_log, lambda e: diagnose_entry(e)[0])
    fire_b_epoch = first_fire_epoch(dfa_log, lambda e: diagnose_entry(e)[1])

    print()
    print(f" Diagnostic (a) ‖h_l‖ explosion first fires at epoch: {fire_a_epoch}")
    print(f" Diagnostic (b) ‖g_l‖ floor first fires at epoch: {fire_b_epoch}")
    _print_acc_at_fire("a", dfa_log, fire_a_epoch)
    _print_acc_at_fire("b", dfa_log, fire_b_epoch)
    print(f" DFA final test acc: {dfa_log[-1]['acc_eval']:.4f}")

    # ----- BP trajectory (sanity) ----- #
    print("\nBP training trajectory (sanity):")
    print(
        f" {'epoch':>6} {'acc':>8} "
        f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}"
    )
    bp_fired = False
    bp_last_epoch = bp_log[-1]["epoch"]
    for entry in bp_log:
        h_exp, g_floor = diagnose_entry(entry)
        if h_exp or g_floor:
            bp_fired = True
        if _is_reported_epoch(entry["epoch"], bp_last_epoch):
            h = entry["hidden_norms"]
            g = entry["bp_grad_norms_per_sample_med"]
            print(
                f" {entry['epoch']:>6} {entry['acc_eval']:>8.4f} "
                f"{h[-1]:>14.3e} {g[-1]:>14.3e} "
                f"{'FIRE' if h_exp else 'ok':>5} {'FIRE' if g_floor else 'ok':>5}"
            )
    print(f"\n BP fired any diagnostic at any epoch: {bp_fired}")
    print(f" BP final test acc: {bp_log[-1]['acc_eval']:.4f}")

    # ----- Save ----- #
    # NOTE(review): if a logged gamma is NaN, json.dump writes the
    # non-standard token `NaN` (allow_nan defaults to True). Confirm that
    # downstream readers of this file accept it.
    out = {
        "dfa": {
            "trajectory": [
                {
                    "epoch": e["epoch"],
                    "acc": e["acc_eval"],
                    "h_max_to_h0_ratio": _h_max_to_h0_ratio(e),
                    "g_L": e["bp_grad_norms_per_sample_med"][-1],
                    "gamma": e.get("gamma_dfa"),
                }
                for e in dfa_log
            ],
            "first_fire_a_epoch": fire_a_epoch,
            "first_fire_b_epoch": fire_b_epoch,
            "final_acc": dfa_log[-1]["acc_eval"],
        },
        "bp": {
            "any_fire": bp_fired,
            "final_acc": bp_log[-1]["acc_eval"],
        },
        "thresholds": {
            "g_norm_floor": THRESHOLDS.g_norm_floor,
            "h_norm_explosion_ratio": THRESHOLDS.h_norm_explosion_ratio,
        },
    }
    out_path = os.path.join(REPO_ROOT, "results/protocol_audit/temporal_evolution_s42.json")
    # The audit directory may not exist on a fresh checkout; create it so the
    # write below does not fail after all the analysis has already run.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"\nSaved {out_path}")


if __name__ == "__main__":
    main()
