1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
|
"""
Core diagnostics for the FA evaluation protocol.
The four diagnostics:
(a) residual_norms — per-layer ||h_l||_2 on a fixed eval batch
(b) bp_grad_norms — per-layer ||∂CE/∂h_l||_2 on a fixed eval batch
(c) cross_batch_direction_stability — cosine of normalized BP-grad direction
across disjoint minibatches; high values
indicate the reference vector is dominated by
a sample-invariant global drift, NOT per-sample
credit
(d) frozen-blocks accuracy — caller-supplied; the protocol just compares it
to the headline accuracy of the trained network
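The protocol expects a model exposing a `return_hidden=True` mode that yields
the per-layer residual stream as a list `[h_0, h_1, ..., h_L]`. The gradient
diagnostics (b) and (c) re-run the forward pass by hand, so the model must also
expose an embedding (`.embed(x)` or `.patch_embed(x)`), a `.blocks` sequence of
residual blocks (each block returns the update `f`, and `h_{l+1} = h_l + f`),
and a terminal stack: either `out_ln` + `out_head`, or ViT-style `norm` +
`head`. ResidualMLP and ViT-Mini in this repo both satisfy that contract.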
All metrics use `tensor.norm(dim=-1)` (L2 along the feature dim) — NEVER
`tensor.norm(-1)` (which is L_{-1} of the entire tensor and is the bug behind
several walk-backs in our project).
"""
from __future__ import annotations
from typing import List, Sequence, Tuple, Optional
import torch
import torch.nn.functional as F
from .report import DiagnosticReport, DiagnosticThresholds
# --------------------------------------------------------------------------- #
# Diagnostic (a): residual stream norm
# --------------------------------------------------------------------------- #
def residual_norms(model, x: torch.Tensor) -> List[float]:
    """Median L2 norm of the residual stream at every layer, over one batch.

    Args:
        model: callable as `model(x, return_hidden=True) -> (logits, hiddens)`,
            with `hiddens == [h_0, ..., h_L]`; each entry is either
            (B, d_hidden) or token-shaped (B, T, d_hidden).
        x: eval batch, already on the right device and in the input shape the
            model expects.

    Returns:
        L+1 Python floats: for each layer, the batch median (torch.median,
        i.e. the lower middle value for an even batch) of ||h_l||_2.

    Side effects:
        Calls `model.eval()` and leaves the model in eval mode.
    """
    model.eval()
    with torch.no_grad():
        _, hidden_states = model(x, return_hidden=True)

    def _median_l2(state: torch.Tensor) -> float:
        # Token-shaped states are reduced to the cls token (index 0), which is
        # what the head consumes in our ViT.
        vec = state[:, 0, :] if state.dim() == 3 else state
        # dim=-1 is required: `.norm(-1)` would be the L_{-1} "norm" of the
        # whole tensor (see project_norm_minus1_footgun).
        return vec.norm(dim=-1).median().item()

    return [_median_l2(state) for state in hidden_states]
# --------------------------------------------------------------------------- #
# Diagnostic (b): BP gradient norm at hidden layers
# --------------------------------------------------------------------------- #
def bp_grad_norms(
    model,
    x: torch.Tensor,
    y: torch.Tensor,
    loss_fn=F.cross_entropy,
) -> List[float]:
    """Per-layer median ||∂L/∂h_l||_2 on the eval batch.

    The BP gradient at hidden layers is the *reference vector* that offline
    alignment metrics (Γ) compare against. When this is at the numerical
    floor (~1e-10) on modern pre-LN residual nets, Γ measures alignment to a
    noise vector and is no longer interpretable as "credit assignment
    quality".

    Args:
        model: must expose `.embed(x)` or `.patch_embed(x)`, a `.blocks`
            sequence of residual blocks, and a terminal stack
            (`out_ln`/`out_head` or `norm`/`head`). The forward pass is
            re-run by hand as h_{l+1} = h_l + block(h_l); `return_hidden`
            is NOT used here.
        x: eval batch.
        y: labels for the eval batch.
        loss_fn: scalar loss given (logits, y); cross_entropy by default.

    Returns:
        list of L+1 floats — the median (over batch) of ||g_l||_2 per layer,
        where g_l = ∂loss/∂h_l. For token-shaped gradients (B, T, d) only the
        token at index 0 is used.

    Side effects:
        Calls `model.eval()` and leaves the model in eval mode. Parameter
        `.grad` fields are not modified (`torch.autograd.grad` returns the
        gradients instead of accumulating them).
    """
    model.eval()
    # Forward through the canonical path (embed -> blocks with explicit
    # residual addition -> terminal LN -> head), keeping every intermediate
    # h_l as a non-leaf in the autograd graph. `torch.autograd.grad` accepts
    # non-leaf tensors as `inputs` and computes ∂L/∂h_l via the chain rule.
    # NOTE: do not clone+detach hidden states that are already in the graph —
    # that severs it.
    with torch.enable_grad():
        h = _embed(model, x)
        if not h.requires_grad:
            # Nothing upstream of h_0 requires grad (e.g. the model was frozen
            # with `requires_grad_(False)` for evaluation, or the embedding is
            # parameter-free). In that case `autograd.grad` would raise, so
            # make h_0 the root of the graph. h_0 has no graph, so detaching it
            # severs nothing. ∂L/∂h_l depends only on the computation
            # downstream of h_l, so the returned gradients are unchanged.
            h = h.detach().requires_grad_(True)
        hiddens: List[torch.Tensor] = [h]
        for block in _iter_blocks(model):
            h = h + block(h)
            hiddens.append(h)
        loss = loss_fn(_terminal(model, h), y)
        grads = torch.autograd.grad(loss, hiddens)
    out: List[float] = []
    for g in grads:
        if g.dim() == 3:
            # Same cls-token pooling as `residual_norms`.
            g = g[:, 0, :]
        # dim=-1 is mandatory (see module docstring).
        out.append(g.norm(dim=-1).median().item())
    return out
def _embed(model, x):
if hasattr(model, "embed"):
return model.embed(x)
if hasattr(model, "patch_embed"):
# ViT-style
return model.patch_embed(x)
raise AttributeError(
"model must expose `.embed(x)` or `.patch_embed(x)` for the protocol"
)
def _iter_blocks(model):
if hasattr(model, "blocks"):
return list(model.blocks)
raise AttributeError("model must expose `.blocks` (a Module list)")
def _terminal(model, h):
"""Apply terminal LN (if any) + classification head to hidden state."""
if hasattr(model, "out_ln") and hasattr(model, "out_head"):
return model.out_head(model.out_ln(h))
if hasattr(model, "norm") and hasattr(model, "head"):
# ViT-style naming
h = model.norm(h)
if h.dim() == 3:
h = h[:, 0, :]
return model.head(h)
raise AttributeError(
"model must expose `(out_ln, out_head)` or `(norm, head)` for the protocol"
)
# --------------------------------------------------------------------------- #
# Diagnostic (c): cross-batch direction stability
# --------------------------------------------------------------------------- #
def cross_batch_direction_stability(
    model,
    eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]],
    layer_index: int,
    loss_fn=F.cross_entropy,
) -> float:
    """Mean pairwise cosine of the BP-grad *direction* at layer `layer_index`
    across disjoint minibatches.

    A value near 1.0 means: the BP gradient direction at that hidden layer is
    the same regardless of which samples are in the batch. That is the
    fingerprint of a **sample-invariant global drift** — i.e. the reference
    vector against which Γ is computed is NOT per-sample credit, it is a
    constant artifact of the trained network's geometry. On healthy
    BP-trained or EP-trained networks this value is small (~0.05-0.18); on
    DFA/SB/CB pre-LN ResMLPs we see ~0.5-0.99.

    Args:
        model: must expose `.embed(x)` or `.patch_embed(x)`, a `.blocks`
            sequence of residual blocks, and a terminal stack
            (`out_ln`/`out_head` or `norm`/`head`); the forward pass is
            re-run by hand as h_{l+1} = h_l + block(h_l).
        eval_batches: list of (x, y) pairs, each a *disjoint* minibatch on the
            same device. Use ≥8 batches of ~128 samples each. The metric is
            calibrated for this batch size: at ~128 samples, per-sample noise
            still dominates the batch-mean direction on healthy networks
            (~0.10), which gives the cleanest separation from drift-dominated
            failure modes (~0.95). Larger batches average out per-sample
            noise via CLT and shrink the gap.
        layer_index: which layer's BP grad direction to compare (index into
            [h_0, ..., h_L]; negative indices count from the end).
        loss_fn: scalar loss; cross_entropy by default.

    Returns:
        mean pairwise cosine in [-1, 1].

    Raises:
        ValueError: if fewer than 2 batches are given.
        IndexError: if `layer_index` is out of range.

    Side effects:
        Calls `model.eval()` and leaves the model in eval mode, like the other
        gradient diagnostics.
    """
    if len(eval_batches) < 2:
        raise ValueError("need ≥2 eval batches to measure stability")
    # Without this, a standalone call on a train-mode model would inject
    # dropout noise into the per-batch directions and let any batch-norm
    # layers update their running statistics during the probe.
    model.eval()
    directions: List[torch.Tensor] = []
    for x, y in eval_batches:
        with torch.enable_grad():
            h = _embed(model, x)
            if not h.requires_grad:
                # Frozen / parameter-free embedding: make h_0 the graph root so
                # autograd.grad does not raise. This does not change ∂L/∂h_l.
                h = h.detach().requires_grad_(True)
            hiddens: List[torch.Tensor] = [h]
            for block in _iter_blocks(model):
                h = h + block(h)
                hiddens.append(h)
            loss = loss_fn(_terminal(model, h), y)
            # Only the requested layer's gradient is needed; asking autograd
            # for just that input gives the same values with less work.
            (g,) = torch.autograd.grad(loss, [hiddens[layer_index]])
        if g.dim() == 3:
            g = g[:, 0, :]
        # Reduce to a single direction per batch by averaging across samples,
        # then normalizing. This is the natural "direction the gradient is
        # pointing on this batch" signal. Flattening lets `torch.dot` handle
        # hidden states with more than one feature dim. It is a no-op for the
        # usual (B, d) case.
        gbar = g.mean(dim=0).reshape(-1)
        n = gbar.norm()
        if n < 1e-30:
            # A vanished mean gradient has no direction; it contributes a
            # cosine of 0 to every pair it appears in.
            directions.append(torch.zeros_like(gbar))
        else:
            directions.append(gbar / n)
    # Mean pairwise cosine over all i < j (directions are unit or zero).
    cos_vals: List[float] = []
    for i, d_i in enumerate(directions):
        for d_j in directions[i + 1:]:
            cos_vals.append(float(torch.dot(d_i, d_j).item()))
    return sum(cos_vals) / len(cos_vals)
# --------------------------------------------------------------------------- #
# Diagnostic (d) is just a comparison: caller supplies frozen_baseline_acc
# --------------------------------------------------------------------------- #
# --------------------------------------------------------------------------- #
# Top-level convenience: run all four and produce a verdict
# --------------------------------------------------------------------------- #
def diagnose(
    model,
    eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]],
    headline_acc: float,
    frozen_baseline_acc: Optional[float] = None,
    method_name: str = "method",
    notes: str = "",
    layer_for_stability: Optional[int] = None,
    thresholds: Optional[DiagnosticThresholds] = None,
) -> DiagnosticReport:
    """Run the full protocol on `model` and return a DiagnosticReport.

    Diagnostics (a) residual norms and (b) BP-grad norms are computed on the
    first batch only; (c) cross-batch stability uses all batches.

    Args:
        model: trained network. It must support
            `model(x, return_hidden=True) -> (logits, [h_0, ..., h_L])`
            (used by (a)), and expose `.embed(x)` or `.patch_embed(x)`, a
            `.blocks` sequence of residual blocks, and a terminal stack
            `(out_ln, out_head)` or `(norm, head)` (used by (b) and (c)).
            The model is put into eval mode and left there.
        eval_batches: disjoint minibatches `(x, y)` on the model's device.
            At least 2 are required. ≥8 batches of ~128 samples each are
            recommended (see `cross_batch_direction_stability`).
        headline_acc: the accuracy you would otherwise have reported as the
            method's result on this network.
        frozen_baseline_acc: accuracy of an architecture-matched baseline
            with the deep blocks **frozen at random initialization** and only
            embed/LN/head trained. If None, the (d)-diagnostic is skipped.
            See `protocol/examples/frozen_baseline.py` for a reference impl.
        method_name: free-form label, only used in the printed report.
        notes: free-form architecture/dataset description for the report.
        layer_for_stability: which hidden layer index to use for the cross-
            batch stability check. Defaults to the middle layer,
            `max(1, (L + 1) // 2)`.
        thresholds: degeneracy thresholds; defaults to those used in our
            paper (see DiagnosticThresholds for details).

    Returns:
        A DiagnosticReport summarizing each diagnostic and its verdict.

    Raises:
        ValueError: if `eval_batches` is empty or has fewer than 2 batches.
    """
    if thresholds is None:
        thresholds = DiagnosticThresholds()
    if not eval_batches:
        raise ValueError("eval_batches must not be empty")
    # Fail fast: diagnostic (c) needs ≥2 batches. Checking here avoids running
    # the forward/backward passes for (a) and (b) only to raise afterwards.
    if len(eval_batches) < 2:
        raise ValueError("need ≥2 eval batches to measure stability")
    x0, y0 = eval_batches[0]
    h_norms = residual_norms(model, x0)
    g_norms = bp_grad_norms(model, x0, y0)
    if layer_for_stability is None:
        # Middle of [h_0, ..., h_L], never the raw embedding h_0.
        layer_for_stability = max(1, len(h_norms) // 2)
    stability = cross_batch_direction_stability(
        model, eval_batches, layer_index=layer_for_stability
    )
    return DiagnosticReport(
        method_name=method_name,
        notes=notes,
        residual_norms=h_norms,
        bp_grad_norms=g_norms,
        stability_layer=layer_for_stability,
        cross_batch_stability=stability,
        headline_acc=headline_acc,
        frozen_baseline_acc=frozen_baseline_acc,
        thresholds=thresholds,
    )
|