# Source: protocol/report.py as added by commit
# 7b64702ad970c16171142665365e16a8e1737190 (YurenHao0426, 2026-04-07),
# "Add FA diagnostic protocol reference implementation".
"""
DiagnosticReport: structured output of the FA evaluation protocol.

Holds the per-layer numbers from the four diagnostics and emits a verdict
for each one based on the published thresholds. The verdict is intentionally
binary ("trustworthy" / "needs walk-back"); fine-grained reading is the
caller's job.
"""
from __future__ import annotations

import math
from dataclasses import asdict, dataclass, field
from typing import List, Optional

# Absolute slack (in percentage points) used when comparing the accuracy
# margin with its threshold, so that a margin that is exactly on the
# threshold is not flagged because of binary floating-point round-off
# (e.g. (0.82 - 0.80) * 100 evaluates to slightly less than 2.0).
_MARGIN_ABS_TOL_PP = 1e-9


@dataclass
class DiagnosticThresholds:
    """Degeneracy thresholds (defaults match the paper).

    g_norm_floor: BP gradient norms below this are considered to be at the
        numerical floor (Γ measured against this is not interpretable as
        credit alignment). Default 1e-7 — well above the fp32 floor (~1e-38)
        and well above F.cosine_similarity's eps=1e-8 clamp, but several
        orders below healthy networks (~1e-5).
    h_norm_explosion_ratio: residual stream norm growth (relative to layer 0)
        above this is considered "exploded". Default 50× — BP-trained
        networks are ~1-3× per layer; failure modes show ~10^5-10^6×.
    stability_drift_ceiling: cross-batch direction cosine above this is
        considered to be drift-dominated (reference vector is sample-
        invariant). Default 0.30 — BP-trained / EP-trained networks are
        below 0.20, failure modes are 0.5-0.99.
    frozen_acc_margin_pp: minimum acc gain (in percentage points) of the
        trainable-blocks variant over the frozen-random-blocks baseline for
        the deep blocks to be considered "actually contributing". Default
        2.0 pp.
    """

    g_norm_floor: float = 1e-7
    h_norm_explosion_ratio: float = 50.0
    stability_drift_ceiling: float = 0.30
    frozen_acc_margin_pp: float = 2.0


@dataclass
class DiagnosticReport:
    """Result of running the protocol on one trained network.

    ``residual_norms`` and ``bp_grad_norms`` are per-layer values and are
    handled independently; they do not need to have the same length.
    ``headline_acc`` and ``frozen_baseline_acc`` are multiplied by 100 to
    obtain percentage points, i.e. they are treated as fractions.
    """

    method_name: str
    notes: str
    residual_norms: List[float]
    bp_grad_norms: List[float]
    stability_layer: int
    cross_batch_stability: float
    headline_acc: float
    frozen_baseline_acc: Optional[float]
    thresholds: DiagnosticThresholds = field(default_factory=DiagnosticThresholds)

    # ------------------------------------------------------------------ #
    # Per-diagnostic verdicts
    # ------------------------------------------------------------------ #

    @property
    def residual_stream_exploded(self) -> bool:
        """True if the residual norm grew past the threshold relative to layer 0.

        Returns False when no norms were recorded or when the layer-0 norm is
        not positive (the ratio is undefined). Any non-finite norm (inf/NaN)
        counts as exploded; without this check every comparison against NaN
        is False and a diverged network would pass silently.
        """
        norms = self.residual_norms
        if not norms:
            return False
        if not all(math.isfinite(h) for h in norms):
            return True
        h0 = norms[0]
        if h0 <= 0:
            return False
        return (max(norms) / h0) > self.thresholds.h_norm_explosion_ratio

    @property
    def bp_grad_at_floor(self) -> bool:
        """True if the last recorded BP gradient norm is below ``g_norm_floor``."""
        if not self.bp_grad_norms:
            return False
        # Check the *deepest* hidden layer's BP grad — that's where Γ is
        # typically reported and where LN-driven collapse hits hardest.
        return self.bp_grad_norms[-1] < self.thresholds.g_norm_floor

    @property
    def reference_drift_dominated(self) -> bool:
        """True if cross-batch stability exceeds ``stability_drift_ceiling``."""
        return self.cross_batch_stability > self.thresholds.stability_drift_ceiling

    @property
    def frozen_baseline_undercut(self) -> Optional[bool]:
        """True if the trainable-blocks acc fails to clear the frozen baseline
        by `frozen_acc_margin_pp`. None if no frozen baseline supplied.

        A margin equal to the threshold (up to floating-point round-off)
        clears it.
        """
        if self.frozen_baseline_acc is None:
            return None
        margin_pp = (self.headline_acc - self.frozen_baseline_acc) * 100
        required_pp = self.thresholds.frozen_acc_margin_pp
        if math.isclose(
            margin_pp, required_pp, rel_tol=0.0, abs_tol=_MARGIN_ABS_TOL_PP
        ):
            return False
        return margin_pp < required_pp

    # ------------------------------------------------------------------ #
    # Aggregate verdict
    # ------------------------------------------------------------------ #

    @property
    def verdict(self) -> str:
        """Return "trustworthy", or "needs walk-back: ..." listing every flag."""
        flags = [
            ("residual stream exploded", self.residual_stream_exploded),
            ("BP grad at numerical floor", self.bp_grad_at_floor),
            ("BP grad direction is drift-dominated", self.reference_drift_dominated),
            # `is True` keeps a skipped diagnostic (None) from raising a flag.
            (
                "deep blocks fail to beat frozen-random baseline",
                self.frozen_baseline_undercut is True,
            ),
        ]
        flagged = [name for name, raised in flags if raised]
        if not flagged:
            return "trustworthy"
        return "needs walk-back: " + "; ".join(flagged)

    # ------------------------------------------------------------------ #
    # Pretty-print
    # ------------------------------------------------------------------ #

    def __str__(self) -> str:
        """Render the four diagnostics, their flags, and the verdict as text."""
        thr = self.thresholds
        rule = "=" * 72
        lines: List[str] = [
            rule,
            f"FA Diagnostic Protocol Report — method: {self.method_name}",
        ]
        if self.notes:
            lines.append(f"Notes: {self.notes}")
        lines.append(rule)

        # (a) Residual stream norms
        lines.append("(a) Residual stream norms ||h_l||_2 (median over batch):")
        lines.extend(
            f" h_{idx}: {h_norm:.3e}"
            for idx, h_norm in enumerate(self.residual_norms)
        )
        if self.residual_stream_exploded:
            if all(math.isfinite(h) for h in self.residual_norms):
                # The flag is only raised for finite norms when h_0 > 0, so
                # this division is safe.
                ratio = max(self.residual_norms) / self.residual_norms[0]
                lines.append(
                    f" FLAG: max/h_0 ratio {ratio:.2e} "
                    f"> threshold {thr.h_norm_explosion_ratio}× — "
                    "residual stream exploded."
                )
            else:
                lines.append(
                    " FLAG: non-finite (inf/NaN) residual norm — "
                    "residual stream exploded."
                )

        # (b) BP grad norms — iterated over their own length, which may
        # differ from the number of residual norms.
        lines.append("")
        lines.append("(b) BP gradient norms ||g_l||_2 (median over batch):")
        lines.extend(
            f" g_{idx}: {g_norm:.3e}"
            for idx, g_norm in enumerate(self.bp_grad_norms)
        )
        if self.bp_grad_at_floor:
            lines.append(
                f" FLAG: deepest ||g_L|| = {self.bp_grad_norms[-1]:.2e} "
                f"< floor {thr.g_norm_floor:.0e} — Γ measured "
                "against numerical floor; not interpretable as credit alignment."
            )

        # (c) Cross-batch direction stability
        lines.append("")
        lines.append(
            f"(c) Cross-batch direction stability at layer {self.stability_layer}: "
            f"{self.cross_batch_stability:.3f}"
        )
        if self.reference_drift_dominated:
            lines.append(
                f" FLAG: stability {self.cross_batch_stability:.3f} > "
                f"ceiling {thr.stability_drift_ceiling} — "
                "reference vector is sample-invariant drift, not per-sample credit."
            )

        # (d) Frozen baseline comparison
        lines.append("")
        if self.frozen_baseline_acc is None:
            lines.append("(d) Frozen-blocks baseline: NOT PROVIDED — diagnostic skipped.")
        else:
            margin_pp = (self.headline_acc - self.frozen_baseline_acc) * 100
            lines.append(
                f"(d) Headline acc: {self.headline_acc:.4f}, "
                f"frozen-blocks baseline: {self.frozen_baseline_acc:.4f}, "
                f"margin: {margin_pp:+.2f} pp"
            )
            if self.frozen_baseline_undercut:
                lines.append(
                    f" FLAG: margin {margin_pp:+.2f} pp < "
                    f"required {thr.frozen_acc_margin_pp} pp — "
                    "deep blocks are not contributing over a random untrained baseline."
                )

        # Verdict
        lines.append("")
        lines.append(f"VERDICT: {self.verdict}")
        lines.append(rule)
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return the raw inputs, the verdict, and the thresholds as a dict.

        The per-layer lists are copied so that mutating the result does not
        affect this report.
        """
        return {
            "method_name": self.method_name,
            "notes": self.notes,
            "residual_norms": list(self.residual_norms),
            "bp_grad_norms": list(self.bp_grad_norms),
            "stability_layer": self.stability_layer,
            "cross_batch_stability": self.cross_batch_stability,
            "headline_acc": self.headline_acc,
            "frozen_baseline_acc": self.frozen_baseline_acc,
            "verdict": self.verdict,
            "thresholds": asdict(self.thresholds),
        }