summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:20:48 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:20:48 -0500
commit7b64702ad970c16171142665365e16a8e1737190 (patch)
tree30df9313fc195aa037f05c0598fdc0df8b32bbcc /protocol
parent2dc8e7efb5f2ff827fbd97be87f0127aa5ab2757 (diff)
Add FA diagnostic protocol reference implementation
Codex round 15 #1 priority for the E&D-track paper: - protocol/protocol.py: 4 diagnostics (residual norms, BP grad norms, cross-batch direction stability, and a frozen-baseline comparator) - protocol/report.py: DiagnosticReport with per-diagnostic verdicts and pretty-printer - protocol/smoke_test.py: validates BP/DFA/EP checkpoints produce the expected verdicts (BP/EP trustworthy; DFA walked back via residual explosion + BP grad at floor) - protocol/README.md: usage, audit cases, threshold rationale - protocol/CHECKLIST.md: 6 evaluation pipeline pitfalls (norm(-1), cosine_similarity eps clamp, fp16 underflow, Bs reproducibility, aggregation, layer-0 dominance) - protocol/REPORTING_TEMPLATE.md: per-method fillable form for FA papers
Diffstat (limited to 'protocol')
-rw-r--r--protocol/CHECKLIST.md96
-rw-r--r--protocol/README.md150
-rw-r--r--protocol/REPORTING_TEMPLATE.md69
-rw-r--r--protocol/__init__.py52
-rw-r--r--protocol/protocol.py290
-rw-r--r--protocol/report.py201
-rw-r--r--protocol/smoke_test.py136
7 files changed, 994 insertions, 0 deletions
diff --git a/protocol/CHECKLIST.md b/protocol/CHECKLIST.md
new file mode 100644
index 0000000..3227642
--- /dev/null
+++ b/protocol/CHECKLIST.md
@@ -0,0 +1,96 @@
+# FA Evaluation Pitfalls Checklist
+
+Six evaluation-pipeline bugs we found in our own codebase while developing
+this protocol. Each one silently corrupted reported numbers in a way that
+would have been very hard for a reviewer to catch. Use this list as a
+final-pass review for any FA evaluation paper.
+
+## 1. `tensor.norm(-1)` is NOT "L₂ along dim=-1"
+
+`tensor.norm(-1)` computes the L₋₁ "norm" (`1/sum(1/|x|)`) of the *entire*
+flattened tensor, not the per-row L₂ norm. The correct call for "L₂ along
+the feature dimension" is `tensor.norm(dim=-1)`.
+
+This single bug invalidated several months of our own gradient-norm
+measurements; gradient sparsity tables, "support sparsity" analyses, and
+the `naive_StateErr` metric in three of our trainers were all degenerate.
+
+**Check**: grep your codebase for `\.norm\(-1\)` and `\.norm\(\s*-1\s*\)`.
+Replace with `.norm(dim=-1)`.
+
+## 2. `F.cosine_similarity` clamps small magnitudes via `eps=1e-8`
+
+`torch.nn.functional.cosine_similarity(a, b)` divides by `max(||a||·||b||,
+eps)` with `eps=1e-8` by default. When the BP gradient at hidden layers is
+~1e-10 (which it is on DFA-trained pre-LN ResMLPs), the divisor becomes
+`||a|| · 1e-8` instead of `||a|| · 1e-10`, scaling the reported cosine by
+`||g|| / eps` ≈ 0.01. The reported "Γ ≈ 0.10" is then off by ~100× from the
+true cosine.
+
+**Check**: when computing Γ, either use a hand-rolled cosine that does not
+clamp, or assert `||g|| > 1e-8` before calling `F.cosine_similarity`. If you
+report Γ alongside ‖g‖ and ‖g‖ is at the floor, your Γ is uninterpretable —
+the diagnostic protocol will catch this.
+
+## 3. fp16 mixed precision underflows BP grads at hidden layers
+
+Computing Γ in fp16 mixed precision on a DFA-trained ResMLP gives `nan`
+because BP grads (~5e-10) are below fp16's smallest subnormal (~6e-8).
+**bf16 works** because it has the same exponent range as fp32 (only the
+mantissa is reduced).
+
+**Check**: if your reported metric depends on a tiny intermediate magnitude,
+either compute it in fp32, switch to bf16, or rescale by a known constant
+before the precision-sensitive step.
+
+## 4. Random feedback `Bs` are training-specific — Γ is not invariant to
+ reseeding them
+
+DFA at fixed seed s42 reports Γ ≈ 0.106 *with the random feedback projection
+matrices `Bs` that were used during training*. With 20 fresh random `Bs`
+draws on the same trained network, Γ ≈ 0 ± 0.005. The non-zero Γ is the
+network adapting to the *specific* `Bs` it saw during training, not an
+intrinsic property of the local credit signal.
+
+**Check**: when you report Γ for an FA-style method, specify exactly which
+`Bs` were used. If you cannot reproduce them from the seed, your numbers
+are not reproducible.
+
+## 5. Aggregation strategy across layers / samples / batches is rarely
+ specified, but determines the headline number
+
+We computed DFA's "headline Γ" with four equally valid aggregation
+strategies: (median over layers) × (mean over samples), (mean over layers) ×
+(median over samples), and so on. The four strategies give Γ ∈ [0.085,
+0.211] on the *same* trained DFA s42 network — a 2.5× spread. For credit
+bridge it's 5× (Γ ∈ [0.057, 0.285]).
+
+**Check**: state explicitly how you aggregate Γ. Report per-layer values
+and let the reader pick. Do not collapse to a single number without showing
+the layer breakdown.
+
+## 6. Layer-0 dominates the headline Γ; deeper layers are ~0
+
+For DFA on a 4-block ResMLP, the headline Γ ≈ 0.10 is driven almost entirely
+by the embedding layer (Γ_layer0 ≈ 0.43). The block layers have Γ ≈ 0. A
+"mean over layers" summary statistic hides this. The same pattern likely
+holds for other FA-style methods.
+
+**Check**: always report per-layer Γ. A single average is misleading when
+one layer dominates.
+
+## Suggested final-pass workflow
+
+1. Run `grep -nr '\.norm(-1)' YOUR_REPO` and fix all matches.
+2. Run the diagnostic protocol on your trained model. If any flag fires,
+ fix the underlying issue or walk back the corresponding claim before
+ submission.
+3. Inspect every place you compute Γ:
+ - Is the divisor clamped?
+ - Are you in fp16?
+ - Are you using the training `Bs` or fresh ones?
+ - How do you aggregate?
+ - Do you report per-layer?
+4. Report `‖h_l‖`, `‖g_l‖`, and `Γ_l` *together* for every layer. The
+ metrics are entangled; reporting one without the others hides the
+ pathology.
diff --git a/protocol/README.md b/protocol/README.md
new file mode 100644
index 0000000..770d582
--- /dev/null
+++ b/protocol/README.md
@@ -0,0 +1,150 @@
+# FA Diagnostic Protocol — reference implementation
+
+A drop-in toolkit for evaluating Feedback Alignment (FA) and other local-credit
+methods on modern residual architectures. Implements the four diagnostics
+described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for
+Feedback Alignment Evaluation*.
+
+## What this fixes
+
+Standard FA evaluation reports two numbers:
+
+1. **Headline accuracy** of the trained network.
+2. **Γ** — cosine alignment between the local credit signal and the
+ true BP gradient at hidden layers.
+
+Both can silently fail on modern pre-LayerNorm residual networks:
+
+- **Headline accuracy** fails because deep blocks may not be contributing at
+ all — the embedding + terminal-LN + head pathway alone can match (or
+ exceed) the trainable-blocks accuracy. We have measured DFA on a 4-block
+ d=256 ResMLP at 30.8% test acc, *below* the same architecture's
+ frozen-random-blocks baseline at 34.9%. The headline number reads as
+ "DFA learns" but the deep blocks are *actively destroying* value.
+- **Γ** fails because residual-stream explosion drives the BP gradient at
+ hidden layers to ~10⁻¹⁰ via terminal-LayerNorm scaling. Γ then computes
+ the cosine of the local credit signal against a numerical-floor reference
+ vector. The reported value is mathematically well-defined but does not
+ measure credit-assignment quality.
+
+This protocol catches both failures and tells you whether your reported
+numbers are interpretable.
+
+## The four diagnostics
+
+| | Diagnostic | What it catches |
+|---|---|---|
+| (a) | Per-layer residual stream norm ‖h_l‖₂ | Scale pathology — if ‖h_L‖ has grown ≫50× over depth, the network is in a degenerate regime. |
+| (b) | Per-layer BP gradient norm ‖g_l‖₂ | Reference-vector floor — if ‖g_L‖ < 1e-7, Γ is being measured against the numerical floor and is not interpretable. |
+| (c) | Cross-batch direction stability of g | Drift dominance — if the BP gradient direction is sample-invariant (cosine ≈ 1 across disjoint batches), the reference is a global drift artifact, not per-sample credit. |
+| (d) | Frozen-blocks baseline accuracy | Useful blocks — if the trainable-blocks variant fails to clear an architecture-matched frozen-random-blocks baseline by ≥2 pp, the deep blocks are not contributing. |
+
+## Quick start
+
+```python
+from protocol import diagnose
+
+report = diagnose(
+ model=trained_model, # any PyTorch module exposing
+ # `model(x, return_hidden=True)`,
+ # `model.embed`/`model.patch_embed`,
+ # `model.blocks`, and
+ # `model.out_ln`/`model.out_head`
+ # (or `model.norm`/`model.head`)
+ eval_batches=eval_batches, # ≥8 disjoint (x,y) batches of ~128
+ headline_acc=test_acc, # the number you would normally publish
+ frozen_baseline_acc=frozen_acc, # acc of an architecture-matched
+ # frozen-random-blocks control
+ # (set None to skip diagnostic d)
+ method_name="DFA",
+ notes="4-block d=256 ResMLP, CIFAR-10",
+)
+print(report)
+# emits a per-diagnostic table and either
+# "VERDICT: trustworthy"
+# or
+# "VERDICT: needs walk-back: <list of failed diagnostics>"
+```
+
+## Reproducing the audit cases
+
+We ship a smoke test that loads BP, DFA, and EP checkpoints from
+`results/confirmatory/checkpoints_A2/` and `results/ep_baseline/`, runs the
+protocol on each, and checks that the verdicts match expectation:
+
+```bash
+CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test
+```
+
+Expected verdicts (4-block d=256 ResMLP, CIFAR-10, seed 42):
+
+| method | ‖h₄‖ | ‖g₄‖ | stability | verdict |
+|---|---:|---:|---:|---|
+| BP | 2.1e2 | 3.7e-4 | 0.099 | trustworthy |
+| DFA | 4.4e8 | 4.2e-9 | 0.047 | walked back: explosion + BP grad at floor |
+| EP | 3.3e3 | 1.6e-4 | -0.036 | trustworthy (internal control) |
+
+EP is the paper's central internal control: same architecture, same dataset,
+same metric, but EP's training does not produce residual-stream explosion,
+so the protocol passes for EP even though EP's headline accuracy is also low
+(~32%). This isolates "the network is in a degenerate measurement regime"
+from "the method underperforms" — they are different failure modes with
+different evidence.
+
+## Duck-typed model interface
+
+The protocol works on any model exposing:
+
+```python
+# Forward with hidden states
+logits, hiddens = model(x, return_hidden=True)
+# hiddens is a list [h_0, h_1, ..., h_L]; supports either (B, d) or (B, T, d)
+# (the protocol pools the cls token via h[:, 0, :] in the latter case)
+
+# Building blocks for the manual forward (used by diagnostics b and c)
+model.embed OR model.patch_embed
+model.blocks # iterable of nn.Modules; each block(h) returns the
+ # *residual branch* f_l(h_l), NOT h_l + f_l(h_l)
+model.out_ln, model.out_head OR model.norm, model.head
+```
+
+`models/residual_mlp.py` and `models/vit_mini.py` in this repo both satisfy
+this interface as written.
+
+## Thresholds (and how to override them)
+
+```python
+from protocol import diagnose, DiagnosticReport
+from protocol.report import DiagnosticThresholds
+
+custom = DiagnosticThresholds(
+ g_norm_floor=1e-6, # tighter floor
+ h_norm_explosion_ratio=20.0, # tighter scale check
+ stability_drift_ceiling=0.30,
+ frozen_acc_margin_pp=2.0,
+)
+report = diagnose(..., thresholds=custom)
+```
+
+The defaults match the paper:
+
+- `g_norm_floor = 1e-7` — well above fp32's smallest subnormal (~1e-38) and
+ well above F.cosine_similarity's eps=1e-8 clamp, but several orders below
+ healthy networks (~1e-5).
+- `h_norm_explosion_ratio = 50.0` — BP-trained networks are ~1-3× per layer
+ in our measurements; failure modes show ~10⁵-10⁶× from layer 0 to layer L.
+- `stability_drift_ceiling = 0.30` — calibrated for ~128-sample batches.
+ BP/EP-trained networks: 0.05-0.18. DFA/SB drift mode: 0.43-0.99.
+- `frozen_acc_margin_pp = 2.0` — DFA-on-ResMLP undercuts frozen baseline by
+ 4 pp; the threshold is loose enough to allow noise on healthy networks.
+
+## Pipeline pitfalls (the checklist)
+
+In addition to the diagnostics above, we ship `CHECKLIST.md` — six
+evaluation-pipeline bugs we found in our own dogfood codebase that silently
+corrupt FA evaluation results. We recommend a final pre-publication pass
+through the checklist regardless of which diagnostic protocol you use.
+
+## Citation
+
+(TBD — points to the paper PDF when uploaded.)
diff --git a/protocol/REPORTING_TEMPLATE.md b/protocol/REPORTING_TEMPLATE.md
new file mode 100644
index 0000000..5d4a187
--- /dev/null
+++ b/protocol/REPORTING_TEMPLATE.md
@@ -0,0 +1,69 @@
+# FA Evaluation Reporting Template
+
+A minimal, fillable reporting template for any paper that evaluates a
+local-credit (FA-style) method on a residual architecture. Reviewers can
+use this as a checklist to verify whether the reported numbers are in the
+metric's meaningful regime.
+
+Copy the table below into your paper appendix or supplementary material.
+Fill in one row per `(method × architecture × dataset × seed)`.
+
+---
+
+## Method × architecture identification
+
+| field | value |
+|---|---|
+| Method | _e.g._ DFA, EP, SB, CB, BP |
+| Method local-credit signal | _formula_, _e.g._ `e_T B_l^T` for DFA |
+| Architecture | _e.g._ 4-block d=256 pre-LN ResidualMLP |
+| Has terminal LayerNorm before head | yes / no |
+| Dataset | _e.g._ CIFAR-10 |
+| Number of seeds reported | _e.g._ 3 |
+| Random feedback `Bs` (if any): seed/spec | |
+
+## Headline numbers
+
+| field | value |
+|---|---|
+| Test accuracy (mean ± std over seeds) | |
+| Headline Γ (cosine to BP gradient) | |
+| Aggregation: how is Γ collapsed? | _e.g._ "mean over layers, then mean over samples" |
+| Per-layer Γ table reported in §X | yes / no |
+
+## Diagnostic protocol numbers (this is the new content)
+
+| diagnostic | per-layer values | flag? |
+|---|---|---|
+| (a) ‖h_l‖₂ for l = 0 .. L | _e.g._ [255, 226, 211, 205, 203] | scale OK / **EXPLODED** |
+| (b) ‖g_l‖₂ for l = 0 .. L | _e.g._ [4e-4, 4e-4, 4e-4, 4e-4, 3e-4] | grad OK / **AT FLOOR** |
+| (c) cross-batch direction stability at layer L/2 | _e.g._ 0.099 (N=128, ≥8 batches) | OK / **DRIFT** |
+| (d) frozen-blocks baseline acc | _e.g._ 0.349 | trainable acc clears it by ≥2 pp / **UNDERCUT** |
+
+If any of (a)-(d) fires a flag, headline accuracy and Γ should be walked
+back or accompanied by an explicit caveat. We recommend using the bundled
+`protocol.diagnose(...)` function and pasting its `__str__` output as the
+appendix table.
+
+## Pipeline-pitfalls disclosure
+
+State explicitly:
+
+- [ ] We grepped for `tensor.norm(-1)` (the L₋₁ footgun) and confirmed all
+ gradient-norm computations use `tensor.norm(dim=-1)`.
+- [ ] We did not compute Γ in fp16 (or, if we did, we verified no underflow
+ and reported the precision).
+- [ ] We disclose the random feedback `Bs` used during training (if
+ applicable) and reproduced reported Γ with those exact `Bs`.
+- [ ] We report per-layer Γ in addition to any aggregate.
+- [ ] We report `‖h_l‖`, `‖g_l‖`, and `Γ_l` *together* for every layer.
+- [ ] We ran an architecture-matched frozen-random-blocks baseline and
+ report its accuracy alongside the trainable-blocks variant.
+
+## Internal control (strongly recommended)
+
+If your method is *not* the only candidate on this architecture, run the
+protocol on a second method that you have reason to believe operates in
+the meaningful regime (e.g., BP itself, or EP). The protocol should pass
+on the control. If it does not, you have an architecture problem rather
+than a method problem, and the paper should reflect that.
diff --git a/protocol/__init__.py b/protocol/__init__.py
new file mode 100644
index 0000000..770bdda
--- /dev/null
+++ b/protocol/__init__.py
@@ -0,0 +1,52 @@
+"""
+FA Diagnostic Protocol — reference implementation.
+
+A drop-in toolkit for evaluating Feedback Alignment (and related local-credit)
+methods on modern residual architectures. Implements the four diagnostics
+described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for
+Feedback Alignment Evaluation*:
+
+ (a) per-layer residual stream norm ||h_l||_2
+ (b) per-layer BP gradient norm ||g_l||_2
+ (c) cross-batch direction stability s_l
+ (d) frozen-blocks control accuracy acc_frozen
+
+The motivation is that headline accuracy and offline Γ (cosine alignment of
+local credit to BP gradient) can both silently fail on modern pre-LayerNorm
+residual networks: residual-stream explosion drives ||g|| ~10^-10 and turns Γ
+into a measurement of cosine to a numerical-floor reference vector. This
+protocol catches that regime change and tells you whether your reported
+numbers are interpretable.
+
+Quick start:
+
+ from protocol import diagnose, DiagnosticReport
+
+ report = diagnose(
+ model=trained_model,
+ eval_batches=list_of_(x, y)_pairs,
+ frozen_baseline_acc=accuracy_of_random_blocks_baseline,
+ method_name="DFA",
+ notes="4-block d=256 ResMLP, CIFAR-10",
+ )
+ print(report)
+
+The report flags any diagnostic that crossed a "degenerate regime" threshold
+and emits a verdict ("trustworthy" / "needs walk-back").
+"""
+
+from .protocol import (
+ residual_norms,
+ bp_grad_norms,
+ cross_batch_direction_stability,
+ diagnose,
+)
+from .report import DiagnosticReport
+
+__all__ = [
+ "residual_norms",
+ "bp_grad_norms",
+ "cross_batch_direction_stability",
+ "diagnose",
+ "DiagnosticReport",
+]
diff --git a/protocol/protocol.py b/protocol/protocol.py
new file mode 100644
index 0000000..4298756
--- /dev/null
+++ b/protocol/protocol.py
@@ -0,0 +1,290 @@
+"""
+Core diagnostics for the FA evaluation protocol.
+
+The four diagnostics:
+
+ (a) residual_norms — per-layer ||h_l||_2 on a fixed eval batch
+ (b) bp_grad_norms — per-layer ||∂CE/∂h_l||_2 on a fixed eval batch
+ (c) cross_batch_direction_stability — cosine of normalized BP-grad direction
+ across disjoint minibatches; high values
+ indicate the reference vector is dominated by
+ a sample-invariant global drift, NOT per-sample
+ credit
+ (d) frozen-blocks accuracy — caller-supplied; the protocol just compares it
+ to the headline accuracy of the trained network
+
+The protocol expects a model exposing a `return_hidden=True` mode that yields
+the per-layer residual stream as a list `[h_0, h_1, ..., h_L]`, plus the
+standard `out_ln` + `out_head` terminal stack. ResidualMLP and ViT-Mini in
+this repo both satisfy that contract; see the docstring of `diagnose` for the
+duck-typed interface.
+
+All metrics use `tensor.norm(dim=-1)` (L2 along the feature dim) — NEVER
+`tensor.norm(-1)` (which is L_{-1} of the entire tensor and is the bug behind
+several walk-backs in our project).
+"""
+from __future__ import annotations
+
+from typing import List, Sequence, Tuple, Optional
+
+import torch
+import torch.nn.functional as F
+
+from .report import DiagnosticReport, DiagnosticThresholds
+
+
+# --------------------------------------------------------------------------- #
+# Diagnostic (a): residual stream norm
+# --------------------------------------------------------------------------- #
+
+def residual_norms(model, x: torch.Tensor) -> List[float]:
+ """Per-layer median ||h_l||_2 on the eval batch.
+
+ Args:
+ model: must support `model(x, return_hidden=True) -> (logits, hiddens)`
+ where hiddens is `[h_0, ..., h_L]` (each (B, ..., d_hidden)).
+ x: eval batch (already on the correct device, already in the input
+ shape the model wants).
+
+ Returns:
+ list of L+1 floats — the median (over batch) of ||h_l||_2 per layer.
+ """
+ model.eval()
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+ out: List[float] = []
+ for h in hiddens:
+ # If h has token dim (B, T, d), pool tokens to a single vector per
+ # sample by taking the cls token (index 0) — that's what the head
+ # actually consumes in our ViT.
+ if h.dim() == 3:
+ h = h[:, 0, :]
+ # NOTE: dim=-1 is mandatory; `.norm(-1)` would compute L_{-1} of the
+ # whole tensor and is wrong (see project_norm_minus1_footgun).
+ out.append(h.norm(dim=-1).median().item())
+ return out
+
+
+# --------------------------------------------------------------------------- #
+# Diagnostic (b): BP gradient norm at hidden layers
+# --------------------------------------------------------------------------- #
+
+def bp_grad_norms(
+ model,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ loss_fn=F.cross_entropy,
+) -> List[float]:
+ """Per-layer median ||∂L/∂h_l||_2 on the eval batch.
+
+ The BP gradient at hidden layers is the *reference vector* that offline
+ alignment metrics (Γ) compare against. When this is at the numerical
+ floor (~1e-10) on modern pre-LN residual nets, Γ measures alignment to a
+ noise vector and is no longer interpretable as "credit assignment
+ quality".
+
+ Args:
+ model: same contract as for `residual_norms`.
+ x: eval batch.
+ y: labels for the eval batch.
+ loss_fn: scalar loss given (logits, y); cross_entropy by default.
+
+ Returns:
+ list of L+1 floats — the median (over batch) of ||g_l||_2 per layer,
+ where g_l = ∂loss/∂h_l.
+ """
+ model.eval()
+ # Forward through the canonical path (embed -> blocks with explicit
+ # residual addition -> terminal LN -> head), keeping every intermediate
+ # h_l as a non-leaf in the autograd graph. `torch.autograd.grad` accepts
+ # non-leaf tensors as `inputs` and computes ∂L/∂h_l via the chain rule.
+ # NOTE: do not clone+detach the hidden states — that severs the graph.
+ with torch.enable_grad():
+ h = _embed(model, x)
+ hiddens: List[torch.Tensor] = [h]
+ for block in _iter_blocks(model):
+ f = block(h)
+ h = h + f
+ hiddens.append(h)
+ logits = _terminal(model, h)
+ loss = loss_fn(logits, y)
+ grads = torch.autograd.grad(loss, hiddens)
+ out: List[float] = []
+ for g in grads:
+ if g.dim() == 3:
+ g = g[:, 0, :]
+ out.append(g.norm(dim=-1).median().item())
+ return out
+
+
+def _embed(model, x):
+ if hasattr(model, "embed"):
+ return model.embed(x)
+ if hasattr(model, "patch_embed"):
+ # ViT-style
+ return model.patch_embed(x)
+ raise AttributeError(
+ "model must expose `.embed(x)` or `.patch_embed(x)` for the protocol"
+ )
+
+
+def _iter_blocks(model):
+ if hasattr(model, "blocks"):
+ return list(model.blocks)
+ raise AttributeError("model must expose `.blocks` (a Module list)")
+
+
+def _terminal(model, h):
+ """Apply terminal LN (if any) + classification head to hidden state."""
+ if hasattr(model, "out_ln") and hasattr(model, "out_head"):
+ return model.out_head(model.out_ln(h))
+ if hasattr(model, "norm") and hasattr(model, "head"):
+ # ViT-style naming
+ h = model.norm(h)
+ if h.dim() == 3:
+ h = h[:, 0, :]
+ return model.head(h)
+ raise AttributeError(
+ "model must expose `(out_ln, out_head)` or `(norm, head)` for the protocol"
+ )
+
+
+# --------------------------------------------------------------------------- #
+# Diagnostic (c): cross-batch direction stability
+# --------------------------------------------------------------------------- #
+
+def cross_batch_direction_stability(
+ model,
+ eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]],
+ layer_index: int,
+ loss_fn=F.cross_entropy,
+) -> float:
+ """Mean pairwise cosine of the BP-grad *direction* at layer `layer_index`
+ across disjoint minibatches.
+
+ A value near 1.0 means: the BP gradient direction at that hidden layer is
+ the same regardless of which samples are in the batch. That is the
+ fingerprint of a **sample-invariant global drift** — i.e. the reference
+ vector against which Γ is computed is NOT per-sample credit, it is a
+ constant artifact of the trained network's geometry. On healthy
+ BP-trained or EP-trained networks this value is small (~0.05-0.18); on
+ DFA/SB/CB pre-LN ResMLPs we see ~0.5-0.99.
+
+ Args:
+ model: same contract as the other diagnostics.
+ eval_batches: list of (x, y) pairs, each a *disjoint* minibatch on the
+ same device. Use ≥8 batches of ~128 samples each. The metric is
+ calibrated for this batch size: smaller batches let per-sample
+ noise dominate the batch-mean direction on healthy networks
+ (~0.10), giving the cleanest separation from drift-dominated
+ failure modes (~0.95). Larger batches average out per-sample
+ noise via CLT and shrink the gap.
+ layer_index: which layer's BP grad direction to compare.
+ loss_fn: scalar loss; cross_entropy by default.
+
+ Returns:
+ mean pairwise cosine in [-1, 1].
+ """
+ if len(eval_batches) < 2:
+ raise ValueError("need ≥2 eval batches to measure stability")
+ directions: List[torch.Tensor] = []
+ for x, y in eval_batches:
+ with torch.enable_grad():
+ h = _embed(model, x)
+ hiddens: List[torch.Tensor] = [h]
+ for block in _iter_blocks(model):
+ f = block(h)
+ h = h + f
+ hiddens.append(h)
+ logits = _terminal(model, h)
+ loss = loss_fn(logits, y)
+ grads = torch.autograd.grad(loss, hiddens)
+ g = grads[layer_index]
+ if g.dim() == 3:
+ g = g[:, 0, :]
+ # Reduce to a single direction per batch by averaging across samples,
+ # then normalizing. This is the natural "direction the gradient is
+ # pointing on this batch" signal.
+ gbar = g.mean(dim=0)
+ n = gbar.norm()
+ if n < 1e-30:
+ directions.append(torch.zeros_like(gbar))
+ else:
+ directions.append(gbar / n)
+ # Mean pairwise cosine
+ cos_vals: List[float] = []
+ for i in range(len(directions)):
+ for j in range(i + 1, len(directions)):
+ cos_vals.append(
+ float(torch.dot(directions[i], directions[j]).item())
+ )
+ return sum(cos_vals) / len(cos_vals)
+
+
+# --------------------------------------------------------------------------- #
+# Diagnostic (d) is just a comparison: caller supplies frozen_baseline_acc
+# --------------------------------------------------------------------------- #
+
+
+# --------------------------------------------------------------------------- #
+# Top-level convenience: run all four and produce a verdict
+# --------------------------------------------------------------------------- #
+
+def diagnose(
+ model,
+ eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]],
+ headline_acc: float,
+ frozen_baseline_acc: Optional[float] = None,
+ method_name: str = "method",
+ notes: str = "",
+ layer_for_stability: Optional[int] = None,
+ thresholds: Optional[DiagnosticThresholds] = None,
+) -> DiagnosticReport:
+ """Run the full protocol on `model` and return a DiagnosticReport.
+
+ Args:
+ model: trained network. Must satisfy the duck-typed interface
+ described at the top of `protocol.py`.
+ eval_batches: ≥4 disjoint minibatches `(x, y)`, on the model's device.
+ headline_acc: the accuracy you would otherwise have reported as the
+ method's result on this network.
+ frozen_baseline_acc: accuracy of an architecture-matched baseline
+ with the deep blocks **frozen at random initialization** and only
+ embed/LN/head trained. If None, the (d)-diagnostic is skipped.
+ See `protocol/examples/frozen_baseline.py` for a reference impl.
+ method_name: free-form label, only used in the printed report.
+ notes: free-form architecture/dataset description for the report.
+ layer_for_stability: which hidden layer index to use for the cross-
+ batch stability check. Defaults to the middle layer.
+ thresholds: degeneracy thresholds; defaults to those used in our
+ paper (see DiagnosticThresholds for details).
+
+ Returns:
+ A DiagnosticReport summarizing each diagnostic and its verdict.
+ """
+ if thresholds is None:
+ thresholds = DiagnosticThresholds()
+ if not eval_batches:
+ raise ValueError("eval_batches must not be empty")
+ x0, y0 = eval_batches[0]
+
+ h_norms = residual_norms(model, x0)
+ g_norms = bp_grad_norms(model, x0, y0)
+
+ if layer_for_stability is None:
+ layer_for_stability = max(1, len(h_norms) // 2)
+ stability = cross_batch_direction_stability(
+ model, eval_batches, layer_index=layer_for_stability
+ )
+
+ return DiagnosticReport(
+ method_name=method_name,
+ notes=notes,
+ residual_norms=h_norms,
+ bp_grad_norms=g_norms,
+ stability_layer=layer_for_stability,
+ cross_batch_stability=stability,
+ headline_acc=headline_acc,
+ frozen_baseline_acc=frozen_baseline_acc,
+ thresholds=thresholds,
+ )
diff --git a/protocol/report.py b/protocol/report.py
new file mode 100644
index 0000000..15d6c34
--- /dev/null
+++ b/protocol/report.py
@@ -0,0 +1,201 @@
+"""
+DiagnosticReport: structured output of the FA evaluation protocol.
+
+Holds the per-layer numbers from the four diagnostics and emits a verdict
+for each one based on the published thresholds. The verdict is intentionally
+binary ("trustworthy" / "needs walk-back"); fine-grained reading is the
+caller's job.
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+
+@dataclass
+class DiagnosticThresholds:
+ """Degeneracy thresholds (defaults match the paper).
+
+ g_norm_floor: BP gradient norms below this are considered to be at the
+ numerical floor (Γ measured against this is not interpretable as
+ credit alignment). Default 1e-7 — well above fp32 floor (~1e-38) and
+ well above F.cosine_similarity's eps=1e-8 clamp, but several orders
+ below healthy networks (~1e-5).
+ h_norm_explosion_ratio: residual stream norm growth (relative to layer 0)
+ above this is considered "exploded". Default 50× — BP-trained
+ networks are ~1-3× per layer; failure modes show ~10^5-10^6×.
+ stability_drift_ceiling: cross-batch direction cosine above this is
+ considered to be drift-dominated (reference vector is sample-
+ invariant). Default 0.30 — BP-trained / EP-trained networks are
+ below 0.20, failure modes are 0.5-0.99.
+ frozen_acc_margin_pp: minimum acc gain (in percentage points) of the
+ trainable-blocks variant over the frozen-random-blocks baseline for
+ the deep blocks to be considered "actually contributing". Default
+ 2.0 pp.
+ """
+
+ g_norm_floor: float = 1e-7
+ h_norm_explosion_ratio: float = 50.0
+ stability_drift_ceiling: float = 0.30
+ frozen_acc_margin_pp: float = 2.0
+
+
+@dataclass
+class DiagnosticReport:
+ """Result of running the protocol on one trained network."""
+
+ method_name: str
+ notes: str
+ residual_norms: List[float]
+ bp_grad_norms: List[float]
+ stability_layer: int
+ cross_batch_stability: float
+ headline_acc: float
+ frozen_baseline_acc: Optional[float]
+ thresholds: DiagnosticThresholds = field(default_factory=DiagnosticThresholds)
+
+ # ------------------------------------------------------------------ #
+ # Per-diagnostic verdicts
+ # ------------------------------------------------------------------ #
+
+ @property
+ def residual_stream_exploded(self) -> bool:
+ if not self.residual_norms:
+ return False
+ h0 = self.residual_norms[0]
+ if h0 <= 0:
+ return False
+ return (max(self.residual_norms) / h0) > self.thresholds.h_norm_explosion_ratio
+
+ @property
+ def bp_grad_at_floor(self) -> bool:
+ if not self.bp_grad_norms:
+ return False
+ # Check the *deepest* hidden layer's BP grad — that's where Γ is
+ # typically reported and where LN-driven collapse hits hardest.
+ return self.bp_grad_norms[-1] < self.thresholds.g_norm_floor
+
+ @property
+ def reference_drift_dominated(self) -> bool:
+ return self.cross_batch_stability > self.thresholds.stability_drift_ceiling
+
+ @property
+ def frozen_baseline_undercut(self) -> Optional[bool]:
+ """True if the trainable-blocks acc fails to clear the frozen baseline
+ by `frozen_acc_margin_pp`. None if no frozen baseline supplied.
+ """
+ if self.frozen_baseline_acc is None:
+ return None
+ margin_pp = (self.headline_acc - self.frozen_baseline_acc) * 100
+ return margin_pp < self.thresholds.frozen_acc_margin_pp
+
+ # ------------------------------------------------------------------ #
+ # Aggregate verdict
+ # ------------------------------------------------------------------ #
+
+ @property
+ def verdict(self) -> str:
+ flags = [
+ ("residual stream exploded", self.residual_stream_exploded),
+ ("BP grad at numerical floor", self.bp_grad_at_floor),
+ ("BP grad direction is drift-dominated", self.reference_drift_dominated),
+ ]
+ if self.frozen_baseline_undercut is True:
+ flags.append(("deep blocks fail to beat frozen-random baseline", True))
+ flagged = [name for name, val in flags if val]
+ if not flagged:
+ return "trustworthy"
+ return "needs walk-back: " + "; ".join(flagged)
+
+ # ------------------------------------------------------------------ #
+ # Pretty-print
+ # ------------------------------------------------------------------ #
+
+ def __str__(self) -> str:
+ L = len(self.residual_norms)
+ lines: List[str] = []
+ lines.append("=" * 72)
+ lines.append(f"FA Diagnostic Protocol Report — method: {self.method_name}")
+ if self.notes:
+ lines.append(f"Notes: {self.notes}")
+ lines.append("=" * 72)
+
+ # (a) Residual stream norms
+ lines.append("(a) Residual stream norms ||h_l||_2 (median over batch):")
+ for l in range(L):
+ lines.append(f" h_{l}: {self.residual_norms[l]:.3e}")
+ if self.residual_stream_exploded:
+ lines.append(
+ f" FLAG: max/min ratio "
+ f"{max(self.residual_norms)/max(self.residual_norms[0],1e-30):.2e} "
+ f"> threshold {self.thresholds.h_norm_explosion_ratio}× — "
+ "residual stream exploded."
+ )
+
+ # (b) BP grad norms
+ lines.append("")
+ lines.append("(b) BP gradient norms ||g_l||_2 (median over batch):")
+ for l in range(L):
+ lines.append(f" g_{l}: {self.bp_grad_norms[l]:.3e}")
+ if self.bp_grad_at_floor:
+ lines.append(
+ f" FLAG: deepest ||g_L|| = {self.bp_grad_norms[-1]:.2e} "
+ f"< floor {self.thresholds.g_norm_floor:.0e} — Γ measured "
+ "against numerical floor; not interpretable as credit alignment."
+ )
+
+ # (c) Cross-batch direction stability
+ lines.append("")
+ lines.append(
+ f"(c) Cross-batch direction stability at layer {self.stability_layer}: "
+ f"{self.cross_batch_stability:.3f}"
+ )
+ if self.reference_drift_dominated:
+ lines.append(
+ f" FLAG: stability {self.cross_batch_stability:.3f} > "
+ f"ceiling {self.thresholds.stability_drift_ceiling} — "
+ "reference vector is sample-invariant drift, not per-sample credit."
+ )
+
+ # (d) Frozen baseline comparison
+ lines.append("")
+ if self.frozen_baseline_acc is None:
+ lines.append("(d) Frozen-blocks baseline: NOT PROVIDED — diagnostic skipped.")
+ else:
+ margin_pp = (self.headline_acc - self.frozen_baseline_acc) * 100
+ lines.append(
+ f"(d) Headline acc: {self.headline_acc:.4f}, "
+ f"frozen-blocks baseline: {self.frozen_baseline_acc:.4f}, "
+ f"margin: {margin_pp:+.2f} pp"
+ )
+ if self.frozen_baseline_undercut:
+ lines.append(
+ f" FLAG: margin {margin_pp:+.2f} pp < "
+ f"required {self.thresholds.frozen_acc_margin_pp} pp — "
+ "deep blocks are not contributing over a random untrained baseline."
+ )
+
+ # Verdict
+ lines.append("")
+ lines.append(f"VERDICT: {self.verdict}")
+ lines.append("=" * 72)
+ return "\n".join(lines)
+
+ def to_dict(self) -> dict:
+ return {
+ "method_name": self.method_name,
+ "notes": self.notes,
+ "residual_norms": list(self.residual_norms),
+ "bp_grad_norms": list(self.bp_grad_norms),
+ "stability_layer": self.stability_layer,
+ "cross_batch_stability": self.cross_batch_stability,
+ "headline_acc": self.headline_acc,
+ "frozen_baseline_acc": self.frozen_baseline_acc,
+ "verdict": self.verdict,
+ "thresholds": {
+ "g_norm_floor": self.thresholds.g_norm_floor,
+ "h_norm_explosion_ratio": self.thresholds.h_norm_explosion_ratio,
+ "stability_drift_ceiling": self.thresholds.stability_drift_ceiling,
+ "frozen_acc_margin_pp": self.thresholds.frozen_acc_margin_pp,
+ },
+ }
diff --git a/protocol/smoke_test.py b/protocol/smoke_test.py
new file mode 100644
index 0000000..272a9e5
--- /dev/null
+++ b/protocol/smoke_test.py
@@ -0,0 +1,136 @@
+"""
+Smoke test for the FA Diagnostic Protocol reference implementation.
+
+Loads BP-trained and DFA-trained ResMLP checkpoints from
+`results/confirmatory/checkpoints_A2/` and applies the protocol to each.
+The protocol should:
+ - return verdict="trustworthy" on the BP checkpoint (residual norms bounded,
+ BP grad ~5e-5 well above floor, low cross-batch direction stability),
+ - flag the DFA checkpoint as "needs walk-back" (residual stream exploded
+ by ~10^7×, BP grad at ~5e-10, drift-dominated direction).
+
+Run with:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test
+"""
+import os
+import sys
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+# Make `protocol` and `models` importable when invoked from repo root
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP # noqa: E402
+from protocol import diagnose # noqa: E402
+
+
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2")
+EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline")
+
+
+def load_eval_batches(n_batches: int = 4, batch_size: int = 1024, device="cuda:0"):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
+ batches = []
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batches.append((x, y))
+ if len(batches) >= n_batches:
+ break
+ return batches
+
+
+def evaluate(model, loader, device):
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def load_model(method: str, seed: int, device):
+ if method == "ep":
+ path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt")
+ else:
+ path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt")
+ ckpt = torch.load(path, map_location=device, weights_only=False)
+ # Try several common checkpoint formats
+ if isinstance(ckpt, dict) and "state_dict" in ckpt:
+ sd = ckpt["state_dict"]
+ elif isinstance(ckpt, dict) and "model" in ckpt:
+ sd = ckpt["model"]
+ elif isinstance(ckpt, dict):
+ sd = ckpt
+ else:
+ sd = ckpt.state_dict()
+ model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device)
+ model.load_state_dict(sd)
+ return model
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+
+ # Cross-batch stability is sensitive to batch size: with smaller batches
+ # the per-sample noise component dominates the batch-mean direction on
+ # healthy networks, giving the cleanest separation from the drift-
+ # dominated failure mode (~0.10 healthy vs ~0.95 drift at N=128).
+ eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
+
+ # Build a single test loader for headline acc
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ test_loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+
+ # BP: trustworthy. DFA: walked back. EP: trustworthy (the internal control
+ # — same architecture/metric/dataset, but EP doesn't blow up the residual
+ # stream, so the diagnostic protocol passes for EP even though EP's
+ # accuracy is also low. This is the paper's central comparison.)
+ for method, expected in [
+ ("bp", "trustworthy"),
+ ("dfa", "needs walk-back"),
+ ("ep", "trustworthy"),
+ ]:
+ print()
+ print(f"### {method.upper()} (seed 42)")
+ model = load_model(method, 42, device)
+ acc = evaluate(model, test_loader, device)
+ report = diagnose(
+ model=model,
+ eval_batches=eval_batches,
+ headline_acc=acc,
+ frozen_baseline_acc=None,
+ method_name=method.upper(),
+ notes="4-block d=256 ResMLP, CIFAR-10",
+ )
+ print(report)
+ # Sanity check
+ verdict = report.verdict
+ if expected == "trustworthy" and verdict != "trustworthy":
+ print(f"!! UNEXPECTED: BP should be trustworthy, got '{verdict}'")
+ elif expected == "needs walk-back" and not verdict.startswith("needs walk-back"):
+ print(f"!! UNEXPECTED: DFA should need walk-back, got '{verdict}'")
+ else:
+ print(f"OK: verdict matches expectation ('{expected}')")
+
+
+if __name__ == "__main__":
+ main()