summaryrefslogtreecommitdiff
path: root/protocol/examples/threshold_sensitivity.py
blob: fdf0f6cabe43389fee07357d777733821f6c8106 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Threshold sensitivity analysis: how does the protocol's verdict change as
the diagnostic thresholds are varied? E&D reviewers will reasonably ask
"what makes 50× per-block growth and 1e-7 g-norm-floor the right cutoffs?"

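This script sweeps each threshold independently over a grid of values
(per-block growth 5×–5000× and g-norm floor 1e-9–1e-5, i.e. three to four
orders of magnitude; stability ceiling 0.10–0.90) and reports, for every
method × seed row of the 5-method audit table:
  - the verdict at each threshold
  - the boundary value at which verdicts change for each method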

If a method's verdict is robust to a wide range of thresholds, the
protocol is well-calibrated for that method. If verdicts flip near the
chosen threshold, the calibration is fragile and the threshold needs
explicit defense in the paper.

Run:
    python -m protocol.examples.threshold_sensitivity
"""
import os
import sys
import json

# This file sits at <repo>/protocol/examples/, so the repository root is
# three directory levels above its absolute path.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Put the repository root first on the import path so repo-local packages
# resolve when the script is run directly.
sys.path.insert(0, REPO_ROOT)

# JSON audit table that main() loads and re-analyses.
AUDIT_PATH = os.path.join(
    REPO_ROOT,
    "results/protocol_audit/audit_table_s42_s123_s456.json",
)


def max_per_block_growth(h):
    """Return the largest ratio between consecutive entries of ``h``.

    Each ratio is ``h[i + 1] / h[i]``, with the denominator floored at
    1e-30 so a zero entry cannot cause a division by zero. A sequence
    with fewer than two entries has no ratio and yields 1.0 ("no growth").
    """
    if len(h) >= 2:
        floor = 1e-30
        return max(nxt / max(prev, floor) for prev, nxt in zip(h, h[1:]))
    return 1.0


# (method key in the audit table, short label used in the report)
_HEALTHY_METHODS = (("bp", "BP"), ("ep", "EP"))
_DEGENERATE_METHODS = (
    ("dfa", "DFA"),
    ("state_bridge", "SB"),
    ("credit_bridge", "CB"),
)


def _load_metrics(audit_path):
    """Load the audit table and compute each row's raw diagnostic values.

    Returns one dict per entry of ``data["summary"]`` with keys ``method``,
    ``seed``, ``max_per_block_growth`` (diagnostic (a)), ``g_L`` (last entry
    of the report's ``bp_grad_norms``, diagnostic (b)) and ``stability``
    (the report's ``cross_batch_stability``, diagnostic (c)). Each row's
    report is looked up in ``data["reports"]`` under ``"<method>_s<seed>"``.
    """
    # JSON is UTF-8 by specification; don't depend on the locale encoding.
    with open(audit_path, encoding="utf-8") as f:
        data = json.load(f)
    metrics = []
    for row in data["summary"]:
        report = data["reports"][f"{row['method']}_s{row['seed']}"]
        metrics.append({
            "method": row["method"],
            "seed": row["seed"],
            "max_per_block_growth": max_per_block_growth(report["residual_norms"]),
            "g_L": report["bp_grad_norms"][-1],
            "stability": report["cross_batch_stability"],
        })
    return metrics


def _print_sweep(title, metrics, key, value_spec, thresholds, label, fires):
    """Print one method × seed table of FIRE/ok verdicts over ``thresholds``.

    Args:
        title: heading line printed between two rules.
        metrics: rows produced by ``_load_metrics``.
        key: which raw diagnostic value of each row to test.
        value_spec: format spec for the raw-value column.
        thresholds: grid of thresholds, one column each.
        label: callable mapping a threshold to its column header.
        fires: callable ``(value, threshold) -> bool``; True prints FIRE.
    """
    rule = "=" * 88
    print(rule)
    print(title)
    print(rule)
    print(f"{'method':<16}{'seed':>6}{'value':>14}"
          + "".join(f"{label(t):>10}" for t in thresholds))
    for m in metrics:
        value = m[key]
        cells = "".join(
            f"{('FIRE' if fires(value, t) else 'ok'):>10}" for t in thresholds
        )
        print(f"{m['method']:<16}{m['seed']:>6}{value:{value_spec}}" + cells)
    print()


def _format_value(value):
    """Format a raw diagnostic value as ``.2e``, or 'n/a' when it is None."""
    return "n/a" if value is None else format(value, ".2e")


def _format_gap(upper, lower):
    """Format the separation factor ``upper / lower`` as e.g. ``"120×"``.

    A factor >= 1 is printed with ``.0f``. A zero ``lower`` (e.g. a
    degenerate run whose ``g_L`` is exactly 0) is reported as an infinite
    gap instead of raising ZeroDivisionError, and a factor < 1 is flagged
    as an overlap rather than rounded down to "0×".
    """
    if upper is None or lower is None:
        return "n/a (no rows for one of the method groups)"
    if lower == 0:
        return "inf×" if upper > 0 else "undefined (both extremes are 0)"
    factor = upper / lower
    if factor < 1:
        return f"{factor:.3g}× (OVERLAP: healthy and degenerate ranges are not separated)"
    return f"{factor:.0f}×"


def _print_group_extremes(metrics, tag, key, quantity, methods, pick, word):
    """Print ``pick`` (min or max) of ``key`` per method; return the group's.

    A method with no rows is printed as 'n/a' instead of letting ``min`` /
    ``max`` raise ValueError on an empty list. Returns ``pick`` over the
    per-method results, or None if no method in the group has rows.
    """
    extremes = []
    for method, label in methods:
        values = [m[key] for m in metrics if m["method"] == method]
        extreme = pick(values) if values else None
        noun = "seed" if len(values) == 1 else "seeds"
        print(f"    ({tag}) {word} {label} {quantity} across {len(values)} {noun}: "
              f"{_format_value(extreme)}")
        if extreme is not None:
            extremes.append(extreme)
    return pick(extremes) if extremes else None


def _print_separation(metrics, tag, key, quantity, healthy_above):
    """Print how far apart healthy (BP/EP) and degenerate runs are on ``key``.

    For each group the reported extreme is the one nearest the other group.
    With ``healthy_above`` False (diagnostic (a): healthy values are small)
    that is the healthy max and degenerate min; with it True (diagnostic
    (b): healthy values are large) it is the healthy min and degenerate max.
    The gap factor is the upper group's extreme over the lower group's.
    """
    if healthy_above:
        h_pick, h_word, d_pick, d_word = min, "min", max, "max"
    else:
        h_pick, h_word, d_pick, d_word = max, "max", min, "min"
    healthy = _print_group_extremes(
        metrics, tag, key, quantity, _HEALTHY_METHODS, h_pick, h_word)
    degenerate = _print_group_extremes(
        metrics, tag, key, quantity, _DEGENERATE_METHODS, d_pick, d_word)
    upper, lower = (healthy, degenerate) if healthy_above else (degenerate, healthy)
    print(f"    -> separation gap: healthy {h_word} = {_format_value(healthy)},")
    print(f"       degenerate {d_word} = {_format_value(degenerate)},")
    print(f"       gap factor = {_format_gap(upper, lower)}")


def _print_robustness(metrics):
    """Print each row's verdict flip points and the healthy/degenerate gaps."""
    rule = "=" * 88
    print(rule)
    print("VERDICT ROBUSTNESS: at what threshold does each verdict CHANGE?")
    print(rule)
    print(f"{'method':<16}{'seed':>6}{'(a) value':>14}{'(a) flip near':>16}{'(b) value':>14}{'(b) flip near':>16}")
    for m in metrics:
        # (a) fires on value > threshold and (b) on value < threshold. Both
        # tests are strict, so each verdict flips exactly where the threshold
        # crosses the raw value: the flip point is the raw value itself.
        flip_a = m["max_per_block_growth"]
        flip_b = m["g_L"]
        print(f"{m['method']:<16}{m['seed']:>6}{m['max_per_block_growth']:>14.2e}{flip_a:>16.2e}{m['g_L']:>14.2e}{flip_b:>16.2e}")
    print()
    print("Interpretation:")
    print("  - For DFA/SB/CB, the 'flip near' values are the diagnostic raw values.")
    print("    Default (a) threshold 50 catches all if raw values > 50; default (b)")
    print("    threshold 1e-7 catches all if raw values < 1e-7. Compare:")
    _print_separation(metrics, "a", "max_per_block_growth", "per-block growth",
                      healthy_above=False)
    print()
    _print_separation(metrics, "b", "g_L", "‖g_L‖", healthy_above=True)


def main(audit_path=None):
    """Print threshold-sensitivity tables for diagnostics (a), (b) and (c).

    For every method × seed row of the audit table, show the FIRE/ok
    verdict of each diagnostic over a grid of thresholds. Then show where
    each verdict flips and how widely healthy and degenerate methods are
    separated.

    Args:
        audit_path: JSON audit table to analyse; ``None`` uses ``AUDIT_PATH``.
    """
    metrics = _load_metrics(AUDIT_PATH if audit_path is None else audit_path)

    # ----- (a) per-block growth threshold sensitivity ----- #
    _print_sweep(
        "DIAGNOSTIC (a) per-block growth: sensitivity over threshold",
        metrics, "max_per_block_growth", ">14.2e",
        [5, 10, 20, 50, 100, 500, 1000, 5000],
        label=lambda t: ">" + str(t) + "×",
        fires=lambda value, t: value > t,
    )
    print("Reading: BP/EP rows should be 'ok' across the entire row (the whole")
    print("threshold range is healthy for them). DFA/SB/CB rows should be 'FIRE'")
    print("at the chosen threshold and have a comfortable margin on either side.")
    print()

    # ----- (b) g-norm floor sensitivity ----- #
    _print_sweep(
        "DIAGNOSTIC (b) g-norm floor: sensitivity over threshold",
        metrics, "g_L", ">14.2e",
        [1e-9, 1e-8, 1e-7, 1e-6, 1e-5],
        label=lambda t: "<" + f"{t:.0e}",
        fires=lambda value, t: value < t,
    )

    # ----- (c) stability ceiling sensitivity ----- #
    _print_sweep(
        "DIAGNOSTIC (c) stability ceiling: sensitivity over threshold",
        metrics, "stability", ">14.3f",
        [0.10, 0.20, 0.30, 0.50, 0.70, 0.90],
        label=lambda t: ">" + f"{t:.2f}",
        fires=lambda value, t: value > t,
    )

    # ----- Aggregate verdict robustness ----- #
    _print_robustness(metrics)


if __name__ == "__main__":
    main()