""" Protocol decision-utility ablation: starting from the 5-method audit table, evaluate what each *subset* of the protocol would catch. For each method we ask: under each evaluation strategy, would a reviewer have walked back the headline accuracy claim? Strategies considered: S0: Headline accuracy only (the conventional reporting) S1: Headline accuracy + Γ (the field's standard FA evaluation) S2: + diagnostic (a) per-layer ‖h_l‖ — catches scale pathology S3: + diagnostic (b) per-layer ‖g_l‖ — catches reference at floor S4: + diagnostic (c) cross-batch dir stability — catches drift dominance S5: + diagnostic (d) frozen-blocks baseline — catches passive blocks S_full: full protocol (a)+(b)+(c)+(d) S1 corresponds to the field's status quo. S_full is what this paper proposes. The "decision utility" of the protocol is the set of cases where S_full flags but S1 does not. Run: CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.ablation_decision_utility """ import os import sys import json REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42.json") def gamma_proxy(method: str, g_norm: float) -> str: """Return what S1 (headline acc + Γ) would conclude for a given method. Γ is "high-ish" for DFA/SB/CB at the noise floor (0.10 / 0.005 / 0.07) and ~1.0 for BP and ~0.008 for EP — but in all cases the value LOOKS plausible to a reviewer who is not also looking at ‖g‖. The point of S1 is that it gives no walk-back signal.""" return "no walk-back (looks fine)" def max_per_block_growth(h): if len(h) < 2: return 1.0 return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str: """Return whether the strategy would have walked back the claim, and why.""" h_exploded = ( max_per_block_growth(report["residual_norms"]) > report["thresholds"]["h_norm_explosion_ratio"] ) g_at_floor = report["bp_grad_norms"][-1] < report["thresholds"]["g_norm_floor"] drift = report["cross_batch_stability"] > report["thresholds"]["stability_drift_ceiling"] frozen = report.get("frozen_baseline_acc") undercut = ( frozen is not None and (headline_acc - frozen) * 100 < report["thresholds"]["frozen_acc_margin_pp"] ) flags = [] if strategy in ("S2", "S_full") and h_exploded: flags.append("(a)scale") if strategy in ("S3", "S_full") and g_at_floor: flags.append("(b)floor") if strategy in ("S4", "S_full") and drift: flags.append("(c)drift") if strategy in ("S5", "S_full") and undercut: flags.append("(d)passive") if strategy == "S0": return "no walk-back (acc only)" if strategy == "S1": return gamma_proxy(method, report["bp_grad_norms"][-1]) if not flags: return "no walk-back" return "WALK-BACK: " + " + ".join(flags) def main(): with open(AUDIT_PATH) as f: data = json.load(f) methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"] method_acc = {} for row in data["summary"]: # Use seed-42 row for the ablation; the table is single-seed by design if row.get("seed", 42) == 42: method_acc[row["method"]] = row["acc"] strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"] strategy_label = { "S0": "headline acc only", "S1": "+ Γ (field standard)", "S2": "+ diagnostic (a) ‖h_l‖", "S3": "+ diagnostic (b) ‖g_l‖", "S4": "+ diagnostic (c) stability", "S5": "+ diagnostic (d) frozen baseline", "S_full": "full protocol", } print("=" * 100) print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)") print("=" * 100) # Audit JSON keys reports as either "bp" (legacy) or "bp_s42" (current). table = {} for method in methods: if method in data["reports"]: report = data["reports"][method] elif f"{method}_s42" in data["reports"]: report = data["reports"][f"{method}_s42"] else: print(f" SKIPPED (no report): {method}") continue acc = method_acc[method] table[method] = {} for s in strategies: verdict = evaluate_strategy(s, method, report, acc) table[method][s] = verdict # Print row-by-method for method in methods: print(f"\n## {method.upper()} (acc {method_acc[method]:.4f})") for s in strategies: print(f" {s:<8} ({strategy_label[s]:<30}): {table[method][s]}") # Decision utility = methods caught by S_full but missed by S1 print("\n" + "=" * 100) print("DECISION UTILITY: methods walked back by S_full but NOT by S1 (status quo)") print("=" * 100) for method in methods: s1 = table[method]["S1"] sf = table[method]["S_full"] if "WALK-BACK" in sf and "WALK-BACK" not in s1: print(f" {method.upper():<16} S1='{s1}' -> S_full='{sf}'") # Per-diagnostic recall: which method does each diagnostic catch alone? print("\n" + "=" * 100) print("PER-DIAGNOSTIC RECALL: which methods does each single diagnostic catch?") print("=" * 100) diag_strats = {"S2": "(a) ‖h_l‖", "S3": "(b) ‖g_l‖", "S4": "(c) stability", "S5": "(d) frozen"} for s, name in diag_strats.items(): caught = [] for method in methods: if "WALK-BACK" in table[method][s]: caught.append(method) print(f" {name:<16}: catches {caught}") # Save out = { "table": table, "strategies": strategy_label, "summary": { "missed_by_S1": [m for m in methods if "WALK-BACK" in table[m]["S_full"] and "WALK-BACK" not in table[m]["S1"]], "trustworthy_by_S_full": [m for m in methods if "WALK-BACK" not in table[m]["S_full"]], }, } out_path = os.path.join(REPO_ROOT, "results/protocol_audit/ablation_decision_utility.json") with open(out_path, "w") as f: json.dump(out, f, indent=2) print(f"\nSaved {out_path}") if __name__ == "__main__": main()