author  YurenHao0426 <Blackhao0426@gmail.com>  2026-05-04 19:50:45 -0500
commit  b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch)

Initial submission code: FA evaluation protocol + reproduction scripts

Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and the full reproduction pipeline.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
# What Accuracy and Gradient Cosine Miss: Evaluating Feedback Alignment via Scale Stability, Reference Validity, and Depth Utility

Code for the NeurIPS 2026 Evaluations & Datasets Track submission.
+
## Structure

```
submission_code/
  protocol/            # Reference evaluation artifact (the main deliverable)
    fa_protocol.py     # Three-diagnostic protocol: drop-in evaluator
    example_usage.py   # Minimal example applying the protocol to a trained model
  models/              # Architectures used in the paper
    residual_mlp.py    # Pre-LayerNorm ResMLP (primary audit architecture)
    vit_mini.py        # ViT-Mini (transformer with terminal LN)
    small_resnet.py    # SmallResNet (BatchNorm, no LN)
  metrics/
    credit_metrics.py  # Cosine similarity, nudging test, perturbation correlation
  reproduce/           # Scripts to reproduce all paper results
    train_methods.py   # Train BP/FA/DFA on any architecture + compute diagnostics
    frozen_baseline.py # Frozen-blocks and shallow baselines
    penalty_sweep.py   # Penalty intervention (lambda sweep + fresh-B null)
    run_all.sh         # One-command full reproduction
  requirements.txt
```
+
## Quick start: applying the protocol to your own FA method

```python
from protocol.fa_protocol import FAProtocol

# Your trained model and a held-out evaluation batch
model = ...           # nn.Module exposing .blocks, .out_head, and optionally .out_ln
x_eval, y_eval = ...  # shapes (N, d_input) and (N,)

protocol = FAProtocol(model, x_eval, y_eval)
report = protocol.run()

print(report['verdict'])      # 'PASS' or 'FAIL(D1+D2)', etc.
print(report['diagnostics'])  # {'D1_scale_growth': ..., 'D2_ref_validity': ..., 'D3_depth_utility': ...}
```
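The evaluator expects a model that exposes `.blocks`, `.out_head`, and (optionally) `.out_ln`. A minimal sketch of a compatible pre-LN residual MLP, assuming that interface; this is illustrative only, not the paper's architecture (see `models/residual_mlp.py` for that):

```python
import torch
import torch.nn as nn

class TinyResMLP(nn.Module):
    """Hypothetical minimal model exposing the attributes FAProtocol expects."""

    def __init__(self, d_in=32, d_hidden=64, n_blocks=4, n_classes=10):
        super().__init__()
        self.embed = nn.Linear(d_in, d_hidden)
        # Pre-LN residual blocks, iterated by the protocol via .blocks
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.LayerNorm(d_hidden),
                nn.Linear(d_hidden, d_hidden),
                nn.ReLU(),
                nn.Linear(d_hidden, d_hidden),
            )
            for _ in range(n_blocks)
        )
        self.out_ln = nn.LayerNorm(d_hidden)  # terminal LN (optional)
        self.out_head = nn.Linear(d_hidden, n_classes)

    def forward(self, x):
        h = self.embed(x)
        for block in self.blocks:
            h = h + block(h)  # residual update
        return self.out_head(self.out_ln(h))
```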
+
## Reproducing paper results

```bash
# Full reproduction (all figures, tables, appendices)
cd reproduce/
bash run_all.sh --gpu 0

# Or individual experiments:
python train_methods.py --arch resmlp --methods bp fa dfa --seeds 42 123 456 --epochs 100
python frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
python penalty_sweep.py --lambdas 0 1e-4 1e-2 --seeds 42 123 456 --epochs 30
```
+
## Requirements

- Python >= 3.10
- PyTorch >= 2.0
- torchvision
- numpy
- scipy (for perturbation correlation)
+
## Protocol specification

The protocol consists of three diagnostic checks applied to a trained model:

| Diagnostic | What it measures | Threshold | Fires when |
|---|---|---|---|
| D1: Scale stability | max per-block residual growth | > 50x | residual stream explodes |
| D2: Reference validity | deepest-layer BP gradient norm | < 10 * eps | BP reference collapses below the cosine clamp |
| D3: Depth utility | test accuracy vs. frozen-blocks baseline | margin < 2 pp | trained blocks do not outperform random ones |
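D1 amounts to tracking residual-stream norms across blocks. A sketch under the assumptions that the model applies its blocks residually and exposes `.embed` and `.blocks` (hypothetical interface; `fa_protocol.py` contains the shipped measurement):

```python
import torch

@torch.no_grad()
def scale_growth(model, x):
    """Max per-block growth ratio of the residual-stream norm (diagnostic D1).

    Illustrative sketch only: assumes h = h + block(h) updates and a model
    with .embed and .blocks attributes.
    """
    h = model.embed(x)
    norms = [h.norm(dim=-1).mean()]  # mean L2 norm over the batch
    for block in model.blocks:
        h = h + block(h)
        norms.append(h.norm(dim=-1).mean())
    ratios = [float(b / a) for a, b in zip(norms, norms[1:])]
    return max(ratios)  # compare against the 50x threshold
```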

**Verdict logic**: A run fails the protocol if Mode 1 fires (D1 and D2 both flag) or if depth utility fails (D3 flags). Passing all three diagnostics does not certify the method as effective; it only rules out the specific class of silent failures the audit revealed.
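The verdict rule can be sketched as a small pure function; the thresholds mirror the table above, while the function name, signature, and `eps` default are assumptions for illustration, not the shipped API:

```python
def fa_verdict(d1_scale_growth, d2_ref_grad_norm, d3_acc_margin_pp,
               eps=1.19e-7):
    """Combine the three diagnostics into a PASS/FAIL verdict (sketch).

    Thresholds follow the table: 50x residual growth, 10*eps reference
    gradient norm, 2 pp margin over the frozen-blocks baseline.
    """
    d1 = d1_scale_growth > 50.0          # residual stream explodes
    d2 = d2_ref_grad_norm < 10.0 * eps   # BP reference collapses
    d3 = d3_acc_margin_pp < 2.0          # no gain over frozen blocks

    flags = []
    if d1 and d2:                        # Mode 1: scale + reference both fail
        flags.append('D1+D2')
    if d3:                               # depth-utility failure
        flags.append('D3')
    return 'PASS' if not flags else 'FAIL(' + '+'.join(flags) + ')'
```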
+
## Training inventory

The paper reports 125 training runs across 5+ architectures, 10 training methods, and 17 experimental settings. Total estimated GPU time: ~12 hours on a single A100/A6000.
+
## License

MIT