From b480d0cdc21f944e4adccf6e81cc939b0450c5e9 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 4 May 2026 19:50:45 -0500
Subject: Initial submission code: FA evaluation protocol + reproduction scripts

Reference implementation of the three-diagnostic FA evaluation protocol
(scale stability, reference validity, depth utility) from the NeurIPS 2026
E&D track paper. Includes models, metrics, and the full reproduction pipeline.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 README.md | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 README.md

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3345986
--- /dev/null
+++ b/README.md
@@ -0,0 +1,81 @@

# What Accuracy and Gradient Cosine Miss: Evaluating Feedback Alignment via Scale Stability, Reference Validity, and Depth Utility

Code for the NeurIPS 2026 Evaluations & Datasets Track submission.

## Structure

```
submission_code/
  protocol/              # Reference evaluation artifact (the main deliverable)
    fa_protocol.py       # Three-diagnostic protocol: drop-in evaluator
    example_usage.py     # Minimal example applying the protocol to a trained model
  models/                # Architectures used in the paper
    residual_mlp.py      # Pre-LayerNorm ResMLP (primary audit architecture)
    vit_mini.py          # ViT-Mini (transformer with terminal LN)
    small_resnet.py      # SmallResNet (BatchNorm, no LN)
  metrics/
    credit_metrics.py    # Cosine similarity, nudging test, perturbation correlation
  reproduce/             # Scripts to reproduce all paper results
    train_methods.py     # Train BP/FA/DFA on any architecture + compute diagnostics
    frozen_baseline.py   # Frozen-blocks and shallow baselines
    penalty_sweep.py     # Penalty intervention (lambda sweep + fresh-B null)
    run_all.sh           # One-command full reproduction
  requirements.txt
```

## Quick start: applying the protocol to your own FA method

```python
from protocol.fa_protocol import FAProtocol

# Your trained model and a test batch
model = ...           # nn.Module with .blocks, .out_head, and optionally .out_ln
x_eval, y_eval = ...  # (N, d_input), (N,)

protocol = FAProtocol(model, x_eval, y_eval)
report = protocol.run()

print(report['verdict'])      # 'PASS', 'FAIL(D1+D2)', etc.
print(report['diagnostics'])  # {'D1_scale_growth': ..., 'D2_ref_validity': ..., 'D3_depth_utility': ...}
```

## Reproducing paper results

```bash
# Full reproduction (all figures, tables, appendices)
cd reproduce/
bash run_all.sh --gpu 0

# Or individual experiments:
python train_methods.py --arch resmlp --methods bp fa dfa --seeds 42 123 456 --epochs 100
python frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
python penalty_sweep.py --lambdas 0 1e-4 1e-2 --seeds 42 123 456 --epochs 30
```

## Requirements

- Python >= 3.10
- PyTorch >= 2.0
- torchvision
- numpy
- scipy (for the perturbation-correlation metric)

## Protocol specification

The protocol consists of three diagnostic checks applied to a trained model:

| Diagnostic | What it measures | Threshold | Fires when |
|---|---|---|---|
| D1: Scale stability | Max per-block residual growth | > 50x | The residual stream explodes |
| D2: Reference validity | Deepest-layer BP gradient norm | < 10 * eps | The BP reference collapses below the cosine clamp |
| D3: Depth utility | Test accuracy vs. the frozen-blocks baseline | < 2 pp above baseline | Trained blocks do not outperform random frozen blocks |

**Verdict logic**: A run fails the protocol if Mode 1 is triggered (D1 and D2 both flag) or if depth utility fails (D3 flags). Passing all three diagnostics does not certify the method as effective; it rules out the specific class of silent failures the audit revealed.
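As a concrete reading of the table, the sketch below shows one way the three thresholds could combine into a verdict string. It is a minimal illustration of the decision rule, not the shipped `fa_protocol.py` implementation: the diagnostic values are assumed to be precomputed floats, the argument names are hypothetical (they mirror the report keys from the quick-start example), and `eps` is assumed to mean float32 machine epsilon.

```python
import torch

# Illustrative decision rule only -- protocol/fa_protocol.py is the actual
# implementation. Thresholds follow the table above; reading "eps" as
# float32 machine epsilon is an assumption.
D1_MAX_GROWTH = 50.0                                    # max per-block residual growth
D2_MIN_GRAD_NORM = 10 * torch.finfo(torch.float32).eps  # BP-reference floor
D3_MIN_MARGIN_PP = 2.0                                  # margin (pp) over frozen baseline


def verdict(d1_scale_growth: float, d2_ref_grad_norm: float,
            d3_margin_pp: float) -> str:
    """Combine the three diagnostics into a 'PASS' or 'FAIL(...)' string."""
    flags = []
    if d1_scale_growth > D1_MAX_GROWTH:
        flags.append('D1')
    if d2_ref_grad_norm < D2_MIN_GRAD_NORM:
        flags.append('D2')
    if d3_margin_pp < D3_MIN_MARGIN_PP:
        flags.append('D3')

    # Mode 1 requires D1 and D2 to fire together; D3 fails a run on its own.
    if ('D1' in flags and 'D2' in flags) or 'D3' in flags:
        return 'FAIL(' + '+'.join(flags) + ')'
    return 'PASS'
```

Note that D1 or D2 firing alone does not fail a run: under the verdict logic above, scale growth counts against a method only when the BP reference has also collapsed.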
## Training inventory

The paper reports 125 training runs across 5+ architectures, 10 training methods, and 17 experimental settings. Total estimated GPU time: ~12 hours on a single A100 or A6000.

## License

MIT