"""
Example: use the diagnostic protocol as an in-training early-stopping criterion.

This is the practical deployment of the protocol — not as a post-hoc audit,
but as a runtime check that aborts training if any diagnostic fires.

The example is *synthetic* — no GPU, no real model, no checkpoint loading.
It replays the per-epoch trace from `results/snapshot_evolution_v2/` to
simulate a "live" training loop, and demonstrates how a user of the protocol
would wrap their actual training step with `diagnose(...)`. Only two checks
are re-evaluated here, directly from the logged trace:

  (a) maximum per-block growth of `hidden_norms` above
      `DiagnosticThresholds.h_norm_explosion_ratio`;
  (b) last entry of `bp_grad_norms_per_sample_med` (reported as ‖g_L‖)
      below `DiagnosticThresholds.g_norm_floor`.

The output illustrates:
  - epoch-by-epoch diagnostic evaluation
  - early termination as soon as any flag fires
  - the number of epochs saved by the protocol vs. the full logged run

Pseudo-code for a real use case:

    from protocol import diagnose

    for epoch in range(num_epochs):
        train_one_epoch(model, train_loader)
        if epoch % 5 == 0:  # check every 5 epochs is enough
            report = diagnose(
                model=model,
                eval_batches=fixed_eval_batches,
                headline_acc=evaluate(model, test_loader),
                frozen_baseline_acc=fixed_frozen_baseline_acc,
                method_name="my method",
            )
            if report.verdict != "trustworthy":
                print(f"Aborting at epoch {epoch}: {report.verdict}")
                break

Run the synthetic demo:

    python -m protocol.examples.training_monitor_demo
"""
import os
import sys
import json

REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.insert(0, REPO_ROOT)

from protocol.report import DiagnosticThresholds  # noqa: E402

THRESHOLDS = DiagnosticThresholds()


def max_per_block_growth(h):
    """Return the largest ratio h[i+1] / h[i] between consecutive entries of `h`.

    Sequences with fewer than two entries have no consecutive pair, so the
    neutral ratio 1.0 is returned. Denominators are clamped to 1e-30 so a
    zero entry does not raise ZeroDivisionError.
    """
    if len(h) < 2:
        return 1.0
    return max(nxt / max(cur, 1e-30) for cur, nxt in zip(h, h[1:]))


def _flags_for(entry):
    """Return descriptions of the checks (a)/(b) that fire for one epoch entry.

    An empty list means no flag fired. The messages quote the thresholds
    actually used for the comparison, so they stay correct if the
    `DiagnosticThresholds` defaults change.
    """
    h = entry["hidden_norms"]
    g = entry["bp_grad_norms_per_sample_med"]
    growth_limit = THRESHOLDS.h_norm_explosion_ratio
    g_floor = THRESHOLDS.g_norm_floor

    per_block = max_per_block_growth(h)
    reasons = []
    if per_block > growth_limit:
        reasons.append(f"(a) per-block growth {per_block:.0f}× > {growth_limit:g}×")
    if g[-1] < g_floor:
        reasons.append(f"(b) ‖g_L‖ {g[-1]:.1e} < {g_floor:g}")
    return reasons


def early_stop_simulation(log, method_name):
    """Replay a per-epoch trace and report when the protocol would abort.

    Walks `log` in order and stops at the first entry where check (a) or (b)
    fires. Prints the fire epoch, the eval accuracy at that point, the eval
    accuracy of the last logged entry, and how many epochs stopping early
    would have saved.

    Args:
        log: sequence of per-epoch dicts; each must have the keys "epoch",
            "acc_eval", "hidden_norms" and "bp_grad_norms_per_sample_med".
        method_name: label used in the printed section header.

    Returns:
        None if `log` is empty; otherwise a dict with keys
        "fired_epoch" (None when no flag fired), "final_epoch",
        "epochs_saved" and "fraction_saved" (None when the final epoch
        number is not positive, so no fraction can be computed).
    """
    print(f"\n=== {method_name} (simulating early-stop with the protocol) ===")
    if not log:
        print(" Empty trace: nothing to simulate.")
        return None

    fired = None
    fired_epoch = None
    fired_acc = None
    for entry in log:
        reasons = _flags_for(entry)
        if reasons:
            fired = reasons
            fired_epoch = entry["epoch"]
            fired_acc = entry["acc_eval"]
            break

    final_epoch = log[-1]["epoch"]
    final_acc = log[-1]["acc_eval"]

    if fired is None:
        print(f" No flag fired in {final_epoch} epochs. Final acc: {final_acc:.4f}")
        print(" Verdict: trustworthy. Run to completion.")
        return {
            "fired_epoch": None,
            "final_epoch": final_epoch,
            "epochs_saved": 0,
            "fraction_saved": 0.0,
        }

    epochs_saved = final_epoch - fired_epoch
    # Guard against a zero/negative final epoch number (division by zero).
    fraction_saved = epochs_saved / final_epoch if final_epoch > 0 else None
    diff_pp = (final_acc - fired_acc) * 100

    print(f" Protocol fires at epoch {fired_epoch}: {' AND '.join(fired)}")
    print(f" Acc at fire epoch: {fired_acc:.4f}")
    print(f" Acc at final epoch (would have been): {final_acc:.4f}")
    print(f" Diff (final - fire-time): {diff_pp:+.1f} pp")

    compute = "" if fraction_saved is None else f" ({fraction_saved * 100:.0f}% compute)"
    # Only claim "no loss" when the diff printed above actually supports it.
    if round(diff_pp, 1) <= 0:
        outcome = "with no headline acc loss."
    else:
        outcome = f"forgoing {diff_pp:.1f} pp of headline acc."
    print(f" -> Stopping at epoch {fired_epoch} would save "
          f"{epochs_saved} epochs{compute}, {outcome}")

    return {
        "fired_epoch": fired_epoch,
        "final_epoch": final_epoch,
        "epochs_saved": epochs_saved,
        "fraction_saved": fraction_saved,
    }


def _describe_savings(label, result):
    """Format one summary line from an `early_stop_simulation` result."""
    if result is None:
        return "n/a (empty trace)"
    if result["fired_epoch"] is None:
        return f"0% (protocol never fires, {label} runs to completion)"
    fraction = result["fraction_saved"]
    pct = "n/a" if fraction is None else f"~{fraction * 100:.0f}%"
    return (f"{pct} (protocol fires at epoch {result['fired_epoch']} of a "
            f"{result['final_epoch']}-epoch run)")


def main():
    """Run the early-stop simulation on the seed-42 BP and DFA traces."""
    snapshot_path = os.path.join(
        REPO_ROOT, "results", "snapshot_evolution_v2", "snapshot_evolution_s42.json"
    )
    with open(snapshot_path, encoding="utf-8") as f:
        d = json.load(f)

    # Entries are consumed as-is; early_stop_simulation reads the keys it
    # needs and raises KeyError if one is missing.
    bp_log = d["bp_log"]
    dfa_log = d["dfa_log"]

    print("=" * 72)
    print("EARLY-STOP SIMULATION: 4-block d=256 ResMLP, CIFAR-10, seed 42")
    print("Using the diagnostic protocol as an in-training abort condition.")
    print("=" * 72)

    bp_result = early_stop_simulation(bp_log, "BP (sanity)")
    dfa_result = early_stop_simulation(dfa_log, "DFA (failing)")

    print()
    print("Compute saved by using the protocol as an early-stop criterion:")
    # Derived from the simulation results rather than hard-coded, so the
    # summary cannot contradict the per-method output above.
    for label, result in (("BP", bp_result), ("DFA", dfa_result)):
        print(f" {label}: {_describe_savings(label, result)}")
    print()
    print("The protocol pays for itself: a single forward+backward pass on a")
    print("fixed eval batch every few epochs is negligible compared to the")
    print("compute saved when training is aborted on a failing run.")


if __name__ == "__main__":
    main()