diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:20:48 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:20:48 -0500 |
| commit | 7b64702ad970c16171142665365e16a8e1737190 (patch) | |
| tree | 30df9313fc195aa037f05c0598fdc0df8b32bbcc /protocol/__init__.py | |
| parent | 2dc8e7efb5f2ff827fbd97be87f0127aa5ab2757 (diff) | |
Add FA diagnostic protocol reference implementation
Codex round 15 #1 priority for the E&D-track paper:
- protocol/protocol.py: 4 diagnostics (residual norms, BP grad norms,
cross-batch direction stability, and a frozen-baseline comparator)
- protocol/report.py: DiagnosticReport with per-diagnostic verdicts and
pretty-printer
- protocol/smoke_test.py: validates BP/DFA/EP checkpoints produce the
expected verdicts (BP/EP trustworthy; DFA walked back via residual
explosion + BP grad at floor)
- protocol/README.md: usage, audit cases, threshold rationale
- protocol/CHECKLIST.md: 6 evaluation pipeline pitfalls (norm(-1),
cosine_similarity eps clamp, fp16 underflow, Bs reproducibility,
aggregation, layer-0 dominance)
- protocol/REPORTING_TEMPLATE.md: per-method fillable form for FA papers
Diffstat (limited to 'protocol/__init__.py')
| -rw-r--r-- | protocol/__init__.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/protocol/__init__.py b/protocol/__init__.py new file mode 100644 index 0000000..770bdda --- /dev/null +++ b/protocol/__init__.py @@ -0,0 +1,52 @@ +""" +FA Diagnostic Protocol — reference implementation. + +A drop-in toolkit for evaluating Feedback Alignment (and related local-credit) +methods on modern residual architectures. Implements the four diagnostics +described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for +Feedback Alignment Evaluation*: + + (a) per-layer residual stream norm ||h_l||_2 + (b) per-layer BP gradient norm ||g_l||_2 + (c) cross-batch direction stability s_l + (d) frozen-blocks control accuracy acc_frozen + +The motivation is that headline accuracy and offline Γ (cosine alignment of +local credit to BP gradient) can both silently fail on modern pre-LayerNorm +residual networks: residual-stream explosion drives ||g|| ~10^-10 and turns Γ +into a measurement of cosine to a numerical-floor reference vector. This +protocol catches that regime change and tells you whether your reported +numbers are interpretable. + +Quick start: + + from protocol import diagnose, DiagnosticReport + + report = diagnose( + model=trained_model, + eval_batches=list_of_(x, y)_pairs, + frozen_baseline_acc=accuracy_of_random_blocks_baseline, + method_name="DFA", + notes="4-block d=256 ResMLP, CIFAR-10", + ) + print(report) + +The report flags any diagnostic that crossed a "degenerate regime" threshold +and emits a verdict ("trustworthy" / "needs walk-back"). +""" + +from .protocol import ( + residual_norms, + bp_grad_norms, + cross_batch_direction_stability, + diagnose, +) +from .report import DiagnosticReport + +__all__ = [ + "residual_norms", + "bp_grad_norms", + "cross_batch_direction_stability", + "diagnose", + "DiagnosticReport", +] |
