summaryrefslogtreecommitdiff
path: root/protocol/examples/training_monitor_demo.py
blob: 3c9a0e99af5e13d6afd559c3b656e37e84fe4a58 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Example: use the diagnostic protocol as an in-training early-stopping
criterion. This is the practical deployment of the protocol — not as a
post-hoc audit, but as a runtime check that aborts training if any
diagnostic fires.

The example is *synthetic* — no GPU, no real model, no checkpoint loading.
It replays the per-epoch trace stored in `results/snapshot_evolution_v2/`
as if it were a "live" training loop, and shows how a user of the protocol
would wrap their actual training step with `diagnose(...)`.

The output illustrates:
  - epoch-by-epoch diagnostic evaluation
  - early termination as soon as any flag fires
  - the *epochs saved* by the protocol relative to the full 100-epoch run

Pseudo-code for a real use case:

    from protocol import diagnose

    for epoch in range(num_epochs):
        train_one_epoch(model, train_loader)
        if epoch % 5 == 0:  # checking every 5 epochs is enough
            report = diagnose(
                model=model,
                eval_batches=fixed_eval_batches,
                headline_acc=evaluate(model, test_loader),
                frozen_baseline_acc=fixed_frozen_baseline_acc,
                method_name="my method",
            )
            if report.verdict != "trustworthy":
                print(f"Aborting at epoch {epoch}: {report.verdict}")
                break

Run the synthetic demo:
    python -m protocol.examples.training_monitor_demo
"""
import os
import sys
import json

# Repository root: three dirname() steps up from this file
# (<root>/protocol/examples/<this file> -> <root>). It is prepended to
# sys.path so the `protocol` package resolves when the script is run
# directly. This must happen before the `protocol` import below.
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.insert(0, REPO_ROOT)

from protocol.report import DiagnosticThresholds  # noqa: E402

# Thresholds built with DiagnosticThresholds' no-argument constructor; read by
# early_stop_simulation for the (a) hidden-norm and (b) gradient-norm checks.
THRESHOLDS = DiagnosticThresholds()


def max_per_block_growth(h, eps=1e-30):
    """Return the largest ratio between consecutive per-block norms.

    For ``h = [h_0, h_1, ..., h_L]`` this computes
    ``max(h[i + 1] / max(h[i], eps))`` over every adjacent pair. Each
    denominator is clamped to ``eps``, so a zero (or tiny) norm produces a
    very large ratio rather than a ``ZeroDivisionError``.

    Args:
        h: sequence of per-block norms (must support ``len`` and slicing).
        eps: lower bound applied to each denominator.

    Returns:
        The maximum adjacent ratio, or ``1.0`` (i.e. "no growth") when fewer
        than two norms are given.
    """
    if len(h) < 2:
        return 1.0
    # Pair each norm with its successor: (h[0], h[1]), (h[1], h[2]), ...
    return max(nxt / max(cur, eps) for cur, nxt in zip(h, h[1:]))


def early_stop_simulation(log, method_name, *, thresholds=None):
    """Replay a per-epoch trace and report when the protocol would abort.

    Walks ``log`` in order and stops at the first entry where either
    diagnostic fires:

      (a) the largest block-to-block ratio of ``entry["hidden_norms"]``
          exceeds ``thresholds.h_norm_explosion_ratio``;
      (b) the last value of ``entry["bp_grad_norms_per_sample_med"]`` is
          below ``thresholds.g_norm_floor``.

    If a flag fires, prints the firing epoch and reason(s), the eval accuracy
    at that epoch and at the final logged epoch, and how many epochs stopping
    early would have skipped. Otherwise prints a "trustworthy" verdict.

    Args:
        log: sequence of per-epoch dicts. Entries are read for the keys
            ``"epoch"``, ``"acc_eval"``, ``"hidden_norms"`` and
            ``"bp_grad_norms_per_sample_med"``.
        method_name: label shown in the section header.
        thresholds: object exposing ``h_norm_explosion_ratio`` and
            ``g_norm_floor``; defaults to the module-level ``THRESHOLDS``.
    """
    if thresholds is None:
        thresholds = THRESHOLDS
    h_limit = thresholds.h_norm_explosion_ratio
    g_floor = thresholds.g_norm_floor

    print(f"\n=== {method_name} (simulating early-stop with the protocol) ===")
    if not log:
        print("  Empty trace: nothing to simulate.")
        return

    fired = None
    fired_epoch = None
    fired_acc = None
    for entry in log:
        per_block = max_per_block_growth(entry["hidden_norms"])
        g_last = entry["bp_grad_norms_per_sample_med"][-1]
        # Messages quote the thresholds actually used, so they stay accurate
        # if DiagnosticThresholds changes.
        reasons = []
        if per_block > h_limit:
            reasons.append(f"(a) per-block growth {per_block:.0f}× > {h_limit:g}×")
        if g_last < g_floor:
            reasons.append(f"(b) ‖g_L‖ {g_last:.1e} < {g_floor:.0e}")
        if reasons:
            fired = reasons
            fired_epoch = entry["epoch"]
            fired_acc = entry["acc_eval"]
            break

    final_epoch = log[-1]["epoch"]
    final_acc = log[-1]["acc_eval"]

    if fired is None:
        print(f"  No flag fired in {final_epoch} epochs. Final acc: {final_acc:.4f}")
        print("  Verdict: trustworthy. Run to completion.")
        return

    epochs_saved = final_epoch - fired_epoch
    acc_diff_pp = (final_acc - fired_acc) * 100
    # Guard a zero final epoch (e.g. a single entry logged as epoch 0).
    saved_pct = epochs_saved / final_epoch * 100 if final_epoch else 0.0
    # Only claim "no loss" when training on would not have raised the
    # headline accuracy, judged at the precision we print.
    if round(acc_diff_pp, 1) > 0:
        acc_note = f"at a cost of {acc_diff_pp:.1f} pp headline acc."
    else:
        acc_note = "with no headline acc loss."

    print(f"  Protocol fires at epoch {fired_epoch}: {' AND '.join(fired)}")
    print(f"  Acc at fire epoch: {fired_acc:.4f}")
    print(f"  Acc at final epoch (would have been): {final_acc:.4f}")
    print(f"  Diff (final - fire-time): {acc_diff_pp:+.1f} pp")
    print(f"  -> Stopping at epoch {fired_epoch} would save "
          f"{epochs_saved} epochs ({saved_pct:.0f}% compute), "
          f"{acc_note}")


def main():
    """Replay the bundled seed-42 snapshot trace through the early-stop check.

    Loads ``results/snapshot_evolution_v2/snapshot_evolution_s42.json`` under
    ``REPO_ROOT``, then runs ``early_stop_simulation`` on its ``"bp_log"``
    and ``"dfa_log"`` traces and prints a closing summary.

    Raises:
        FileNotFoundError: if the snapshot JSON is not present.
        KeyError: if the JSON lacks ``"bp_log"``/``"dfa_log"``, or a log
            entry lacks ``"hidden_norms"`` or
            ``"bp_grad_norms_per_sample_med"``.
    """
    snapshot_path = os.path.join(
        REPO_ROOT, "results", "snapshot_evolution_v2", "snapshot_evolution_s42.json"
    )
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(snapshot_path, encoding="utf-8") as f:
        d = json.load(f)

    # Per-entry fields that the diagnostics read. Check them up front so a
    # malformed snapshot fails before any output, with a message naming the
    # offending entry.
    required = ("hidden_norms", "bp_grad_norms_per_sample_med")
    logs = {}
    for name in ("bp_log", "dfa_log"):
        entries = d[name]
        for i, entry in enumerate(entries):
            missing = [k for k in required if k not in entry]
            if missing:
                raise KeyError(f"{name}[{i}] is missing {missing}")
        logs[name] = entries

    print("=" * 72)
    print("EARLY-STOP SIMULATION: 4-block d=256 ResMLP, CIFAR-10, seed 42")
    print("Using the diagnostic protocol as an in-training abort condition.")
    print("=" * 72)

    early_stop_simulation(logs["bp_log"], "BP (sanity)")
    early_stop_simulation(logs["dfa_log"], "DFA (failing)")

    # NOTE: the figures below are hard-coded for the bundled seed-42 trace;
    # they are not derived from the simulation output above.
    print()
    print("Compute saved by using the protocol as an early-stop criterion:")
    print("  BP: 0% (protocol never fires, BP runs to completion)")
    print("  DFA: ~96% (protocol fires within first 4 epochs of a 100-epoch run)")
    print()
    print("The protocol pays for itself: a single forward+backward pass on a")
    print("fixed eval batch every few epochs is negligible compared to the")
    print("compute saved when training is aborted on a failing run.")


if __name__ == "__main__":
    # Run the demo only when executed as a script (or via `python -m`), not on import.
    main()