summaryrefslogtreecommitdiff
path: root/protocol/__init__.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:20:48 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:20:48 -0500
commit7b64702ad970c16171142665365e16a8e1737190 (patch)
tree30df9313fc195aa037f05c0598fdc0df8b32bbcc /protocol/__init__.py
parent2dc8e7efb5f2ff827fbd97be87f0127aa5ab2757 (diff)
Add FA diagnostic protocol reference implementation
Codex round 15 #1 priority for the E&D-track paper: - protocol/protocol.py: 4 diagnostics (residual norms, BP grad norms, cross-batch direction stability, and a frozen-baseline comparator) - protocol/report.py: DiagnosticReport with per-diagnostic verdicts and pretty-printer - protocol/smoke_test.py: validates BP/DFA/EP checkpoints produce the expected verdicts (BP/EP trustworthy; DFA walked back via residual explosion + BP grad at floor) - protocol/README.md: usage, audit cases, threshold rationale - protocol/CHECKLIST.md: 6 evaluation pipeline pitfalls (norm(-1), cosine_similarity eps clamp, fp16 underflow, Bs reproducibility, aggregation, layer-0 dominance) - protocol/REPORTING_TEMPLATE.md: per-method fillable form for FA papers
Diffstat (limited to 'protocol/__init__.py')
-rw-r--r--protocol/__init__.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/protocol/__init__.py b/protocol/__init__.py
new file mode 100644
index 0000000..770bdda
--- /dev/null
+++ b/protocol/__init__.py
@@ -0,0 +1,52 @@
+"""
+FA Diagnostic Protocol — reference implementation.
+
+A drop-in toolkit for evaluating Feedback Alignment (and related local-credit)
+methods on modern residual architectures. Implements the four diagnostics
+described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for
+Feedback Alignment Evaluation*:
+
+ (a) per-layer residual stream norm ||h_l||_2
+ (b) per-layer BP gradient norm ||g_l||_2
+ (c) cross-batch direction stability s_l
+ (d) frozen-blocks control accuracy acc_frozen
+
+The motivation is that headline accuracy and offline Γ (cosine alignment of
+local credit to BP gradient) can both silently fail on modern pre-LayerNorm
+residual networks: residual-stream explosion drives ||g|| ~10^-10 and turns Γ
+into a measurement of cosine to a numerical-floor reference vector. This
+protocol catches that regime change and tells you whether your reported
+numbers are interpretable.
+
+Quick start:
+
+ from protocol import diagnose, DiagnosticReport
+
+ report = diagnose(
+ model=trained_model,
+ eval_batches=list_of_(x, y)_pairs,
+ frozen_baseline_acc=accuracy_of_random_blocks_baseline,
+ method_name="DFA",
+ notes="4-block d=256 ResMLP, CIFAR-10",
+ )
+ print(report)
+
+The report flags any diagnostic that crossed a "degenerate regime" threshold
+and emits a verdict ("trustworthy" / "needs walk-back").
+"""
+
+from .protocol import (
+ residual_norms,
+ bp_grad_norms,
+ cross_batch_direction_stability,
+ diagnose,
+)
+from .report import DiagnosticReport
+
+__all__ = [
+ "residual_norms",
+ "bp_grad_norms",
+ "cross_batch_direction_stability",
+ "diagnose",
+ "DiagnosticReport",
+]