diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:20:48 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:20:48 -0500 |
| commit | 7b64702ad970c16171142665365e16a8e1737190 (patch) | |
| tree | 30df9313fc195aa037f05c0598fdc0df8b32bbcc /protocol | |
| parent | 2dc8e7efb5f2ff827fbd97be87f0127aa5ab2757 (diff) | |
Add FA diagnostic protocol reference implementation
Codex round 15 #1 priority for the E&D-track paper:
- protocol/protocol.py: 4 diagnostics (residual norms, BP grad norms,
cross-batch direction stability, and a frozen-baseline comparator)
- protocol/report.py: DiagnosticReport with per-diagnostic verdicts and
pretty-printer
- protocol/smoke_test.py: validates BP/DFA/EP checkpoints produce the
expected verdicts (BP/EP trustworthy; DFA walked back via residual
explosion + BP grad at floor)
- protocol/README.md: usage, audit cases, threshold rationale
- protocol/CHECKLIST.md: 6 evaluation pipeline pitfalls (norm(-1),
cosine_similarity eps clamp, fp16 underflow, Bs reproducibility,
aggregation, layer-0 dominance)
- protocol/REPORTING_TEMPLATE.md: per-method fillable form for FA papers
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/CHECKLIST.md | 96 | ||||
| -rw-r--r-- | protocol/README.md | 150 | ||||
| -rw-r--r-- | protocol/REPORTING_TEMPLATE.md | 69 | ||||
| -rw-r--r-- | protocol/__init__.py | 52 | ||||
| -rw-r--r-- | protocol/protocol.py | 290 | ||||
| -rw-r--r-- | protocol/report.py | 201 | ||||
| -rw-r--r-- | protocol/smoke_test.py | 136 |
7 files changed, 994 insertions, 0 deletions
diff --git a/protocol/CHECKLIST.md b/protocol/CHECKLIST.md new file mode 100644 index 0000000..3227642 --- /dev/null +++ b/protocol/CHECKLIST.md @@ -0,0 +1,96 @@ +# FA Evaluation Pitfalls Checklist + +Six evaluation-pipeline bugs we found in our own codebase while developing +this protocol. Each one silently corrupted reported numbers in a way that +would have been very hard for a reviewer to catch. Use this list as a +final-pass review for any FA evaluation paper. + +## 1. `tensor.norm(-1)` is NOT "L₂ along dim=-1" + +`tensor.norm(-1)` computes the L₋₁ "norm" (`1/sum(1/|x|)`) of the *entire* +flattened tensor, not the per-row L₂ norm. The correct call for "L₂ along +the feature dimension" is `tensor.norm(dim=-1)`. + +This single bug invalidated several months of our own gradient-norm +measurements; gradient sparsity tables, "support sparsity" analyses, and +the `naive_StateErr` metric in three of our trainers were all degenerate. + +**Check**: grep your codebase for `\.norm\(-1\)` and `\.norm\(\s*-1\s*\)`. +Replace with `.norm(dim=-1)`. + +## 2. `F.cosine_similarity` clamps small magnitudes via `eps=1e-8` + +`torch.nn.functional.cosine_similarity(a, b)` divides by `max(||a||·||b||, +eps)` with `eps=1e-8` by default. When the BP gradient at hidden layers is +~1e-10 (which it is on DFA-trained pre-LN ResMLPs), the divisor becomes +`||a|| · 1e-8` instead of `||a|| · 1e-10`, scaling the reported cosine by +`||g|| / eps` ≈ 0.01. The reported "Γ ≈ 0.10" is then off by ~100× from the +true cosine. + +**Check**: when computing Γ, either use a hand-rolled cosine that does not +clamp, or assert `||g|| > 1e-8` before calling `F.cosine_similarity`. If you +report Γ alongside ‖g‖ and ‖g‖ is at the floor, your Γ is uninterpretable — +the diagnostic protocol will catch this. + +## 3. fp16 mixed precision underflows BP grads at hidden layers + +Computing Γ in fp16 mixed precision on a DFA-trained ResMLP gives `nan` +because BP grads (~5e-10) are below fp16's smallest subnormal (~6e-8). +**bf16 works** because it has the same exponent range as fp32 (only the +mantissa is reduced). + +**Check**: if your reported metric depends on a tiny intermediate magnitude, +either compute it in fp32, switch to bf16, or rescale by a known constant +before the precision-sensitive step. + +## 4. Random feedback `Bs` are training-specific — Γ is not invariant to + reseeding them + +DFA at fixed seed s42 reports Γ ≈ 0.106 *with the random feedback projection +matrices `Bs` that were used during training*. With 20 fresh random `Bs` +draws on the same trained network, Γ ≈ 0 ± 0.005. The non-zero Γ is the +network adapting to the *specific* `Bs` it saw during training, not an +intrinsic property of the local credit signal. + +**Check**: when you report Γ for an FA-style method, specify exactly which +`Bs` were used. If you cannot reproduce them from the seed, your numbers +are not reproducible. + +## 5. Aggregation strategy across layers / samples / batches is rarely + specified, but determines the headline number + +We computed DFA's "headline Γ" with four equally valid aggregation +strategies: (median over layers) × (mean over samples), (mean over layers) × +(median over samples), and so on. The four strategies give Γ ∈ [0.085, +0.211] on the *same* trained DFA s42 network — a 2.5× spread. For credit +bridge it's 5× (Γ ∈ [0.057, 0.285]). + +**Check**: state explicitly how you aggregate Γ. Report per-layer values +and let the reader pick. Do not collapse to a single number without showing +the layer breakdown. + +## 6. Layer-0 dominates the headline Γ; deeper layers are ~0 + +For DFA on a 4-block ResMLP, the headline Γ ≈ 0.10 is driven almost entirely +by the embedding layer (Γ_layer0 ≈ 0.43). The block layers have Γ ≈ 0. A +"mean over layers" summary statistic hides this. The same pattern likely +holds for other FA-style methods. + +**Check**: always report per-layer Γ. A single average is misleading when +one layer dominates. + +## Suggested final-pass workflow + +1. Run `grep -nr '\.norm(-1)' YOUR_REPO` and fix all matches. +2. Run the diagnostic protocol on your trained model. If any flag fires, + fix the underlying issue or walk back the corresponding claim before + submission. +3. Inspect every place you compute Γ: + - Is the divisor clamped? + - Are you in fp16? + - Are you using the training `Bs` or fresh ones? + - How do you aggregate? + - Do you report per-layer? +4. Report `‖h_l‖`, `‖g_l‖`, and `Γ_l` *together* for every layer. The + metrics are entangled; reporting one without the others hides the + pathology. diff --git a/protocol/README.md b/protocol/README.md new file mode 100644 index 0000000..770d582 --- /dev/null +++ b/protocol/README.md @@ -0,0 +1,150 @@ +# FA Diagnostic Protocol — reference implementation + +A drop-in toolkit for evaluating Feedback Alignment (FA) and other local-credit +methods on modern residual architectures. Implements the four diagnostics +described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for +Feedback Alignment Evaluation*. + +## What this fixes + +Standard FA evaluation reports two numbers: + +1. **Headline accuracy** of the trained network. +2. **Γ** — cosine alignment between the local credit signal and the + true BP gradient at hidden layers. + +Both can silently fail on modern pre-LayerNorm residual networks: + +- **Headline accuracy** fails because deep blocks may not be contributing at + all — the embedding + terminal-LN + head pathway alone can match (or + exceed) the trainable-blocks accuracy. We have measured DFA on a 4-block + d=256 ResMLP at 30.8% test acc, *below* the same architecture's + frozen-random-blocks baseline at 34.9%. The headline number reads as + "DFA learns" but the deep blocks are *actively destroying* value. +- **Γ** fails because residual-stream explosion drives the BP gradient at + hidden layers to ~10⁻¹⁰ via terminal-LayerNorm scaling. Γ then computes + the cosine of the local credit signal against a numerical-floor reference + vector. The reported value is mathematically well-defined but does not + measure credit-assignment quality. + +This protocol catches both failures and tells you whether your reported +numbers are interpretable. + +## The four diagnostics + +| | Diagnostic | What it catches | +|---|---|---| +| (a) | Per-layer residual stream norm ‖h_l‖₂ | Scale pathology — if ‖h_L‖ has grown ≫50× over depth, the network is in a degenerate regime. | +| (b) | Per-layer BP gradient norm ‖g_l‖₂ | Reference-vector floor — if ‖g_L‖ < 1e-7, Γ is being measured against the numerical floor and is not interpretable. | +| (c) | Cross-batch direction stability of g | Drift dominance — if the BP gradient direction is sample-invariant (cosine ≈ 1 across disjoint batches), the reference is a global drift artifact, not per-sample credit. | +| (d) | Frozen-blocks baseline accuracy | Useful blocks — if the trainable-blocks variant fails to clear an architecture-matched frozen-random-blocks baseline by ≥2 pp, the deep blocks are not contributing. | + +## Quick start + +```python +from protocol import diagnose + +report = diagnose( + model=trained_model, # any PyTorch module exposing + # `model(x, return_hidden=True)`, + # `model.embed`/`model.patch_embed`, + # `model.blocks`, and + # `model.out_ln`/`model.out_head` + # (or `model.norm`/`model.head`) + eval_batches=eval_batches, # ≥8 disjoint (x,y) batches of ~128 + headline_acc=test_acc, # the number you would normally publish + frozen_baseline_acc=frozen_acc, # acc of an architecture-matched + # frozen-random-blocks control + # (set None to skip diagnostic d) + method_name="DFA", + notes="4-block d=256 ResMLP, CIFAR-10", +) +print(report) +# emits a per-diagnostic table and either +# "VERDICT: trustworthy" +# or +# "VERDICT: needs walk-back: <list of failed diagnostics>" +``` + +## Reproducing the audit cases + +We ship a smoke test that loads BP, DFA, and EP checkpoints from +`results/confirmatory/checkpoints_A2/` and `results/ep_baseline/`, runs the +protocol on each, and checks that the verdicts match expectation: + +```bash +CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test +``` + +Expected verdicts (4-block d=256 ResMLP, CIFAR-10, seed 42): + +| method | ‖h₄‖ | ‖g₄‖ | stability | verdict | +|---|---:|---:|---:|---| +| BP | 2.1e2 | 3.7e-4 | 0.099 | trustworthy | +| DFA | 4.4e8 | 4.2e-9 | 0.047 | walked back: explosion + BP grad at floor | +| EP | 3.3e3 | 1.6e-4 | -0.036 | trustworthy (internal control) | + +EP is the paper's central internal control: same architecture, same dataset, +same metric, but EP's training does not produce residual-stream explosion, +so the protocol passes for EP even though EP's headline accuracy is also low +(~32%). This isolates "the network is in a degenerate measurement regime" +from "the method underperforms" — they are different failure modes with +different evidence. + +## Duck-typed model interface + +The protocol works on any model exposing: + +```python +# Forward with hidden states +logits, hiddens = model(x, return_hidden=True) +# hiddens is a list [h_0, h_1, ..., h_L]; supports either (B, d) or (B, T, d) +# (the protocol pools the cls token via h[:, 0, :] in the latter case) + +# Building blocks for the manual forward (used by diagnostics b and c) +model.embed OR model.patch_embed +model.blocks # iterable of nn.Modules; each block(h) returns the + # *residual branch* f_l(h_l), NOT h_l + f_l(h_l) +model.out_ln, model.out_head OR model.norm, model.head +``` + +`models/residual_mlp.py` and `models/vit_mini.py` in this repo both satisfy +this interface as written. + +## Thresholds (and how to override them) + +```python +from protocol import diagnose, DiagnosticReport +from protocol.report import DiagnosticThresholds + +custom = DiagnosticThresholds( + g_norm_floor=1e-6, # tighter floor + h_norm_explosion_ratio=20.0, # tighter scale check + stability_drift_ceiling=0.30, + frozen_acc_margin_pp=2.0, +) +report = diagnose(..., thresholds=custom) +``` + +The defaults match the paper: + +- `g_norm_floor = 1e-7` — well above fp32's smallest subnormal (~1e-38) and + well above F.cosine_similarity's eps=1e-8 clamp, but several orders below + healthy networks (~1e-5). +- `h_norm_explosion_ratio = 50.0` — BP-trained networks are ~1-3× per layer + in our measurements; failure modes show ~10⁵-10⁶× from layer 0 to layer L. +- `stability_drift_ceiling = 0.30` — calibrated for ~128-sample batches. + BP/EP-trained networks: 0.05-0.18. DFA/SB drift mode: 0.43-0.99. +- `frozen_acc_margin_pp = 2.0` — DFA-on-ResMLP undercuts frozen baseline by + 4 pp; the threshold is loose enough to allow noise on healthy networks. + +## Pipeline pitfalls (the checklist) + +In addition to the diagnostics above, we ship `CHECKLIST.md` — six +evaluation-pipeline bugs we found in our own dogfood codebase that silently +corrupt FA evaluation results. We recommend a final pre-publication pass +through the checklist regardless of which diagnostic protocol you use. + +## Citation + +(TBD — points to the paper PDF when uploaded.) diff --git a/protocol/REPORTING_TEMPLATE.md b/protocol/REPORTING_TEMPLATE.md new file mode 100644 index 0000000..5d4a187 --- /dev/null +++ b/protocol/REPORTING_TEMPLATE.md @@ -0,0 +1,69 @@ +# FA Evaluation Reporting Template + +A minimal, fillable reporting template for any paper that evaluates a +local-credit (FA-style) method on a residual architecture. Reviewers can +use this as a checklist to verify whether the reported numbers are in the +metric's meaningful regime. + +Copy the table below into your paper appendix or supplementary material. +Fill in one row per `(method × architecture × dataset × seed)`. + +--- + +## Method × architecture identification + +| field | value | +|---|---| +| Method | _e.g._ DFA, EP, SB, CB, BP | +| Method local-credit signal | _formula_, _e.g._ `e_T B_l^T` for DFA | +| Architecture | _e.g._ 4-block d=256 pre-LN ResidualMLP | +| Has terminal LayerNorm before head | yes / no | +| Dataset | _e.g._ CIFAR-10 | +| Number of seeds reported | _e.g._ 3 | +| Random feedback `Bs` (if any): seed/spec | | + +## Headline numbers + +| field | value | +|---|---| +| Test accuracy (mean ± std over seeds) | | +| Headline Γ (cosine to BP gradient) | | +| Aggregation: how is Γ collapsed? | _e.g._ "mean over layers, then mean over samples" | +| Per-layer Γ table reported in §X | yes / no | + +## Diagnostic protocol numbers (this is the new content) + +| diagnostic | per-layer values | flag? | +|---|---|---| +| (a) ‖h_l‖₂ for l = 0 .. L | _e.g._ [255, 226, 211, 205, 203] | scale OK / **EXPLODED** | +| (b) ‖g_l‖₂ for l = 0 .. L | _e.g._ [4e-4, 4e-4, 4e-4, 4e-4, 3e-4] | grad OK / **AT FLOOR** | +| (c) cross-batch direction stability at layer L/2 | _e.g._ 0.099 (N=128, ≥8 batches) | OK / **DRIFT** | +| (d) frozen-blocks baseline acc | _e.g._ 0.349 | trainable acc clears it by ≥2 pp / **UNDERCUT** | + +If any of (a)-(d) fires a flag, headline accuracy and Γ should be walked +back or accompanied by an explicit caveat. We recommend using the bundled +`protocol.diagnose(...)` function and pasting its `__str__` output as the +appendix table. + +## Pipeline-pitfalls disclosure + +State explicitly: + +- [ ] We grepped for `tensor.norm(-1)` (the L₋₁ footgun) and confirmed all + gradient-norm computations use `tensor.norm(dim=-1)`. +- [ ] We did not compute Γ in fp16 (or, if we did, we verified no underflow + and reported the precision). +- [ ] We disclose the random feedback `Bs` used during training (if + applicable) and reproduced reported Γ with those exact `Bs`. +- [ ] We report per-layer Γ in addition to any aggregate. +- [ ] We report `‖h_l‖`, `‖g_l‖`, and `Γ_l` *together* for every layer. +- [ ] We ran an architecture-matched frozen-random-blocks baseline and + report its accuracy alongside the trainable-blocks variant. + +## Internal control (strongly recommended) + +If your method is *not* the only candidate on this architecture, run the +protocol on a second method that you have reason to believe operates in +the meaningful regime (e.g., BP itself, or EP). The protocol should pass +on the control. If it does not, you have an architecture problem rather +than a method problem, and the paper should reflect that. diff --git a/protocol/__init__.py b/protocol/__init__.py new file mode 100644 index 0000000..770bdda --- /dev/null +++ b/protocol/__init__.py @@ -0,0 +1,52 @@ +""" +FA Diagnostic Protocol — reference implementation. + +A drop-in toolkit for evaluating Feedback Alignment (and related local-credit) +methods on modern residual architectures. Implements the four diagnostics +described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for +Feedback Alignment Evaluation*: + + (a) per-layer residual stream norm ||h_l||_2 + (b) per-layer BP gradient norm ||g_l||_2 + (c) cross-batch direction stability s_l + (d) frozen-blocks control accuracy acc_frozen + +The motivation is that headline accuracy and offline Γ (cosine alignment of +local credit to BP gradient) can both silently fail on modern pre-LayerNorm +residual networks: residual-stream explosion drives ||g|| ~10^-10 and turns Γ +into a measurement of cosine to a numerical-floor reference vector. This +protocol catches that regime change and tells you whether your reported +numbers are interpretable. + +Quick start: + + from protocol import diagnose, DiagnosticReport + + report = diagnose( + model=trained_model, + eval_batches=list_of_(x, y)_pairs, + frozen_baseline_acc=accuracy_of_random_blocks_baseline, + method_name="DFA", + notes="4-block d=256 ResMLP, CIFAR-10", + ) + print(report) + +The report flags any diagnostic that crossed a "degenerate regime" threshold +and emits a verdict ("trustworthy" / "needs walk-back"). +""" + +from .protocol import ( + residual_norms, + bp_grad_norms, + cross_batch_direction_stability, + diagnose, +) +from .report import DiagnosticReport + +__all__ = [ + "residual_norms", + "bp_grad_norms", + "cross_batch_direction_stability", + "diagnose", + "DiagnosticReport", +] diff --git a/protocol/protocol.py b/protocol/protocol.py new file mode 100644 index 0000000..4298756 --- /dev/null +++ b/protocol/protocol.py @@ -0,0 +1,290 @@ +""" +Core diagnostics for the FA evaluation protocol. + +The four diagnostics: + + (a) residual_norms — per-layer ||h_l||_2 on a fixed eval batch + (b) bp_grad_norms — per-layer ||∂CE/∂h_l||_2 on a fixed eval batch + (c) cross_batch_direction_stability — cosine of normalized BP-grad direction + across disjoint minibatches; high values + indicate the reference vector is dominated by + a sample-invariant global drift, NOT per-sample + credit + (d) frozen-blocks accuracy — caller-supplied; the protocol just compares it + to the headline accuracy of the trained network + +The protocol expects a model exposing a `return_hidden=True` mode that yields +the per-layer residual stream as a list `[h_0, h_1, ..., h_L]`, plus the +standard `out_ln` + `out_head` terminal stack. ResidualMLP and ViT-Mini in +this repo both satisfy that contract; see the docstring of `diagnose` for the +duck-typed interface. + +All metrics use `tensor.norm(dim=-1)` (L2 along the feature dim) — NEVER +`tensor.norm(-1)` (which is L_{-1} of the entire tensor and is the bug behind +several walk-backs in our project). +""" +from __future__ import annotations + +from typing import List, Sequence, Tuple, Optional + +import torch +import torch.nn.functional as F + +from .report import DiagnosticReport, DiagnosticThresholds + + +# --------------------------------------------------------------------------- # +# Diagnostic (a): residual stream norm +# --------------------------------------------------------------------------- # + +def residual_norms(model, x: torch.Tensor) -> List[float]: + """Per-layer median ||h_l||_2 on the eval batch. + + Args: + model: must support `model(x, return_hidden=True) -> (logits, hiddens)` + where hiddens is `[h_0, ..., h_L]` (each (B, ..., d_hidden)). + x: eval batch (already on the correct device, already in the input + shape the model wants). + + Returns: + list of L+1 floats — the median (over batch) of ||h_l||_2 per layer. + """ + model.eval() + with torch.no_grad(): + _, hiddens = model(x, return_hidden=True) + out: List[float] = [] + for h in hiddens: + # If h has token dim (B, T, d), pool tokens to a single vector per + # sample by taking the cls token (index 0) — that's what the head + # actually consumes in our ViT. + if h.dim() == 3: + h = h[:, 0, :] + # NOTE: dim=-1 is mandatory; `.norm(-1)` would compute L_{-1} of the + # whole tensor and is wrong (see project_norm_minus1_footgun). + out.append(h.norm(dim=-1).median().item()) + return out + + +# --------------------------------------------------------------------------- # +# Diagnostic (b): BP gradient norm at hidden layers +# --------------------------------------------------------------------------- # + +def bp_grad_norms( + model, + x: torch.Tensor, + y: torch.Tensor, + loss_fn=F.cross_entropy, +) -> List[float]: + """Per-layer median ||∂L/∂h_l||_2 on the eval batch. + + The BP gradient at hidden layers is the *reference vector* that offline + alignment metrics (Γ) compare against. When this is at the numerical + floor (~1e-10) on modern pre-LN residual nets, Γ measures alignment to a + noise vector and is no longer interpretable as "credit assignment + quality". + + Args: + model: same contract as for `residual_norms`. + x: eval batch. + y: labels for the eval batch. + loss_fn: scalar loss given (logits, y); cross_entropy by default. + + Returns: + list of L+1 floats — the median (over batch) of ||g_l||_2 per layer, + where g_l = ∂loss/∂h_l. + """ + model.eval() + # Forward through the canonical path (embed -> blocks with explicit + # residual addition -> terminal LN -> head), keeping every intermediate + # h_l as a non-leaf in the autograd graph. `torch.autograd.grad` accepts + # non-leaf tensors as `inputs` and computes ∂L/∂h_l via the chain rule. + # NOTE: do not clone+detach the hidden states — that severs the graph. + with torch.enable_grad(): + h = _embed(model, x) + hiddens: List[torch.Tensor] = [h] + for block in _iter_blocks(model): + f = block(h) + h = h + f + hiddens.append(h) + logits = _terminal(model, h) + loss = loss_fn(logits, y) + grads = torch.autograd.grad(loss, hiddens) + out: List[float] = [] + for g in grads: + if g.dim() == 3: + g = g[:, 0, :] + out.append(g.norm(dim=-1).median().item()) + return out + + +def _embed(model, x): + if hasattr(model, "embed"): + return model.embed(x) + if hasattr(model, "patch_embed"): + # ViT-style + return model.patch_embed(x) + raise AttributeError( + "model must expose `.embed(x)` or `.patch_embed(x)` for the protocol" + ) + + +def _iter_blocks(model): + if hasattr(model, "blocks"): + return list(model.blocks) + raise AttributeError("model must expose `.blocks` (a Module list)") + + +def _terminal(model, h): + """Apply terminal LN (if any) + classification head to hidden state.""" + if hasattr(model, "out_ln") and hasattr(model, "out_head"): + return model.out_head(model.out_ln(h)) + if hasattr(model, "norm") and hasattr(model, "head"): + # ViT-style naming + h = model.norm(h) + if h.dim() == 3: + h = h[:, 0, :] + return model.head(h) + raise AttributeError( + "model must expose `(out_ln, out_head)` or `(norm, head)` for the protocol" + ) + + +# --------------------------------------------------------------------------- # +# Diagnostic (c): cross-batch direction stability +# --------------------------------------------------------------------------- # + +def cross_batch_direction_stability( + model, + eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]], + layer_index: int, + loss_fn=F.cross_entropy, +) -> float: + """Mean pairwise cosine of the BP-grad *direction* at layer `layer_index` + across disjoint minibatches. + + A value near 1.0 means: the BP gradient direction at that hidden layer is + the same regardless of which samples are in the batch. That is the + fingerprint of a **sample-invariant global drift** — i.e. the reference + vector against which Γ is computed is NOT per-sample credit, it is a + constant artifact of the trained network's geometry. On healthy + BP-trained or EP-trained networks this value is small (~0.05-0.18); on + DFA/SB/CB pre-LN ResMLPs we see ~0.5-0.99. + + Args: + model: same contract as the other diagnostics. + eval_batches: list of (x, y) pairs, each a *disjoint* minibatch on the + same device. Use ≥8 batches of ~128 samples each. The metric is + calibrated for this batch size: smaller batches let per-sample + noise dominate the batch-mean direction on healthy networks + (~0.10), giving the cleanest separation from drift-dominated + failure modes (~0.95). Larger batches average out per-sample + noise via CLT and shrink the gap. + layer_index: which layer's BP grad direction to compare. + loss_fn: scalar loss; cross_entropy by default. + + Returns: + mean pairwise cosine in [-1, 1]. + """ + if len(eval_batches) < 2: + raise ValueError("need ≥2 eval batches to measure stability") + directions: List[torch.Tensor] = [] + for x, y in eval_batches: + with torch.enable_grad(): + h = _embed(model, x) + hiddens: List[torch.Tensor] = [h] + for block in _iter_blocks(model): + f = block(h) + h = h + f + hiddens.append(h) + logits = _terminal(model, h) + loss = loss_fn(logits, y) + grads = torch.autograd.grad(loss, hiddens) + g = grads[layer_index] + if g.dim() == 3: + g = g[:, 0, :] + # Reduce to a single direction per batch by averaging across samples, + # then normalizing. This is the natural "direction the gradient is + # pointing on this batch" signal. + gbar = g.mean(dim=0) + n = gbar.norm() + if n < 1e-30: + directions.append(torch.zeros_like(gbar)) + else: + directions.append(gbar / n) + # Mean pairwise cosine + cos_vals: List[float] = [] + for i in range(len(directions)): + for j in range(i + 1, len(directions)): + cos_vals.append( + float(torch.dot(directions[i], directions[j]).item()) + ) + return sum(cos_vals) / len(cos_vals) + + +# --------------------------------------------------------------------------- # +# Diagnostic (d) is just a comparison: caller supplies frozen_baseline_acc +# --------------------------------------------------------------------------- # + + +# --------------------------------------------------------------------------- # +# Top-level convenience: run all four and produce a verdict +# --------------------------------------------------------------------------- # + +def diagnose( + model, + eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]], + headline_acc: float, + frozen_baseline_acc: Optional[float] = None, + method_name: str = "method", + notes: str = "", + layer_for_stability: Optional[int] = None, + thresholds: Optional[DiagnosticThresholds] = None, +) -> DiagnosticReport: + """Run the full protocol on `model` and return a DiagnosticReport. + + Args: + model: trained network. Must satisfy the duck-typed interface + described at the top of `protocol.py`. + eval_batches: ≥4 disjoint minibatches `(x, y)`, on the model's device. + headline_acc: the accuracy you would otherwise have reported as the + method's result on this network. + frozen_baseline_acc: accuracy of an architecture-matched baseline + with the deep blocks **frozen at random initialization** and only + embed/LN/head trained. If None, the (d)-diagnostic is skipped. + See `protocol/examples/frozen_baseline.py` for a reference impl. + method_name: free-form label, only used in the printed report. + notes: free-form architecture/dataset description for the report. + layer_for_stability: which hidden layer index to use for the cross- + batch stability check. Defaults to the middle layer. + thresholds: degeneracy thresholds; defaults to those used in our + paper (see DiagnosticThresholds for details). + + Returns: + A DiagnosticReport summarizing each diagnostic and its verdict. + """ + if thresholds is None: + thresholds = DiagnosticThresholds() + if not eval_batches: + raise ValueError("eval_batches must not be empty") + x0, y0 = eval_batches[0] + + h_norms = residual_norms(model, x0) + g_norms = bp_grad_norms(model, x0, y0) + + if layer_for_stability is None: + layer_for_stability = max(1, len(h_norms) // 2) + stability = cross_batch_direction_stability( + model, eval_batches, layer_index=layer_for_stability + ) + + return DiagnosticReport( + method_name=method_name, + notes=notes, + residual_norms=h_norms, + bp_grad_norms=g_norms, + stability_layer=layer_for_stability, + cross_batch_stability=stability, + headline_acc=headline_acc, + frozen_baseline_acc=frozen_baseline_acc, + thresholds=thresholds, + ) diff --git a/protocol/report.py b/protocol/report.py new file mode 100644 index 0000000..15d6c34 --- /dev/null +++ b/protocol/report.py @@ -0,0 +1,201 @@ +""" +DiagnosticReport: structured output of the FA evaluation protocol. + +Holds the per-layer numbers from the four diagnostics and emits a verdict +for each one based on the published thresholds. The verdict is intentionally +binary ("trustworthy" / "needs walk-back"); fine-grained reading is the +caller's job. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List, Optional + + +@dataclass +class DiagnosticThresholds: + """Degeneracy thresholds (defaults match the paper). + + g_norm_floor: BP gradient norms below this are considered to be at the + numerical floor (Γ measured against this is not interpretable as + credit alignment). Default 1e-7 — well above fp32 floor (~1e-38) and + well above F.cosine_similarity's eps=1e-8 clamp, but several orders + below healthy networks (~1e-5). + h_norm_explosion_ratio: residual stream norm growth (relative to layer 0) + above this is considered "exploded". Default 50× — BP-trained + networks are ~1-3× per layer; failure modes show ~10^5-10^6×. + stability_drift_ceiling: cross-batch direction cosine above this is + considered to be drift-dominated (reference vector is sample- + invariant). Default 0.30 — BP-trained / EP-trained networks are + below 0.20, failure modes are 0.5-0.99. + frozen_acc_margin_pp: minimum acc gain (in percentage points) of the + trainable-blocks variant over the frozen-random-blocks baseline for + the deep blocks to be considered "actually contributing". Default + 2.0 pp. + """ + + g_norm_floor: float = 1e-7 + h_norm_explosion_ratio: float = 50.0 + stability_drift_ceiling: float = 0.30 + frozen_acc_margin_pp: float = 2.0 + + +@dataclass +class DiagnosticReport: + """Result of running the protocol on one trained network.""" + + method_name: str + notes: str + residual_norms: List[float] + bp_grad_norms: List[float] + stability_layer: int + cross_batch_stability: float + headline_acc: float + frozen_baseline_acc: Optional[float] + thresholds: DiagnosticThresholds = field(default_factory=DiagnosticThresholds) + + # ------------------------------------------------------------------ # + # Per-diagnostic verdicts + # ------------------------------------------------------------------ # + + @property + def residual_stream_exploded(self) -> bool: + if not self.residual_norms: + return False + h0 = self.residual_norms[0] + if h0 <= 0: + return False + return (max(self.residual_norms) / h0) > self.thresholds.h_norm_explosion_ratio + + @property + def bp_grad_at_floor(self) -> bool: + if not self.bp_grad_norms: + return False + # Check the *deepest* hidden layer's BP grad — that's where Γ is + # typically reported and where LN-driven collapse hits hardest. + return self.bp_grad_norms[-1] < self.thresholds.g_norm_floor + + @property + def reference_drift_dominated(self) -> bool: + return self.cross_batch_stability > self.thresholds.stability_drift_ceiling + + @property + def frozen_baseline_undercut(self) -> Optional[bool]: + """True if the trainable-blocks acc fails to clear the frozen baseline + by `frozen_acc_margin_pp`. None if no frozen baseline supplied. + """ + if self.frozen_baseline_acc is None: + return None + margin_pp = (self.headline_acc - self.frozen_baseline_acc) * 100 + return margin_pp < self.thresholds.frozen_acc_margin_pp + + # ------------------------------------------------------------------ # + # Aggregate verdict + # ------------------------------------------------------------------ # + + @property + def verdict(self) -> str: + flags = [ + ("residual stream exploded", self.residual_stream_exploded), + ("BP grad at numerical floor", self.bp_grad_at_floor), + ("BP grad direction is drift-dominated", self.reference_drift_dominated), + ] + if self.frozen_baseline_undercut is True: + flags.append(("deep blocks fail to beat frozen-random baseline", True)) + flagged = [name for name, val in flags if val] + if not flagged: + return "trustworthy" + return "needs walk-back: " + "; ".join(flagged) + + # ------------------------------------------------------------------ # + # Pretty-print + # ------------------------------------------------------------------ # + + def __str__(self) -> str: + L = len(self.residual_norms) + lines: List[str] = [] + lines.append("=" * 72) + lines.append(f"FA Diagnostic Protocol Report — method: {self.method_name}") + if self.notes: + lines.append(f"Notes: {self.notes}") + lines.append("=" * 72) + + # (a) Residual stream norms + lines.append("(a) Residual stream norms ||h_l||_2 (median over batch):") + for l in range(L): + lines.append(f" h_{l}: {self.residual_norms[l]:.3e}") + if self.residual_stream_exploded: + lines.append( + f" FLAG: max/min ratio " + f"{max(self.residual_norms)/max(self.residual_norms[0],1e-30):.2e} " + f"> threshold {self.thresholds.h_norm_explosion_ratio}× — " + "residual stream exploded." + ) + + # (b) BP grad norms + lines.append("") + lines.append("(b) BP gradient norms ||g_l||_2 (median over batch):") + for l in range(L): + lines.append(f" g_{l}: {self.bp_grad_norms[l]:.3e}") + if self.bp_grad_at_floor: + lines.append( + f" FLAG: deepest ||g_L|| = {self.bp_grad_norms[-1]:.2e} " + f"< floor {self.thresholds.g_norm_floor:.0e} — Γ measured " + "against numerical floor; not interpretable as credit alignment." + ) + + # (c) Cross-batch direction stability + lines.append("") + lines.append( + f"(c) Cross-batch direction stability at layer {self.stability_layer}: " + f"{self.cross_batch_stability:.3f}" + ) + if self.reference_drift_dominated: + lines.append( + f" FLAG: stability {self.cross_batch_stability:.3f} > " + f"ceiling {self.thresholds.stability_drift_ceiling} — " + "reference vector is sample-invariant drift, not per-sample credit." + ) + + # (d) Frozen baseline comparison + lines.append("") + if self.frozen_baseline_acc is None: + lines.append("(d) Frozen-blocks baseline: NOT PROVIDED — diagnostic skipped.") + else: + margin_pp = (self.headline_acc - self.frozen_baseline_acc) * 100 + lines.append( + f"(d) Headline acc: {self.headline_acc:.4f}, " + f"frozen-blocks baseline: {self.frozen_baseline_acc:.4f}, " + f"margin: {margin_pp:+.2f} pp" + ) + if self.frozen_baseline_undercut: + lines.append( + f" FLAG: margin {margin_pp:+.2f} pp < " + f"required {self.thresholds.frozen_acc_margin_pp} pp — " + "deep blocks are not contributing over a random untrained baseline." + ) + + # Verdict + lines.append("") + lines.append(f"VERDICT: {self.verdict}") + lines.append("=" * 72) + return "\n".join(lines) + + def to_dict(self) -> dict: + return { + "method_name": self.method_name, + "notes": self.notes, + "residual_norms": list(self.residual_norms), + "bp_grad_norms": list(self.bp_grad_norms), + "stability_layer": self.stability_layer, + "cross_batch_stability": self.cross_batch_stability, + "headline_acc": self.headline_acc, + "frozen_baseline_acc": self.frozen_baseline_acc, + "verdict": self.verdict, + "thresholds": { + "g_norm_floor": self.thresholds.g_norm_floor, + "h_norm_explosion_ratio": self.thresholds.h_norm_explosion_ratio, + "stability_drift_ceiling": self.thresholds.stability_drift_ceiling, + "frozen_acc_margin_pp": self.thresholds.frozen_acc_margin_pp, + }, + } diff --git a/protocol/smoke_test.py b/protocol/smoke_test.py new file mode 100644 index 0000000..272a9e5 --- /dev/null +++ b/protocol/smoke_test.py @@ -0,0 +1,136 @@ +""" +Smoke test for the FA Diagnostic Protocol reference implementation. + +Loads BP-trained and DFA-trained ResMLP checkpoints from +`results/confirmatory/checkpoints_A2/` and applies the protocol to each. +The protocol should: + - return verdict="trustworthy" on the BP checkpoint (residual norms bounded, + BP grad ~5e-5 well above floor, low cross-batch direction stability), + - flag the DFA checkpoint as "needs walk-back" (residual stream exploded + by ~10^7×, BP grad at ~5e-10, drift-dominated direction). + +Run with: + CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test +""" +import os +import sys + +import torch +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +# Make `protocol` and `models` importable when invoked from repo root +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP # noqa: E402 +from protocol import diagnose # noqa: E402 + + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2") +EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline") + + +def load_eval_batches(n_batches: int = 4, batch_size: int = 1024, device="cuda:0"): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0) + batches = [] + for x, y in loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batches.append((x, y)) + if len(batches) >= n_batches: + break + return batches + + +def evaluate(model, loader, device): + model.eval() + correct = total = 0 + with torch.no_grad(): + for x, y in loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + preds = model(x).argmax(-1) + correct += (preds == y).sum().item() + total += x.size(0) + return correct / total + + +def load_model(method: str, seed: int, device): + if method == "ep": + path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt") + else: + path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt") + ckpt = torch.load(path, map_location=device, weights_only=False) + # Try several common checkpoint formats + if isinstance(ckpt, dict) and "state_dict" in ckpt: + sd = ckpt["state_dict"] + elif isinstance(ckpt, dict) and "model" in ckpt: + sd = ckpt["model"] + elif isinstance(ckpt, dict): + sd = ckpt + else: + sd = ckpt.state_dict() + model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device) + model.load_state_dict(sd) + return model + + +def main(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + # Cross-batch stability is sensitive to batch size: with smaller batches + # the per-sample noise component dominates the batch-mean direction on + # healthy networks, giving the cleanest separation from the drift- + # dominated failure mode (~0.10 healthy vs ~0.95 drift at N=128). + eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device) + + # Build a single test loader for headline acc + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + test_loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) + + # BP: trustworthy. DFA: walked back. EP: trustworthy (the internal control + # — same architecture/metric/dataset, but EP doesn't blow up the residual + # stream, so the diagnostic protocol passes for EP even though EP's + # accuracy is also low. This is the paper's central comparison.) + for method, expected in [ + ("bp", "trustworthy"), + ("dfa", "needs walk-back"), + ("ep", "trustworthy"), + ]: + print() + print(f"### {method.upper()} (seed 42)") + model = load_model(method, 42, device) + acc = evaluate(model, test_loader, device) + report = diagnose( + model=model, + eval_batches=eval_batches, + headline_acc=acc, + frozen_baseline_acc=None, + method_name=method.upper(), + notes="4-block d=256 ResMLP, CIFAR-10", + ) + print(report) + # Sanity check + verdict = report.verdict + if expected == "trustworthy" and verdict != "trustworthy": + print(f"!! UNEXPECTED: BP should be trustworthy, got '{verdict}'") + elif expected == "needs walk-back" and not verdict.startswith("needs walk-back"): + print(f"!! UNEXPECTED: DFA should need walk-back, got '{verdict}'") + else: + print(f"OK: verdict matches expectation ('{expected}')") + + +if __name__ == "__main__": + main() |
