diff options
Diffstat (limited to 'protocol/protocol.py')
| -rw-r--r-- | protocol/protocol.py | 290 |
1 files changed, 290 insertions, 0 deletions
diff --git a/protocol/protocol.py b/protocol/protocol.py new file mode 100644 index 0000000..4298756 --- /dev/null +++ b/protocol/protocol.py @@ -0,0 +1,290 @@ +""" +Core diagnostics for the FA evaluation protocol. + +The four diagnostics: + + (a) residual_norms — per-layer ||h_l||_2 on a fixed eval batch + (b) bp_grad_norms — per-layer ||∂CE/∂h_l||_2 on a fixed eval batch + (c) cross_batch_direction_stability — cosine of normalized BP-grad direction + across disjoint minibatches; high values + indicate the reference vector is dominated by + a sample-invariant global drift, NOT per-sample + credit + (d) frozen-blocks accuracy — caller-supplied; the protocol just compares it + to the headline accuracy of the trained network + +The protocol expects a model exposing a `return_hidden=True` mode that yields +the per-layer residual stream as a list `[h_0, h_1, ..., h_L]`, plus the +standard `out_ln` + `out_head` terminal stack. ResidualMLP and ViT-Mini in +this repo both satisfy that contract; see the docstring of `diagnose` for the +duck-typed interface. + +All metrics use `tensor.norm(dim=-1)` (L2 along the feature dim) — NEVER +`tensor.norm(-1)` (which is L_{-1} of the entire tensor and is the bug behind +several walk-backs in our project). +""" +from __future__ import annotations + +from typing import List, Sequence, Tuple, Optional + +import torch +import torch.nn.functional as F + +from .report import DiagnosticReport, DiagnosticThresholds + + +# --------------------------------------------------------------------------- # +# Diagnostic (a): residual stream norm +# --------------------------------------------------------------------------- # + +def residual_norms(model, x: torch.Tensor) -> List[float]: + """Per-layer median ||h_l||_2 on the eval batch. + + Args: + model: must support `model(x, return_hidden=True) -> (logits, hiddens)` + where hiddens is `[h_0, ..., h_L]` (each (B, ..., d_hidden)). + x: eval batch (already on the correct device, already in the input + shape the model wants). + + Returns: + list of L+1 floats — the median (over batch) of ||h_l||_2 per layer. + """ + model.eval() + with torch.no_grad(): + _, hiddens = model(x, return_hidden=True) + out: List[float] = [] + for h in hiddens: + # If h has token dim (B, T, d), pool tokens to a single vector per + # sample by taking the cls token (index 0) — that's what the head + # actually consumes in our ViT. + if h.dim() == 3: + h = h[:, 0, :] + # NOTE: dim=-1 is mandatory; `.norm(-1)` would compute L_{-1} of the + # whole tensor and is wrong (see project_norm_minus1_footgun). + out.append(h.norm(dim=-1).median().item()) + return out + + +# --------------------------------------------------------------------------- # +# Diagnostic (b): BP gradient norm at hidden layers +# --------------------------------------------------------------------------- # + +def bp_grad_norms( + model, + x: torch.Tensor, + y: torch.Tensor, + loss_fn=F.cross_entropy, +) -> List[float]: + """Per-layer median ||∂L/∂h_l||_2 on the eval batch. + + The BP gradient at hidden layers is the *reference vector* that offline + alignment metrics (Γ) compare against. When this is at the numerical + floor (~1e-10) on modern pre-LN residual nets, Γ measures alignment to a + noise vector and is no longer interpretable as "credit assignment + quality". + + Args: + model: same contract as for `residual_norms`. + x: eval batch. + y: labels for the eval batch. + loss_fn: scalar loss given (logits, y); cross_entropy by default. + + Returns: + list of L+1 floats — the median (over batch) of ||g_l||_2 per layer, + where g_l = ∂loss/∂h_l. + """ + model.eval() + # Forward through the canonical path (embed -> blocks with explicit + # residual addition -> terminal LN -> head), keeping every intermediate + # h_l as a non-leaf in the autograd graph. `torch.autograd.grad` accepts + # non-leaf tensors as `inputs` and computes ∂L/∂h_l via the chain rule. + # NOTE: do not clone+detach the hidden states — that severs the graph. + with torch.enable_grad(): + h = _embed(model, x) + hiddens: List[torch.Tensor] = [h] + for block in _iter_blocks(model): + f = block(h) + h = h + f + hiddens.append(h) + logits = _terminal(model, h) + loss = loss_fn(logits, y) + grads = torch.autograd.grad(loss, hiddens) + out: List[float] = [] + for g in grads: + if g.dim() == 3: + g = g[:, 0, :] + out.append(g.norm(dim=-1).median().item()) + return out + + +def _embed(model, x): + if hasattr(model, "embed"): + return model.embed(x) + if hasattr(model, "patch_embed"): + # ViT-style + return model.patch_embed(x) + raise AttributeError( + "model must expose `.embed(x)` or `.patch_embed(x)` for the protocol" + ) + + +def _iter_blocks(model): + if hasattr(model, "blocks"): + return list(model.blocks) + raise AttributeError("model must expose `.blocks` (a Module list)") + + +def _terminal(model, h): + """Apply terminal LN (if any) + classification head to hidden state.""" + if hasattr(model, "out_ln") and hasattr(model, "out_head"): + return model.out_head(model.out_ln(h)) + if hasattr(model, "norm") and hasattr(model, "head"): + # ViT-style naming + h = model.norm(h) + if h.dim() == 3: + h = h[:, 0, :] + return model.head(h) + raise AttributeError( + "model must expose `(out_ln, out_head)` or `(norm, head)` for the protocol" + ) + + +# --------------------------------------------------------------------------- # +# Diagnostic (c): cross-batch direction stability +# --------------------------------------------------------------------------- # + +def cross_batch_direction_stability( + model, + eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]], + layer_index: int, + loss_fn=F.cross_entropy, +) -> float: + """Mean pairwise cosine of the BP-grad *direction* at layer `layer_index` + across disjoint minibatches. + + A value near 1.0 means: the BP gradient direction at that hidden layer is + the same regardless of which samples are in the batch. That is the + fingerprint of a **sample-invariant global drift** — i.e. the reference + vector against which Γ is computed is NOT per-sample credit, it is a + constant artifact of the trained network's geometry. On healthy + BP-trained or EP-trained networks this value is small (~0.05-0.18); on + DFA/SB/CB pre-LN ResMLPs we see ~0.5-0.99. + + Args: + model: same contract as the other diagnostics. + eval_batches: list of (x, y) pairs, each a *disjoint* minibatch on the + same device. Use ≥8 batches of ~128 samples each. The metric is + calibrated for this batch size: smaller batches let per-sample + noise dominate the batch-mean direction on healthy networks + (~0.10), giving the cleanest separation from drift-dominated + failure modes (~0.95). Larger batches average out per-sample + noise via CLT and shrink the gap. + layer_index: which layer's BP grad direction to compare. + loss_fn: scalar loss; cross_entropy by default. + + Returns: + mean pairwise cosine in [-1, 1]. + """ + if len(eval_batches) < 2: + raise ValueError("need ≥2 eval batches to measure stability") + directions: List[torch.Tensor] = [] + for x, y in eval_batches: + with torch.enable_grad(): + h = _embed(model, x) + hiddens: List[torch.Tensor] = [h] + for block in _iter_blocks(model): + f = block(h) + h = h + f + hiddens.append(h) + logits = _terminal(model, h) + loss = loss_fn(logits, y) + grads = torch.autograd.grad(loss, hiddens) + g = grads[layer_index] + if g.dim() == 3: + g = g[:, 0, :] + # Reduce to a single direction per batch by averaging across samples, + # then normalizing. This is the natural "direction the gradient is + # pointing on this batch" signal. + gbar = g.mean(dim=0) + n = gbar.norm() + if n < 1e-30: + directions.append(torch.zeros_like(gbar)) + else: + directions.append(gbar / n) + # Mean pairwise cosine + cos_vals: List[float] = [] + for i in range(len(directions)): + for j in range(i + 1, len(directions)): + cos_vals.append( + float(torch.dot(directions[i], directions[j]).item()) + ) + return sum(cos_vals) / len(cos_vals) + + +# --------------------------------------------------------------------------- # +# Diagnostic (d) is just a comparison: caller supplies frozen_baseline_acc +# --------------------------------------------------------------------------- # + + +# --------------------------------------------------------------------------- # +# Top-level convenience: run all four and produce a verdict +# --------------------------------------------------------------------------- # + +def diagnose( + model, + eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]], + headline_acc: float, + frozen_baseline_acc: Optional[float] = None, + method_name: str = "method", + notes: str = "", + layer_for_stability: Optional[int] = None, + thresholds: Optional[DiagnosticThresholds] = None, +) -> DiagnosticReport: + """Run the full protocol on `model` and return a DiagnosticReport. + + Args: + model: trained network. Must satisfy the duck-typed interface + described at the top of `protocol.py`. + eval_batches: ≥4 disjoint minibatches `(x, y)`, on the model's device. + headline_acc: the accuracy you would otherwise have reported as the + method's result on this network. + frozen_baseline_acc: accuracy of an architecture-matched baseline + with the deep blocks **frozen at random initialization** and only + embed/LN/head trained. If None, the (d)-diagnostic is skipped. + See `protocol/examples/frozen_baseline.py` for a reference impl. + method_name: free-form label, only used in the printed report. + notes: free-form architecture/dataset description for the report. + layer_for_stability: which hidden layer index to use for the cross- + batch stability check. Defaults to the middle layer. + thresholds: degeneracy thresholds; defaults to those used in our + paper (see DiagnosticThresholds for details). + + Returns: + A DiagnosticReport summarizing each diagnostic and its verdict. + """ + if thresholds is None: + thresholds = DiagnosticThresholds() + if not eval_batches: + raise ValueError("eval_batches must not be empty") + x0, y0 = eval_batches[0] + + h_norms = residual_norms(model, x0) + g_norms = bp_grad_norms(model, x0, y0) + + if layer_for_stability is None: + layer_for_stability = max(1, len(h_norms) // 2) + stability = cross_batch_direction_stability( + model, eval_batches, layer_index=layer_for_stability + ) + + return DiagnosticReport( + method_name=method_name, + notes=notes, + residual_norms=h_norms, + bp_grad_norms=g_norms, + stability_layer=layer_for_stability, + cross_batch_stability=stability, + headline_acc=headline_acc, + frozen_baseline_acc=frozen_baseline_acc, + thresholds=thresholds, + ) |
