diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:38:57 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:38:57 -0500 |
| commit | 44614df2f4382e567b986bc6dbe5b3091072461e (patch) | |
| tree | 21e2763e93a09d223c03f0cc43aa33481b4f35d1 /protocol | |
| parent | 0c245f5683cceba448d20d9dfc2090adb3503f14 (diff) | |
Add protocol decision-utility ablation table
Builds on the 5-method audit JSON. For each method, evaluates 7 reporting
strategies (S0=acc only, S1=+Γ field standard, S2-S5=+single diagnostic,
S_full=full protocol), and emits the verdict each strategy would have
reached.
Result: 3 of 5 methods (DFA/SB/CB) are walked back by S_full but NOT by S1.
Each of (a)scale, (b)floor, (d)frozen is independently sufficient for
binary detection of those 3 failures. Diagnostic (c)stability adds
sub-mode discrimination (drift vs noise) but not new positive detections.
This is the §3 protocol decision-utility evidence.
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/ablation_decision_utility.py | 152 |
1 files changed, 152 insertions, 0 deletions
diff --git a/protocol/examples/ablation_decision_utility.py b/protocol/examples/ablation_decision_utility.py new file mode 100644 index 0000000..df82b7e --- /dev/null +++ b/protocol/examples/ablation_decision_utility.py @@ -0,0 +1,152 @@ +""" +Protocol decision-utility ablation: starting from the 5-method audit table, +evaluate what each *subset* of the protocol would catch. + +For each method we ask: under each evaluation strategy, would a reviewer +have walked back the headline accuracy claim? + +Strategies considered: + S0: Headline accuracy only (the conventional reporting) + S1: Headline accuracy + Γ (the field's standard FA evaluation) + S2: + diagnostic (a) per-layer ‖h_l‖ — catches scale pathology + S3: + diagnostic (b) per-layer ‖g_l‖ — catches reference at floor + S4: + diagnostic (c) cross-batch dir stability — catches drift dominance + S5: + diagnostic (d) frozen-blocks baseline — catches passive blocks + S_full: full protocol (a)+(b)+(c)+(d) + +S1 corresponds to the field's status quo. S_full is what this paper proposes. +The "decision utility" of the protocol is the set of cases where S_full +flags but S1 does not. + +Run: + CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.ablation_decision_utility +""" +import os +import sys +import json + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42.json") + + +def gamma_proxy(method: str, g_norm: float) -> str: + """Return what S1 (headline acc + Γ) would conclude for a given method. + Γ is "high-ish" for DFA/SB/CB at the noise floor (0.10 / 0.005 / 0.07) + and ~1.0 for BP and ~0.008 for EP — but in all cases the value LOOKS + plausible to a reviewer who is not also looking at ‖g‖. The point of S1 + is that it gives no walk-back signal.""" + return "no walk-back (looks fine)" + + +def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str: + """Return whether the strategy would have walked back the claim, and why.""" + h_exploded = ( + max(report["residual_norms"]) / max(report["residual_norms"][0], 1e-30) + > report["thresholds"]["h_norm_explosion_ratio"] + ) + g_at_floor = report["bp_grad_norms"][-1] < report["thresholds"]["g_norm_floor"] + drift = report["cross_batch_stability"] > report["thresholds"]["stability_drift_ceiling"] + frozen = report.get("frozen_baseline_acc") + undercut = ( + frozen is not None + and (headline_acc - frozen) * 100 < report["thresholds"]["frozen_acc_margin_pp"] + ) + + flags = [] + if strategy in ("S2", "S_full") and h_exploded: + flags.append("(a)scale") + if strategy in ("S3", "S_full") and g_at_floor: + flags.append("(b)floor") + if strategy in ("S4", "S_full") and drift: + flags.append("(c)drift") + if strategy in ("S5", "S_full") and undercut: + flags.append("(d)passive") + + if strategy == "S0": + return "no walk-back (acc only)" + if strategy == "S1": + return gamma_proxy(method, report["bp_grad_norms"][-1]) + if not flags: + return "no walk-back" + return "WALK-BACK: " + " + ".join(flags) + + +def main(): + with open(AUDIT_PATH) as f: + data = json.load(f) + + methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"] + method_acc = {row["method"]: row["acc"] for row in data["summary"]} + + strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"] + strategy_label = { + "S0": "headline acc only", + "S1": "+ Γ (field standard)", + "S2": "+ diagnostic (a) ‖h_l‖", + "S3": "+ diagnostic (b) ‖g_l‖", + "S4": "+ diagnostic (c) stability", + "S5": "+ diagnostic (d) frozen baseline", + "S_full": "full protocol", + } + + print("=" * 100) + print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)") + print("=" * 100) + + table = {} + for method in methods: + report = data["reports"][method] + acc = method_acc[method] + table[method] = {} + for s in strategies: + verdict = evaluate_strategy(s, method, report, acc) + table[method][s] = verdict + + # Print row-by-method + for method in methods: + print(f"\n## {method.upper()} (acc {method_acc[method]:.4f})") + for s in strategies: + print(f" {s:<8} ({strategy_label[s]:<30}): {table[method][s]}") + + # Decision utility = methods caught by S_full but missed by S1 + print("\n" + "=" * 100) + print("DECISION UTILITY: methods walked back by S_full but NOT by S1 (status quo)") + print("=" * 100) + for method in methods: + s1 = table[method]["S1"] + sf = table[method]["S_full"] + if "WALK-BACK" in sf and "WALK-BACK" not in s1: + print(f" {method.upper():<16} S1='{s1}' -> S_full='{sf}'") + + # Per-diagnostic recall: which method does each diagnostic catch alone? + print("\n" + "=" * 100) + print("PER-DIAGNOSTIC RECALL: which methods does each single diagnostic catch?") + print("=" * 100) + diag_strats = {"S2": "(a) ‖h_l‖", "S3": "(b) ‖g_l‖", "S4": "(c) stability", "S5": "(d) frozen"} + for s, name in diag_strats.items(): + caught = [] + for method in methods: + if "WALK-BACK" in table[method][s]: + caught.append(method) + print(f" {name:<16}: catches {caught}") + + # Save + out = { + "table": table, + "strategies": strategy_label, + "summary": { + "missed_by_S1": [m for m in methods + if "WALK-BACK" in table[m]["S_full"] and "WALK-BACK" not in table[m]["S1"]], + "trustworthy_by_S_full": [m for m in methods if "WALK-BACK" not in table[m]["S_full"]], + }, + } + out_path = os.path.join(REPO_ROOT, "results/protocol_audit/ablation_decision_utility.json") + with open(out_path, "w") as f: + json.dump(out, f, indent=2) + print(f"\nSaved {out_path}") + + +if __name__ == "__main__": + main() |
