summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:00:54 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:00:54 -0500
commit31ddecc9eb646b15c4ac5960c7de9346c8f7be68 (patch)
treeeb3d7784aa24dbcd0aca348c0239df609ba3fbf5 /protocol
parentede7cca3e4f9048e3fc6d99077f8842e9b598ff4 (diff)
Protocol diagnostic (a): use max per-block growth, not max/min ratio
Old metric: max(||h||) / max(||h_0||, eps). False-positives on ViT-style architectures because the cls token at layer 0 (right after patch_embed) has anomalously small magnitude (~0.3-1.5), inflating the ratio even on healthy BP-trained ViTs. New metric: max_l(||h_{l+1}|| / ||h_l||) — the largest single-block residual amplification. Architecture-invariant. Calibration: - BP-trained, late training: <5x per block - BP ViT, early epochs (cls token resolving): 13-25x max - DFA-trained ResMLP/ViT: 100-4000x per block Threshold raised from 10 to 50 to sit cleanly between healthy-early- training (max 25) and failure-regime (min 100). Re-verifications: - smoke test (BP/DFA/EP): all 3 verdicts unchanged - random init (3 seeds): trustworthy on all 3 - 5-method audit table single-seed: identical verdicts - decision-utility ablation: identical (still 0/5 by S1, 3/5 by S_full) - temporal evolution 3-seed: (b) now fires first at ep 3-4, (a) at ep 8-11. Both well before training ends. The 'protocol fires ~92 epochs early' story still holds. - ViT temporal evolution: BP no longer false-fires; DFA fires (a) ep 1, (b) ep 3 — protocol works on the second architecture.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/ablation_decision_utility.py23
-rw-r--r--protocol/examples/temporal_diagnostic_evolution.py34
-rw-r--r--protocol/report.py38
3 files changed, 75 insertions, 20 deletions
diff --git a/protocol/examples/ablation_decision_utility.py b/protocol/examples/ablation_decision_utility.py
index df82b7e..17aa6c5 100644
--- a/protocol/examples/ablation_decision_utility.py
+++ b/protocol/examples/ablation_decision_utility.py
@@ -40,10 +40,16 @@ def gamma_proxy(method: str, g_norm: float) -> str:
return "no walk-back (looks fine)"
+def max_per_block_growth(h):
+ if len(h) < 2:
+ return 1.0
+ return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1))
+
+
def evaluate_strategy(strategy: str, method: str, report: dict, headline_acc: float) -> str:
"""Return whether the strategy would have walked back the claim, and why."""
h_exploded = (
- max(report["residual_norms"]) / max(report["residual_norms"][0], 1e-30)
+ max_per_block_growth(report["residual_norms"])
> report["thresholds"]["h_norm_explosion_ratio"]
)
g_at_floor = report["bp_grad_norms"][-1] < report["thresholds"]["g_norm_floor"]
@@ -78,7 +84,11 @@ def main():
data = json.load(f)
methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
- method_acc = {row["method"]: row["acc"] for row in data["summary"]}
+ method_acc = {}
+ for row in data["summary"]:
+ # Use seed-42 row for the ablation; the table is single-seed by design
+ if row.get("seed", 42) == 42:
+ method_acc[row["method"]] = row["acc"]
strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"]
strategy_label = {
@@ -95,9 +105,16 @@ def main():
print("Protocol decision-utility ablation (4-block d=256 ResMLP, CIFAR-10, seed 42)")
print("=" * 100)
+ # Audit JSON keys reports as either "bp" (legacy) or "bp_s42" (current).
table = {}
for method in methods:
- report = data["reports"][method]
+ if method in data["reports"]:
+ report = data["reports"][method]
+ elif f"{method}_s42" in data["reports"]:
+ report = data["reports"][f"{method}_s42"]
+ else:
+ print(f" SKIPPED (no report): {method}")
+ continue
acc = method_acc[method]
table[method] = {}
for s in strategies:
diff --git a/protocol/examples/temporal_diagnostic_evolution.py b/protocol/examples/temporal_diagnostic_evolution.py
index 6a2c042..35cf720 100644
--- a/protocol/examples/temporal_diagnostic_evolution.py
+++ b/protocol/examples/temporal_diagnostic_evolution.py
@@ -39,10 +39,16 @@ from protocol.report import DiagnosticThresholds # noqa: E402
THRESHOLDS = DiagnosticThresholds()
+def max_per_block_growth(h):
+ if len(h) < 2:
+ return 1.0
+ return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1))
+
+
def diagnose_entry(entry):
h = entry["hidden_norms"]
g = entry["bp_grad_norms_per_sample_med"]
- h_exploded = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio
+ h_exploded = max_per_block_growth(h) > THRESHOLDS.h_norm_explosion_ratio
g_at_floor = g[-1] < THRESHOLDS.g_norm_floor
return h_exploded, g_at_floor
@@ -58,17 +64,27 @@ def main():
import argparse
p = argparse.ArgumentParser()
p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--arch", type=str, default="resmlp", choices=["resmlp", "vit"])
args = p.parse_args()
- snapshot_path = os.path.join(
- REPO_ROOT, f"results/snapshot_evolution_v2/snapshot_evolution_s{args.seed}.json"
- )
+ if args.arch == "resmlp":
+ snapshot_path = os.path.join(
+ REPO_ROOT, f"results/snapshot_evolution_v2/snapshot_evolution_s{args.seed}.json"
+ )
+ h_key = "hidden_norms"
+ g_key = "bp_grad_norms_per_sample_med"
+ else:
+ snapshot_path = os.path.join(
+ REPO_ROOT, f"results/snapshot_vit_v1/snapshot_vit_s{args.seed}.json"
+ )
+ h_key = "hidden_norms_cls"
+ g_key = "bp_grad_per_sample_l2_med"
if not os.path.exists(snapshot_path):
print(f"snapshot not found: {snapshot_path}")
return
with open(snapshot_path) as f:
d = json.load(f)
- bp_log = d["bp_log"]
- dfa_log = d["dfa_log"]
+ bp_log = [{**e, "hidden_norms": e[h_key], "bp_grad_norms_per_sample_med": e[g_key]} for e in d["bp_log"]]
+ dfa_log = [{**e, "hidden_norms": e[h_key], "bp_grad_norms_per_sample_med": e[g_key]} for e in d["dfa_log"]]
print("=" * 88)
print("TEMPORAL DIAGNOSTIC EVOLUTION (4-block d=256 ResMLP, CIFAR-10, seed 42)")
@@ -87,7 +103,7 @@ def main():
for entry in dfa_log:
h = entry["hidden_norms"]
g = entry["bp_grad_norms_per_sample_med"]
- h_exp = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio
+ h_exp = max_per_block_growth(h) > THRESHOLDS.h_norm_explosion_ratio
g_floor = g[-1] < THRESHOLDS.g_norm_floor
flag_a = "FIRE" if h_exp else "ok"
flag_b = "FIRE" if g_floor else "ok"
@@ -125,7 +141,7 @@ def main():
for entry in bp_log:
h = entry["hidden_norms"]
g = entry["bp_grad_norms_per_sample_med"]
- h_exp = (max(h) / max(h[0], 1e-30)) > THRESHOLDS.h_norm_explosion_ratio
+ h_exp = max_per_block_growth(h) > THRESHOLDS.h_norm_explosion_ratio
g_floor = g[-1] < THRESHOLDS.g_norm_floor
if h_exp or g_floor:
bp_fired = True
@@ -145,7 +161,7 @@ def main():
{
"epoch": e["epoch"],
"acc": e["acc_eval"],
- "h_max_to_h0_ratio": (max(e["hidden_norms"]) / max(e["hidden_norms"][0], 1e-30)),
+ "max_per_block_growth": max_per_block_growth(e["hidden_norms"]),
"g_L": e["bp_grad_norms_per_sample_med"][-1],
"gamma": e.get("gamma_dfa"),
}
diff --git a/protocol/report.py b/protocol/report.py
index 15d6c34..00640eb 100644
--- a/protocol/report.py
+++ b/protocol/report.py
@@ -35,6 +35,19 @@ class DiagnosticThresholds:
"""
g_norm_floor: float = 1e-7
+ # Per-block residual growth ratio threshold. The diagnostic is
+ # `max_l(||h_{l+1}|| / ||h_l||)` — the largest single-block residual
+ # amplification. We avoided `max(||h||) / ||h_0||` because it false-
+ # positives on ViT-style architectures where the cls token at layer 0
+ # is anomalously small after patch_embed.
+ #
+ # Calibration on observed data:
+ # - BP-trained, late training: <5× per block (steady state)
+ # - BP ViT, early training (epoch 1-5): 13-25× max (cls token still
+ # resolving from its small init magnitude)
+ # - DFA-trained ResMLP / ViT: 100-4000× max per block
+ # Threshold 50 sits cleanly between healthy-early-training (max 25) and
+ # failure-regime (min 100), with margin on both sides.
h_norm_explosion_ratio: float = 50.0
stability_drift_ceiling: float = 0.30
frozen_acc_margin_pp: float = 2.0
@@ -59,13 +72,22 @@ class DiagnosticReport:
# ------------------------------------------------------------------ #
@property
+ def max_per_block_growth(self) -> float:
+ """max_l (||h_{l+1}|| / ||h_l||) — the largest residual-stream
+ amplification by any single block. Healthy BP/EP networks have all
+ per-block growth < 5×; pathological networks (DFA/SB/CB on pre-LN
+ residuals) have at least one block with growth > 100×."""
+ if len(self.residual_norms) < 2:
+ return 1.0
+ ratios = []
+ for i in range(len(self.residual_norms) - 1):
+ denom = max(self.residual_norms[i], 1e-30)
+ ratios.append(self.residual_norms[i + 1] / denom)
+ return max(ratios)
+
+ @property
def residual_stream_exploded(self) -> bool:
- if not self.residual_norms:
- return False
- h0 = self.residual_norms[0]
- if h0 <= 0:
- return False
- return (max(self.residual_norms) / h0) > self.thresholds.h_norm_explosion_ratio
+ return self.max_per_block_growth > self.thresholds.h_norm_explosion_ratio
@property
def bp_grad_at_floor(self) -> bool:
@@ -126,8 +148,8 @@ class DiagnosticReport:
lines.append(f" h_{l}: {self.residual_norms[l]:.3e}")
if self.residual_stream_exploded:
lines.append(
- f" FLAG: max/min ratio "
- f"{max(self.residual_norms)/max(self.residual_norms[0],1e-30):.2e} "
+ f" FLAG: max per-block growth ‖h_{{l+1}}‖/‖h_l‖ = "
+ f"{self.max_per_block_growth:.2e} "
f"> threshold {self.thresholds.h_norm_explosion_ratio}× — "
"residual stream exploded."
)