diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:00:54 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:00:54 -0500 |
| commit | 31ddecc9eb646b15c4ac5960c7de9346c8f7be68 (patch) | |
| tree | eb3d7784aa24dbcd0aca348c0239df609ba3fbf5 /protocol/examples/ablation_decision_utility.py | |
| parent | ede7cca3e4f9048e3fc6d99077f8842e9b598ff4 (diff) | |
Protocol diagnostic (a): use max per-block growth, not max/min ratio
Old metric: max(||h||) / max(||h_0||, eps). False-positives on ViT-style
architectures because the cls token at layer 0 (right after patch_embed)
has anomalously small magnitude (~0.3-1.5), inflating the ratio even on
healthy BP-trained ViTs.
New metric: max_l(||h_{l+1}|| / ||h_l||) — the largest single-block
residual amplification. Architecture-invariant.
Calibration:
- BP-trained, late training: <5x per block
- BP ViT, early epochs (cls token resolving): 13-25x max
- DFA-trained ResMLP/ViT: 100-4000x per block
Threshold raised from 10 to 50 to sit cleanly between healthy-early-
training (max 25) and failure-regime (min 100).
Re-verifications:
- smoke test (BP/DFA/EP): all 3 verdicts unchanged
- random init (3 seeds): trustworthy on all 3
- 5-method audit table single-seed: identical verdicts
- decision-utility ablation: identical (still 0/5 by S1, 3/5 by S_full)
- temporal evolution 3-seed: (b) now fires first at ep 3-4, (a) at ep
8-11. Both well before training ends. The 'protocol fires ~92 epochs
early' story still holds.
- ViT temporal evolution: BP no longer false-fires; DFA fires (a) ep 1,
(b) ep 3 — protocol works on the second architecture.
Diffstat (limited to 'protocol/examples/ablation_decision_utility.py')
| -rw-r--r-- | protocol/examples/ablation_decision_utility.py | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/protocol/examples/ablation_decision_utility.py b/protocol/examples/ablation_decision_utility.py index df82b7e..17aa6c5 100644 --- a/protocol/examples/ablation_decision_utility.py +++ b/protocol/examples/ablation_decision_utility.py @@ -40,10 +40,16 @@ def gamma_proxy(method: str, g_norm: float) -> str: return "no walk-back (looks fine)" +def max_per_block_growth(h): + if len(h) < 2: + return 1.0 + return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) + + def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str: """Return whether the strategy would have walked back the claim, and why.""" h_exploded = ( - max(report["residual_norms"]) / max(report["residual_norms"][0], 1e-30) + max_per_block_growth(report["residual_norms"]) > report["thresholds"]["h_norm_explosion_ratio"] ) g_at_floor = report["bp_grad_norms"][-1] < report["thresholds"]["g_norm_floor"] @@ -78,7 +84,11 @@ def main(): data = json.load(f) methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"] - method_acc = {row["method"]: row["acc"] for row in data["summary"]} + method_acc = {} + for row in data["summary"]: + # Use seed-42 row for the ablation; the table is single-seed by design + if row.get("seed", 42) == 42: + method_acc[row["method"]] = row["acc"] strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"] strategy_label = { @@ -95,9 +105,16 @@ def main(): print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)") print("=" * 100) + # Audit JSON keys reports as either "bp" (legacy) or "bp_s42" (current). table = {} for method in methods: - report = data["reports"][method] + if method in data["reports"]: + report = data["reports"][method] + elif f"{method}_s42" in data["reports"]: + report = data["reports"][f"{method}_s42"] + else: + print(f" SKIPPED (no report): {method}") + continue acc = method_acc[method] table[method] = {} for s in strategies: |
