summaryrefslogtreecommitdiff
path: root/protocol/report.py
blob: 00640eb99893199b53dae790eb9701f38415ecde (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
"""
DiagnosticReport: structured output of the FA evaluation protocol.

Holds the per-layer numbers from the four diagnostics and emits a verdict
for each one based on the published thresholds. The verdict is intentionally
binary ("trustworthy" / "needs walk-back"); fine-grained reading is the
caller's job.
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DiagnosticThresholds:
    """Degeneracy thresholds used by the protocol's verdicts (defaults match the paper).

    g_norm_floor: BP gradient norms below this are considered to be at the
        numerical floor (Γ measured against this is not interpretable as
        credit alignment). Default 1e-7 — far above the smallest normal fp32
        value (~1.2e-38), one order of magnitude above F.cosine_similarity's
        eps=1e-8 clamp, and several orders below healthy networks (~1e-5).
    h_norm_explosion_ratio: largest *single-block* residual-stream growth,
        max_l(||h_{l+1}|| / ||h_l||), above which the residual stream is
        considered "exploded". This is a per-block ratio, not growth relative
        to layer 0. Default 50× — see the calibration notes on the field
        below (healthy: <5× late / up to 25× early ViT; failure: 100-4000×).
    stability_drift_ceiling: cross-batch direction cosine above this is
        considered to be drift-dominated (reference vector is sample-
        invariant). Default 0.30 — BP-trained / EP-trained networks are
        below 0.20, failure modes are 0.5-0.99.
    frozen_acc_margin_pp: minimum acc gain, in percentage points, of the
        trainable-blocks variant over the frozen-random-blocks baseline for
        the deep blocks to be considered "actually contributing". Compared
        against (headline_acc - frozen_baseline_acc) * 100, so accuracies are
        fractions while this margin is in pp. Default 2.0 pp.
    """

    g_norm_floor: float = 1e-7
    # Per-block residual growth ratio threshold. The diagnostic is
    # `max_l(||h_{l+1}|| / ||h_l||)` — the largest single-block residual
    # amplification. We avoided `max(||h||) / ||h_0||` because it false-
    # positives on ViT-style architectures where the cls token at layer 0
    # is anomalously small after patch_embed.
    #
    # Calibration on observed data:
    #   - BP-trained, late training: <5× per block (steady state)
    #   - BP ViT, early training (epoch 1-5): 13-25× max (cls token still
    #     resolving from its small init magnitude)
    #   - DFA-trained ResMLP / ViT: 100-4000× max per block
    # Threshold 50 sits cleanly between healthy-early-training (max 25) and
    # failure-regime (min 100), with margin on both sides.
    h_norm_explosion_ratio: float = 50.0
    stability_drift_ceiling: float = 0.30
    frozen_acc_margin_pp: float = 2.0


@dataclass
class DiagnosticReport:
    """Result of running the protocol on one trained network.

    Stores the raw numbers for the four diagnostics and derives one flag per
    diagnostic, plus an aggregate ``verdict``, from ``thresholds``:

      (a) residual stream explosion  — from ``residual_norms``
      (b) BP gradient at the floor   — from ``bp_grad_norms``
      (c) reference-direction drift  — from ``cross_batch_stability``
      (d) frozen-baseline undercut   — from ``headline_acc`` vs.
          ``frozen_baseline_acc`` (skipped when the latter is None)

    Field notes:
      residual_norms: ||h_l||_2 per layer, index 0 = h_0; consecutive entries
          are divided to obtain per-block growth.
      bp_grad_norms: ||g_l||_2 per layer; only the LAST entry is checked
          against the floor (treated as the deepest layer — assumes
          shallow-to-deep ordering; confirm against the producer). Its
          length may differ from ``residual_norms``.
      stability_layer: informational only — printed and serialized, never
          used in a verdict.
      headline_acc / frozen_baseline_acc: the margin is computed as
          ``(headline_acc - frozen_baseline_acc) * 100`` and compared in
          percentage points, so both are expected as fractions in [0, 1].

    A NaN measurement (e.g. from a diverged run) fails its check and is
    flagged instead of silently passing; for finite values every check is
    exactly a strict comparison against its threshold.
    """

    method_name: str
    notes: str
    residual_norms: List[float]
    bp_grad_norms: List[float]
    stability_layer: int
    cross_batch_stability: float
    headline_acc: float
    frozen_baseline_acc: Optional[float]
    thresholds: DiagnosticThresholds = field(default_factory=DiagnosticThresholds)

    # ------------------------------------------------------------------ #
    # Per-diagnostic verdicts
    # ------------------------------------------------------------------ #

    @property
    def max_per_block_growth(self) -> float:
        """max_l (||h_{l+1}|| / ||h_l||) — the largest residual-stream
        amplification by any single block. Healthy BP/EP networks have all
        per-block growth < 5×; pathological networks (DFA/SB/CB on pre-LN
        residuals) have at least one block with growth > 100×.

        Returns 1.0 when fewer than two norms are available (no block to
        measure). Denominators are clamped to 1e-30 so a zero norm yields a
        huge ratio rather than ZeroDivisionError. Returns NaN if any ratio is
        NaN — plain ``max`` would give an order-dependent answer there.
        """
        norms = self.residual_norms
        if len(norms) < 2:
            return 1.0
        ratios = [nxt / max(cur, 1e-30) for cur, nxt in zip(norms, norms[1:])]
        if any(math.isnan(r) for r in ratios):
            return math.nan
        return max(ratios)

    @property
    def residual_stream_exploded(self) -> bool:
        """True if the largest per-block growth exceeds
        ``thresholds.h_norm_explosion_ratio`` (or is NaN)."""
        growth = self.max_per_block_growth
        # `not (<=)` instead of `>` so NaN is flagged rather than passing.
        return not (growth <= self.thresholds.h_norm_explosion_ratio)

    @property
    def bp_grad_at_floor(self) -> bool:
        """True if the deepest BP grad norm is below ``thresholds.g_norm_floor``
        (or is NaN). False when no grad norms were supplied."""
        if not self.bp_grad_norms:
            return False
        # Check the *deepest* hidden layer's BP grad — that's where Γ is
        # typically reported and where LN-driven collapse hits hardest.
        deepest = self.bp_grad_norms[-1]
        # `not (>=)` instead of `<` so NaN is flagged rather than passing.
        return not (deepest >= self.thresholds.g_norm_floor)

    @property
    def reference_drift_dominated(self) -> bool:
        """True if cross-batch direction stability exceeds
        ``thresholds.stability_drift_ceiling`` (or is NaN)."""
        ceiling = self.thresholds.stability_drift_ceiling
        return not (self.cross_batch_stability <= ceiling)

    @property
    def frozen_margin_pp(self) -> Optional[float]:
        """Headline-minus-frozen accuracy in percentage points, or None when
        no frozen baseline was supplied."""
        if self.frozen_baseline_acc is None:
            return None
        return (self.headline_acc - self.frozen_baseline_acc) * 100

    @property
    def frozen_baseline_undercut(self) -> Optional[bool]:
        """True if the trainable-blocks acc fails to clear the frozen baseline
        by `frozen_acc_margin_pp` (a NaN margin also counts as failing).
        None if no frozen baseline supplied.
        """
        margin_pp = self.frozen_margin_pp
        if margin_pp is None:
            return None
        return not (margin_pp >= self.thresholds.frozen_acc_margin_pp)

    # ------------------------------------------------------------------ #
    # Aggregate verdict
    # ------------------------------------------------------------------ #

    @property
    def verdict(self) -> str:
        """"trustworthy" if no diagnostic is flagged, otherwise
        "needs walk-back: " followed by the flagged diagnostics joined by "; "."""
        flags = [
            ("residual stream exploded", self.residual_stream_exploded),
            ("BP grad at numerical floor", self.bp_grad_at_floor),
            ("BP grad direction is drift-dominated", self.reference_drift_dominated),
        ]
        # Diagnostic (d) only participates when a frozen baseline exists.
        if self.frozen_baseline_undercut is True:
            flags.append(("deep blocks fail to beat frozen-random baseline", True))
        flagged = [name for name, val in flags if val]
        if not flagged:
            return "trustworthy"
        return "needs walk-back: " + "; ".join(flagged)

    # ------------------------------------------------------------------ #
    # Pretty-print
    # ------------------------------------------------------------------ #

    def __str__(self) -> str:
        """Multi-line human-readable report: one section per diagnostic,
        a FLAG line under each failing one, then the aggregate verdict."""
        th = self.thresholds
        rule = "=" * 72
        lines: List[str] = [
            rule,
            f"FA Diagnostic Protocol Report — method: {self.method_name}",
        ]
        if self.notes:
            lines.append(f"Notes: {self.notes}")
        lines.append(rule)

        # (a) Residual stream norms
        lines.append("(a) Residual stream norms ||h_l||_2 (median over batch):")
        for idx, h_norm in enumerate(self.residual_norms):
            lines.append(f"      h_{idx}:  {h_norm:.3e}")
        if self.residual_stream_exploded:
            lines.append(
                f"    FLAG: max per-block growth ‖h_{{l+1}}‖/‖h_l‖ = "
                f"{self.max_per_block_growth:.2e} "
                f"> threshold {th.h_norm_explosion_ratio}× — "
                "residual stream exploded."
            )

        # (b) BP grad norms — iterate the grad list by its own length; it
        # need not match len(residual_norms) and may be empty.
        lines.append("")
        lines.append("(b) BP gradient norms ||g_l||_2 (median over batch):")
        for idx, g_norm in enumerate(self.bp_grad_norms):
            lines.append(f"      g_{idx}:  {g_norm:.3e}")
        if self.bp_grad_at_floor:
            lines.append(
                f"    FLAG: deepest ||g_L|| = {self.bp_grad_norms[-1]:.2e} "
                f"< floor {th.g_norm_floor:.0e} — Γ measured "
                "against numerical floor; not interpretable as credit alignment."
            )

        # (c) Cross-batch direction stability
        lines.append("")
        lines.append(
            f"(c) Cross-batch direction stability at layer {self.stability_layer}: "
            f"{self.cross_batch_stability:.3f}"
        )
        if self.reference_drift_dominated:
            lines.append(
                f"    FLAG: stability {self.cross_batch_stability:.3f} > "
                f"ceiling {th.stability_drift_ceiling} — "
                "reference vector is sample-invariant drift, not per-sample credit."
            )

        # (d) Frozen baseline comparison
        lines.append("")
        margin_pp = self.frozen_margin_pp
        if margin_pp is None:
            lines.append("(d) Frozen-blocks baseline: NOT PROVIDED — diagnostic skipped.")
        else:
            lines.append(
                f"(d) Headline acc: {self.headline_acc:.4f}, "
                f"frozen-blocks baseline: {self.frozen_baseline_acc:.4f}, "
                f"margin: {margin_pp:+.2f} pp"
            )
            if self.frozen_baseline_undercut:
                lines.append(
                    f"    FLAG: margin {margin_pp:+.2f} pp < "
                    f"required {th.frozen_acc_margin_pp} pp — "
                    "deep blocks are not contributing over a random untrained baseline."
                )

        # Verdict
        lines.append("")
        lines.append(f"VERDICT: {self.verdict}")
        lines.append(rule)
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Plain-dict snapshot of the raw measurements, thresholds and the
        aggregate verdict string (list fields are copied)."""
        th = self.thresholds
        return {
            "method_name": self.method_name,
            "notes": self.notes,
            "residual_norms": list(self.residual_norms),
            "bp_grad_norms": list(self.bp_grad_norms),
            "stability_layer": self.stability_layer,
            "cross_batch_stability": self.cross_batch_stability,
            "headline_acc": self.headline_acc,
            "frozen_baseline_acc": self.frozen_baseline_acc,
            "verdict": self.verdict,
            "thresholds": {
                "g_norm_floor": th.g_norm_floor,
                "h_norm_explosion_ratio": th.h_norm_explosion_ratio,
                "stability_drift_ceiling": th.stability_drift_ceiling,
                "frozen_acc_margin_pp": th.frozen_acc_margin_pp,
            },
        }