""" Core diagnostics for the FA evaluation protocol. The four diagnostics: (a) residual_norms — per-layer ||h_l||_2 on a fixed eval batch (b) bp_grad_norms — per-layer ||∂CE/∂h_l||_2 on a fixed eval batch (c) cross_batch_direction_stability — cosine of normalized BP-grad direction across disjoint minibatches; high values indicate the reference vector is dominated by a sample-invariant global drift, NOT per-sample credit (d) frozen-blocks accuracy — caller-supplied; the protocol just compares it to the headline accuracy of the trained network The protocol expects a model exposing a `return_hidden=True` mode that yields the per-layer residual stream as a list `[h_0, h_1, ..., h_L]`, plus the standard `out_ln` + `out_head` terminal stack. ResidualMLP and ViT-Mini in this repo both satisfy that contract; see the docstring of `diagnose` for the duck-typed interface. All metrics use `tensor.norm(dim=-1)` (L2 along the feature dim) — NEVER `tensor.norm(-1)` (which is L_{-1} of the entire tensor and is the bug behind several walk-backs in our project). """ from __future__ import annotations from typing import List, Sequence, Tuple, Optional import torch import torch.nn.functional as F from .report import DiagnosticReport, DiagnosticThresholds # --------------------------------------------------------------------------- # # Diagnostic (a): residual stream norm # --------------------------------------------------------------------------- # def residual_norms(model, x: torch.Tensor) -> List[float]: """Per-layer median ||h_l||_2 on the eval batch. Args: model: must support `model(x, return_hidden=True) -> (logits, hiddens)` where hiddens is `[h_0, ..., h_L]` (each (B, ..., d_hidden)). x: eval batch (already on the correct device, already in the input shape the model wants). Returns: list of L+1 floats — the median (over batch) of ||h_l||_2 per layer. """ model.eval() with torch.no_grad(): _, hiddens = model(x, return_hidden=True) out: List[float] = [] for h in hiddens: # If h has token dim (B, T, d), pool tokens to a single vector per # sample by taking the cls token (index 0) — that's what the head # actually consumes in our ViT. if h.dim() == 3: h = h[:, 0, :] # NOTE: dim=-1 is mandatory; `.norm(-1)` would compute L_{-1} of the # whole tensor and is wrong (see project_norm_minus1_footgun). out.append(h.norm(dim=-1).median().item()) return out # --------------------------------------------------------------------------- # # Diagnostic (b): BP gradient norm at hidden layers # --------------------------------------------------------------------------- # def bp_grad_norms( model, x: torch.Tensor, y: torch.Tensor, loss_fn=F.cross_entropy, ) -> List[float]: """Per-layer median ||∂L/∂h_l||_2 on the eval batch. The BP gradient at hidden layers is the *reference vector* that offline alignment metrics (Γ) compare against. When this is at the numerical floor (~1e-10) on modern pre-LN residual nets, Γ measures alignment to a noise vector and is no longer interpretable as "credit assignment quality". Args: model: same contract as for `residual_norms`. x: eval batch. y: labels for the eval batch. loss_fn: scalar loss given (logits, y); cross_entropy by default. Returns: list of L+1 floats — the median (over batch) of ||g_l||_2 per layer, where g_l = ∂loss/∂h_l. """ model.eval() # Forward through the canonical path (embed -> blocks with explicit # residual addition -> terminal LN -> head), keeping every intermediate # h_l as a non-leaf in the autograd graph. `torch.autograd.grad` accepts # non-leaf tensors as `inputs` and computes ∂L/∂h_l via the chain rule. # NOTE: do not clone+detach the hidden states — that severs the graph. with torch.enable_grad(): h = _embed(model, x) hiddens: List[torch.Tensor] = [h] for block in _iter_blocks(model): f = block(h) h = h + f hiddens.append(h) logits = _terminal(model, h) loss = loss_fn(logits, y) grads = torch.autograd.grad(loss, hiddens) out: List[float] = [] for g in grads: if g.dim() == 3: g = g[:, 0, :] out.append(g.norm(dim=-1).median().item()) return out def _embed(model, x): if hasattr(model, "embed"): return model.embed(x) if hasattr(model, "patch_embed"): # ViT-style return model.patch_embed(x) raise AttributeError( "model must expose `.embed(x)` or `.patch_embed(x)` for the protocol" ) def _iter_blocks(model): if hasattr(model, "blocks"): return list(model.blocks) raise AttributeError("model must expose `.blocks` (a Module list)") def _terminal(model, h): """Apply terminal LN (if any) + classification head to hidden state.""" if hasattr(model, "out_ln") and hasattr(model, "out_head"): return model.out_head(model.out_ln(h)) if hasattr(model, "norm") and hasattr(model, "head"): # ViT-style naming h = model.norm(h) if h.dim() == 3: h = h[:, 0, :] return model.head(h) raise AttributeError( "model must expose `(out_ln, out_head)` or `(norm, head)` for the protocol" ) # --------------------------------------------------------------------------- # # Diagnostic (c): cross-batch direction stability # --------------------------------------------------------------------------- # def cross_batch_direction_stability( model, eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]], layer_index: int, loss_fn=F.cross_entropy, ) -> float: """Mean pairwise cosine of the BP-grad *direction* at layer `layer_index` across disjoint minibatches. A value near 1.0 means: the BP gradient direction at that hidden layer is the same regardless of which samples are in the batch. That is the fingerprint of a **sample-invariant global drift** — i.e. the reference vector against which Γ is computed is NOT per-sample credit, it is a constant artifact of the trained network's geometry. In our 5-method CIFAR-10 audit at K=10 batches of 128 samples, healthy BP/EP networks cluster near zero with all six values in [-0.04, +0.12]; drift-dominated DFA/SB/CB pre-LN ResMLPs reach high tails up to +0.99 (5/9 of the nine degen values exceed the 0.30 default cutoff). See results/protocol_audit/audit_table_s42_s123_s456.json for the per-method per-seed values. Args: model: same contract as the other diagnostics. eval_batches: list of (x, y) pairs, each a *disjoint* minibatch on the same device. Use ≥8 batches of ~128 samples each. The metric is calibrated for this batch size: smaller batches let per-sample noise dominate the batch-mean direction on healthy networks (~0.10), giving the cleanest separation from drift-dominated failure modes (~0.95). Larger batches average out per-sample noise via CLT and shrink the gap. layer_index: which layer's BP grad direction to compare. loss_fn: scalar loss; cross_entropy by default. Returns: mean pairwise cosine in [-1, 1]. """ if len(eval_batches) < 2: raise ValueError("need ≥2 eval batches to measure stability") directions: List[torch.Tensor] = [] for x, y in eval_batches: with torch.enable_grad(): h = _embed(model, x) hiddens: List[torch.Tensor] = [h] for block in _iter_blocks(model): f = block(h) h = h + f hiddens.append(h) logits = _terminal(model, h) loss = loss_fn(logits, y) grads = torch.autograd.grad(loss, hiddens) g = grads[layer_index] if g.dim() == 3: g = g[:, 0, :] # Reduce to a single direction per batch by averaging across samples, # then normalizing. This is the natural "direction the gradient is # pointing on this batch" signal. gbar = g.mean(dim=0) n = gbar.norm() if n < 1e-30: directions.append(torch.zeros_like(gbar)) else: directions.append(gbar / n) # Mean pairwise cosine cos_vals: List[float] = [] for i in range(len(directions)): for j in range(i + 1, len(directions)): cos_vals.append( float(torch.dot(directions[i], directions[j]).item()) ) return sum(cos_vals) / len(cos_vals) # --------------------------------------------------------------------------- # # Diagnostic (d) is just a comparison: caller supplies frozen_baseline_acc # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # Top-level convenience: run all four and produce a verdict # --------------------------------------------------------------------------- # def diagnose( model, eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]], headline_acc: float, frozen_baseline_acc: Optional[float] = None, method_name: str = "method", notes: str = "", layer_for_stability: Optional[int] = None, thresholds: Optional[DiagnosticThresholds] = None, ) -> DiagnosticReport: """Run the full protocol on `model` and return a DiagnosticReport. Args: model: trained network. Must satisfy the duck-typed interface described at the top of `protocol.py`. eval_batches: ≥4 disjoint minibatches `(x, y)`, on the model's device. headline_acc: the accuracy you would otherwise have reported as the method's result on this network. frozen_baseline_acc: accuracy of an architecture-matched baseline with the deep blocks **frozen at random initialization** and only embed/LN/head trained. If None, the (d)-diagnostic is skipped. See `protocol/examples/frozen_baseline.py` for a reference impl. method_name: free-form label, only used in the printed report. notes: free-form architecture/dataset description for the report. layer_for_stability: which hidden layer index to use for the cross- batch stability check. Defaults to the middle layer. thresholds: degeneracy thresholds; defaults to those used in our paper (see DiagnosticThresholds for details). Returns: A DiagnosticReport summarizing each diagnostic and its verdict. """ if thresholds is None: thresholds = DiagnosticThresholds() if not eval_batches: raise ValueError("eval_batches must not be empty") x0, y0 = eval_batches[0] h_norms = residual_norms(model, x0) g_norms = bp_grad_norms(model, x0, y0) if layer_for_stability is None: layer_for_stability = max(1, len(h_norms) // 2) stability = cross_batch_direction_stability( model, eval_batches, layer_index=layer_for_stability ) return DiagnosticReport( method_name=method_name, notes=notes, residual_norms=h_norms, bp_grad_norms=g_norms, stability_layer=layer_for_stability, cross_batch_stability=stability, headline_acc=headline_acc, frozen_baseline_acc=frozen_baseline_acc, thresholds=thresholds, )