diff options
Diffstat (limited to 'protocol/__init__.py')
| -rw-r--r-- | protocol/__init__.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/protocol/__init__.py b/protocol/__init__.py new file mode 100644 index 0000000..770bdda --- /dev/null +++ b/protocol/__init__.py @@ -0,0 +1,52 @@ +""" +FA Diagnostic Protocol — reference implementation. + +A drop-in toolkit for evaluating Feedback Alignment (and related local-credit) +methods on modern residual architectures. Implements the four diagnostics +described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for +Feedback Alignment Evaluation*: + + (a) per-layer residual stream norm ||h_l||_2 + (b) per-layer BP gradient norm ||g_l||_2 + (c) cross-batch direction stability s_l + (d) frozen-blocks control accuracy acc_frozen + +The motivation is that headline accuracy and offline Γ (cosine alignment of +local credit to BP gradient) can both silently fail on modern pre-LayerNorm +residual networks: residual-stream explosion drives ||g|| ~10^-10 and turns Γ +into a measurement of cosine to a numerical-floor reference vector. This +protocol catches that regime change and tells you whether your reported +numbers are interpretable. + +Quick start: + + from protocol import diagnose, DiagnosticReport + + report = diagnose( + model=trained_model, + eval_batches=list_of_(x, y)_pairs, + frozen_baseline_acc=accuracy_of_random_blocks_baseline, + method_name="DFA", + notes="4-block d=256 ResMLP, CIFAR-10", + ) + print(report) + +The report flags any diagnostic that crossed a "degenerate regime" threshold +and emits a verdict ("trustworthy" / "needs walk-back"). +""" + +from .protocol import ( + residual_norms, + bp_grad_norms, + cross_batch_direction_stability, + diagnose, +) +from .report import DiagnosticReport + +__all__ = [ + "residual_norms", + "bp_grad_norms", + "cross_batch_direction_stability", + "diagnose", + "DiagnosticReport", +] |
