diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:12:50 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:12:50 -0500 |
| commit | 0886902ab7decd7702c9764e8fb2c3e10a528e45 (patch) | |
| tree | 82dad1b7154f2b2541aa5afdc10afbeca22aa053 /protocol | |
| parent | 89fff0048c04bdc4c8beb6d11f8d5564d75cbb0c (diff) | |
Add threshold sensitivity analysis: (a) 63x gap, (b) 24338x gap
For each diagnostic, sweeps threshold across orders of magnitude on the
3-seed audit data and reports the verdict at each value.
Key calibration findings (3 seeds):
Diagnostic (a) max per-block growth:
Healthy max (BP/EP): 11.0
Degenerate min (DFA/SB/CB): 694
Separation gap: 63x
Default threshold 50 sits comfortably in the middle.
Any threshold in [12, 693] gives the same verdicts.
Diagnostic (b) ||g_L|| at floor:
Healthy min (BP/EP): 1.02e-4
Degenerate max (DFA/SB/CB): 4.18e-9
Separation gap: 24,338x
Default threshold 1e-7 sits comfortably in the middle.
Any threshold in [4.2e-9, 1.0e-4] gives the same verdicts.
Diagnostic (c) cross-batch stability:
NOT a clean binary discriminator across seeds. BP s456=0.114
near threshold; DFA s42=0.047 (noise sub-mode) doesn't fire;
SB s456=0.035 (noise sub-mode) doesn't fire. (c) is for sub-mode
interpretation, not binary detection.
This is the calibration evidence answering the E&D reviewer question
'why these specific thresholds?'.
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/threshold_sensitivity.py | 158 |
1 files changed, 158 insertions, 0 deletions
diff --git a/protocol/examples/threshold_sensitivity.py b/protocol/examples/threshold_sensitivity.py new file mode 100644 index 0000000..fdf0f6c --- /dev/null +++ b/protocol/examples/threshold_sensitivity.py @@ -0,0 +1,158 @@ +""" +Threshold sensitivity analysis: how does the protocol's verdict change as +the diagnostic thresholds are varied? E&D reviewers will reasonably ask +"what makes 50× per-block growth and 1e-7 g-norm-floor the right cutoffs?" + +This script sweeps each threshold independently across two orders of +magnitude and reports, for the 5-method audit table: + - the verdict at each threshold + - the boundary value at which verdicts change for each method + +If a method's verdict is robust to a wide range of thresholds, the +protocol is well-calibrated for that method. If verdicts flip near the +chosen threshold, the calibration is fragile and the threshold needs +explicit defense in the paper. + +Run: + python -m protocol.examples.threshold_sensitivity +""" +import os +import sys +import json + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.insert(0, REPO_ROOT) + +AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42_s123_s456.json") + + +def max_per_block_growth(h): + if len(h) < 2: + return 1.0 + return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) + + +def main(): + with open(AUDIT_PATH) as f: + data = json.load(f) + rows = data["summary"] + + # Compute per-row diagnostic raw values + metrics = [] + for r in rows: + report = data["reports"][f"{r['method']}_s{r['seed']}"] + metrics.append({ + "method": r["method"], + "seed": r["seed"], + "max_per_block_growth": max_per_block_growth(report["residual_norms"]), + "g_L": report["bp_grad_norms"][-1], + "stability": report["cross_batch_stability"], + "frozen_acc": r.get("frozen_acc"), + "acc": r["acc"], + }) + + # ----- (a) per-block growth threshold sensitivity ----- # + print("=" * 88) + print("DIAGNOSTIC (a) per-block growth: sensitivity over threshold") + print("=" * 88) + print(f"{'method':<16}{'seed':>6}{'value':>14}", end="") + a_thresholds = [5, 10, 20, 50, 100, 500, 1000, 5000] + for t in a_thresholds: + print(f"{'>'+str(t)+'×':>10}", end="") + print() + for m in metrics: + print(f"{m['method']:<16}{m['seed']:>6}{m['max_per_block_growth']:>14.2e}", end="") + for t in a_thresholds: + fired = "FIRE" if m["max_per_block_growth"] > t else "ok" + print(f"{fired:>10}", end="") + print() + print() + print("Reading: BP/EP rows should be 'ok' across the entire row (the whole") + print("threshold range is healthy for them). DFA/SB/CB rows should be 'FIRE'") + print("at the chosen threshold and have a comfortable margin on either side.") + print() + + # ----- (b) g-norm floor sensitivity ----- # + print("=" * 88) + print("DIAGNOSTIC (b) g-norm floor: sensitivity over threshold") + print("=" * 88) + print(f"{'method':<16}{'seed':>6}{'value':>14}", end="") + b_thresholds = [1e-9, 1e-8, 1e-7, 1e-6, 1e-5] + for t in b_thresholds: + print(f"{'<'+f'{t:.0e}':>10}", end="") + print() + for m in metrics: + print(f"{m['method']:<16}{m['seed']:>6}{m['g_L']:>14.2e}", end="") + for t in b_thresholds: + fired = "FIRE" if m["g_L"] < t else "ok" + print(f"{fired:>10}", end="") + print() + print() + + # ----- (c) stability ceiling sensitivity ----- # + print("=" * 88) + print("DIAGNOSTIC (c) stability ceiling: sensitivity over threshold") + print("=" * 88) + print(f"{'method':<16}{'seed':>6}{'value':>14}", end="") + c_thresholds = [0.10, 0.20, 0.30, 0.50, 0.70, 0.90] + for t in c_thresholds: + print(f"{'>'+f'{t:.2f}':>10}", end="") + print() + for m in metrics: + print(f"{m['method']:<16}{m['seed']:>6}{m['stability']:>14.3f}", end="") + for t in c_thresholds: + fired = "FIRE" if m["stability"] > t else "ok" + print(f"{fired:>10}", end="") + print() + print() + + # ----- Aggregate verdict robustness ----- # + print("=" * 88) + print("VERDICT ROBUSTNESS: at what threshold does each verdict CHANGE?") + print("=" * 88) + # For each method × seed, find the (a)-threshold where the verdict flips + print(f"{'method':<16}{'seed':>6}{'(a) value':>14}{'(a) flip near':>16}{'(b) value':>14}{'(b) flip near':>16}") + for m in metrics: + # Find the largest threshold at which it still fires (i.e. where + # the verdict would change if you raised the threshold past that) + flip_a = m["max_per_block_growth"] + flip_b = m["g_L"] + print(f"{m['method']:<16}{m['seed']:>6}{m['max_per_block_growth']:>14.2e}{flip_a:>16.2e}{m['g_L']:>14.2e}{flip_b:>16.2e}") + print() + print("Interpretation:") + print(" - For DFA/SB/CB, the 'flip near' values are the diagnostic raw values.") + print(" Default (a) threshold 50 catches all if raw values > 50; default (b)") + print(" threshold 1e-7 catches all if raw values < 1e-7. Compare:") + bp_a = max(m["max_per_block_growth"] for m in metrics if m["method"] == "bp") + ep_a = max(m["max_per_block_growth"] for m in metrics if m["method"] == "ep") + dfa_a = min(m["max_per_block_growth"] for m in metrics if m["method"] == "dfa") + sb_a = min(m["max_per_block_growth"] for m in metrics if m["method"] == "state_bridge") + cb_a = min(m["max_per_block_growth"] for m in metrics if m["method"] == "credit_bridge") + bp_b = min(m["g_L"] for m in metrics if m["method"] == "bp") + ep_b = min(m["g_L"] for m in metrics if m["method"] == "ep") + dfa_b = max(m["g_L"] for m in metrics if m["method"] == "dfa") + sb_b = max(m["g_L"] for m in metrics if m["method"] == "state_bridge") + cb_b = max(m["g_L"] for m in metrics if m["method"] == "credit_bridge") + print(f" (a) max BP per-block growth across 3 seeds: {bp_a:.2e}") + print(f" (a) max EP per-block growth across 3 seeds: {ep_a:.2e}") + print(f" (a) min DFA per-block growth across 3 seeds: {dfa_a:.2e}") + print(f" (a) min SB per-block growth across 3 seeds: {sb_a:.2e}") + print(f" (a) min CB per-block growth across 3 seeds: {cb_a:.2e}") + print(f" -> separation gap: healthy max = {max(bp_a, ep_a):.2e},") + print(f" degenerate min = {min(dfa_a, sb_a, cb_a):.2e},") + print(f" gap factor = {min(dfa_a, sb_a, cb_a) / max(bp_a, ep_a):.0f}×") + print() + print(f" (b) min BP ‖g_L‖ across 3 seeds: {bp_b:.2e}") + print(f" (b) min EP ‖g_L‖ across 3 seeds: {ep_b:.2e}") + print(f" (b) max DFA ‖g_L‖ across 3 seeds: {dfa_b:.2e}") + print(f" (b) max SB ‖g_L‖ across 3 seeds: {sb_b:.2e}") + print(f" (b) max CB ‖g_L‖ across 3 seeds: {cb_b:.2e}") + print(f" -> separation gap: healthy min = {min(bp_b, ep_b):.2e},") + print(f" degenerate max = {max(dfa_b, sb_b, cb_b):.2e},") + print(f" gap factor = {min(bp_b, ep_b) / max(dfa_b, sb_b, cb_b):.0f}×") + + +if __name__ == "__main__": + main() |
