"""
FA Diagnostic Protocol — reference implementation.

A drop-in toolkit for evaluating Feedback Alignment (and related local-credit)
methods on modern residual architectures. Implements the four diagnostics
described in *Beyond Accuracy and Alignment: A Diagnostic Protocol for
Feedback Alignment Evaluation*:

    (a) per-layer residual stream norm   ||h_l||_2
    (b) per-layer BP gradient norm       ||g_l||_2
    (c) cross-batch direction stability  s_l
    (d) frozen-blocks control accuracy   acc_frozen
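
For concreteness, diagnostics (a)-(c) can be sketched in plain Python. This is
a minimal illustration, not the code behind `residual_norms`, `bp_grad_norms`,
or `cross_batch_direction_stability`. The helpers `l2_norm`, `cosine`, and
`direction_stability` are hypothetical, and each layer's tensor is assumed to
be flattened into a list of floats. s_l is assumed to be the mean pairwise
cosine between the layer-l gradient directions computed on different batches;
the protocol's exact definition may differ.

```python
import math
from itertools import combinations


def l2_norm(v):
    # (a) and (b): Euclidean norm ||v||_2 of one layer's flattened tensor
    return math.sqrt(sum(x * x for x in v))


def cosine(u, v, eps=1e-30):
    # Cosine similarity; eps only guards against division by zero
    return sum(a * b for a, b in zip(u, v)) / (l2_norm(u) * l2_norm(v) + eps)


def direction_stability(grads_per_batch):
    # (c) ASSUMED definition: s_l is the mean pairwise cosine between the
    # layer-l gradient directions measured on different batches.
    if len(grads_per_batch) < 2:
        raise ValueError("need gradients from at least two batches")
    pairs = list(combinations(grads_per_batch, 2))
    return sum(cosine(u, v) for u, v in pairs) / len(pairs)


# One layer, three batches: a consistent gradient direction vs. a wandering one
stable = [[1.0, 2.0, 0.0], [2.0, 4.1, 0.0], [0.5, 1.0, 0.01]]
wandering = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]
print(round(direction_stability(stable), 3))     # close to 1
print(round(direction_stability(wandering), 3))  # -0.333
```

Under this assumed definition, s_l near 1 means the layer's gradient direction
barely changes between batches. Values near 0 or below mean the direction is
dominated by batch-to-batch noise.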

The motivation is that headline accuracy and offline Γ (the cosine alignment
between local credit and the BP gradient) can both fail silently on modern
pre-LayerNorm residual networks. Residual-stream explosion drives ||g|| down to
~10^-10. At that point Γ measures the cosine to a reference vector sitting at
the numerical floor, not alignment with a meaningful gradient. This protocol
detects that regime change and tells you whether your reported numbers are
interpretable.
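
The failure mode above can be illustrated with a small guard around Γ. The
sketch assumes Γ is the plain cosine between a layer's local credit signal and
its BP gradient. It uses a hypothetical cut-off `GRAD_FLOOR = 1e-8`, below
which the BP gradient is treated as numerical noise. Neither `gamma_with_guard`
nor the threshold is part of this package.

```python
import math
import random

# Hypothetical cut-off: a BP gradient with ||g||_2 below this is treated as
# numerical noise, so any cosine measured against it is not interpretable.
GRAD_FLOOR = 1e-8


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


def gamma_with_guard(local_credit, bp_grad, floor=GRAD_FLOOR):
    # Returns (gamma, interpretable): gamma is cos(local credit, BP gradient)
    g_norm = l2_norm(bp_grad)
    c_norm = l2_norm(local_credit)
    if g_norm == 0.0 or c_norm == 0.0:
        return math.nan, False
    gamma = sum(a * b for a, b in zip(local_credit, bp_grad)) / (c_norm * g_norm)
    return gamma, g_norm >= floor


rng = random.Random(0)
local = [0.9, 0.6, -0.1]
healthy = [1.0, 0.5, -0.2]
# Gradient collapsed to ~1e-10: its direction is effectively noise
collapsed = [rng.gauss(0.0, 1e-10) for _ in range(3)]

print(gamma_with_guard(local, healthy))    # high gamma, interpretable=True
print(gamma_with_guard(local, collapsed))  # some gamma, interpretable=False
```

A cosine is still computed in the collapsed case. Once ||g|| sits at the floor,
however, that number says nothing about alignment. For that reason the guard
reports it as not interpretable instead of silently returning it.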

Quick start:

    from protocol import diagnose, DiagnosticReport

    report = diagnose(
        model=trained_model,
        eval_batches=eval_batches,  # list of (x, y) pairs
        frozen_baseline_acc=accuracy_of_random_blocks_baseline,
        method_name="DFA",
        notes="4-block d=256 ResMLP, CIFAR-10",
    )
    print(report)

The report flags every diagnostic that crosses its "degenerate regime" threshold
and emits a verdict ("trustworthy" or "needs walk-back").
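
The flag-and-verdict logic could work roughly as follows. This is not
`DiagnosticReport` or `diagnose`. `sketch_verdict` and all four thresholds are
hypothetical values chosen for illustration. Diagnostic (d) is assumed to pass
only when the trained model beats the frozen-blocks control by a fixed accuracy
margin.

```python
# Hypothetical thresholds for illustration only; they are not the
# package's actual cut-offs.
MAX_RESIDUAL_NORM = 1e4   # (a) residual stream considered exploded above this
MIN_GRAD_NORM = 1e-8      # (b) BP gradient considered at the numerical floor below this
MIN_STABILITY = 0.1       # (c) direction considered unstable across batches below this
MIN_FROZEN_MARGIN = 0.02  # (d) trained accuracy must beat acc_frozen by this much


def sketch_verdict(res_norms, grad_norms, stability, acc, acc_frozen):
    # Collect one flag per diagnostic value that lands in the degenerate regime
    flags = []
    for layer, h in enumerate(res_norms):
        if h > MAX_RESIDUAL_NORM:
            flags.append(f"(a) layer {layer}: ||h|| = {h:.1e}")
    for layer, g in enumerate(grad_norms):
        if g < MIN_GRAD_NORM:
            flags.append(f"(b) layer {layer}: ||g|| = {g:.1e}")
    for layer, s in enumerate(stability):
        if s < MIN_STABILITY:
            flags.append(f"(c) layer {layer}: s = {s:.2f}")
    if acc - acc_frozen < MIN_FROZEN_MARGIN:
        flags.append(f"(d) acc {acc:.3f} vs acc_frozen {acc_frozen:.3f}")
    verdict = "trustworthy" if not flags else "needs walk-back"
    return verdict, flags


print(sketch_verdict([3.0, 5.0], [1e-2, 5e-3], [0.8, 0.7], 0.55, 0.40)[0])
# trustworthy
verdict, flags = sketch_verdict([3.0, 2e6], [1e-10, 1e-10], [0.05, 0.02], 0.41, 0.40)
print(verdict, len(flags))
# needs walk-back 6
```

A single flag is enough to withhold "trustworthy" in this sketch. The reasoning
is that any one degenerate diagnostic can make accuracy or Γ uninterpretable on
its own.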
"""

from .protocol import (
    residual_norms,
    bp_grad_norms,
    cross_batch_direction_stability,
    diagnose,
)
from .report import DiagnosticReport

__all__ = [
    "residual_norms",
    "bp_grad_norms",
    "cross_batch_direction_stability",
    "diagnose",
    "DiagnosticReport",
]