# What Accuracy and Gradient Cosine Miss: Evaluating Feedback Alignment via Scale Stability, Reference Validity, and Depth Utility
Code for the NeurIPS 2026 Evaluations & Datasets Track submission.
## Structure
```
submission_code/
├── protocol/               # Reference evaluation artifact (the main deliverable)
│   ├── fa_protocol.py      # Three-diagnostic protocol: drop-in evaluator
│   └── example_usage.py    # Minimal example applying the protocol to a trained model
├── models/                 # Architectures used in the paper
│   ├── residual_mlp.py     # Pre-LayerNorm ResMLP (primary audit architecture)
│   ├── vit_mini.py         # ViT-Mini (transformer with terminal LN)
│   └── small_resnet.py     # SmallResNet (BatchNorm, no LN)
├── metrics/
│   └── credit_metrics.py   # Cosine similarity, nudging test, perturbation correlation
├── reproduce/              # Scripts to reproduce all paper results
│   ├── train_methods.py    # Train BP/FA/DFA on any architecture + compute diagnostics
│   ├── frozen_baseline.py  # Frozen-blocks and shallow baselines
│   ├── penalty_sweep.py    # Penalty intervention (lambda sweep + fresh-B null)
│   └── run_all.sh          # One-command full reproduction
└── requirements.txt
```
## Quick start: applying the protocol to your own FA method
```python
from protocol.fa_protocol import FAProtocol

# Your trained model and a held-out evaluation batch
model = ...           # nn.Module exposing .blocks and .out_head (.out_ln optional)
x_eval, y_eval = ...  # shapes: (N, d_input), (N,)

protocol = FAProtocol(model, x_eval, y_eval)
report = protocol.run()

print(report['verdict'])      # 'PASS' or e.g. 'FAIL(D1+D2)'
print(report['diagnostics'])  # {'D1_scale_growth': ..., 'D2_ref_validity': ..., 'D3_depth_utility': ...}
```
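The protocol expects the model to expose `.blocks`, `.out_head`, and (optionally) `.out_ln`. As a rough illustration of that interface, not the paper's actual architecture, a compatible pre-LayerNorm residual MLP might look like the sketch below; all layer sizes and the class name are made up for the example:

```python
import torch
import torch.nn as nn

class TinyResMLP(nn.Module):
    """Illustrative model exposing the attributes FAProtocol expects."""

    def __init__(self, d_in=32, d_hidden=64, n_blocks=4, n_classes=10):
        super().__init__()
        self.embed = nn.Linear(d_in, d_hidden)
        # .blocks: the residual trunk the diagnostics iterate over
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(d_hidden),
                nn.Linear(d_hidden, d_hidden),
                nn.GELU(),
                nn.Linear(d_hidden, d_hidden),
            )
            for _ in range(n_blocks)
        ])
        self.out_ln = nn.LayerNorm(d_hidden)  # optional terminal LN
        self.out_head = nn.Linear(d_hidden, n_classes)

    def forward(self, x):
        h = self.embed(x)
        for block in self.blocks:
            h = h + block(h)  # pre-LN residual stream
        return self.out_head(self.out_ln(h))
```

A model without a terminal LayerNorm (like `small_resnet.py`) would simply omit `.out_ln`.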
## Reproducing paper results
```bash
# Full reproduction (all figures, tables, appendices)
cd reproduce/
bash run_all.sh --gpu 0
# Or individual experiments:
python train_methods.py --arch resmlp --methods bp fa dfa --seeds 42 123 456 --epochs 100
python frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
python penalty_sweep.py --lambdas 0 1e-4 1e-2 --seeds 42 123 456 --epochs 30
```
## Requirements
- Python >= 3.10
- PyTorch >= 2.0
- torchvision
- numpy
- scipy (for perturbation correlation)
## Protocol specification
The protocol consists of three diagnostic checks applied to a trained model:
| Diagnostic | What it measures | Threshold | Fires when |
|---|---|---|---|
| D1: Scale stability | max per-block residual growth | > 50x | Residual stream explodes |
| D2: Reference validity | deepest-layer BP gradient norm | < 10 * eps | BP reference collapses below cosine clamp |
| D3: Depth utility | test accuracy vs frozen-blocks baseline | < 2 pp above | Trained blocks do not outperform random |
**Verdict logic**: A run fails the protocol if Mode 1 fires (D1 and D2 both flag) or if D3 flags on its own. Passing all three diagnostics does not certify the method as effective; it only rules out the specific class of silent failures the audit revealed.
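Under the thresholds above, the verdict logic can be sketched in a few lines of plain Python. The function and argument names here are illustrative, not the exact API of `fa_protocol.py`, and the clamp epsilon is an assumed placeholder:

```python
EPS = 1e-8  # illustrative clamp; the paper's cosine-clamp epsilon may differ

def verdict(scale_growth, ref_grad_norm, acc_gain_pp):
    """Combine the three diagnostics into a PASS/FAIL verdict.

    scale_growth  -- max per-block residual growth factor (D1)
    ref_grad_norm -- deepest-layer BP gradient norm (D2)
    acc_gain_pp   -- test-accuracy gain over the frozen-blocks
                     baseline, in percentage points (D3)
    """
    d1 = scale_growth > 50.0       # residual stream explodes
    d2 = ref_grad_norm < 10 * EPS  # BP reference collapses below the clamp
    d3 = acc_gain_pp < 2.0         # trained blocks no better than random

    failed = []
    if d1 and d2:  # Mode 1: scale and reference flags must fire together
        failed.append('D1+D2')
    if d3:         # depth utility fails on its own
        failed.append('D3')
    return 'PASS' if not failed else 'FAIL(' + '+'.join(failed) + ')'
```

For example, `verdict(3.0, 1e-2, 5.0)` returns `'PASS'`, while `verdict(80.0, 1e-9, 5.0)` returns `'FAIL(D1+D2)'`. Note that D1 firing alone (scale growth without reference collapse) does not fail the run.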
## Training inventory
The paper reports 125 training runs across 5+ architectures, 10 training methods, and 17 experimental settings. Total estimated GPU time: ~12 hours on a single A100/A6000.
## License
MIT