summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:12:50 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:12:50 -0500
commit0886902ab7decd7702c9764e8fb2c3e10a528e45 (patch)
tree82dad1b7154f2b2541aa5afdc10afbeca22aa053 /protocol
parent89fff0048c04bdc4c8beb6d11f8d5564d75cbb0c (diff)
Add threshold sensitivity analysis: (a) 63x gap, (b) 24338x gap
For each diagnostic, sweeps threshold across orders of magnitude on the 3-seed audit data and reports the verdict at each value. Key calibration findings (3 seeds): Diagnostic (a) max per-block growth: Healthy max (BP/EP): 11.0 Degenerate min (DFA/SB/CB): 694 Separation gap: 63x Default threshold 50 sits comfortably in the middle. Any threshold in [12, 693] gives the same verdicts. Diagnostic (b) ||g_L|| at floor: Healthy min (BP/EP): 1.02e-4 Degenerate max (DFA/SB/CB): 4.18e-9 Separation gap: 24,338x Default threshold 1e-7 sits comfortably in the middle. Any threshold in [4.2e-9, 1.0e-4] gives the same verdicts. Diagnostic (c) cross-batch stability: NOT a clean binary discriminator across seeds. BP s456=0.114 near threshold; DFA s42=0.047 (noise sub-mode) doesn't fire; SB s456=0.035 (noise sub-mode) doesn't fire. (c) is for sub-mode interpretation, not binary detection. This is the calibration evidence answering the E&D reviewer question 'why these specific thresholds?'.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/threshold_sensitivity.py158
1 files changed, 158 insertions, 0 deletions
diff --git a/protocol/examples/threshold_sensitivity.py b/protocol/examples/threshold_sensitivity.py
new file mode 100644
index 0000000..fdf0f6c
--- /dev/null
+++ b/protocol/examples/threshold_sensitivity.py
@@ -0,0 +1,158 @@
+"""
+Threshold sensitivity analysis: how does the protocol's verdict change as
+the diagnostic thresholds are varied? E&D reviewers will reasonably ask
+"what makes 50× per-block growth and 1e-7 g-norm-floor the right cutoffs?"
+
+This script sweeps each threshold independently (several orders of
+magnitude for diagnostics (a) and (b), 0.10-0.90 for (c)) and reports,
+for the 5-method audit table:
+ - the verdict at each threshold
+ - the boundary value at which verdicts change for each method
+
+If a method's verdict is robust to a wide range of thresholds, the
+protocol is well-calibrated for that method. If verdicts flip near the
+chosen threshold, the calibration is fragile and the threshold needs
+explicit defense in the paper.
+
+Run:
+ python -m protocol.examples.threshold_sensitivity
+"""
+import os
+import sys
+import json
+
# This file lives at <repo>/protocol/examples/, so the repository root is
# three directory levels above its absolute path.
_THIS_FILE = os.path.abspath(__file__)
_PROTOCOL_DIR = os.path.dirname(os.path.dirname(_THIS_FILE))
REPO_ROOT = os.path.dirname(_PROTOCOL_DIR)

# Put the repository root first on the import path.
sys.path.insert(0, REPO_ROOT)

# JSON audit table (seeds 42, 123, 456) that main() reads.
AUDIT_PATH = os.path.join(
    REPO_ROOT,
    "results/protocol_audit/audit_table_s42_s123_s456.json",
)
+
+
def max_per_block_growth(h, floor=1e-30):
    """Return the largest ratio between consecutive entries of *h*.

    Args:
        h: Sized, sliceable sequence of numbers (``main`` passes a
            report's ``residual_norms`` list).
        floor: Lower bound applied to every denominator so that a zero
            entry produces a very large ratio instead of raising
            ``ZeroDivisionError``.

    Returns:
        The maximum of ``h[i + 1] / max(h[i], floor)`` over all adjacent
        pairs, or ``1.0`` when *h* has fewer than two entries (there is
        no pair to compare).
    """
    if len(h) < 2:
        return 1.0
    # Pair each entry with its successor instead of indexing by position.
    return max(after / max(before, floor) for before, after in zip(h, h[1:]))
+
+
def _print_banner(title):
    """Print *title* framed above and below by an 88-character rule."""
    rule = "=" * 88
    print(rule)
    print(title)
    print(rule)


def _print_sweep(metrics, key, thresholds, label, fires, value_spec):
    """Print one FIRE/ok table: a row per metric entry, a column per threshold.

    Args:
        metrics: List of per-(method, seed) dicts built in ``main``.
        key: Name of the raw diagnostic value inside each metrics dict.
        thresholds: Threshold values to sweep, in column order.
        label: Callable mapping a threshold to its column header text.
        fires: Callable ``(value, threshold) -> bool``; true means the
            diagnostic fires at that threshold.
        value_spec: Format spec used for the raw-value column.
    """
    print(f"{'method':<16}{'seed':>6}{'value':>14}", end="")
    for t in thresholds:
        print(f"{label(t):>10}", end="")
    print()
    for m in metrics:
        print(f"{m['method']:<16}{m['seed']:>6}{m[key]:{value_spec}}", end="")
        for t in thresholds:
            verdict = "FIRE" if fires(m[key], t) else "ok"
            print(f"{verdict:>10}", end="")
        print()
    print()


def _extreme(metrics, method, key, pick):
    """Return ``pick`` (``min`` or ``max``) of *key* over the rows of *method*.

    Raises:
        ValueError: If no row belongs to *method*. This is the same
            exception type ``min``/``max`` raise on an empty sequence,
            but the message names the missing method.
    """
    values = [m[key] for m in metrics if m["method"] == method]
    if not values:
        raise ValueError(f"audit summary has no rows for method {method!r}")
    return pick(values)


def _gap_factor(numerator, denominator):
    """Return ``numerator / denominator`` without raising on a zero denominator.

    A zero denominator gives ``inf`` for a positive numerator and ``nan``
    otherwise, so the report line still prints.
    """
    if denominator == 0:
        return float("inf") if numerator > 0 else float("nan")
    return numerator / denominator


def main():
    """Print threshold-sensitivity tables for diagnostics (a), (b) and (c).

    Reads the JSON audit table at ``AUDIT_PATH``. Its ``"summary"`` list
    supplies one row per (method, seed); the matching entry in
    ``"reports"`` (keyed ``"<method>_s<seed>"``) supplies
    ``residual_norms``, ``bp_grad_norms`` and ``cross_batch_stability``.
    All output goes to stdout.

    Raises:
        ValueError: If the summary has no rows for one of the methods
            ``bp``, ``ep``, ``dfa``, ``state_bridge`` or ``credit_bridge``.
    """
    with open(AUDIT_PATH, encoding="utf-8") as f:
        data = json.load(f)

    # Compute per-row diagnostic raw values.
    metrics = []
    for row in data["summary"]:
        report = data["reports"][f"{row['method']}_s{row['seed']}"]
        metrics.append({
            "method": row["method"],
            "seed": row["seed"],
            "max_per_block_growth": max_per_block_growth(report["residual_norms"]),
            "g_L": report["bp_grad_norms"][-1],
            "stability": report["cross_batch_stability"],
            "frozen_acc": row.get("frozen_acc"),
            "acc": row["acc"],
        })

    # ----- (a) per-block growth threshold sensitivity ----- #
    _print_banner("DIAGNOSTIC (a) per-block growth: sensitivity over threshold")
    _print_sweep(
        metrics,
        key="max_per_block_growth",
        thresholds=[5, 10, 20, 50, 100, 500, 1000, 5000],
        label=lambda t: f">{t}×",
        fires=lambda value, t: value > t,
        value_spec=">14.2e",
    )
    print("Reading: BP/EP rows should be 'ok' across the entire row (the whole")
    print("threshold range is healthy for them). DFA/SB/CB rows should be 'FIRE'")
    print("at the chosen threshold and have a comfortable margin on either side.")
    print()

    # ----- (b) g-norm floor sensitivity ----- #
    _print_banner("DIAGNOSTIC (b) g-norm floor: sensitivity over threshold")
    _print_sweep(
        metrics,
        key="g_L",
        thresholds=[1e-9, 1e-8, 1e-7, 1e-6, 1e-5],
        label=lambda t: f"<{t:.0e}",
        fires=lambda value, t: value < t,
        value_spec=">14.2e",
    )

    # ----- (c) stability ceiling sensitivity ----- #
    _print_banner("DIAGNOSTIC (c) stability ceiling: sensitivity over threshold")
    _print_sweep(
        metrics,
        key="stability",
        thresholds=[0.10, 0.20, 0.30, 0.50, 0.70, 0.90],
        label=lambda t: f">{t:.2f}",
        fires=lambda value, t: value > t,
        value_spec=">14.3f",
    )

    # ----- Aggregate verdict robustness ----- #
    _print_banner("VERDICT ROBUSTNESS: at what threshold does each verdict CHANGE?")
    print(f"{'method':<16}{'seed':>6}{'(a) value':>14}{'(a) flip near':>16}{'(b) value':>14}{'(b) flip near':>16}")
    for m in metrics:
        # Both diagnostics use a strict comparison against the threshold,
        # so a row's verdict flips exactly where the threshold crosses its
        # raw value; the "flip near" columns therefore repeat that value.
        flip_a = m["max_per_block_growth"]
        flip_b = m["g_L"]
        print(f"{m['method']:<16}{m['seed']:>6}{m['max_per_block_growth']:>14.2e}{flip_a:>16.2e}{m['g_L']:>14.2e}{flip_b:>16.2e}")
    print()
    print("Interpretation:")
    print(" - For DFA/SB/CB, the 'flip near' values are the diagnostic raw values.")
    print(" Default (a) threshold 50 catches all if raw values > 50; default (b)")
    print(" threshold 1e-7 catches all if raw values < 1e-7. Compare:")

    # Worst case per method: for (a) healthy methods are bounded from above
    # and degenerate ones from below; for (b) the directions are reversed.
    bp_a = _extreme(metrics, "bp", "max_per_block_growth", max)
    ep_a = _extreme(metrics, "ep", "max_per_block_growth", max)
    dfa_a = _extreme(metrics, "dfa", "max_per_block_growth", min)
    sb_a = _extreme(metrics, "state_bridge", "max_per_block_growth", min)
    cb_a = _extreme(metrics, "credit_bridge", "max_per_block_growth", min)
    bp_b = _extreme(metrics, "bp", "g_L", min)
    ep_b = _extreme(metrics, "ep", "g_L", min)
    dfa_b = _extreme(metrics, "dfa", "g_L", max)
    sb_b = _extreme(metrics, "state_bridge", "g_L", max)
    cb_b = _extreme(metrics, "credit_bridge", "g_L", max)

    healthy_max_a = max(bp_a, ep_a)
    degenerate_min_a = min(dfa_a, sb_a, cb_a)
    print(f" (a) max BP per-block growth across 3 seeds: {bp_a:.2e}")
    print(f" (a) max EP per-block growth across 3 seeds: {ep_a:.2e}")
    print(f" (a) min DFA per-block growth across 3 seeds: {dfa_a:.2e}")
    print(f" (a) min SB per-block growth across 3 seeds: {sb_a:.2e}")
    print(f" (a) min CB per-block growth across 3 seeds: {cb_a:.2e}")
    print(f" -> separation gap: healthy max = {healthy_max_a:.2e},")
    print(f" degenerate min = {degenerate_min_a:.2e},")
    print(f" gap factor = {_gap_factor(degenerate_min_a, healthy_max_a):.0f}×")
    print()

    healthy_min_b = min(bp_b, ep_b)
    degenerate_max_b = max(dfa_b, sb_b, cb_b)
    print(f" (b) min BP ‖g_L‖ across 3 seeds: {bp_b:.2e}")
    print(f" (b) min EP ‖g_L‖ across 3 seeds: {ep_b:.2e}")
    print(f" (b) max DFA ‖g_L‖ across 3 seeds: {dfa_b:.2e}")
    print(f" (b) max SB ‖g_L‖ across 3 seeds: {sb_b:.2e}")
    print(f" (b) max CB ‖g_L‖ across 3 seeds: {cb_b:.2e}")
    print(f" -> separation gap: healthy min = {healthy_min_b:.2e},")
    print(f" degenerate max = {degenerate_max_b:.2e},")
    # A degenerate gradient norm can be exactly zero, hence the guarded
    # division.
    print(f" gap factor = {_gap_factor(healthy_min_b, degenerate_max_b):.0f}×")


if __name__ == "__main__":
    main()