From 7b64702ad970c16171142665365e16a8e1737190 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 7 Apr 2026 22:20:48 -0500 Subject: Add FA diagnostic protocol reference implementation Codex round 15 #1 priority for the E&D-track paper: - protocol/protocol.py: 4 diagnostics (residual norms, BP grad norms, cross-batch direction stability, and a frozen-baseline comparator) - protocol/report.py: DiagnosticReport with per-diagnostic verdicts and pretty-printer - protocol/smoke_test.py: validates BP/DFA/EP checkpoints produce the expected verdicts (BP/EP trustworthy; DFA walked back via residual explosion + BP grad at floor) - protocol/README.md: usage, audit cases, threshold rationale - protocol/CHECKLIST.md: 6 evaluation pipeline pitfalls (norm(-1), cosine_similarity eps clamp, fp16 underflow, Bs reproducibility, aggregation, layer-0 dominance) - protocol/REPORTING_TEMPLATE.md: per-method fillable form for FA papers --- protocol/__init__.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 protocol/__init__.py (limited to 'protocol/__init__.py') diff --git a/protocol/__init__.py b/protocol/__init__.py new file mode 100644 index 0000000..770bdda --- /dev/null +++ b/protocol/__init__.py @@ -0,0 +1,52 @@ +""" +FA Diagnostic Protocol — reference implementation. + +A drop-in toolkit for evaluating Feedback Alignment (and related local-credit) +methods on modern residual architectures. 
Implements the four diagnostics +described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for +Feedback Alignment Evaluation*: + + (a) per-layer residual stream norm ||h_l||_2 + (b) per-layer BP gradient norm ||g_l||_2 + (c) cross-batch direction stability s_l + (d) frozen-blocks control accuracy acc_frozen + +The motivation is that headline accuracy and offline Γ (cosine alignment of +local credit to BP gradient) can both silently fail on modern pre-LayerNorm +residual networks: residual-stream explosion drives ||g|| down to ~10^-10 and turns Γ +into a measurement of cosine to a numerical-floor reference vector. This +protocol catches that regime change and tells you whether your reported +numbers are interpretable. + +Quick start: + + from protocol import diagnose, DiagnosticReport + + report = diagnose( + model=trained_model, + eval_batches=list_of_xy_pairs, + frozen_baseline_acc=accuracy_of_random_blocks_baseline, + method_name="DFA", + notes="4-block d=256 ResMLP, CIFAR-10", + ) + print(report) + +The report flags any diagnostic that crossed a "degenerate regime" threshold +and emits a verdict ("trustworthy" / "needs walk-back"). +""" + +from .protocol import ( + residual_norms, + bp_grad_norms, + cross_batch_direction_stability, + diagnose, +) +from .report import DiagnosticReport + +__all__ = [ + "residual_norms", + "bp_grad_norms", + "cross_batch_direction_stability", + "diagnose", + "DiagnosticReport", +] -- cgit v1.2.3