""" Temporal validation of the diagnostic protocol: at what epoch during DFA training does each diagnostic cross its degeneracy threshold? This uses the existing snapshot evolution data in `results/snapshot_evolution_v2/`, which logs per-epoch: - hidden_norms (the (a) diagnostic) - bp_grad_norms_per_sample_med (the (b) diagnostic) - gamma_dfa (the field-standard reference number) - acc_eval over 100 epochs of both BP and DFA training on the standard 4-block d=256 ResMLP CIFAR-10 setup. We replay this data through the protocol's threshold logic and report: (i) the epoch at which each diagnostic first FIRES on DFA, (ii) the per-epoch headline accuracy (so we can show that the diagnostic fires BEFORE the headline acc has converged — i.e. the protocol could have caught the pathology mid-training), (iii) the trajectory on BP for comparison (which should never fire). This is the temporal validation of the protocol's decision utility: the protocol catches the pathology *as it happens*, not just retrospectively. Run: python -m protocol.examples.temporal_diagnostic_evolution """ import os import json import sys REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) sys.path.insert(0, REPO_ROOT) from protocol.report import DiagnosticThresholds # noqa: E402 THRESHOLDS = DiagnosticThresholds() def diagnose_entry(entry): h = entry["hidden_norms"] g = entry["bp_grad_norms_per_sample_med"] h_exploded = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio g_at_floor = g[-1] < THRESHOLDS.g_norm_floor return h_exploded, g_at_floor def first_fire_epoch(log, predicate): for entry in log: if predicate(entry): return entry["epoch"] return None def main(): import argparse p = argparse.ArgumentParser() p.add_argument("--seed", type=int, default=42) args = p.parse_args() snapshot_path = os.path.join( REPO_ROOT, f"results/snapshot_evolution_v2/snapshot_evolution_s{args.seed}.json" ) if not os.path.exists(snapshot_path): print(f"snapshot not found: {snapshot_path}") return with open(snapshot_path) as f: d = json.load(f) bp_log = d["bp_log"] dfa_log = d["dfa_log"] print("=" * 88) print("TEMPORAL DIAGNOSTIC EVOLUTION (4-block d=256 ResMLP, CIFAR-10, seed 42)") print("=" * 88) # ----- DFA trajectory ----- # print("\nDFA training trajectory (each row = one logged epoch):") print( f" {'epoch':>6} {'acc':>8} {'gamma':>10} " f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}" ) fired_a = False fired_b = False fire_a_epoch = None fire_b_epoch = None for entry in dfa_log: h = entry["hidden_norms"] g = entry["bp_grad_norms_per_sample_med"] h_exp = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio g_floor = g[-1] < THRESHOLDS.g_norm_floor flag_a = "FIRE" if h_exp else "ok" flag_b = "FIRE" if g_floor else "ok" ep = entry["epoch"] if h_exp and not fired_a: fired_a = True fire_a_epoch = ep if g_floor and not fired_b: fired_b = True fire_b_epoch = ep if ep <= 5 or ep % 10 == 0 or ep == dfa_log[-1]["epoch"]: gamma = entry.get("gamma_dfa") gamma_s = "nan" if gamma is None or (isinstance(gamma, float) and gamma != gamma) else f"{gamma:.4f}" print( f" {ep:>6} {entry['acc_eval']:>8.4f} {gamma_s:>10} " f"{h[-1]:>14.3e} {g[-1]:>14.3e} {flag_a:>5} {flag_b:>5}" ) print() print(f" Diagnostic (a) ‖h_l‖ explosion first fires at epoch: {fire_a_epoch}") print(f" Diagnostic (b) ‖g_l‖ floor first fires at epoch: {fire_b_epoch}") print(f" DFA test acc at the moment (a) fires: " f"{next(e['acc_eval'] for e in dfa_log if e['epoch'] == fire_a_epoch):.4f}" if fire_a_epoch is not None else " (a) never fires") print(f" DFA test acc at the moment (b) fires: " f"{next(e['acc_eval'] for e in dfa_log if e['epoch'] == fire_b_epoch):.4f}" if fire_b_epoch is not None else " (b) never fires") print(f" DFA final test acc: {dfa_log[-1]['acc_eval']:.4f}") # ----- BP trajectory (sanity) ----- # print("\nBP training trajectory (sanity):") print( f" {'epoch':>6} {'acc':>8} " f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}" ) bp_fired = False for entry in bp_log: h = entry["hidden_norms"] g = entry["bp_grad_norms_per_sample_med"] h_exp = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio g_floor = g[-1] < THRESHOLDS.g_norm_floor if h_exp or g_floor: bp_fired = True if entry["epoch"] <= 5 or entry["epoch"] % 10 == 0 or entry["epoch"] == bp_log[-1]["epoch"]: print( f" {entry['epoch']:>6} {entry['acc_eval']:>8.4f} " f"{h[-1]:>14.3e} {g[-1]:>14.3e} " f"{'FIRE' if h_exp else 'ok':>5} {'FIRE' if g_floor else 'ok':>5}" ) print(f"\n BP fired any diagnostic at any epoch: {bp_fired}") print(f" BP final test acc: {bp_log[-1]['acc_eval']:.4f}") # ----- Save ----- # out = { "dfa": { "trajectory": [ { "epoch": e["epoch"], "acc": e["acc_eval"], "h_max_to_h0_ratio": (max(e["hidden_norms"]) / max(e["hidden_norms"][0], 1e-30)), "g_L": e["bp_grad_norms_per_sample_med"][-1], "gamma": e.get("gamma_dfa"), } for e in dfa_log ], "first_fire_a_epoch": fire_a_epoch, "first_fire_b_epoch": fire_b_epoch, "final_acc": dfa_log[-1]["acc_eval"], }, "bp": { "any_fire": bp_fired, "final_acc": bp_log[-1]["acc_eval"], }, "thresholds": { "g_norm_floor": THRESHOLDS.g_norm_floor, "h_norm_explosion_ratio": THRESHOLDS.h_norm_explosion_ratio, }, } out_path = os.path.join(REPO_ROOT, f"results/protocol_audit/temporal_evolution_s{args.seed}.json") with open(out_path, "w") as f: json.dump(out, f, indent=2) print(f"\nSaved {out_path}") if __name__ == "__main__": main()