# What Accuracy and Gradient Cosine Miss: Evaluating Feedback Alignment via Scale Stability, Reference Validity, and Depth Utility

Code for the NeurIPS 2026 Evaluations & Datasets Track submission.

## Structure

```
submission_code/
  protocol/              # Reference evaluation artifact (the main deliverable)
    fa_protocol.py        # Three-diagnostic protocol: drop-in evaluator
    example_usage.py      # Minimal example applying protocol to a trained model
  models/                # Architectures used in the paper
    residual_mlp.py       # Pre-LayerNorm ResMLP (primary audit architecture)
    vit_mini.py           # ViT-Mini (transformer with terminal LN)
    small_resnet.py       # SmallResNet (BatchNorm, no LN)
  metrics/
    credit_metrics.py     # Cosine similarity, nudging test, perturbation correlation
  reproduce/             # Scripts to reproduce all paper results
    train_methods.py      # Train BP/FA/DFA on any architecture + compute diagnostics
    frozen_baseline.py    # Frozen-blocks and shallow baselines
    penalty_sweep.py      # Penalty intervention (lambda sweep + fresh-B null)
    run_all.sh            # One-command full reproduction
  requirements.txt
```
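The cosine-similarity metric in `metrics/credit_metrics.py` is the score that diagnostic D2 guards. When the BP reference gradient is tiny, the eps clamp in the denominator dominates and the cosine shrinks regardless of true alignment. The following is a minimal numpy sketch of that effect. It assumes each norm is clamped from below by `eps`, as documented for PyTorch's `torch.nn.functional.cosine_similarity`. `clamped_cosine` is a hypothetical name, and the repository's own metric may differ in detail.

```python
import numpy as np


def clamped_cosine(a, b, eps=1e-8):
    """Cosine similarity with each vector norm clamped from below by eps.

    Assumption: mirrors the per-norm clamp documented for
    torch.nn.functional.cosine_similarity; not the repository's code.
    """
    a = np.ravel(np.asarray(a, dtype=float))
    b = np.ravel(np.asarray(b, dtype=float))
    # Clamp each norm separately so a near-zero reference cannot divide by ~0.
    denom = max(np.linalg.norm(a), eps) * max(np.linalg.norm(b), eps)
    return float(np.dot(a, b) / denom)


# Perfectly aligned vectors: cosine is 1.
print(clamped_cosine([1.0, 0.0], [2.0, 0.0]))
# Still perfectly aligned, but the reference norm (1e-10) is below eps (1e-8):
# the clamp pulls the score down to about 0.01.
print(clamped_cosine([1.0, 0.0], [1e-10, 0.0]))
```

A low cosine in the second case reflects the clamp rather than misalignment, which is why a collapsed BP reference can make an FA method look worse (or make comparisons meaningless) without anything being wrong with its updates.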

## Quick start: applying the protocol to your own FA method

```python
from protocol.fa_protocol import FAProtocol

# Your trained model and a test batch
model = ...          # nn.Module with .blocks and .out_head (.out_ln is optional)
x_eval, y_eval = ... # (N, d_input), (N,)

protocol = FAProtocol(model, x_eval, y_eval)
report = protocol.run()

print(report['verdict'])        # 'PASS' or 'FAIL(D1+D2)' etc.
print(report['diagnostics'])    # {'D1_scale_growth': ..., 'D2_ref_validity': ..., 'D3_depth_utility': ...}
```
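When the protocol is run once per seed (for example the seeds 42, 123, and 456 used in the reproduction commands), it can help to tally verdicts across runs. The sketch below assumes each report is a dict with a `'verdict'` key holding `'PASS'` or a `'FAIL(...)'` string, as in the quick start; `summarize_reports` is a hypothetical helper, not part of `protocol.fa_protocol`.

```python
from collections import Counter


def summarize_reports(reports):
    """Tally protocol verdicts across several runs (e.g., one report per seed).

    Assumption: each report is a dict with a 'verdict' key, as returned
    by FAProtocol(...).run() in the quick start.
    """
    verdicts = Counter(r["verdict"] for r in reports)
    n_runs = len(reports)
    pass_rate = verdicts.get("PASS", 0) / n_runs if n_runs else 0.0
    return {"n_runs": n_runs, "pass_rate": pass_rate, "verdicts": dict(verdicts)}


# Illustrative input only; these are not results from the paper.
example = [{"verdict": "PASS"}, {"verdict": "FAIL(D1+D2)"}, {"verdict": "PASS"}]
print(summarize_reports(example))
```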

## Reproducing paper results

```bash
# Full reproduction (all figures, tables, appendices)
cd reproduce/
bash run_all.sh --gpu 0

# Or individual experiments:
python train_methods.py --arch resmlp --methods bp fa dfa --seeds 42 123 456 --epochs 100
python frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
python penalty_sweep.py --lambdas 0 1e-4 1e-2 --seeds 42 123 456 --epochs 30
```

## Requirements

- Python >= 3.10
- PyTorch >= 2.0
- torchvision
- numpy
- scipy (for perturbation correlation)

## Protocol specification

The protocol consists of three diagnostic checks applied to a trained model:

| Diagnostic | What it measures | Threshold | Fires when |
|---|---|---|---|
| D1: Scale stability | max per-block residual growth | > 50x | Residual stream explodes |
| D2: Reference validity | deepest-layer BP gradient norm | < 10 * eps | BP reference gradient collapses to the scale of the cosine-similarity eps clamp |
| D3: Depth utility | test-accuracy gain over frozen-blocks baseline | < 2 pp | Trained blocks do not outperform frozen random blocks |
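
The three checks in the table can be sketched as simple threshold tests on quantities measured from a trained model. This numpy sketch is illustrative, not the code in `fa_protocol.py`. It reads "per-block residual growth" as the ratio of the residual-stream norm after each block to its norm entering the first block (the repository may normalize differently). It also takes `eps = 1e-8`, PyTorch's default for `cosine_similarity`, as an assumed clamp value. All function names are hypothetical.

```python
import numpy as np

# Thresholds taken from the diagnostic table.
D1_GROWTH_LIMIT = 50.0    # D1: residual growth above 50x
D2_EPS_MULTIPLIER = 10.0  # D2: BP reference norm below 10 * eps
D3_MIN_GAIN_PP = 2.0      # D3: gain over frozen blocks below 2 pp


def scale_growth(stream_norms):
    """Largest residual-stream norm after any block, relative to the input of block 1.

    Assumption: stream_norms[0] is the norm entering the first block and
    stream_norms[l] is the norm after block l.
    """
    norms = np.asarray(stream_norms, dtype=float)
    return float(np.max(norms[1:] / norms[0]))


def d1_flags(stream_norms):
    """D1 fires when the residual stream grows by more than 50x."""
    return scale_growth(stream_norms) > D1_GROWTH_LIMIT


def d2_flags(deepest_bp_grad, eps=1e-8):
    """D2 fires when the deepest-layer BP gradient is too small to be a cosine reference."""
    return float(np.linalg.norm(deepest_bp_grad)) < D2_EPS_MULTIPLIER * eps


def d3_flags(acc_trained, acc_frozen):
    """D3 fires when trained blocks beat the frozen-blocks baseline by under 2 pp.

    Accuracies are fractions in [0, 1]; the gain is converted to percentage points.
    """
    gain_pp = 100.0 * (acc_trained - acc_frozen)
    return gain_pp < D3_MIN_GAIN_PP


# Illustrative values only (not paper results).
print(d1_flags([1.0, 10.0, 60.0]))  # growth of 60x exceeds the 50x limit
print(d2_flags([1e-12] * 4))        # reference norm of 2e-12 is below 10 * 1e-8
print(d3_flags(0.91, 0.90))         # a 1 pp gain is below the 2 pp bar
```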

**Verdict logic**: A run fails the protocol if Mode 1 is triggered (D1 and D2 both flag) or if Depth Utility (D3) flags. A D1 or D2 flag on its own does not fail the run. Passing all three diagnostics does not certify the method as effective; it only rules out the specific class of silent failures the audit revealed.
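
The verdict rule above can be written as a small function over the three flags. This is a sketch: the output strings imitate the `'PASS'` / `'FAIL(D1+D2)'` format shown in the quick start, but how `fa_protocol.py` labels combined failures (for example when D1, D2, and D3 all flag) is an assumption here, and `verdict` is a hypothetical name.

```python
def verdict(d1, d2, d3):
    """Combine the three diagnostic flags into a verdict string.

    Failure conditions: Mode 1 (D1 and D2 both flag) or D3.
    The label format is a hypothetical mirror of 'PASS' / 'FAIL(D1+D2)'.
    """
    reasons = []
    if d1 and d2:  # Mode 1 needs both scale and reference checks to flag
        reasons.append("D1+D2")
    if d3:         # Depth Utility fails on its own
        reasons.append("D3")
    return "PASS" if not reasons else "FAIL(" + "+".join(reasons) + ")"


print(verdict(True, False, False))  # PASS: D1 alone is not Mode 1
print(verdict(True, True, False))   # FAIL(D1+D2)
print(verdict(False, False, True))  # FAIL(D3)
```

Requiring D1 and D2 together keeps a large but harmless residual stream, or a small but still usable reference gradient, from failing a run by itself.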

## Training inventory

The paper reports 125 training runs across 5+ architectures, 10 training methods, and 17 experimental settings. Total estimated GPU time: ~12 hours on a single A100/A6000.

## License

MIT