From 31ddecc9eb646b15c4ac5960c7de9346c8f7be68 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 7 Apr 2026 23:00:54 -0500 Subject: Protocol diagnostic (a): use max per-block growth, not max/min ratio MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Old metric: max(||h||) / max(||h_0||, eps). False-positives on ViT-style architectures because the cls token at layer 0 (right after patch_embed) has anomalously small magnitude (~0.3-1.5), inflating the ratio even on healthy BP-trained ViTs. New metric: max_l(||h_{l+1}|| / ||h_l||) — the largest single-block residual amplification. Architecture-invariant. Calibration: - BP-trained, late training: <5x per block - BP ViT, early epochs (cls token resolving): 13-25x max - DFA-trained ResMLP/ViT: 100-4000x per block Threshold raised from 10 to 50 to sit cleanly between healthy-early- training (max 25) and failure-regime (min 100). Re-verifications: - smoke test (BP/DFA/EP): all 3 verdicts unchanged - random init (3 seeds): trustworthy on all 3 - 5-method audit table single-seed: identical verdicts - decision-utility ablation: identical (still 0/5 by S1, 3/5 by S_full) - temporal evolution 3-seed: (b) now fires first at ep 3-4, (a) at ep 8-11. Both well before training ends. The 'protocol fires ~92 epochs early' story still holds. - ViT temporal evolution: BP no longer false-fires; DFA fires (a) ep 1, (b) ep 3 — protocol works on the second architecture. --- protocol/examples/ablation_decision_utility.py | 23 +++++++++++++-- protocol/examples/temporal_diagnostic_evolution.py | 34 ++++++++++++++++------ 2 files changed, 45 insertions(+), 12 deletions(-) (limited to 'protocol/examples') diff --git a/protocol/examples/ablation_decision_utility.py b/protocol/examples/ablation_decision_utility.py index df82b7e..17aa6c5 100644 --- a/protocol/examples/ablation_decision_utility.py +++ b/protocol/examples/ablation_decision_utility.py @@ -40,10 +40,16 @@ def gamma_proxy(method: str, g_norm: float) -> str: return "no walk-back (looks fine)" +def max_per_block_growth(h): + if len(h) < 2: + return 1.0 + return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) + + def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str: """Return whether the strategy would have walked back the claim, and why.""" h_exploded = ( - max(report["residual_norms"]) / max(report["residual_norms"][0], 1e-30) + max_per_block_growth(report["residual_norms"]) > report["thresholds"]["h_norm_explosion_ratio"] ) g_at_floor = report["bp_grad_norms"][-1] < report["thresholds"]["g_norm_floor"] @@ -78,7 +84,11 @@ def main(): data = json.load(f) methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"] - method_acc = {row["method"]: row["acc"] for row in data["summary"]} + method_acc = {} + for row in data["summary"]: + # Use seed-42 row for the ablation; the table is single-seed by design + if row.get("seed", 42) == 42: + method_acc[row["method"]] = row["acc"] strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"] strategy_label = { @@ -95,9 +105,16 @@ def main(): print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)") print("=" * 100) + # Audit JSON keys reports as either "bp" (legacy) or "bp_s42" (current). table = {} for method in methods: - report = data["reports"][method] + if method in data["reports"]: + report = data["reports"][method] + elif f"{method}_s42" in data["reports"]: + report = data["reports"][f"{method}_s42"] + else: + print(f" SKIPPED (no report): {method}") + continue acc = method_acc[method] table[method] = {} for s in strategies: diff --git a/protocol/examples/temporal_diagnostic_evolution.py b/protocol/examples/temporal_diagnostic_evolution.py index 6a2c042..35cf720 100644 --- a/protocol/examples/temporal_diagnostic_evolution.py +++ b/protocol/examples/temporal_diagnostic_evolution.py @@ -39,10 +39,16 @@ from protocol.report import DiagnosticThresholds # noqa: E402 THRESHOLDS = DiagnosticThresholds() +def max_per_block_growth(h): + if len(h) < 2: + return 1.0 + return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) + + def diagnose_entry(entry): h = entry["hidden_norms"] g = entry["bp_grad_norms_per_sample_med"] - h_exploded = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio + h_exploded = max_per_block_growth(h) > THRESHOLDS.h_norm_explosion_ratio g_at_floor = g[-1] < THRESHOLDS.g_norm_floor return h_exploded, g_at_floor @@ -58,17 +64,27 @@ def main(): import argparse p = argparse.ArgumentParser() p.add_argument("--seed", type=int, default=42) + p.add_argument("--arch", type=str, default="resmlp", choices=["resmlp", "vit"]) args = p.parse_args() - snapshot_path = os.path.join( - REPO_ROOT, f"results/snapshot_evolution_v2/snapshot_evolution_s{args.seed}.json" - ) + if args.arch == "resmlp": + snapshot_path = os.path.join( + REPO_ROOT, f"results/snapshot_evolution_v2/snapshot_evolution_s{args.seed}.json" + ) + h_key = "hidden_norms" + g_key = "bp_grad_norms_per_sample_med" + else: + snapshot_path = os.path.join( + REPO_ROOT, f"results/snapshot_vit_v1/snapshot_vit_s{args.seed}.json" + ) + h_key = "hidden_norms_cls" + g_key = "bp_grad_per_sample_l2_med" if not os.path.exists(snapshot_path): print(f"snapshot not found: {snapshot_path}") return with open(snapshot_path) as f: d = json.load(f) - bp_log = d["bp_log"] - dfa_log = d["dfa_log"] + bp_log = [{**e, "hidden_norms": e[h_key], "bp_grad_norms_per_sample_med": e[g_key]} for e in d["bp_log"]] + dfa_log = [{**e, "hidden_norms": e[h_key], "bp_grad_norms_per_sample_med": e[g_key]} for e in d["dfa_log"]] print("=" * 88) print("TEMPORAL DIAGNOSTIC EVOLUTION (4-block d=256 ResMLP, CIFAR-10, seed 42)") @@ -87,7 +103,7 @@ def main(): for entry in dfa_log: h = entry["hidden_norms"] g = entry["bp_grad_norms_per_sample_med"] - h_exp = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio + h_exp = max_per_block_growth(h) > THRESHOLDS.h_norm_explosion_ratio g_floor = g[-1] < THRESHOLDS.g_norm_floor flag_a = "FIRE" if h_exp else "ok" flag_b = "FIRE" if g_floor else "ok" @@ -125,7 +141,7 @@ def main(): for entry in bp_log: h = entry["hidden_norms"] g = entry["bp_grad_norms_per_sample_med"] - h_exp = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio + h_exp = max_per_block_growth(h) > THRESHOLDS.h_norm_explosion_ratio g_floor = g[-1] < THRESHOLDS.g_norm_floor if h_exp or g_floor: bp_fired = True @@ -145,7 +161,7 @@ def main(): { "epoch": e["epoch"], "acc": e["acc_eval"], - "h_max_to_h0_ratio": (max(e["hidden_norms"]) / max(e["hidden_norms"][0], 1e-30)), + "max_per_block_growth": max_per_block_growth(e["hidden_norms"]), "g_L": e["bp_grad_norms_per_sample_med"][-1], "gamma": e.get("gamma_dfa"), } -- cgit v1.2.3