"""
DiagnosticReport: structured output of the FA evaluation protocol.

Holds the per-layer numbers from the four diagnostics and emits a verdict
for each one based on the published thresholds. The verdict is intentionally
binary ("trustworthy" / "needs walk-back"); fine-grained reading is the
caller's job.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DiagnosticThresholds:
    """Degeneracy thresholds (defaults match the paper).

    g_norm_floor: BP gradient norms below this are considered to be at the
        numerical floor (Γ measured against this is not interpretable as
        credit alignment). Default 1e-7 — well above fp32 floor (~1e-38)
        and above F.cosine_similarity's eps=1e-8 clamp, but several orders
        below healthy networks (~1e-5).
    h_norm_explosion_ratio: largest single-block residual growth
        `max_l(||h_{l+1}|| / ||h_l||)` above this is considered "exploded".
        Default 50× — see the calibration comment on the field below.
    stability_drift_ceiling: cross-batch direction cosine above this is
        considered to be drift-dominated (reference vector is sample-
        invariant). Default 0.30 — BP-trained / EP-trained networks are
        below 0.20, failure modes are 0.5-0.99.
    frozen_acc_margin_pp: minimum acc gain (in percentage points) of the
        trainable-blocks variant over the frozen-random-blocks baseline
        for the deep blocks to be considered "actually contributing".
        Default 2.0 pp.
    """

    g_norm_floor: float = 1e-7
    # Per-block residual growth ratio threshold. The diagnostic is
    # `max_l(||h_{l+1}|| / ||h_l||)` — the largest single-block residual
    # amplification. We avoided `max(||h||) / ||h_0||` because it false-
    # positives on ViT-style architectures where the cls token at layer 0
    # is anomalously small after patch_embed.
    #
    # Calibration on observed data:
    #   - BP-trained, late training: <5× per block (steady state)
    #   - BP ViT, early training (epoch 1-5): 13-25× max (cls token still
    #     resolving from its small init magnitude)
    #   - DFA-trained ResMLP / ViT: 100-4000× max per block
    # Threshold 50 sits cleanly between healthy-early-training (max 25) and
    # failure-regime (min 100), with margin on both sides.
    h_norm_explosion_ratio: float = 50.0
    stability_drift_ceiling: float = 0.30
    frozen_acc_margin_pp: float = 2.0


@dataclass
class DiagnosticReport:
    """Result of running the protocol on one trained network.

    `residual_norms` and `bp_grad_norms` are per-layer lists; they are not
    required to have the same length. `headline_acc` and
    `frozen_baseline_acc` are fractions (they are multiplied by 100 to get
    percentage points). `frozen_baseline_acc` may be None, in which case
    the frozen-baseline diagnostic is skipped.
    """

    method_name: str
    notes: str
    residual_norms: List[float]
    bp_grad_norms: List[float]
    stability_layer: int
    cross_batch_stability: float
    headline_acc: float
    frozen_baseline_acc: Optional[float]
    thresholds: DiagnosticThresholds = field(default_factory=DiagnosticThresholds)

    # ------------------------------------------------------------------ #
    # Per-diagnostic verdicts
    # ------------------------------------------------------------------ #

    @property
    def max_per_block_growth(self) -> float:
        """max_l (||h_{l+1}|| / ||h_l||) — the largest residual-stream
        amplification by any single block.

        Healthy BP/EP networks have all per-block growth < 5×; pathological
        networks (DFA/SB/CB on pre-LN residuals) have at least one block
        with growth > 100×.

        Returns 1.0 when fewer than two norms are available, and `inf` when
        any ratio is NaN (a NaN norm means the stream has already diverged).
        """
        norms = self.residual_norms
        if len(norms) < 2:
            return 1.0
        # Clamp the denominator so a zero norm cannot raise ZeroDivisionError.
        ratios = [
            after / max(before, 1e-30)
            for before, after in zip(norms, norms[1:])
        ]
        # The built-in max() is order-dependent when NaN is present and
        # `nan > threshold` is always False, so a NaN would otherwise be able
        # to hide an explosion. Report it as unbounded growth instead.
        if any(math.isnan(ratio) for ratio in ratios):
            return math.inf
        return max(ratios)

    @property
    def residual_stream_exploded(self) -> bool:
        """True if the largest per-block growth exceeds the threshold."""
        return self.max_per_block_growth > self.thresholds.h_norm_explosion_ratio

    @property
    def bp_grad_at_floor(self) -> bool:
        """True if the last BP gradient norm is below `g_norm_floor`.

        Returns False when no gradient norms were supplied.
        """
        if not self.bp_grad_norms:
            return False
        # Check the *deepest* hidden layer's BP grad — that's where Γ is
        # typically reported and where LN-driven collapse hits hardest.
        return self.bp_grad_norms[-1] < self.thresholds.g_norm_floor

    @property
    def reference_drift_dominated(self) -> bool:
        """True if cross-batch stability exceeds the drift ceiling."""
        return self.cross_batch_stability > self.thresholds.stability_drift_ceiling

    @property
    def _margin_pp(self) -> Optional[float]:
        """Headline acc minus frozen baseline acc, in percentage points.

        None if no frozen baseline was supplied.
        """
        if self.frozen_baseline_acc is None:
            return None
        return (self.headline_acc - self.frozen_baseline_acc) * 100

    @property
    def frozen_baseline_undercut(self) -> Optional[bool]:
        """True if the trainable-blocks acc fails to clear the frozen
        baseline by `frozen_acc_margin_pp`. None if no frozen baseline
        supplied.
        """
        margin_pp = self._margin_pp
        if margin_pp is None:
            return None
        return margin_pp < self.thresholds.frozen_acc_margin_pp

    # ------------------------------------------------------------------ #
    # Aggregate verdict
    # ------------------------------------------------------------------ #

    @property
    def verdict(self) -> str:
        """Return "trustworthy", or "needs walk-back: ..." followed by the
        "; "-joined names of every diagnostic that fired.
        """
        flags = [
            ("residual stream exploded", self.residual_stream_exploded),
            ("BP grad at numerical floor", self.bp_grad_at_floor),
            ("BP grad direction is drift-dominated", self.reference_drift_dominated),
        ]
        # `is True` on purpose: None (no baseline supplied) must not flag.
        if self.frozen_baseline_undercut is True:
            flags.append(("deep blocks fail to beat frozen-random baseline", True))
        flagged = [name for name, fired in flags if fired]
        if not flagged:
            return "trustworthy"
        return "needs walk-back: " + "; ".join(flagged)

    # ------------------------------------------------------------------ #
    # Pretty-print
    # ------------------------------------------------------------------ #

    def __str__(self) -> str:
        """Render the four diagnostics, their flags, and the verdict as a
        multi-line, human-readable report.
        """
        rule = "=" * 72
        thr = self.thresholds
        lines: List[str] = [
            rule,
            f"FA Diagnostic Protocol Report — method: {self.method_name}",
        ]
        if self.notes:
            lines.append(f"Notes: {self.notes}")
        lines.append(rule)

        # (a) Residual stream norms
        lines.append("(a) Residual stream norms ||h_l||_2 (median over batch):")
        lines.extend(
            f" h_{idx}: {norm:.3e}"
            for idx, norm in enumerate(self.residual_norms)
        )
        if self.residual_stream_exploded:
            lines.append(
                f" FLAG: max per-block growth ‖h_{{l+1}}‖/‖h_l‖ = "
                f"{self.max_per_block_growth:.2e} "
                f"> threshold {thr.h_norm_explosion_ratio}× — "
                "residual stream exploded."
            )

        # (b) BP grad norms
        # Iterate the gradient list over its own length: it is not guaranteed
        # to be as long as residual_norms (it may even be empty).
        lines.append("")
        lines.append("(b) BP gradient norms ||g_l||_2 (median over batch):")
        lines.extend(
            f" g_{idx}: {norm:.3e}"
            for idx, norm in enumerate(self.bp_grad_norms)
        )
        if self.bp_grad_at_floor:
            lines.append(
                f" FLAG: deepest ||g_L|| = {self.bp_grad_norms[-1]:.2e} "
                f"< floor {thr.g_norm_floor:.0e} — Γ measured "
                "against numerical floor; not interpretable as credit alignment."
            )

        # (c) Cross-batch direction stability
        lines.append("")
        lines.append(
            f"(c) Cross-batch direction stability at layer {self.stability_layer}: "
            f"{self.cross_batch_stability:.3f}"
        )
        if self.reference_drift_dominated:
            lines.append(
                f" FLAG: stability {self.cross_batch_stability:.3f} > "
                f"ceiling {thr.stability_drift_ceiling} — "
                "reference vector is sample-invariant drift, not per-sample credit."
            )

        # (d) Frozen baseline comparison
        lines.append("")
        margin_pp = self._margin_pp
        if margin_pp is None:
            lines.append("(d) Frozen-blocks baseline: NOT PROVIDED — diagnostic skipped.")
        else:
            lines.append(
                f"(d) Headline acc: {self.headline_acc:.4f}, "
                f"frozen-blocks baseline: {self.frozen_baseline_acc:.4f}, "
                f"margin: {margin_pp:+.2f} pp"
            )
            if self.frozen_baseline_undercut:
                lines.append(
                    f" FLAG: margin {margin_pp:+.2f} pp < "
                    f"required {thr.frozen_acc_margin_pp} pp — "
                    "deep blocks are not contributing over a random untrained baseline."
                )

        # Verdict
        lines.append("")
        lines.append(f"VERDICT: {self.verdict}")
        lines.append(rule)
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of the report: the raw inputs (lists
        copied), the computed verdict, and the thresholds in effect.
        """
        thr = self.thresholds
        return {
            "method_name": self.method_name,
            "notes": self.notes,
            "residual_norms": list(self.residual_norms),
            "bp_grad_norms": list(self.bp_grad_norms),
            "stability_layer": self.stability_layer,
            "cross_batch_stability": self.cross_batch_stability,
            "headline_acc": self.headline_acc,
            "frozen_baseline_acc": self.frozen_baseline_acc,
            "verdict": self.verdict,
            "thresholds": {
                "g_norm_floor": thr.g_norm_floor,
                "h_norm_explosion_ratio": thr.h_norm_explosion_ratio,
                "stability_drift_ceiling": thr.stability_drift_ceiling,
                "frozen_acc_margin_pp": thr.frozen_acc_margin_pp,
            },
        }