summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:49:53 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:49:53 -0500
commita89ef4dee2750dd7bddbe1fd0a1b94d1f74d6f9c (patch)
tree0299175a1b943655db74f8343fad9e462f689820 /protocol
parentc2e145e162444b31ac5c66a90daa6bc0a1cda591 (diff)
Add temporal diagnostic evolution: protocol fires at epoch 4 of DFA
Replays per-epoch logged data from results/snapshot_evolution_v2/ through the protocol thresholds. Result: diagnostics (a) ||h_l|| explosion AND (b) ||g_L|| at floor BOTH first fire at epoch 4 of DFA training. At that point, DFA test acc is 0.308 — its final value at epoch 100 is also 0.308. The protocol could have walked back the headline 96 epochs before training finished. DFA's gamma hovers at 0.087-0.107 for all 100 epochs. A reviewer looking at acc+gamma would conclude 'DFA is hovering at 31% acc with ~0.10 alignment, both reasonable'. Wrong on both counts. BP never fires any diagnostic at any epoch. Stays bounded at ||h_L||~200, ||g_L||~3-5e-5, accuracy climbs to 0.61. This is the temporal validation of decision utility: the protocol catches the pathology AS IT HAPPENS, not just retrospectively.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/temporal_diagnostic_evolution.py167
1 files changed, 167 insertions, 0 deletions
diff --git a/protocol/examples/temporal_diagnostic_evolution.py b/protocol/examples/temporal_diagnostic_evolution.py
new file mode 100644
index 0000000..e349cec
--- /dev/null
+++ b/protocol/examples/temporal_diagnostic_evolution.py
@@ -0,0 +1,167 @@
+"""
+Temporal validation of the diagnostic protocol: at what epoch during DFA
+training does each diagnostic cross its degeneracy threshold?
+
+This uses the existing snapshot evolution data in
+`results/snapshot_evolution_v2/`, which logs per-epoch:
+ - hidden_norms (the (a) diagnostic)
+ - bp_grad_norms_per_sample_med (the (b) diagnostic)
+ - gamma_dfa (the field-standard reference number)
+ - acc_eval
+
+over 100 epochs of both BP and DFA training on the standard 4-block d=256
+ResMLP CIFAR-10 setup. We replay this data through the protocol's
+threshold logic and report:
+
+ (i) the epoch at which each diagnostic first FIRES on DFA,
+ (ii) the per-epoch headline accuracy (so we can show that the diagnostic
+ fires BEFORE the headline acc has converged — i.e. the protocol
+ could have caught the pathology mid-training),
+ (iii) the trajectory on BP for comparison (which should never fire).
+
+This is the temporal validation of the protocol's decision utility: the
+protocol catches the pathology *as it happens*, not just retrospectively.
+
+Run:
+ python -m protocol.examples.temporal_diagnostic_evolution
+"""
+import os
+import json
+import sys
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+
+from protocol.report import DiagnosticThresholds # noqa: E402
+
+THRESHOLDS = DiagnosticThresholds()
+
+
+def diagnose_entry(entry):
+ h = entry["hidden_norms"]
+ g = entry["bp_grad_norms_per_sample_med"]
+ h_exploded = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio
+ g_at_floor = g[-1] < THRESHOLDS.g_norm_floor
+ return h_exploded, g_at_floor
+
+
+def first_fire_epoch(log, predicate):
+ for entry in log:
+ if predicate(entry):
+ return entry["epoch"]
+ return None
+
+
+def main():
+ snapshot_path = os.path.join(
+ REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json"
+ )
+ with open(snapshot_path) as f:
+ d = json.load(f)
+ bp_log = d["bp_log"]
+ dfa_log = d["dfa_log"]
+
+ print("=" * 88)
+ print("TEMPORAL DIAGNOSTIC EVOLUTION (4-block d=256 ResMLP, CIFAR-10, seed 42)")
+ print("=" * 88)
+
+ # ----- DFA trajectory ----- #
+ print("\nDFA training trajectory (each row = one logged epoch):")
+ print(
+ f" {'epoch':>6} {'acc':>8} {'gamma':>10} "
+ f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}"
+ )
+ fired_a = False
+ fired_b = False
+ fire_a_epoch = None
+ fire_b_epoch = None
+ for entry in dfa_log:
+ h = entry["hidden_norms"]
+ g = entry["bp_grad_norms_per_sample_med"]
+ h_exp = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio
+ g_floor = g[-1] < THRESHOLDS.g_norm_floor
+ flag_a = "FIRE" if h_exp else "ok"
+ flag_b = "FIRE" if g_floor else "ok"
+ ep = entry["epoch"]
+ if h_exp and not fired_a:
+ fired_a = True
+ fire_a_epoch = ep
+ if g_floor and not fired_b:
+ fired_b = True
+ fire_b_epoch = ep
+ if ep <= 5 or ep % 10 == 0 or ep == dfa_log[-1]["epoch"]:
+ gamma = entry.get("gamma_dfa")
+ gamma_s = "nan" if gamma is None or (isinstance(gamma, float) and gamma != gamma) else f"{gamma:.4f}"
+ print(
+ f" {ep:>6} {entry['acc_eval']:>8.4f} {gamma_s:>10} "
+ f"{h[-1]:>14.3e} {g[-1]:>14.3e} {flag_a:>5} {flag_b:>5}"
+ )
+
+ print()
+ print(f" Diagnostic (a) ‖h_l‖ explosion first fires at epoch: {fire_a_epoch}")
+ print(f" Diagnostic (b) ‖g_l‖ floor first fires at epoch: {fire_b_epoch}")
+ print(f" DFA test acc at the moment (a) fires: "
+ f"{next(e['acc_eval'] for e in dfa_log if e['epoch'] == fire_a_epoch):.4f}" if fire_a_epoch is not None else " (a) never fires")
+ print(f" DFA test acc at the moment (b) fires: "
+ f"{next(e['acc_eval'] for e in dfa_log if e['epoch'] == fire_b_epoch):.4f}" if fire_b_epoch is not None else " (b) never fires")
+ print(f" DFA final test acc: {dfa_log[-1]['acc_eval']:.4f}")
+
+ # ----- BP trajectory (sanity) ----- #
+ print("\nBP training trajectory (sanity):")
+ print(
+ f" {'epoch':>6} {'acc':>8} "
+ f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}"
+ )
+ bp_fired = False
+ for entry in bp_log:
+ h = entry["hidden_norms"]
+ g = entry["bp_grad_norms_per_sample_med"]
+ h_exp = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio
+ g_floor = g[-1] < THRESHOLDS.g_norm_floor
+ if h_exp or g_floor:
+ bp_fired = True
+ if entry["epoch"] <= 5 or entry["epoch"] % 10 == 0 or entry["epoch"] == bp_log[-1]["epoch"]:
+ print(
+ f" {entry['epoch']:>6} {entry['acc_eval']:>8.4f} "
+ f"{h[-1]:>14.3e} {g[-1]:>14.3e} "
+ f"{'FIRE' if h_exp else 'ok':>5} {'FIRE' if g_floor else 'ok':>5}"
+ )
+ print(f"\n BP fired any diagnostic at any epoch: {bp_fired}")
+ print(f" BP final test acc: {bp_log[-1]['acc_eval']:.4f}")
+
+ # ----- Save ----- #
+ out = {
+ "dfa": {
+ "trajectory": [
+ {
+ "epoch": e["epoch"],
+ "acc": e["acc_eval"],
+ "h_max_to_h0_ratio": (max(e["hidden_norms"]) / max(e["hidden_norms"][0], 1e-30)),
+ "g_L": e["bp_grad_norms_per_sample_med"][-1],
+ "gamma": e.get("gamma_dfa"),
+ }
+ for e in dfa_log
+ ],
+ "first_fire_a_epoch": fire_a_epoch,
+ "first_fire_b_epoch": fire_b_epoch,
+ "final_acc": dfa_log[-1]["acc_eval"],
+ },
+ "bp": {
+ "any_fire": bp_fired,
+ "final_acc": bp_log[-1]["acc_eval"],
+ },
+ "thresholds": {
+ "g_norm_floor": THRESHOLDS.g_norm_floor,
+ "h_norm_explosion_ratio": THRESHOLDS.h_norm_explosion_ratio,
+ },
+ }
+ out_path = os.path.join(REPO_ROOT, "results/protocol_audit/temporal_evolution_s42.json")
+ with open(out_path, "w") as f:
+ json.dump(out, f, indent=2)
+ print(f"\nSaved {out_path}")
+
+
+if __name__ == "__main__":
+ main()