summaryrefslogtreecommitdiff
path: root/protocol/examples/ablation_decision_utility.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:00:54 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:00:54 -0500
commit31ddecc9eb646b15c4ac5960c7de9346c8f7be68 (patch)
treeeb3d7784aa24dbcd0aca348c0239df609ba3fbf5 /protocol/examples/ablation_decision_utility.py
parentede7cca3e4f9048e3fc6d99077f8842e9b598ff4 (diff)
Protocol diagnostic (a): use max per-block growth, not max/min ratio
Old metric: max(||h||) / max(||h_0||, eps). False-positives on ViT-style architectures because the cls token at layer 0 (right after patch_embed) has anomalously small magnitude (~0.3-1.5), inflating the ratio even on healthy BP-trained ViTs. New metric: max_l(||h_{l+1}|| / ||h_l||) — the largest single-block residual amplification. Architecture-invariant. Calibration: - BP-trained, late training: <5x per block - BP ViT, early epochs (cls token resolving): 13-25x max - DFA-trained ResMLP/ViT: 100-4000x per block Threshold raised from 10 to 50 to sit cleanly between healthy-early- training (max 25) and failure-regime (min 100). Re-verifications: - smoke test (BP/DFA/EP): all 3 verdicts unchanged - random init (3 seeds): trustworthy on all 3 - 5-method audit table single-seed: identical verdicts - decision-utility ablation: identical (still 0/5 by S1, 3/5 by S_full) - temporal evolution 3-seed: (b) now fires first at ep 3-4, (a) at ep 8-11. Both well before training ends. The 'protocol fires ~92 epochs early' story still holds. - ViT temporal evolution: BP no longer false-fires; DFA fires (a) ep 1, (b) ep 3 — protocol works on the second architecture.
Diffstat (limited to 'protocol/examples/ablation_decision_utility.py')
-rw-r--r--protocol/examples/ablation_decision_utility.py23
1 files changed, 20 insertions, 3 deletions
diff --git a/protocol/examples/ablation_decision_utility.py b/protocol/examples/ablation_decision_utility.py
index df82b7e..17aa6c5 100644
--- a/protocol/examples/ablation_decision_utility.py
+++ b/protocol/examples/ablation_decision_utility.py
@@ -40,10 +40,16 @@ def gamma_proxy(method: str, g_norm: float) -> str:
return "no walk-back (looks fine)"
+def max_per_block_growth(h):
+ if len(h) < 2:
+ return 1.0
+ return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1))
+
+
def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str:
"""Return whether the strategy would have walked back the claim, and why."""
h_exploded = (
- max(report["residual_norms"]) / max(report["residual_norms"][0], 1e-30)
+ max_per_block_growth(report["residual_norms"])
> report["thresholds"]["h_norm_explosion_ratio"]
)
g_at_floor = report["bp_grad_norms"][-1] < report["thresholds"]["g_norm_floor"]
@@ -78,7 +84,11 @@ def main():
data = json.load(f)
methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
- method_acc = {row["method"]: row["acc"] for row in data["summary"]}
+ method_acc = {}
+ for row in data["summary"]:
+ # Use seed-42 row for the ablation; the table is single-seed by design
+ if row.get("seed", 42) == 42:
+ method_acc[row["method"]] = row["acc"]
strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"]
strategy_label = {
@@ -95,9 +105,16 @@ def main():
print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)")
print("=" * 100)
+ # Audit JSON keys reports as either "bp" (legacy) or "bp_s42" (current).
table = {}
for method in methods:
- report = data["reports"][method]
+ if method in data["reports"]:
+ report = data["reports"][method]
+ elif f"{method}_s42" in data["reports"]:
+ report = data["reports"][f"{method}_s42"]
+ else:
+ print(f" SKIPPED (no report): {method}")
+ continue
acc = method_acc[method]
table[method] = {}
for s in strategies: