summaryrefslogtreecommitdiff
path: root/protocol/examples/temporal_diagnostic_evolution.py
blob: 6a2c04233320976c1491664f8376d27268aebbc8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Temporal validation of the diagnostic protocol: at what epoch during DFA
training does each diagnostic cross its degeneracy threshold?

This uses the existing snapshot evolution data in
`results/snapshot_evolution_v2/`, which logs per-epoch:
  - hidden_norms (the (a) diagnostic)
  - bp_grad_norms_per_sample_med (the (b) diagnostic)
  - gamma_dfa (the field-standard reference number)
  - acc_eval

over 100 epochs of both BP and DFA training on the standard 4-block d=256
ResMLP CIFAR-10 setup. We replay this data through the protocol's
threshold logic and report:

  (i) the epoch at which each diagnostic first FIRES on DFA,
  (ii) the per-epoch headline accuracy (so we can show that the diagnostic
       fires BEFORE the headline acc has converged — i.e. the protocol
       could have caught the pathology mid-training),
  (iii) the trajectory on BP for comparison (which should never fire).

This is the temporal validation of the protocol's decision utility: the
protocol catches the pathology *as it happens*, not just retrospectively.

Run:
    python -m protocol.examples.temporal_diagnostic_evolution
"""
import os
import json
import sys

# Repository root: three dirname() steps up from this file's absolute path
# (protocol/examples/<this file> -> protocol/examples -> protocol -> root).
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Prepend the repo root so the `protocol` package import below resolves even
# when this file is executed directly; must run before that import.
sys.path.insert(0, REPO_ROOT)

from protocol.report import DiagnosticThresholds  # noqa: E402

# Threshold set built with no arguments; every diagnostic check in this
# script reads `h_norm_explosion_ratio` and `g_norm_floor` from it.
THRESHOLDS = DiagnosticThresholds()


def diagnose_entry(entry, thresholds=None):
    """Apply the protocol's two degeneracy checks to one logged epoch.

    Args:
        entry: one per-epoch log record holding the sequences
            ``hidden_norms`` and ``bp_grad_norms_per_sample_med``.
        thresholds: object exposing ``h_norm_explosion_ratio`` and
            ``g_norm_floor``. Defaults to the module-level ``THRESHOLDS``.

    Returns:
        ``(h_exploded, g_at_floor)`` where

        * ``h_exploded`` (diagnostic (a)) is True when the largest entry of
          ``hidden_norms`` exceeds its first entry by more than
          ``h_norm_explosion_ratio``;
        * ``g_at_floor`` (diagnostic (b)) is True when the last entry of
          ``bp_grad_norms_per_sample_med`` is below ``g_norm_floor``.
    """
    if thresholds is None:
        thresholds = THRESHOLDS
    h = entry["hidden_norms"]
    g = entry["bp_grad_norms_per_sample_med"]
    # Clamp the denominator so a zero first norm cannot divide by zero.
    h_exploded = (max(h) / max(h[0], 1e-30)) > thresholds.h_norm_explosion_ratio
    g_at_floor = g[-1] < thresholds.g_norm_floor
    return h_exploded, g_at_floor


def first_fire_epoch(log, predicate):
    """Return the ``epoch`` of the earliest record in ``log`` satisfying ``predicate``.

    Records are tested in order and scanning stops at the first match;
    ``None`` is returned when no record matches (including an empty log).
    """
    matching_epochs = (record["epoch"] for record in log if predicate(record))
    return next(matching_epochs, None)


def _should_print_row(epoch, last_epoch):
    """Return True for epochs shown in the console tables.

    Rows shown: the first five epochs, every tenth epoch, and the last
    logged epoch.
    """
    return epoch <= 5 or epoch % 10 == 0 or epoch == last_epoch


def _format_gamma(gamma):
    """Format a logged ``gamma_dfa`` value; missing or NaN values print as 'nan'."""
    # NaN is the only float value that compares unequal to itself.
    if gamma is None or (isinstance(gamma, float) and gamma != gamma):
        return "nan"
    return f"{gamma:.4f}"


def _report_dfa(dfa_log):
    """Print the DFA trajectory table and the first-fire summary.

    Args:
        dfa_log: non-empty list of per-epoch DFA log records.

    Returns:
        ``(fire_a_epoch, fire_b_epoch)``: the first epoch at which diagnostic
        (a) / (b) fired, or ``None`` if it never fired.
    """
    print("\nDFA training trajectory (each row = one logged epoch):")
    print(
        f"  {'epoch':>6} {'acc':>8} {'gamma':>10} "
        f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}"
    )
    last_epoch = dfa_log[-1]["epoch"]
    fire_a_epoch = fire_b_epoch = None
    # Accuracy recorded at the moment each diagnostic first fires.
    fire_a_acc = fire_b_acc = None
    for entry in dfa_log:
        h_exp, g_floor = diagnose_entry(entry)
        ep = entry["epoch"]
        if h_exp and fire_a_epoch is None:
            fire_a_epoch, fire_a_acc = ep, entry["acc_eval"]
        if g_floor and fire_b_epoch is None:
            fire_b_epoch, fire_b_acc = ep, entry["acc_eval"]
        if _should_print_row(ep, last_epoch):
            h = entry["hidden_norms"]
            g = entry["bp_grad_norms_per_sample_med"]
            flag_a = "FIRE" if h_exp else "ok"
            flag_b = "FIRE" if g_floor else "ok"
            gamma_s = _format_gamma(entry.get("gamma_dfa"))
            print(
                f"  {ep:>6} {entry['acc_eval']:>8.4f} {gamma_s:>10} "
                f"{h[-1]:>14.3e} {g[-1]:>14.3e} {flag_a:>5} {flag_b:>5}"
            )

    print()
    print(f"  Diagnostic (a) ‖h_l‖ explosion first fires at epoch: {fire_a_epoch}")
    print(f"  Diagnostic (b) ‖g_l‖ floor    first fires at epoch: {fire_b_epoch}")
    if fire_a_epoch is not None:
        print(f"  DFA test acc at the moment (a) fires: {fire_a_acc:.4f}")
    else:
        print("  (a) never fires")
    if fire_b_epoch is not None:
        print(f"  DFA test acc at the moment (b) fires: {fire_b_acc:.4f}")
    else:
        print("  (b) never fires")
    print(f"  DFA final test acc: {dfa_log[-1]['acc_eval']:.4f}")
    return fire_a_epoch, fire_b_epoch


def _report_bp(bp_log):
    """Print the BP sanity table.

    Args:
        bp_log: non-empty list of per-epoch BP log records.

    Returns:
        True if either diagnostic fired at any BP epoch, else False.
    """
    print("\nBP training trajectory (sanity):")
    print(
        f"  {'epoch':>6} {'acc':>8} "
        f"{'||h_L||':>14} {'||g_L||':>14} {'(a)':>5} {'(b)':>5}"
    )
    last_epoch = bp_log[-1]["epoch"]
    bp_fired = False
    for entry in bp_log:
        h_exp, g_floor = diagnose_entry(entry)
        if h_exp or g_floor:
            bp_fired = True
        if _should_print_row(entry["epoch"], last_epoch):
            h = entry["hidden_norms"]
            g = entry["bp_grad_norms_per_sample_med"]
            print(
                f"  {entry['epoch']:>6} {entry['acc_eval']:>8.4f} "
                f"{h[-1]:>14.3e} {g[-1]:>14.3e} "
                f"{'FIRE' if h_exp else 'ok':>5} {'FIRE' if g_floor else 'ok':>5}"
            )
    print(f"\n  BP fired any diagnostic at any epoch: {bp_fired}")
    print(f"  BP final test acc: {bp_log[-1]['acc_eval']:.4f}")
    return bp_fired


def main(argv=None):
    """Replay a snapshot log through the protocol thresholds and report.

    Loads ``results/snapshot_evolution_v2/snapshot_evolution_s<seed>.json``,
    prints the DFA and BP diagnostic tables, and writes a JSON summary to
    ``results/protocol_audit/temporal_evolution_s<seed>.json``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default).
    """
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42, help="snapshot seed to replay")
    args = p.parse_args(argv)
    snapshot_path = os.path.join(
        REPO_ROOT, f"results/snapshot_evolution_v2/snapshot_evolution_s{args.seed}.json"
    )
    if not os.path.exists(snapshot_path):
        print(f"snapshot not found: {snapshot_path}")
        return
    with open(snapshot_path, encoding="utf-8") as f:
        d = json.load(f)
    bp_log = d["bp_log"]
    dfa_log = d["dfa_log"]
    if not bp_log or not dfa_log:
        # Every report below indexes the last entry of each log.
        print(f"snapshot has an empty bp_log or dfa_log: {snapshot_path}")
        return

    print("=" * 88)
    print(
        f"TEMPORAL DIAGNOSTIC EVOLUTION (4-block d=256 ResMLP, CIFAR-10, seed {args.seed})"
    )
    print("=" * 88)

    fire_a_epoch, fire_b_epoch = _report_dfa(dfa_log)
    bp_fired = _report_bp(bp_log)

    out = {
        "dfa": {
            "trajectory": [
                {
                    "epoch": e["epoch"],
                    "acc": e["acc_eval"],
                    "h_max_to_h0_ratio": (
                        max(e["hidden_norms"]) / max(e["hidden_norms"][0], 1e-30)
                    ),
                    "g_L": e["bp_grad_norms_per_sample_med"][-1],
                    "gamma": e.get("gamma_dfa"),
                }
                for e in dfa_log
            ],
            "first_fire_a_epoch": fire_a_epoch,
            "first_fire_b_epoch": fire_b_epoch,
            "final_acc": dfa_log[-1]["acc_eval"],
        },
        "bp": {
            "any_fire": bp_fired,
            "final_acc": bp_log[-1]["acc_eval"],
        },
        "thresholds": {
            "g_norm_floor": THRESHOLDS.g_norm_floor,
            "h_norm_explosion_ratio": THRESHOLDS.h_norm_explosion_ratio,
        },
    }
    out_path = os.path.join(
        REPO_ROOT, f"results/protocol_audit/temporal_evolution_s{args.seed}.json"
    )
    # Create the output directory if missing; otherwise open() below raises
    # FileNotFoundError.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"\nSaved {out_path}")


if __name__ == "__main__":
    main()