summaryrefslogtreecommitdiff
path: root/protocol/report.py
diff options
context:
space:
mode:
Diffstat (limited to 'protocol/report.py')
-rw-r--r--protocol/report.py201
1 files changed, 201 insertions, 0 deletions
diff --git a/protocol/report.py b/protocol/report.py
new file mode 100644
index 0000000..15d6c34
--- /dev/null
+++ b/protocol/report.py
@@ -0,0 +1,201 @@
+"""
+DiagnosticReport: structured output of the FA evaluation protocol.
+
+Holds the per-layer numbers from the four diagnostics and emits a verdict
+for each one based on the published thresholds. The verdict is intentionally
+binary ("trustworthy" / "needs walk-back"); fine-grained reading is the
+caller's job.
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+
@dataclass
class DiagnosticThresholds:
    """Cut-off values that decide when each diagnostic counts as degenerate.

    The defaults are the values used in the paper.

    Attributes:
        g_norm_floor: A BP gradient norm smaller than this is treated as
            sitting at the numerical floor, so Γ measured against it cannot
            be read as credit alignment. The default of 1e-7 is far above
            the fp32 floor (~1e-38) and above the eps=1e-8 clamp used by
            F.cosine_similarity, yet several orders of magnitude under what
            healthy networks show (~1e-5).
        h_norm_explosion_ratio: Growth of the residual stream norm relative
            to layer 0 beyond which the stream is called "exploded". The
            default is 50×; BP-trained networks grow ~1-3× per layer while
            failure modes show ~10^5-10^6×.
        stability_drift_ceiling: A cross-batch direction cosine larger than
            this is treated as drift-dominated, i.e. the reference vector is
            sample-invariant. The default is 0.30; BP-trained / EP-trained
            networks stay below 0.20 and failure modes land in 0.5-0.99.
        frozen_acc_margin_pp: Smallest accuracy gain, in percentage points,
            that the trainable-blocks variant must show over the
            frozen-random-blocks baseline before the deep blocks count as
            "actually contributing". The default is 2.0 pp.
    """

    # Field order is part of the positional constructor signature.
    g_norm_floor: float = 1e-7
    h_norm_explosion_ratio: float = 50.0
    stability_drift_ceiling: float = 0.30
    frozen_acc_margin_pp: float = 2.0
+
+
@dataclass
class DiagnosticReport:
    """Result of running the protocol on one trained network.

    Attributes:
        method_name: Label printed in the report header.
        notes: Free-form text; printed only when non-empty.
        residual_norms: Per-layer residual stream norms. Entry 0 is the
            reference for the explosion ratio.
        bp_grad_norms: Per-layer BP gradient norms. The last entry is the
            one compared against the numerical floor. Its length does not
            have to match ``residual_norms``.
        stability_layer: Layer index reported next to the stability value.
        cross_batch_stability: Cross-batch direction stability, compared
            against ``thresholds.stability_drift_ceiling``.
        headline_acc: Accuracy of the evaluated network as a fraction (it
            is multiplied by 100 to obtain percentage points).
        frozen_baseline_acc: Accuracy of the frozen-blocks baseline as a
            fraction, or None when that baseline was not run.
        thresholds: Cut-offs used by the per-diagnostic verdicts.
    """

    method_name: str
    notes: str
    residual_norms: List[float]
    bp_grad_norms: List[float]
    stability_layer: int
    cross_batch_stability: float
    headline_acc: float
    frozen_baseline_acc: Optional[float]
    thresholds: DiagnosticThresholds = field(default_factory=DiagnosticThresholds)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _all_finite(values: List[float]) -> bool:
        """Return True when no entry of *values* is inf or NaN."""
        import math  # function-scope import keeps this class self-contained

        return all(math.isfinite(v) for v in values)

    # ------------------------------------------------------------------ #
    # Per-diagnostic verdicts
    # ------------------------------------------------------------------ #

    @property
    def residual_stream_exploded(self) -> bool:
        """True if the residual norm grew past the allowed ratio over layer 0.

        Returns False when no norms were recorded or when the layer-0 norm is
        not positive (the ratio is undefined). Any inf/NaN norm counts as
        exploded: comparisons against NaN are always False, so without this
        check an overflowed run would silently pass.
        """
        norms = self.residual_norms
        if not norms:
            return False
        if not self._all_finite(norms):
            return True
        h0 = norms[0]
        if h0 <= 0:
            return False
        return (max(norms) / h0) > self.thresholds.h_norm_explosion_ratio

    @property
    def bp_grad_at_floor(self) -> bool:
        """True if the last BP gradient norm is below ``g_norm_floor``.

        Returns False when no gradient norms were recorded.
        """
        if not self.bp_grad_norms:
            return False
        # Check the *deepest* hidden layer's BP grad — that's where Γ is
        # typically reported and where LN-driven collapse hits hardest.
        return self.bp_grad_norms[-1] < self.thresholds.g_norm_floor

    @property
    def reference_drift_dominated(self) -> bool:
        """True if cross-batch stability exceeds ``stability_drift_ceiling``."""
        return self.cross_batch_stability > self.thresholds.stability_drift_ceiling

    @property
    def frozen_baseline_undercut(self) -> Optional[bool]:
        """True if the trainable-blocks acc fails to clear the frozen baseline
        by `frozen_acc_margin_pp`. None if no frozen baseline supplied.
        """
        if self.frozen_baseline_acc is None:
            return None
        margin_pp = (self.headline_acc - self.frozen_baseline_acc) * 100
        return margin_pp < self.thresholds.frozen_acc_margin_pp

    # ------------------------------------------------------------------ #
    # Aggregate verdict
    # ------------------------------------------------------------------ #

    @property
    def verdict(self) -> str:
        """Return "trustworthy" or "needs walk-back: ..." listing every flag.

        The frozen-baseline diagnostic only contributes when a baseline was
        supplied and was undercut; a missing baseline never raises a flag.
        """
        flags = [
            ("residual stream exploded", self.residual_stream_exploded),
            ("BP grad at numerical floor", self.bp_grad_at_floor),
            ("BP grad direction is drift-dominated", self.reference_drift_dominated),
        ]
        if self.frozen_baseline_undercut is True:
            flags.append(("deep blocks fail to beat frozen-random baseline", True))
        flagged = [name for name, val in flags if val]
        if not flagged:
            return "trustworthy"
        return "needs walk-back: " + "; ".join(flagged)

    # ------------------------------------------------------------------ #
    # Pretty-print
    # ------------------------------------------------------------------ #

    def __str__(self) -> str:
        """Render the four diagnostics, their flags, and the verdict as text."""
        lines: List[str] = []
        lines.append("=" * 72)
        lines.append(f"FA Diagnostic Protocol Report — method: {self.method_name}")
        if self.notes:
            lines.append(f"Notes: {self.notes}")
        lines.append("=" * 72)

        # (a) Residual stream norms
        lines.append("(a) Residual stream norms ||h_l||_2 (median over batch):")
        for idx, norm in enumerate(self.residual_norms):
            lines.append(f" h_{idx}: {norm:.3e}")
        if self.residual_stream_exploded:
            if self._all_finite(self.residual_norms):
                # The flag is only raised for finite norms when h_0 > 0, so
                # this division is safe. The ratio is relative to layer 0,
                # not to the smallest norm, hence "max/h_0".
                ratio = max(self.residual_norms) / self.residual_norms[0]
                lines.append(
                    f" FLAG: max/h_0 ratio "
                    f"{ratio:.2e} "
                    f"> threshold {self.thresholds.h_norm_explosion_ratio}× — "
                    "residual stream exploded."
                )
            else:
                lines.append(
                    " FLAG: non-finite residual norm (inf/NaN) — "
                    "residual stream exploded."
                )

        # (b) BP grad norms — iterated over its own length so a mismatch with
        # the residual list neither raises nor hides entries.
        lines.append("")
        lines.append("(b) BP gradient norms ||g_l||_2 (median over batch):")
        for idx, norm in enumerate(self.bp_grad_norms):
            lines.append(f" g_{idx}: {norm:.3e}")
        if self.bp_grad_at_floor:
            lines.append(
                f" FLAG: deepest ||g_L|| = {self.bp_grad_norms[-1]:.2e} "
                f"< floor {self.thresholds.g_norm_floor:.0e} — Γ measured "
                "against numerical floor; not interpretable as credit alignment."
            )

        # (c) Cross-batch direction stability
        lines.append("")
        lines.append(
            f"(c) Cross-batch direction stability at layer {self.stability_layer}: "
            f"{self.cross_batch_stability:.3f}"
        )
        if self.reference_drift_dominated:
            lines.append(
                f" FLAG: stability {self.cross_batch_stability:.3f} > "
                f"ceiling {self.thresholds.stability_drift_ceiling} — "
                "reference vector is sample-invariant drift, not per-sample credit."
            )

        # (d) Frozen baseline comparison
        lines.append("")
        if self.frozen_baseline_acc is None:
            lines.append("(d) Frozen-blocks baseline: NOT PROVIDED — diagnostic skipped.")
        else:
            margin_pp = (self.headline_acc - self.frozen_baseline_acc) * 100
            lines.append(
                f"(d) Headline acc: {self.headline_acc:.4f}, "
                f"frozen-blocks baseline: {self.frozen_baseline_acc:.4f}, "
                f"margin: {margin_pp:+.2f} pp"
            )
            if self.frozen_baseline_undercut:
                lines.append(
                    f" FLAG: margin {margin_pp:+.2f} pp < "
                    f"required {self.thresholds.frozen_acc_margin_pp} pp — "
                    "deep blocks are not contributing over a random untrained baseline."
                )

        # Verdict
        lines.append("")
        lines.append(f"VERDICT: {self.verdict}")
        lines.append("=" * 72)
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return the inputs, the verdict, and the thresholds as a plain dict.

        The norm lists are copied so later mutation of the result does not
        affect this report.
        """
        return {
            "method_name": self.method_name,
            "notes": self.notes,
            "residual_norms": list(self.residual_norms),
            "bp_grad_norms": list(self.bp_grad_norms),
            "stability_layer": self.stability_layer,
            "cross_batch_stability": self.cross_batch_stability,
            "headline_acc": self.headline_acc,
            "frozen_baseline_acc": self.frozen_baseline_acc,
            "verdict": self.verdict,
            "thresholds": {
                "g_norm_floor": self.thresholds.g_norm_floor,
                "h_norm_explosion_ratio": self.thresholds.h_norm_explosion_ratio,
                "stability_drift_ceiling": self.thresholds.stability_drift_ceiling,
                "frozen_acc_margin_pp": self.thresholds.frozen_acc_margin_pp,
            },
        }