"""
FA Diagnostic Protocol: reference implementation.

A drop-in toolkit for evaluating Feedback Alignment (and related local-credit)
methods on modern residual architectures. Implements the four diagnostics
described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for
Feedback Alignment Evaluation*:

    (a) per-layer residual stream norm ||h_l||_2
    (b) per-layer BP gradient norm ||g_l||_2
    (c) cross-batch direction stability s_l
    (d) frozen-blocks control accuracy acc_frozen

The motivation is that headline accuracy and offline Γ (the cosine alignment
of local credit to the BP gradient) can both fail silently on modern
pre-LayerNorm residual networks: residual-stream explosion drives ||g|| down
to ~10^-10, so Γ ends up measuring the cosine to a reference vector that sits
at the numerical floor. This protocol detects that regime change and tells you
whether your reported numbers are interpretable.

Quick start:

    from protocol import diagnose, DiagnosticReport

    report = diagnose(
        model=trained_model,
        eval_batches=eval_batches,  # list of (x, y) pairs
        frozen_baseline_acc=accuracy_of_random_blocks_baseline,
        method_name="DFA",
        notes="4-block d=256 ResMLP, CIFAR-10",
    )
    print(report)

The report flags every diagnostic that crossed a "degenerate regime" threshold
and emits a verdict ("trustworthy" / "needs walk-back").
"""
from .protocol import (
    residual_norms,
    bp_grad_norms,
    cross_batch_direction_stability,
    diagnose,
)
from .report import DiagnosticReport

__all__ = [
    "residual_norms",
    "bp_grad_norms",
    "cross_batch_direction_stability",
    "diagnose",
    "DiagnosticReport",
]