diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:20:48 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:20:48 -0500 |
| commit | 7b64702ad970c16171142665365e16a8e1737190 (patch) | |
| tree | 30df9313fc195aa037f05c0598fdc0df8b32bbcc /protocol/README.md | |
| parent | 2dc8e7efb5f2ff827fbd97be87f0127aa5ab2757 (diff) | |
Add FA diagnostic protocol reference implementation
Codex round 15 #1 priority for the E&D-track paper:
- protocol/protocol.py: 4 diagnostics (residual norms, BP grad norms,
cross-batch direction stability, and a frozen-baseline comparator)
- protocol/report.py: DiagnosticReport with per-diagnostic verdicts and
pretty-printer
- protocol/smoke_test.py: validates BP/DFA/EP checkpoints produce the
expected verdicts (BP/EP trustworthy; DFA walked back via residual
explosion + BP grad at floor)
- protocol/README.md: usage, audit cases, threshold rationale
- protocol/CHECKLIST.md: 6 evaluation pipeline pitfalls (norm(-1),
cosine_similarity eps clamp, fp16 underflow, Bs reproducibility,
aggregation, layer-0 dominance)
- protocol/REPORTING_TEMPLATE.md: per-method fillable form for FA papers
Diffstat (limited to 'protocol/README.md')
| -rw-r--r-- | protocol/README.md | 150 |
1 files changed, 150 insertions, 0 deletions
diff --git a/protocol/README.md b/protocol/README.md new file mode 100644 index 0000000..770d582 --- /dev/null +++ b/protocol/README.md @@ -0,0 +1,150 @@ +# FA Diagnostic Protocol — reference implementation + +A drop-in toolkit for evaluating Feedback Alignment (FA) and other local-credit +methods on modern residual architectures. Implements the four diagnostics +described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for +Feedback Alignment Evaluation*. + +## What this fixes + +Standard FA evaluation reports two numbers: + +1. **Headline accuracy** of the trained network. +2. **Γ** — cosine alignment between the local credit signal and the + true BP gradient at hidden layers. + +Both can silently fail on modern pre-LayerNorm residual networks: + +- **Headline accuracy** fails because deep blocks may not be contributing at + all — the embedding + terminal-LN + head pathway alone can match (or + exceed) the trainable-blocks accuracy. We have measured DFA on a 4-block + d=256 ResMLP at 30.8% test acc, *below* the same architecture's + frozen-random-blocks baseline at 34.9%. The headline number reads as + "DFA learns" but the deep blocks are *actively destroying* value. +- **Γ** fails because residual-stream explosion drives the BP gradient at + hidden layers to ~10⁻¹⁰ via terminal-LayerNorm scaling. Γ then computes + the cosine of the local credit signal against a numerical-floor reference + vector. The reported value is mathematically well-defined but does not + measure credit-assignment quality. + +This protocol catches both failures and tells you whether your reported +numbers are interpretable. + +## The four diagnostics + +| | Diagnostic | What it catches | +|---|---|---| +| (a) | Per-layer residual stream norm ‖h_l‖₂ | Scale pathology — if ‖h_L‖ has grown ≫50× over depth, the network is in a degenerate regime. | +| (b) | Per-layer BP gradient norm ‖g_l‖₂ | Reference-vector floor — if ‖g_L‖ < 1e-7, Γ is being measured against the numerical floor and is not interpretable. | +| (c) | Cross-batch direction stability of g | Drift dominance — if the BP gradient direction is sample-invariant (cosine ≈ 1 across disjoint batches), the reference is a global drift artifact, not per-sample credit. | +| (d) | Frozen-blocks baseline accuracy | Useful blocks — if the trainable-blocks variant fails to clear an architecture-matched frozen-random-blocks baseline by ≥2 pp, the deep blocks are not contributing. | + +## Quick start + +```python +from protocol import diagnose + +report = diagnose( + model=trained_model, # any PyTorch module exposing + # `model(x, return_hidden=True)`, + # `model.embed`/`model.patch_embed`, + # `model.blocks`, and + # `model.out_ln`/`model.out_head` + # (or `model.norm`/`model.head`) + eval_batches=eval_batches, # ≥8 disjoint (x,y) batches of ~128 + headline_acc=test_acc, # the number you would normally publish + frozen_baseline_acc=frozen_acc, # acc of an architecture-matched + # frozen-random-blocks control + # (set None to skip diagnostic d) + method_name="DFA", + notes="4-block d=256 ResMLP, CIFAR-10", +) +print(report) +# emits a per-diagnostic table and either +# "VERDICT: trustworthy" +# or +# "VERDICT: needs walk-back: <list of failed diagnostics>" +``` + +## Reproducing the audit cases + +We ship a smoke test that loads BP, DFA, and EP checkpoints from +`results/confirmatory/checkpoints_A2/` and `results/ep_baseline/`, runs the +protocol on each, and checks that the verdicts match expectation: + +```bash +CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test +``` + +Expected verdicts (4-block d=256 ResMLP, CIFAR-10, seed 42): + +| method | ‖h₄‖ | ‖g₄‖ | stability | verdict | +|---|---:|---:|---:|---| +| BP | 2.1e2 | 3.7e-4 | 0.099 | trustworthy | +| DFA | 4.4e8 | 4.2e-9 | 0.047 | walked back: explosion + BP grad at floor | +| EP | 3.3e3 | 1.6e-4 | -0.036 | trustworthy (internal control) | + +EP is the paper's central internal control: same architecture, same dataset, +same metric, but EP's training does not produce residual-stream explosion, +so the protocol passes for EP even though EP's headline accuracy is also low +(~32%). This isolates "the network is in a degenerate measurement regime" +from "the method underperforms" — they are different failure modes with +different evidence. + +## Duck-typed model interface + +The protocol works on any model exposing: + +```python +# Forward with hidden states +logits, hiddens = model(x, return_hidden=True) +# hiddens is a list [h_0, h_1, ..., h_L]; supports either (B, d) or (B, T, d) +# (the protocol pools the cls token via h[:, 0, :] in the latter case) + +# Building blocks for the manual forward (used by diagnostics b and c) +model.embed OR model.patch_embed +model.blocks # iterable of nn.Modules; each block(h) returns the + # *residual branch* f_l(h_l), NOT h_l + f_l(h_l) +model.out_ln, model.out_head OR model.norm, model.head +``` + +`models/residual_mlp.py` and `models/vit_mini.py` in this repo both satisfy +this interface as written. + +## Thresholds (and how to override them) + +```python +from protocol import diagnose, DiagnosticReport +from protocol.report import DiagnosticThresholds + +custom = DiagnosticThresholds( + g_norm_floor=1e-6, # tighter floor + h_norm_explosion_ratio=20.0, # tighter scale check + stability_drift_ceiling=0.30, + frozen_acc_margin_pp=2.0, +) +report = diagnose(..., thresholds=custom) +``` + +The defaults match the paper: + +- `g_norm_floor = 1e-7` — well above fp32's smallest subnormal (~1e-38) and + well above F.cosine_similarity's eps=1e-8 clamp, but several orders below + healthy networks (~1e-5). +- `h_norm_explosion_ratio = 50.0` — BP-trained networks are ~1-3× per layer + in our measurements; failure modes show ~10⁵-10⁶× from layer 0 to layer L. +- `stability_drift_ceiling = 0.30` — calibrated for ~128-sample batches. + BP/EP-trained networks: 0.05-0.18. DFA/SB drift mode: 0.43-0.99. +- `frozen_acc_margin_pp = 2.0` — DFA-on-ResMLP undercuts frozen baseline by + 4 pp; the threshold is loose enough to allow noise on healthy networks. + +## Pipeline pitfalls (the checklist) + +In addition to the diagnostics above, we ship `CHECKLIST.md` — six +evaluation-pipeline bugs we found in our own dogfood codebase that silently +corrupt FA evaluation results. We recommend a final pre-publication pass +through the checklist regardless of which diagnostic protocol you use. + +## Citation + +(TBD — points to the paper PDF when uploaded.) |
