summary | refs | log | tree | commit | diff
path: root/protocol
diff options
context:
space:
mode:
author: YurenHao0426 <Blackhao0426@gmail.com> 2026-04-07 22:38:57 -0500
committer: YurenHao0426 <Blackhao0426@gmail.com> 2026-04-07 22:38:57 -0500
commit: 44614df2f4382e567b986bc6dbe5b3091072461e (patch)
tree: 21e2763e93a09d223c03f0cc43aa33481b4f35d1 /protocol
parent: 0c245f5683cceba448d20d9dfc2090adb3503f14 (diff)
Add protocol decision-utility ablation table
Builds on the 5-method audit JSON. For each method, evaluates 7 reporting strategies (S0=acc only, S1=+Γ field standard, S2-S5=+single diagnostic, S_full=full protocol), and emits the verdict each strategy would have reached. Result: 3 of 5 methods (DFA/SB/CB) are walked back by S_full but NOT by S1. Each of (a)scale, (b)floor, (d)frozen is independently sufficient for binary detection of those 3 failures. Diagnostic (c)stability adds sub-mode discrimination (drift vs noise) but not new positive detections. This is the §3 protocol decision-utility evidence.
Diffstat (limited to 'protocol')
-rw-r--r-- protocol/examples/ablation_decision_utility.py 152
1 files changed, 152 insertions, 0 deletions
diff --git a/protocol/examples/ablation_decision_utility.py b/protocol/examples/ablation_decision_utility.py
new file mode 100644
index 0000000..df82b7e
--- /dev/null
+++ b/protocol/examples/ablation_decision_utility.py
@@ -0,0 +1,152 @@
+"""
+Protocol decision-utility ablation: starting from the 5-method audit table,
+evaluate what each *subset* of the protocol would catch.
+
+For each method we ask: under each evaluation strategy, would a reviewer
+have walked back the headline accuracy claim?
+
+Strategies considered:
+ S0: Headline accuracy only (the conventional reporting)
+ S1: Headline accuracy + Γ (the field's standard FA evaluation)
+ S2: + diagnostic (a) per-layer ‖h_l‖ — catches scale pathology
+ S3: + diagnostic (b) per-layer ‖g_l‖ — catches reference at floor
+ S4: + diagnostic (c) cross-batch dir stability — catches drift dominance
+ S5: + diagnostic (d) frozen-blocks baseline — catches passive blocks
+ S_full: full protocol (a)+(b)+(c)+(d)
+
+S1 corresponds to the field's status quo. S_full is what this paper proposes.
+The "decision utility" of the protocol is the set of cases where S_full
+flags but S1 does not.
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.ablation_decision_utility
+"""
+import os
+import sys
+import json
+
+# Repo root sits three dirname() hops above this file's absolute path.
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# Input of this script: the audit table JSON that main() loads.
+AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42.json")
+
+
+def gamma_proxy(method: str, g_norm: float) -> str:
+    """Return the verdict S1 (headline acc + Γ) reaches; `method` and `g_norm` are unused.
+    The author's note gives Γ as "high-ish" for DFA/SB/CB at the noise floor
+    (0.10 / 0.005 / 0.07), ~1.0 for BP and ~0.008 for EP; each value looks
+    plausible to a reviewer who is not also looking at ‖g‖, so S1 never signals."""
+    s1_verdict = "no walk-back (looks fine)"
+    return s1_verdict
+
+
+def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str:
+    """Return the verdict `strategy` reaches on `method`'s headline accuracy claim.
+
+    S0 and S1 never walk back and are answered before any diagnostic is read.
+    S2-S5 apply one diagnostic each ((a) scale, (b) floor, (c) drift, (d) passive)
+    and S_full applies all four; a diagnostic's `report` fields are read only by
+    the strategies that use it. `headline_acc` is presumably a fraction: (d)
+    multiplies its gap to the frozen baseline by 100 before the margin test.
+    Returns "WALK-BACK: " plus the flags joined by " + ", or a "no walk-back"
+    string; raises KeyError/IndexError/ValueError on missing or empty fields.
+    """
+    if strategy == "S0":
+        return "no walk-back (acc only)"
+    if strategy == "S1":
+        return gamma_proxy(method, report["bp_grad_norms"][-1])
+
+    limits = report["thresholds"]
+    frozen = report.get("frozen_baseline_acc")  # optional; (d) never flags without it
+    checks = (  # (owning strategy, flag label, predicate evaluated only on demand)
+        ("S2", "(a)scale", lambda: max(report["residual_norms"])
+         / max(report["residual_norms"][0], 1e-30) > limits["h_norm_explosion_ratio"]),
+        ("S3", "(b)floor", lambda: report["bp_grad_norms"][-1] < limits["g_norm_floor"]),
+        ("S4", "(c)drift", lambda: report["cross_batch_stability"] > limits["stability_drift_ceiling"]),
+        ("S5", "(d)passive", lambda: frozen is not None
+         and (headline_acc - frozen) * 100 < limits["frozen_acc_margin_pp"]),
+    )
+    flags = [label for owner, label, hit in checks if strategy in (owner, "S_full") and hit()]
+
+    if not flags:
+        return "no walk-back"
+    return "WALK-BACK: " + " + ".join(flags)
+
+
+def main():
+    """Build the per-method, per-strategy verdict table, print it, and save it.
+
+    Loads the audit JSON at AUDIT_PATH ("summary" rows give each method's "acc",
+    "reports" gives the per-method diagnostics passed to evaluate_strategy()),
+    prints the verdicts, the decision-utility list and the per-diagnostic recall,
+    and writes them to results/protocol_audit/ablation_decision_utility.json.
+    Raises KeyError if one of the five listed methods is missing from that JSON.
+    """
+    # JSON is UTF-8 by specification, so do not rely on the locale's default codec.
+    with open(AUDIT_PATH, encoding="utf-8") as f:
+        data = json.load(f)
+
+    methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
+    method_acc = {row["method"]: row["acc"] for row in data["summary"]}
+    strategy_label = {
+        "S0": "headline acc only",
+        "S1": "+ Γ (field standard)",
+        "S2": "+ diagnostic (a) ‖h_l‖",
+        "S3": "+ diagnostic (b) ‖g_l‖",
+        "S4": "+ diagnostic (c) stability",
+        "S5": "+ diagnostic (d) frozen baseline",
+        "S_full": "full protocol",
+    }
+    strategies = list(strategy_label)  # insertion order is the reporting order
+    rule = "=" * 100
+    print(rule)
+    print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)")
+    print(rule)
+
+    table = {
+        m: {s: evaluate_strategy(s, m, data["reports"][m], method_acc[m]) for s in strategies}
+        for m in methods
+    }
+
+    def walked_back(method: str, strategy: str) -> bool:
+        """True when the stored verdict for (method, strategy) is a walk-back."""
+        return "WALK-BACK" in table[method][strategy]
+
+    for m in methods:
+        print(f"\n## {m.upper()} (acc {method_acc[m]:.4f})")
+        for s in strategies:
+            print(f" {s:<8} ({strategy_label[s]:<30}): {table[m][s]}")
+
+    # Decision utility: caught by S_full, missed by S1; computed once for print and JSON.
+    missed_by_s1 = [m for m in methods if walked_back(m, "S_full") and not walked_back(m, "S1")]
+    print("\n" + rule)
+    print("DECISION UTILITY: methods walked back by S_full but NOT by S1 (status quo)")
+    print(rule)
+    for m in missed_by_s1:
+        s1, sf = table[m]["S1"], table[m]["S_full"]
+        print(f" {m.upper():<16} S1='{s1}' -> S_full='{sf}'")
+
+    # Per-diagnostic recall: which methods each single diagnostic catches alone.
+    print("\n" + rule)
+    print("PER-DIAGNOSTIC RECALL: which methods does each single diagnostic catch?")
+    print(rule)
+    diag_strats = {"S2": "(a) ‖h_l‖", "S3": "(b) ‖g_l‖", "S4": "(c) stability", "S5": "(d) frozen"}
+    for s, name in diag_strats.items():
+        caught = [m for m in methods if walked_back(m, s)]
+        print(f" {name:<16}: catches {caught}")
+    out = {
+        "table": table,
+        "strategies": strategy_label,
+        "summary": {
+            "missed_by_S1": missed_by_s1,
+            "trustworthy_by_S_full": [m for m in methods if not walked_back(m, "S_full")],
+        },
+    }
+    out_path = os.path.join(REPO_ROOT, "results/protocol_audit/ablation_decision_utility.json")
+    with open(out_path, "w", encoding="utf-8") as f:
+        json.dump(out, f, indent=2)
+    print(f"\nSaved {out_path}")
+
+
+if __name__ == "__main__":  # run main() only when executed as a script or via -m
+    main()