author YurenHao0426 <Blackhao0426@gmail.com> 2026-04-07 22:20:48 -0500
committer YurenHao0426 <Blackhao0426@gmail.com> 2026-04-07 22:20:48 -0500
commit 7b64702ad970c16171142665365e16a8e1737190
tree 30df9313fc195aa037f05c0598fdc0df8b32bbcc /protocol/README.md
parent 2dc8e7efb5f2ff827fbd97be87f0127aa5ab2757
Add FA diagnostic protocol reference implementation
Codex round 15 #1 priority for the E&D-track paper:
- protocol/protocol.py: 4 diagnostics (residual norms, BP grad norms, cross-batch direction stability, and a frozen-baseline comparator)
- protocol/report.py: DiagnosticReport with per-diagnostic verdicts and pretty-printer
- protocol/smoke_test.py: validates BP/DFA/EP checkpoints produce the expected verdicts (BP/EP trustworthy; DFA walked back via residual explosion + BP grad at floor)
- protocol/README.md: usage, audit cases, threshold rationale
- protocol/CHECKLIST.md: 6 evaluation pipeline pitfalls (norm(-1), cosine_similarity eps clamp, fp16 underflow, Bs reproducibility, aggregation, layer-0 dominance)
- protocol/REPORTING_TEMPLATE.md: per-method fillable form for FA papers
Diffstat (limited to 'protocol/README.md')
-rw-r--r-- protocol/README.md 150
1 files changed, 150 insertions, 0 deletions
diff --git a/protocol/README.md b/protocol/README.md
new file mode 100644
index 0000000..770d582
--- /dev/null
+++ b/protocol/README.md
@@ -0,0 +1,150 @@
+# FA Diagnostic Protocol — reference implementation
+
+A drop-in toolkit for evaluating Feedback Alignment (FA) and other local-credit
+methods on modern residual architectures. Implements the four diagnostics
+described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for
+Feedback Alignment Evaluation*.
+
+## What this fixes
+
+Standard FA evaluation reports two numbers:
+
+1. **Headline accuracy** of the trained network.
+2. **Γ** — cosine alignment between the local credit signal and the
+ true BP gradient at hidden layers.
+
+Both can silently fail on modern pre-LayerNorm residual networks:
+
+- **Headline accuracy** fails because deep blocks may not be contributing at
+ all — the embedding + terminal-LN + head pathway alone can match (or
+ exceed) the trainable-blocks accuracy. We have measured DFA on a 4-block
+ d=256 ResMLP at 30.8% test acc, *below* the same architecture's
+ frozen-random-blocks baseline at 34.9%. The headline number reads as
+ "DFA learns" but the deep blocks are *actively destroying* value.
+- **Γ** fails because residual-stream explosion drives the BP gradient at
+ hidden layers to ~10⁻¹⁰ via terminal-LayerNorm scaling. Γ then computes
+ the cosine of the local credit signal against a numerical-floor reference
+ vector. The reported value is mathematically well-defined but does not
+ measure credit-assignment quality.
+
+This protocol catches both failures and tells you whether your reported
+numbers are interpretable.
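
The Γ failure mode follows from the scale invariance of cosine similarity. The sketch below is a minimal illustration in plain Python, not part of the protocol: the `cosine` and `l2_norm` helpers and all vectors are made up for the example. It shrinks a reference gradient to the ~1e-10 magnitude mentioned above and shows that the cosine is unchanged, so Γ alone cannot reveal that the reference sits at the numerical floor.

```python
import math


def l2_norm(v):
    """Euclidean norm of a list of floats (hypothetical helper)."""
    return math.sqrt(sum(a * a for a in v))


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (hypothetical helper)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (l2_norm(u) * l2_norm(v))


# Illustrative vectors, not measured values.
credit_signal = [0.3, -1.2, 0.8, 0.5]
healthy_grad = [0.2, -1.0, 1.1, 0.1]
# Same direction as healthy_grad, scaled down to roughly 1e-10.
floored_grad = [1e-10 * g for g in healthy_grad]

gamma_healthy = cosine(credit_signal, healthy_grad)
gamma_floored = cosine(credit_signal, floored_grad)

G_NORM_FLOOR = 1e-7  # default floor quoted in this README

print(f"Gamma vs healthy reference: {gamma_healthy:.6f}")
print(f"Gamma vs floored reference: {gamma_floored:.6f}")
print(f"floored reference below floor: {l2_norm(floored_grad) < G_NORM_FLOOR}")
```

The example runs in double precision without an eps clamp. The commit message lists the `cosine_similarity` eps clamp and fp16 underflow as separate pitfalls covered in `CHECKLIST.md`.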
+
+## The four diagnostics
+
+| | Diagnostic | What it catches |
+|---|---|---|
+| (a) | Per-layer residual stream norm ‖h_l‖₂ | Scale pathology: if ‖h_L‖ has grown more than 50× (the default threshold) relative to ‖h_0‖, the network is in a degenerate regime. |
+| (b) | Per-layer BP gradient norm ‖g_l‖₂ | Reference-vector floor: if ‖g_L‖ < 1e-7, Γ is being measured against the numerical floor and is not interpretable. |
+| (c) | Cross-batch direction stability of g | Drift dominance: if the BP gradient direction is sample-invariant (cross-batch cosine above the default 0.30 ceiling, approaching 1 in the worst cases), the reference is a global drift artifact, not per-sample credit. |
+| (d) | Frozen-blocks baseline accuracy | Useful blocks: if the trainable-blocks variant fails to clear an architecture-matched frozen-random-blocks baseline by ≥2 pp, the deep blocks are not contributing. |
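
Diagnostic (c) can be illustrated without PyTorch. The sketch below assumes that stability is the mean pairwise cosine between per-batch gradient vectors. That is one plausible reading of the table above; the README does not spell out the exact aggregation used in `protocol/protocol.py`. The function names and the synthetic vectors are hypothetical.

```python
import itertools
import math
import random


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)


def direction_stability(batch_grads):
    """Mean cosine over all unordered pairs of per-batch gradient vectors."""
    pairs = list(itertools.combinations(batch_grads, 2))
    return sum(cosine(u, v) for u, v in pairs) / len(pairs)


rng = random.Random(0)
dim, n_batches = 256, 8

# Case 1: each batch yields an independent direction.
independent = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_batches)]

# Case 2: every batch shares one drift vector plus small batch-specific noise.
drift = [rng.gauss(0, 1) for _ in range(dim)]
drifting = [[d + 0.1 * rng.gauss(0, 1) for d in drift] for _ in range(n_batches)]

STABILITY_DRIFT_CEILING = 0.30  # default ceiling quoted in this README

for name, grads in [("independent", independent), ("drift-dominated", drifting)]:
    s = direction_stability(grads)
    verdict = "fail" if s > STABILITY_DRIFT_CEILING else "pass"
    print(f"{name}: stability={s:+.3f} -> {verdict}")
```

Under these assumptions, independent directions average out near zero, while a shared drift vector pushes the score toward 1 no matter which samples land in each batch.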
+
+## Quick start
+
+```python
+from protocol import diagnose
+
+report = diagnose(
+    model=trained_model,             # any PyTorch module exposing
+                                     # `model(x, return_hidden=True)`,
+                                     # `model.embed`/`model.patch_embed`,
+                                     # `model.blocks`, and
+                                     # `model.out_ln`/`model.out_head`
+                                     # (or `model.norm`/`model.head`)
+    eval_batches=eval_batches,       # ≥8 disjoint (x,y) batches of ~128
+    headline_acc=test_acc,           # the number you would normally publish
+    frozen_baseline_acc=frozen_acc,  # acc of an architecture-matched
+                                     # frozen-random-blocks control
+                                     # (set None to skip diagnostic d)
+    method_name="DFA",
+    notes="4-block d=256 ResMLP, CIFAR-10",
+)
+print(report)
+# emits a per-diagnostic table and either
+# "VERDICT: trustworthy"
+# or
+# "VERDICT: needs walk-back: <list of failed diagnostics>"
+```
+
+## Reproducing the audit cases
+
+We ship a smoke test that loads BP, DFA, and EP checkpoints from
+`results/confirmatory/checkpoints_A2/` and `results/ep_baseline/`, runs the
+protocol on each, and checks that the verdicts match expectation:
+
+```bash
+CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test
+```
+
+Expected verdicts (4-block d=256 ResMLP, CIFAR-10, seed 42):
+
+| method | ‖h₄‖ | ‖g₄‖ | stability | verdict |
+|---|---:|---:|---:|---|
+| BP | 2.1e2 | 3.7e-4 | 0.099 | trustworthy |
+| DFA | 4.4e8 | 4.2e-9 | 0.047 | walked back: explosion + BP grad at floor |
+| EP | 3.3e3 | 1.6e-4 | -0.036 | trustworthy (internal control) |
+
+EP is the paper's central internal control: same architecture, same dataset,
+same metric, but EP's training does not produce residual-stream explosion,
+so the protocol passes for EP even though EP's headline accuracy is also low
+(~32%). This isolates "the network is in a degenerate measurement regime"
+from "the method underperforms" — they are different failure modes with
+different evidence.
+
+## Duck-typed model interface
+
+The protocol works on any model exposing:
+
+```python
+# Forward with hidden states
+logits, hiddens = model(x, return_hidden=True)
+# hiddens is a list [h_0, h_1, ..., h_L]; supports either (B, d) or (B, T, d)
+# (the protocol pools the cls token via h[:, 0, :] in the latter case)
+
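+# Building blocks for the manual forward (used by diagnostics b and c).
+# The model must expose one name from each pair:
+embed = model.embed if hasattr(model, "embed") else model.patch_embed
+blocks = model.blocks  # iterable of nn.Modules; each block(h) returns the
+                       # *residual branch* f_l(h_l), NOT h_l + f_l(h_l)
+if hasattr(model, "out_ln"):
+    out_ln, out_head = model.out_ln, model.out_head
+else:
+    out_ln, out_head = model.norm, model.head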
+```
+
+`models/residual_mlp.py` and `models/vit_mini.py` in this repo both satisfy
+this interface as written.
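
As a rough illustration of the interface, the toy model below exposes the attributes listed above. It is a hypothetical sketch assuming PyTorch is installed; `TinyResMLP` and its layer sizes are invented for the example and are not the repo's `models/residual_mlp.py`.

```python
import torch
from torch import nn


class TinyResMLP(nn.Module):
    """Hypothetical toy model that satisfies the duck-typed interface."""

    def __init__(self, in_dim=32, d=16, n_blocks=2, n_classes=10):
        super().__init__()
        self.embed = nn.Linear(in_dim, d)
        # Each block computes only the residual branch f_l(h_l).
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(d), nn.Linear(d, d), nn.GELU(), nn.Linear(d, d))
            for _ in range(n_blocks)
        ])
        self.out_ln = nn.LayerNorm(d)
        self.out_head = nn.Linear(d, n_classes)

    def forward(self, x, return_hidden=False):
        h = self.embed(x)
        hiddens = [h]                # h_0
        for block in self.blocks:
            h = h + block(h)         # the skip connection is added here
            hiddens.append(h)        # h_1 ... h_L
        logits = self.out_head(self.out_ln(h))
        return (logits, hiddens) if return_hidden else logits


torch.manual_seed(0)
model = TinyResMLP()
x = torch.randn(4, 32)
logits, hiddens = model(x, return_hidden=True)

# One plausible readout for diagnostic (a): mean per-sample L2 norm per layer.
# dim is passed by keyword because the first positional argument of
# Tensor.norm is the norm order p, not the dimension.
h_norms = [h.norm(dim=-1).mean().item() for h in hiddens]
print(tuple(logits.shape), len(hiddens), h_norms)
```

Keeping the skip connection in `forward` rather than inside each block is what lets a manual forward pass rebuild `h_l + f_l(h_l)` from `model.blocks` alone.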
+
+## Thresholds (and how to override them)
+
+```python
+from protocol import diagnose, DiagnosticReport
+from protocol.report import DiagnosticThresholds
+
+custom = DiagnosticThresholds(
+    g_norm_floor=1e-6,            # tighter floor
+    h_norm_explosion_ratio=20.0,  # tighter scale check
+    stability_drift_ceiling=0.30,
+    frozen_acc_margin_pp=2.0,
+)
+report = diagnose(..., thresholds=custom)
+```
+
+The defaults match the paper:
+
+- `g_norm_floor = 1e-7`: far above fp32's smallest normal value (~1.2e-38)
+  and one order of magnitude above F.cosine_similarity's eps=1e-8 clamp, but
+  two orders of magnitude below healthy networks (~1e-5).
+- `h_norm_explosion_ratio = 50.0`: residual norms in BP-trained networks grow
+  ~1-3× per layer in our measurements; failure modes show ~10⁵-10⁶× growth
+  from layer 0 to layer L.
+- `stability_drift_ceiling = 0.30`: calibrated for ~128-sample batches.
+  BP/EP-trained networks: 0.05-0.18. DFA/SB drift mode: 0.43-0.99.
+- `frozen_acc_margin_pp = 2.0`: DFA-on-ResMLP undercuts the frozen baseline by
+  ~4 pp; the threshold is loose enough to allow noise on healthy networks.
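
The sketch below shows how the four defaults could combine into a verdict. It is a hypothetical stand-in, not the logic in `protocol/report.py`: the `Thresholds` dataclass only mirrors the field names above, `failed_diagnostics` is invented for the example, and the per-layer residual norms are made up. The last-layer gradient norms, the stability values, and the 30.8% / 34.9% accuracies are the figures quoted in this README.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Thresholds:
    """Hypothetical stand-in mirroring the README's default values."""
    g_norm_floor: float = 1e-7
    h_norm_explosion_ratio: float = 50.0
    stability_drift_ceiling: float = 0.30
    frozen_acc_margin_pp: float = 2.0


def failed_diagnostics(
    h_norms: Sequence[float],      # per-layer residual norms, layer 0 .. L
    g_norm_last: float,            # BP gradient norm at the last hidden layer
    stability: float,              # cross-batch direction stability
    headline_acc: float,           # accuracy in percent
    frozen_acc: Optional[float],   # frozen-baseline accuracy, or None to skip (d)
    t: Optional[Thresholds] = None,
) -> List[str]:
    t = t or Thresholds()
    failed = []
    if h_norms[-1] / h_norms[0] > t.h_norm_explosion_ratio:
        failed.append("a: residual explosion")
    if g_norm_last < t.g_norm_floor:
        failed.append("b: BP grad at floor")
    if stability > t.stability_drift_ceiling:
        failed.append("c: drift dominance")
    if frozen_acc is not None and headline_acc - frozen_acc < t.frozen_acc_margin_pp:
        failed.append("d: blocks not useful")
    return failed


# Made-up per-layer residual norms; only the overall growth ratio matters here.
healthy_h = [12.0, 20.0, 45.0, 90.0, 210.0]
exploded_h = [12.0, 3.0e3, 8.0e5, 2.0e7, 4.4e8]

# headline_acc is unused when frozen_acc is None, so 0.0 is a placeholder.
bp_failed = failed_diagnostics(healthy_h, 3.7e-4, 0.099, 0.0, None)
dfa_failed = failed_diagnostics(exploded_h, 4.2e-9, 0.047, 30.8, 34.9)

for name, failed in [("BP", bp_failed), ("DFA", dfa_failed)]:
    verdict = "trustworthy" if not failed else "needs walk-back: " + ", ".join(failed)
    print(f"{name}: {verdict}")
```

Passing `frozen_acc=None` skips diagnostic (d), matching the `frozen_baseline_acc=None` option in the quick start.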
+
+## Pipeline pitfalls (the checklist)
+
+In addition to the diagnostics above, we ship `CHECKLIST.md`, which lists six
+evaluation-pipeline bugs that we found while dogfooding our own codebase and
+that silently corrupt FA evaluation results. We recommend a final
+pre-publication pass through the checklist regardless of which diagnostic
+protocol you use.
+
+## Citation
+
+(TBD: will point to the paper PDF once it is uploaded.)