""" Threshold sensitivity analysis: how does the protocol's verdict change as the diagnostic thresholds are varied? E&D reviewers will reasonably ask "what makes 50× per-block growth and 1e-7 g-norm-floor the right cutoffs?" This script sweeps each threshold independently across two orders of magnitude and reports, for the 5-method audit table: - the verdict at each threshold - the boundary value at which verdicts change for each method If a method's verdict is robust to a wide range of thresholds, the protocol is well-calibrated for that method. If verdicts flip near the chosen threshold, the calibration is fragile and the threshold needs explicit defense in the paper. Run: python -m protocol.examples.threshold_sensitivity """ import os import sys import json REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) sys.path.insert(0, REPO_ROOT) AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42_s123_s456.json") def max_per_block_growth(h): if len(h) < 2: return 1.0 return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) def main(): with open(AUDIT_PATH) as f: data = json.load(f) rows = data["summary"] # Compute per-row diagnostic raw values metrics = [] for r in rows: report = data["reports"][f"{r['method']}_s{r['seed']}"] metrics.append({ "method": r["method"], "seed": r["seed"], "max_per_block_growth": max_per_block_growth(report["residual_norms"]), "g_L": report["bp_grad_norms"][-1], "stability": report["cross_batch_stability"], "frozen_acc": r.get("frozen_acc"), "acc": r["acc"], }) # ----- (a) per-block growth threshold sensitivity ----- # print("=" * 88) print("DIAGNOSTIC (a) per-block growth: sensitivity over threshold") print("=" * 88) print(f"{'method':<16}{'seed':>6}{'value':>14}", end="") a_thresholds = [5, 10, 20, 50, 100, 500, 1000, 5000] for t in a_thresholds: print(f"{'>'+str(t)+'×':>10}", end="") print() for m in metrics: print(f"{m['method']:<16}{m['seed']:>6}{m['max_per_block_growth']:>14.2e}", end="") for t in a_thresholds: fired = "FIRE" if m["max_per_block_growth"] > t else "ok" print(f"{fired:>10}", end="") print() print() print("Reading: BP/EP rows should be 'ok' across the entire row (the whole") print("threshold range is healthy for them). DFA/SB/CB rows should be 'FIRE'") print("at the chosen threshold and have a comfortable margin on either side.") print() # ----- (b) g-norm floor sensitivity ----- # print("=" * 88) print("DIAGNOSTIC (b) g-norm floor: sensitivity over threshold") print("=" * 88) print(f"{'method':<16}{'seed':>6}{'value':>14}", end="") b_thresholds = [1e-9, 1e-8, 1e-7, 1e-6, 1e-5] for t in b_thresholds: print(f"{'<'+f'{t:.0e}':>10}", end="") print() for m in metrics: print(f"{m['method']:<16}{m['seed']:>6}{m['g_L']:>14.2e}", end="") for t in b_thresholds: fired = "FIRE" if m["g_L"] < t else "ok" print(f"{fired:>10}", end="") print() print() # ----- (c) stability ceiling sensitivity ----- # print("=" * 88) print("DIAGNOSTIC (c) stability ceiling: sensitivity over threshold") print("=" * 88) print(f"{'method':<16}{'seed':>6}{'value':>14}", end="") c_thresholds = [0.10, 0.20, 0.30, 0.50, 0.70, 0.90] for t in c_thresholds: print(f"{'>'+f'{t:.2f}':>10}", end="") print() for m in metrics: print(f"{m['method']:<16}{m['seed']:>6}{m['stability']:>14.3f}", end="") for t in c_thresholds: fired = "FIRE" if m["stability"] > t else "ok" print(f"{fired:>10}", end="") print() print() # ----- Aggregate verdict robustness ----- # print("=" * 88) print("VERDICT ROBUSTNESS: at what threshold does each verdict CHANGE?") print("=" * 88) # For each method × seed, find the (a)-threshold where the verdict flips print(f"{'method':<16}{'seed':>6}{'(a) value':>14}{'(a) flip near':>16}{'(b) value':>14}{'(b) flip near':>16}") for m in metrics: # Find the largest threshold at which it still fires (i.e. where # the verdict would change if you raised the threshold past that) flip_a = m["max_per_block_growth"] flip_b = m["g_L"] print(f"{m['method']:<16}{m['seed']:>6}{m['max_per_block_growth']:>14.2e}{flip_a:>16.2e}{m['g_L']:>14.2e}{flip_b:>16.2e}") print() print("Interpretation:") print(" - For DFA/SB/CB, the 'flip near' values are the diagnostic raw values.") print(" Default (a) threshold 50 catches all if raw values > 50; default (b)") print(" threshold 1e-7 catches all if raw values < 1e-7. Compare:") bp_a = max(m["max_per_block_growth"] for m in metrics if m["method"] == "bp") ep_a = max(m["max_per_block_growth"] for m in metrics if m["method"] == "ep") dfa_a = min(m["max_per_block_growth"] for m in metrics if m["method"] == "dfa") sb_a = min(m["max_per_block_growth"] for m in metrics if m["method"] == "state_bridge") cb_a = min(m["max_per_block_growth"] for m in metrics if m["method"] == "credit_bridge") bp_b = min(m["g_L"] for m in metrics if m["method"] == "bp") ep_b = min(m["g_L"] for m in metrics if m["method"] == "ep") dfa_b = max(m["g_L"] for m in metrics if m["method"] == "dfa") sb_b = max(m["g_L"] for m in metrics if m["method"] == "state_bridge") cb_b = max(m["g_L"] for m in metrics if m["method"] == "credit_bridge") print(f" (a) max BP per-block growth across 3 seeds: {bp_a:.2e}") print(f" (a) max EP per-block growth across 3 seeds: {ep_a:.2e}") print(f" (a) min DFA per-block growth across 3 seeds: {dfa_a:.2e}") print(f" (a) min SB per-block growth across 3 seeds: {sb_a:.2e}") print(f" (a) min CB per-block growth across 3 seeds: {cb_a:.2e}") print(f" -> separation gap: healthy max = {max(bp_a, ep_a):.2e},") print(f" degenerate min = {min(dfa_a, sb_a, cb_a):.2e},") print(f" gap factor = {min(dfa_a, sb_a, cb_a) / max(bp_a, ep_a):.0f}×") print() print(f" (b) min BP ‖g_L‖ across 3 seeds: {bp_b:.2e}") print(f" (b) min EP ‖g_L‖ across 3 seeds: {ep_b:.2e}") print(f" (b) max DFA ‖g_L‖ across 3 seeds: {dfa_b:.2e}") print(f" (b) max SB ‖g_L‖ across 3 seeds: {sb_b:.2e}") print(f" (b) max CB ‖g_L‖ across 3 seeds: {cb_b:.2e}") print(f" -> separation gap: healthy min = {min(bp_b, ep_b):.2e},") print(f" degenerate max = {max(dfa_b, sb_b, cb_b):.2e},") print(f" gap factor = {min(bp_b, ep_b) / max(dfa_b, sb_b, cb_b):.0f}×") if __name__ == "__main__": main()