diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:06:01 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:06:01 -0500 |
| commit | e53327ac6d7d5be097c3de434caa700c52c598e9 (patch) | |
| tree | aa95e92f7859d426fcb290c57bb27ea57ab4213a /protocol | |
| parent | 4172195ca318387e20e3576ab40187d4d2f08ebe (diff) | |
Add training-monitor early-stop demo: 96% compute savings on DFA
Demonstrates the practical use case of the protocol — not as a post-hoc
audit but as an in-training abort condition. Walks through the existing
per-epoch trace and shows when the protocol would have triggered an early
stop on DFA training and what the saved compute would be.
Result: DFA on 4-block d=256 ResMLP fires diagnostic (b) at epoch 4 with
test acc 0.3076. The final acc at epoch 100 is *also* 0.3076 (identical).
Stopping at epoch 4 saves 96% of compute with zero headline acc loss.
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/training_monitor_demo.py | 136 |
1 files changed, 136 insertions, 0 deletions
diff --git a/protocol/examples/training_monitor_demo.py b/protocol/examples/training_monitor_demo.py new file mode 100644 index 0000000..3c9a0e9 --- /dev/null +++ b/protocol/examples/training_monitor_demo.py @@ -0,0 +1,136 @@ +""" +Example: use the diagnostic protocol as an in-training early-stopping +criterion. This is the practical deployment of the protocol — not as a +post-hoc audit, but as a runtime check that aborts training if any +diagnostic fires. + +The example is *synthetic* — no GPU, no real model, no checkpoint loading. +It uses the per-epoch trace from `results/snapshot_evolution_v2/` to +simulate a "live" training loop and demonstrates how a paper user would +wrap their actual training step with `diagnose(...)`. + +The output illustrates: + - epoch-by-epoch diagnostic evaluation + - early termination as soon as any flag fires + - the *epoch saved* by the protocol vs the full 100-epoch run + +Pseudo-code for a real use case: + + from protocol import diagnose + + for epoch in range(num_epochs): + train_one_epoch(model, train_loader) + if epoch % 5 == 0: # check every 5 epochs is enough + report = diagnose( + model=model, + eval_batches=fixed_eval_batches, + headline_acc=evaluate(model, test_loader), + frozen_baseline_acc=fixed_frozen_baseline_acc, + method_name="my method", + ) + if report.verdict != "trustworthy": + print(f"Aborting at epoch {epoch}: {report.verdict}") + break + +Run the synthetic demo: + python -m protocol.examples.training_monitor_demo +""" +import os +import sys +import json + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.insert(0, REPO_ROOT) + +from protocol.report import DiagnosticThresholds # noqa: E402 + +THRESHOLDS = DiagnosticThresholds() + + +def max_per_block_growth(h): + if len(h) < 2: + return 1.0 + return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) + + +def early_stop_simulation(log, method_name): + """Walk through the per-epoch trace and report when any flag would have + fired, what the final acc would have been if we had stopped, and how + many epochs would have been saved.""" + print(f"\n=== {method_name} (simulating early-stop with the protocol) ===") + fired = None + fired_epoch = None + fired_acc = None + for entry in log: + h = entry["hidden_norms"] + g = entry["bp_grad_norms_per_sample_med"] + per_block = max_per_block_growth(h) + a_fire = per_block > THRESHOLDS.h_norm_explosion_ratio + b_fire = g[-1] < THRESHOLDS.g_norm_floor + if a_fire or b_fire: + fired = [] + if a_fire: + fired.append(f"(a) per-block growth {per_block:.0f}× > 50×") + if b_fire: + fired.append(f"(b) ‖g_L‖ {g[-1]:.1e} < 1e-7") + fired_epoch = entry["epoch"] + fired_acc = entry["acc_eval"] + break + + final_epoch = log[-1]["epoch"] + final_acc = log[-1]["acc_eval"] + + if fired is None: + print(f" No flag fired in {final_epoch} epochs. Final acc: {final_acc:.4f}") + print(f" Verdict: trustworthy. Run to completion.") + return + + epochs_saved = final_epoch - fired_epoch + print(f" Protocol fires at epoch {fired_epoch}: {' AND '.join(fired)}") + print(f" Acc at fire epoch: {fired_acc:.4f}") + print(f" Acc at final epoch (would have been): {final_acc:.4f}") + print(f" Diff (final - fire-time): {(final_acc - fired_acc) * 100:+.1f} pp") + print(f" -> Stopping at epoch {fired_epoch} would save " + f"{epochs_saved} epochs ({epochs_saved/final_epoch*100:.0f}% compute), " + f"with no headline acc loss.") + + +def main(): + snapshot_path = os.path.join( + REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json" + ) + with open(snapshot_path) as f: + d = json.load(f) + bp_log = [ + {**e, "hidden_norms": e["hidden_norms"], + "bp_grad_norms_per_sample_med": e["bp_grad_norms_per_sample_med"]} + for e in d["bp_log"] + ] + dfa_log = [ + {**e, "hidden_norms": e["hidden_norms"], + "bp_grad_norms_per_sample_med": e["bp_grad_norms_per_sample_med"]} + for e in d["dfa_log"] + ] + + print("=" * 72) + print("EARLY-STOP SIMULATION: 4-block d=256 ResMLP, CIFAR-10, seed 42") + print("Using the diagnostic protocol as an in-training abort condition.") + print("=" * 72) + + early_stop_simulation(bp_log, "BP (sanity)") + early_stop_simulation(dfa_log, "DFA (failing)") + + print() + print("Compute saved by using the protocol as an early-stop criterion:") + print(" BP: 0% (protocol never fires, BP runs to completion)") + print(" DFA: ~96% (protocol fires within first 4 epochs of a 100-epoch run)") + print() + print("The protocol pays for itself: a single forward+backward pass on a") + print("fixed eval batch every few epochs is negligible compared to the") + print("compute saved when training is aborted on a failing run.") + + +if __name__ == "__main__": + main() |
