# FA Diagnostic Protocol — reference implementation

A drop-in toolkit for evaluating Feedback Alignment (FA) and other local-credit methods on modern residual architectures. It implements the four diagnostics described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for Feedback Alignment Evaluation*.

## What this fixes

Standard FA evaluation reports two numbers:

1. **Headline accuracy** of the trained network.
2. **Γ**: the cosine alignment between the local credit signal and the true BP gradient at hidden layers.

Both can fail silently on modern pre-LayerNorm residual networks:

- **Headline accuracy** fails because the deep blocks may not be contributing at all: the embedding + terminal-LN + head pathway alone can match (or exceed) the accuracy of the variant with trainable blocks. We measured DFA on a 4-block d=256 ResMLP at 30.8% test accuracy, *below* the same architecture's frozen-random-blocks baseline at 34.9%. The headline number reads as "DFA learns", but the deep blocks are *actively destroying* value.
- **Γ** fails because residual-stream explosion drives the BP gradient at hidden layers down to ~10⁻¹⁰ through terminal-LayerNorm scaling. Γ is then the cosine of the local credit signal against a reference vector sitting at the numerical floor. The reported value is mathematically well-defined, but it does not measure credit-assignment quality.

This protocol catches both failures and tells you whether your reported numbers are interpretable.

## The four diagnostics

| # | Diagnostic | What it catches |
|---|---|---|
| (a) | Per-layer residual stream norm ‖h_l‖₂ | Scale pathology: if ‖h_L‖ has grown by more than 50× relative to ‖h_0‖ over depth, the network is in a degenerate regime. |
| (b) | Per-layer BP gradient norm ‖g_l‖₂ | Reference-vector floor: if ‖g_L‖ < 1e-7, Γ is being measured against the numerical floor and is not interpretable. |
| (c) | Cross-batch direction stability of g | Drift dominance: if the BP gradient direction is sample-invariant (cosine ≈ 1 across disjoint batches), the reference is a global drift artifact, not per-sample credit. |
| (d) | Frozen-blocks baseline accuracy | Useful blocks: if the trainable-blocks variant fails to clear an architecture-matched frozen-random-blocks baseline by ≥2 pp, the deep blocks are not contributing. |

## Quick start

```python
from protocol import diagnose

report = diagnose(
    model=trained_model,              # any PyTorch module exposing
                                      #   `model(x, return_hidden=True)`,
                                      #   `model.embed`/`model.patch_embed`,
                                      #   `model.blocks`, and
                                      #   `model.out_ln`/`model.out_head`
                                      #   (or `model.norm`/`model.head`)
    eval_batches=eval_batches,        # ≥8 disjoint (x, y) batches of ~128 samples each
    headline_acc=test_acc,            # the number you would normally publish
    frozen_baseline_acc=frozen_acc,   # accuracy of an architecture-matched
                                      #   frozen-random-blocks control
                                      #   (set to None to skip diagnostic d)
    method_name="DFA",
    notes="4-block d=256 ResMLP, CIFAR-10",
)
print(report)  # emits a per-diagnostic table and either
               #   "VERDICT: trustworthy"
               # or
               #   "VERDICT: needs walk-back: "
```

## Reproducing the audit cases

We ship a smoke test that loads BP, DFA, and EP checkpoints from `results/confirmatory/checkpoints_A2/` and `results/ep_baseline/`, runs the protocol on each, and checks that the verdicts match expectations:

```bash
CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test
```

Expected verdicts (4-block d=256 ResMLP, CIFAR-10, seed 42):

| Method | ‖h₄‖ (a) | ‖g₄‖ (b) | Stability (c) | Verdict |
|---|---:|---:|---:|---|
| BP | 2.1e2 | 3.7e-4 | 0.099 | trustworthy |
| DFA | 4.4e8 | 4.2e-9 | 0.047 | walked back: explosion + BP grad at floor |
| EP | 3.3e3 | 1.6e-4 | -0.036 | trustworthy (internal control) |

EP is the paper's central internal control: same architecture, same dataset, same metric, but EP's training does not produce residual-stream explosion, so the protocol passes for EP even though EP's headline accuracy is also low (~32%). This separates "the network is in a degenerate measurement regime" from "the method underperforms": they are different failure modes, backed by different evidence.

## Duck-typed model interface

The protocol works on any model exposing:

```python
# Forward with hidden states
logits, hiddens = model(x, return_hidden=True)
# hiddens is a list [h_0, h_1, ..., h_L] of (B, d) or (B, T, d) tensors;
# in the (B, T, d) case the protocol pools the cls token via h[:, 0, :]

# Building blocks for the manual forward (used by diagnostics b and c)
embed = model.embed if hasattr(model, "embed") else model.patch_embed
blocks = model.blocks   # iterable of nn.Modules; each block(h) returns the
                        # *residual branch* f_l(h_l), NOT h_l + f_l(h_l)
if hasattr(model, "out_ln"):
    out_ln, out_head = model.out_ln, model.out_head
else:
    out_ln, out_head = model.norm, model.head
```

`models/residual_mlp.py` and `models/vit_mini.py` in this repo both satisfy this interface as written.

## Thresholds (and how to override them)

```python
from protocol import diagnose, DiagnosticReport
from protocol.report import DiagnosticThresholds

custom = DiagnosticThresholds(
    g_norm_floor=1e-6,            # tighter floor
    h_norm_explosion_ratio=20.0,  # tighter scale check
    stability_drift_ceiling=0.30,
    frozen_acc_margin_pp=2.0,
)
report = diagnose(..., thresholds=custom)
```

The defaults match the paper:

- `g_norm_floor = 1e-7`: well above fp32's smallest normal value (~1.2e-38) and 10× above `F.cosine_similarity`'s default eps=1e-8 clamp, yet at least two orders of magnitude below healthy networks (~1e-5).
- `h_norm_explosion_ratio = 50.0`: in our measurements BP-trained networks grow by ~1-3× per layer, while failure modes show ~10⁵-10⁶× growth from layer 0 to layer L.
- `stability_drift_ceiling = 0.30`: calibrated for ~128-sample batches at K=10. BP/EP-trained networks (6 audited values) fall in [-0.04, +0.12]; DFA/SB/CB in drift mode (9 audited values) reach up to +0.99, with 5 of 9 above the 0.30 cutoff.
- `frozen_acc_margin_pp = 2.0`: DFA on the ResMLP undercuts the frozen baseline by ~4 pp; the 2 pp margin leaves room for run-to-run noise on healthy networks.
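
The four default thresholds combine into a single verdict. Below is a minimal, standalone sketch of that decision rule; `Thresholds` and `verdict` are hypothetical names, not the `protocol` package's API. It assumes the explosion check is the ratio ‖h_L‖/‖h_0‖ (following the "from layer 0 to layer L" wording) and that accuracies are given in percent, so the margin is in percentage points.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Thresholds:
    # Hypothetical mirror of the DiagnosticThresholds fields, with the paper defaults.
    g_norm_floor: float = 1e-7
    h_norm_explosion_ratio: float = 50.0
    stability_drift_ceiling: float = 0.30
    frozen_acc_margin_pp: float = 2.0


def verdict(
    h_norms: List[float],                 # mean ||h_l|| for l = 0..L
    g_norm_last: float,                   # mean ||g_L||
    stability: float,                     # cross-batch cosine statistic
    headline_acc: float,                  # percent
    frozen_baseline_acc: Optional[float], # percent, or None to skip (d)
    t: Thresholds = Thresholds(),
) -> str:
    failures = []
    # (a) scale pathology: growth of the residual stream over depth
    if h_norms[-1] / h_norms[0] > t.h_norm_explosion_ratio:
        failures.append("residual-stream explosion")
    # (b) reference-vector floor: Γ would be measured against numerical noise
    if g_norm_last < t.g_norm_floor:
        failures.append("BP grad at floor")
    # (c) drift dominance: gradient direction barely depends on the batch
    if stability > t.stability_drift_ceiling:
        failures.append("drift-dominated reference")
    # (d) useful blocks: trainable variant must beat the frozen control by the margin
    if (frozen_baseline_acc is not None
            and headline_acc - frozen_baseline_acc < t.frozen_acc_margin_pp):
        failures.append("blocks below frozen baseline")
    if not failures:
        return "VERDICT: trustworthy"
    return "VERDICT: needs walk-back: " + ", ".join(failures)
```

With this rule, the DFA row of the audit table (‖g₄‖ = 4.2e-9) combined with its 30.8% vs 34.9% accuracies trips both the floor check and the frozen-baseline check; the explosion check additionally needs ‖h_0‖, which the table does not report.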
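
Diagnostics (a), (b) and (d) need only the building blocks listed under *Duck-typed model interface*. The sketch below assumes PyTorch and a toy `ToyResMLP` (hypothetical, not `models/residual_mlp.py`) whose blocks return only the residual branch. `per_layer_norms` replays the manual forward h_{l+1} = h_l + f_l(h_l), takes the BP gradient at every h_l with `torch.autograd.grad`, and reports mean per-sample norms. Using summed cross-entropy and averaging norms over the batch are assumptions here; the real protocol may normalize differently. `freeze_blocks` (also hypothetical) shows one way to set up the frozen-random-blocks control before training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyResMLP(nn.Module):
    """Hypothetical pre-LN residual MLP that satisfies the duck-typed interface."""

    def __init__(self, d_in=8, d=16, n_blocks=3, n_classes=4):
        super().__init__()
        self.embed = nn.Linear(d_in, d)
        # Each block returns only the residual branch f_l(h), not h + f_l(h).
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.LayerNorm(d), nn.Linear(d, d), nn.GELU())
            for _ in range(n_blocks)
        )
        self.out_ln = nn.LayerNorm(d)
        self.out_head = nn.Linear(d, n_classes)

    def forward(self, x, return_hidden=False):
        h = self.embed(x)
        hiddens = [h]
        for block in self.blocks:
            h = h + block(h)
            hiddens.append(h)
        logits = self.out_head(self.out_ln(h))
        return (logits, hiddens) if return_hidden else logits


def per_layer_norms(model, x, y):
    """Diagnostics (a) and (b): mean per-sample ||h_l|| and ||g_l|| for l = 0..L."""
    h = model.embed(x)
    hiddens = [h]
    for block in model.blocks:  # manual forward: h_{l+1} = h_l + f_l(h_l)
        h = h + block(h)
        hiddens.append(h)
    # Samples do not interact in this model, so with reduction="sum"
    # row i of each gradient is exactly dL_i / dh_l[i].
    loss = F.cross_entropy(model.out_head(model.out_ln(h)), y, reduction="sum")
    grads = torch.autograd.grad(loss, hiddens)  # does not touch param .grad
    h_norms = [t.detach().norm(dim=-1).mean().item() for t in hiddens]
    g_norms = [g.norm(dim=-1).mean().item() for g in grads]
    return h_norms, g_norms


def freeze_blocks(model):
    """Diagnostic (d) control: keep blocks at random init; return what still trains."""
    for p in model.blocks.parameters():
        p.requires_grad_(False)
    return [p for p in model.parameters() if p.requires_grad]
```

The list returned by `freeze_blocks` can be passed straight to an optimizer, so only the embedding, terminal LayerNorm and head are trained for the baseline.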
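
Diagnostic (c) is described only qualitatively here. One plausible way to operationalize it, assumed for illustration rather than taken from the protocol's source: average the per-sample BP gradients g_L within each of K disjoint batches, then report the mean pairwise cosine between those batch-mean directions. Plain Python lists stand in for tensors so the arithmetic is easy to follow; all names are hypothetical.

```python
import itertools
import math
from typing import List, Sequence

Vector = Sequence[float]


def mean_vector(vectors: Sequence[Vector]) -> List[float]:
    """Average a batch of per-sample gradient vectors into one direction."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def cosine(u: Vector, v: Vector) -> float:
    """Plain cosine similarity (no eps clamp); 0.0 if either vector is exactly zero."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return dot / (nu * nv)


def cross_batch_stability(per_batch_grads: Sequence[Sequence[Vector]]) -> float:
    """Mean pairwise cosine between the batch-mean gradients of K disjoint batches.

    per_batch_grads[k] holds the per-sample BP gradients g_L for batch k.
    """
    means = [mean_vector(batch) for batch in per_batch_grads]
    pairs = list(itertools.combinations(means, 2))
    if not pairs:
        raise ValueError("need at least two disjoint batches")
    return sum(cosine(u, v) for u, v in pairs) / len(pairs)
```

Under this reading, a value near 1 means every batch points the same way regardless of which samples it holds (the drift-dominated case the 0.30 ceiling guards against), whereas the audited BP/EP values fall between -0.04 and +0.12.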
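
The Γ failure described under *What this fixes* follows from how LayerNorm backpropagates. LN(h) is invariant to rescaling h, so its backward pass divides by σ, the standard deviation of h: if the residual stream grows by a factor c, the gradient reaching it shrinks by roughly 1/c. The pure-Python sketch below uses an affine-free LayerNorm with eps = 1e-5 (PyTorch's `nn.LayerNorm` default) and an arbitrary fixed upstream gradient; the helper names are illustrative.

```python
import math


def layer_norm(h, eps=1e-5):
    """Affine-free LayerNorm of one feature vector; also returns sigma."""
    n = len(h)
    mu = sum(h) / n
    var = sum((v - mu) ** 2 for v in h) / n
    sigma = math.sqrt(var + eps)
    return [(v - mu) / sigma for v in h], sigma


def layer_norm_grad(h, upstream, eps=1e-5):
    """dL/dh for L = <upstream, LN(h)>, via the standard LayerNorm backward."""
    y, sigma = layer_norm(h, eps)
    n = len(h)
    mean_up = sum(upstream) / n
    mean_up_y = sum(u * yi for u, yi in zip(upstream, y)) / n
    return [(u - mean_up - yi * mean_up_y) / sigma for u, yi in zip(upstream, y)]


def l2(v):
    return math.sqrt(sum(x * x for x in v))


h = [1.0, 2.0, 3.0, 4.0]
upstream = [0.5, -1.0, 2.0, 0.25]  # stand-in for the gradient coming from the head
for scale in (1.0, 1e3, 1e6):
    hs = [scale * v for v in h]
    # The LN output is (almost) unchanged, but the gradient norm shrinks ~1/scale.
    print(f"scale={scale:g}  ||dL/dh||={l2(layer_norm_grad(hs, upstream)):.3e}")
```

In a pre-LN residual network every hidden-layer BP gradient is obtained by backpropagating through the residual stream from this terminal LayerNorm, so it inherits the same 1/σ factor; this is how a stream norm of ~10⁸ can push ‖g_L‖ below the 1e-7 floor.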
## Pipeline pitfalls (the checklist)

In addition to the diagnostics above, we ship `CHECKLIST.md`: six evaluation-pipeline bugs we found in our own dogfood codebase that silently corrupt FA evaluation results. We recommend a final pre-publication pass through the checklist regardless of which diagnostic protocol you use.

## Citation

(TBD: will point to the paper PDF once it is uploaded.)