diff options
Diffstat (limited to 'protocol/examples/ablation_decision_utility.py')
| -rw-r--r-- | protocol/examples/ablation_decision_utility.py | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/protocol/examples/ablation_decision_utility.py b/protocol/examples/ablation_decision_utility.py index df82b7e..17aa6c5 100644 --- a/protocol/examples/ablation_decision_utility.py +++ b/protocol/examples/ablation_decision_utility.py @@ -40,10 +40,16 @@ def gamma_proxy(method: str, g_norm: float) -> str: return "no walk-back (looks fine)" +def max_per_block_growth(h): + if len(h) < 2: + return 1.0 + return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) + + def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str: """Return whether the strategy would have walked back the claim, and why.""" h_exploded = ( - max(report["residual_norms"]) / max(report["residual_norms"][0], 1e-30) + max_per_block_growth(report["residual_norms"]) > report["thresholds"]["h_norm_explosion_ratio"] ) g_at_floor = report["bp_grad_norms"][-1] < report["thresholds"]["g_norm_floor"] @@ -78,7 +84,11 @@ def main(): data = json.load(f) methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"] - method_acc = {row["method"]: row["acc"] for row in data["summary"]} + method_acc = {} + for row in data["summary"]: + # Use seed-42 row for the ablation; the table is single-seed by design + if row.get("seed", 42) == 42: + method_acc[row["method"]] = row["acc"] strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"] strategy_label = { @@ -95,9 +105,16 @@ def main(): print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)") print("=" * 100) + # Audit JSON keys reports as either "bp" (legacy) or "bp_s42" (current). table = {} for method in methods: - report = data["reports"][method] + if method in data["reports"]: + report = data["reports"][method] + elif f"{method}_s42" in data["reports"]: + report = data["reports"][f"{method}_s42"] + else: + print(f" SKIPPED (no report): {method}") + continue acc = method_acc[method] table[method] = {} for s in strategies: |
