summary | refs | log | tree | commit | diff
path: root/protocol
diff options
context:
space:
mode:
author: YurenHao0426 <Blackhao0426@gmail.com> 2026-04-07 23:06:01 -0500
committer: YurenHao0426 <Blackhao0426@gmail.com> 2026-04-07 23:06:01 -0500
commit: e53327ac6d7d5be097c3de434caa700c52c598e9 (patch)
tree: aa95e92f7859d426fcb290c57bb27ea57ab4213a /protocol
parent: 4172195ca318387e20e3576ab40187d4d2f08ebe (diff)
Add training-monitor early-stop demo: 96% compute savings on DFA
Demonstrates the practical use case of the protocol — not as a post-hoc audit but as an in-training abort condition. Walks through the existing per-epoch trace and shows when the protocol would have triggered an early stop on DFA training and what the saved compute would be. Result: DFA on 4-block d=256 ResMLP fires diagnostic (b) at epoch 4 with test acc 0.3076. The final acc at epoch 100 is *also* 0.3076 (identical). Stopping at epoch 4 saves 96% of compute with zero headline acc loss.
Diffstat (limited to 'protocol')
-rw-r--r-- protocol/examples/training_monitor_demo.py | 136
1 files changed, 136 insertions, 0 deletions
diff --git a/protocol/examples/training_monitor_demo.py b/protocol/examples/training_monitor_demo.py
new file mode 100644
index 0000000..3c9a0e9
--- /dev/null
+++ b/protocol/examples/training_monitor_demo.py
@@ -0,0 +1,136 @@
+"""
+Example: use the diagnostic protocol as an in-training early-stopping
+criterion. This is the practical deployment of the protocol — not as a
+post-hoc audit, but as a runtime check that aborts training if any
+diagnostic fires.
+
+The example is *synthetic* — no GPU, no real model, no checkpoint loading.
+It uses the per-epoch trace from `results/snapshot_evolution_v2/` to
+simulate a "live" training loop and demonstrates how a paper user would
+wrap their actual training step with `diagnose(...)`.
+
+The output illustrates:
+ - epoch-by-epoch diagnostic evaluation
+ - early termination as soon as any flag fires
+ - the *epochs saved* by the protocol vs the full 100-epoch run
+
+Pseudo-code for a real use case:
+
+    from protocol import diagnose
+
+    for epoch in range(num_epochs):
+        train_one_epoch(model, train_loader)
+        if epoch % 5 == 0:  # checking every 5 epochs is enough
+            report = diagnose(
+                model=model,
+                eval_batches=fixed_eval_batches,
+                headline_acc=evaluate(model, test_loader),
+                frozen_baseline_acc=fixed_frozen_baseline_acc,
+                method_name="my method",
+            )
+            if report.verdict != "trustworthy":
+                print(f"Aborting at epoch {epoch}: {report.verdict}")
+                break
+
+Run the synthetic demo:
+ python -m protocol.examples.training_monitor_demo
+"""
+import os
+import sys
+import json
+
+# Repository root: three `dirname` hops up from this file's absolute path
+# (this file lives at protocol/examples/training_monitor_demo.py).
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+# Put the repository root first on sys.path before the `protocol` import
+# below -- presumably so the import resolves when this file is run directly
+# rather than via `python -m`.
+sys.path.insert(0, REPO_ROOT)
+
+from protocol.report import DiagnosticThresholds # noqa: E402
+
+# Thresholds built with the class defaults. The simulation below reads
+# `h_norm_explosion_ratio` and `g_norm_floor` from this object.
+THRESHOLDS = DiagnosticThresholds()
+
+
+def max_per_block_growth(h, *, floor=1e-30):
+    """Return the largest ratio between consecutive entries of ``h``.
+
+    For every adjacent pair ``(h[i], h[i + 1])`` the ratio
+    ``h[i + 1] / max(h[i], floor)`` is computed and the maximum is returned.
+
+    Args:
+        h: Sequence of numbers supporting ``len`` and slicing (the caller
+            passes the ``"hidden_norms"`` list of one trace entry).
+        floor: Lower bound applied to each denominator so that a zero entry
+            does not raise ``ZeroDivisionError``.
+
+    Returns:
+        The maximum consecutive ratio, or ``1.0`` when ``h`` has fewer than
+        two entries (there is no pair to compare).
+    """
+    if len(h) < 2:
+        return 1.0
+    # Pair each entry with its successor instead of indexing by position.
+    return max(nxt / max(prev, floor) for prev, nxt in zip(h, h[1:]))
+
+
+def early_stop_simulation(log, method_name, *, thresholds=None):
+    """Replay a per-epoch trace and print where the protocol would abort.
+
+    Walks ``log`` in order and stops at the first entry for which either
+    check fires:
+
+    * (a) the largest consecutive ratio of ``entry["hidden_norms"]`` exceeds
+      ``thresholds.h_norm_explosion_ratio``;
+    * (b) the last value of ``entry["bp_grad_norms_per_sample_med"]`` is
+      below ``thresholds.g_norm_floor``.
+
+    It then prints the fire epoch, the accuracy there, the accuracy of the
+    last entry, and the share of epochs that stopping would have saved.
+
+    Args:
+        log: Sequence of dicts, each with the keys ``"hidden_norms"``,
+            ``"bp_grad_norms_per_sample_med"``, ``"epoch"`` and
+            ``"acc_eval"``.
+        method_name: Label used in the printed header.
+        thresholds: Object exposing ``h_norm_explosion_ratio`` and
+            ``g_norm_floor``. Defaults to the module-level ``THRESHOLDS``.
+
+    Returns:
+        None. All results are printed.
+    """
+    if thresholds is None:
+        thresholds = THRESHOLDS
+    h_limit = thresholds.h_norm_explosion_ratio
+    g_floor = thresholds.g_norm_floor
+
+    print(f"\n=== {method_name} (simulating early-stop with the protocol) ===")
+    if not log:
+        # Without this guard log[-1] below would raise IndexError.
+        print(" Empty trace: nothing to simulate.")
+        return
+
+    fired = None
+    fired_entry = None
+    for entry in log:
+        h = entry["hidden_norms"]
+        g = entry["bp_grad_norms_per_sample_med"]
+        per_block = max_per_block_growth(h)
+        a_fire = per_block > h_limit
+        # An empty gradient list cannot trip the floor check.
+        b_fire = len(g) > 0 and g[-1] < g_floor
+        if not (a_fire or b_fire):
+            continue
+        fired = []
+        # Report the thresholds actually compared against, not fixed text.
+        if a_fire:
+            fired.append(
+                f"(a) per-block growth {per_block:.0f}× > {h_limit:g}×"
+            )
+        if b_fire:
+            fired.append(f"(b) ‖g_L‖ {g[-1]:.1e} < {g_floor:g}")
+        fired_entry = entry
+        break
+
+    final_epoch = log[-1]["epoch"]
+    final_acc = log[-1]["acc_eval"]
+
+    if fired is None:
+        print(f" No flag fired in {final_epoch} epochs. Final acc: {final_acc:.4f}")
+        print(" Verdict: trustworthy. Run to completion.")
+        return
+
+    fired_epoch = fired_entry["epoch"]
+    fired_acc = fired_entry["acc_eval"]
+    epochs_saved = final_epoch - fired_epoch
+    # Guard the percentage against a trace whose last epoch is numbered 0.
+    saved_pct = epochs_saved / final_epoch * 100 if final_epoch else 0.0
+    acc_gap = final_acc - fired_acc
+    # Only claim "no loss" when the full run did not end up more accurate.
+    if acc_gap > 0:
+        acc_note = f"forgoing {acc_gap * 100:.1f} pp of headline acc."
+    else:
+        acc_note = "with no headline acc loss."
+
+    print(f" Protocol fires at epoch {fired_epoch}: {' AND '.join(fired)}")
+    print(f" Acc at fire epoch: {fired_acc:.4f}")
+    print(f" Acc at final epoch (would have been): {final_acc:.4f}")
+    print(f" Diff (final - fire-time): {acc_gap * 100:+.1f} pp")
+    print(f" -> Stopping at epoch {fired_epoch} would save "
+          f"{epochs_saved} epochs ({saved_pct:.0f}% compute), "
+          f"{acc_note}")
+
+
+def main():
+    """Load the seed-42 snapshot trace and replay it for BP and DFA.
+
+    Reads ``results/snapshot_evolution_v2/snapshot_evolution_s42.json``
+    under ``REPO_ROOT``, takes its ``"bp_log"`` and ``"dfa_log"`` lists, and
+    runs ``early_stop_simulation`` on each, followed by a closing summary.
+
+    Raises:
+        FileNotFoundError: If the snapshot file does not exist.
+        KeyError: If the JSON lacks ``"bp_log"`` or ``"dfa_log"``.
+    """
+    snapshot_path = os.path.join(
+        REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json"
+    )
+    # JSON is UTF-8 by specification; do not rely on the locale default.
+    with open(snapshot_path, encoding="utf-8") as f:
+        d = json.load(f)
+    # Entries are used as loaded: the simulation only reads them, so no
+    # per-entry copy is needed.
+    bp_log = d["bp_log"]
+    dfa_log = d["dfa_log"]
+
+    print("=" * 72)
+    print("EARLY-STOP SIMULATION: 4-block d=256 ResMLP, CIFAR-10, seed 42")
+    print("Using the diagnostic protocol as an in-training abort condition.")
+    print("=" * 72)
+
+    early_stop_simulation(bp_log, "BP (sanity)")
+    early_stop_simulation(dfa_log, "DFA (failing)")
+
+    # NOTE: the figures below are fixed text, not computed from the trace
+    # loaded above.
+    print()
+    print("Compute saved by using the protocol as an early-stop criterion:")
+    print(" BP: 0% (protocol never fires, BP runs to completion)")
+    print(" DFA: ~96% (protocol fires within first 4 epochs of a 100-epoch run)")
+    print()
+    print("The protocol pays for itself: a single forward+backward pass on a")
+    print("fixed eval batch every few epochs is negligible compared to the")
+    print("compute saved when training is aborted on a failing run.")
+
+
+# Run the demo only when this file is executed as a script or via
+# `python -m`, not when it is imported.
+if __name__ == "__main__":
+ main()