summaryrefslogtreecommitdiff
path: root/protocol/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'protocol/__init__.py')
-rw-r--r--protocol/__init__.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/protocol/__init__.py b/protocol/__init__.py
new file mode 100644
index 0000000..770bdda
--- /dev/null
+++ b/protocol/__init__.py
@@ -0,0 +1,52 @@
+"""
+FA Diagnostic Protocol — reference implementation.
+
+A drop-in toolkit for evaluating Feedback Alignment (and related local-credit)
+methods on modern residual architectures. Implements the four diagnostics
+described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for
+Feedback Alignment Evaluation*:
+
+ (a) per-layer residual stream norm ||h_l||_2
+ (b) per-layer BP gradient norm ||g_l||_2
+ (c) cross-batch direction stability s_l
+ (d) frozen-blocks control accuracy acc_frozen
+
+The motivation is that headline accuracy and offline Γ (cosine alignment of
+local credit to BP gradient) can both silently fail on modern pre-LayerNorm
+residual networks: residual-stream explosion drives ||g|| ~10^-10 and turns Γ
+into a measurement of cosine to a numerical-floor reference vector. This
+protocol catches that regime change and tells you whether your reported
+numbers are interpretable.
+
+Quick start:
+
+ from protocol import diagnose, DiagnosticReport
+
+ report = diagnose(
+ model=trained_model,
+ eval_batches=list_of_(x, y)_pairs,
+ frozen_baseline_acc=accuracy_of_random_blocks_baseline,
+ method_name="DFA",
+ notes="4-block d=256 ResMLP, CIFAR-10",
+ )
+ print(report)
+
+The report flags any diagnostic that crossed a "degenerate regime" threshold
+and emits a verdict ("trustworthy" / "needs walk-back").
+"""
+
+from .protocol import (
+ residual_norms,
+ bp_grad_norms,
+ cross_batch_direction_stability,
+ diagnose,
+)
+from .report import DiagnosticReport
+
+__all__ = [
+ "residual_norms",
+ "bp_grad_norms",
+ "cross_batch_direction_stability",
+ "diagnose",
+ "DiagnosticReport",
+]