summaryrefslogtreecommitdiff
path: root/protocol/examples/ablation_decision_utility.py
diff options
context:
space:
mode:
Diffstat (limited to 'protocol/examples/ablation_decision_utility.py')
-rw-r--r--protocol/examples/ablation_decision_utility.py23
1 files changed, 20 insertions, 3 deletions
diff --git a/protocol/examples/ablation_decision_utility.py b/protocol/examples/ablation_decision_utility.py
index df82b7e..17aa6c5 100644
--- a/protocol/examples/ablation_decision_utility.py
+++ b/protocol/examples/ablation_decision_utility.py
@@ -40,10 +40,16 @@ def gamma_proxy(method: str, g_norm: float) -> str:
return "no walk-back (looks fine)"
+def max_per_block_growth(h):
+ if len(h) < 2:
+ return 1.0
+ return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1))
+
+
def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str:
"""Return whether the strategy would have walked back the claim, and why."""
h_exploded = (
- max(report["residual_norms"]) / max(report["residual_norms"][0], 1e-30)
+ max_per_block_growth(report["residual_norms"])
> report["thresholds"]["h_norm_explosion_ratio"]
)
g_at_floor = report["bp_grad_norms"][-1] < report["thresholds"]["g_norm_floor"]
@@ -78,7 +84,11 @@ def main():
data = json.load(f)
methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
- method_acc = {row["method"]: row["acc"] for row in data["summary"]}
+ method_acc = {}
+ for row in data["summary"]:
+ # Use seed-42 row for the ablation; the table is single-seed by design
+ if row.get("seed", 42) == 42:
+ method_acc[row["method"]] = row["acc"]
strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"]
strategy_label = {
@@ -95,9 +105,16 @@ def main():
print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)")
print("=" * 100)
+ # Audit JSON keys reports as either "bp" (legacy) or "bp_s42" (current).
table = {}
for method in methods:
- report = data["reports"][method]
+ if method in data["reports"]:
+ report = data["reports"][method]
+ elif f"{method}_s42" in data["reports"]:
+ report = data["reports"][f"{method}_s42"]
+ else:
+ print(f" SKIPPED (no report): {method}")
+ continue
acc = method_acc[method]
table[method] = {}
for s in strategies: