summaryrefslogtreecommitdiff
path: root/protocol/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'protocol/protocol.py')
-rw-r--r--protocol/protocol.py290
1 files changed, 290 insertions, 0 deletions
diff --git a/protocol/protocol.py b/protocol/protocol.py
new file mode 100644
index 0000000..4298756
--- /dev/null
+++ b/protocol/protocol.py
@@ -0,0 +1,290 @@
+"""
+Core diagnostics for the FA evaluation protocol.
+
+The four diagnostics:
+
+ (a) residual_norms — per-layer ||h_l||_2 on a fixed eval batch
+ (b) bp_grad_norms — per-layer ||∂CE/∂h_l||_2 on a fixed eval batch
+ (c) cross_batch_direction_stability — cosine of normalized BP-grad direction
+ across disjoint minibatches; high values
+ indicate the reference vector is dominated by
+ a sample-invariant global drift, NOT per-sample
+ credit
+ (d) frozen-blocks accuracy — caller-supplied; the protocol just compares it
+ to the headline accuracy of the trained network
+
+The protocol expects a model exposing a `return_hidden=True` mode that yields
+the per-layer residual stream as a list `[h_0, h_1, ..., h_L]`, plus the
+standard `out_ln` + `out_head` terminal stack. ResidualMLP and ViT-Mini in
+this repo both satisfy that contract; see the docstring of `diagnose` for the
+duck-typed interface.
+
+All metrics use `tensor.norm(dim=-1)` (L2 along the feature dim) — NEVER
+`tensor.norm(-1)` (which is L_{-1} of the entire tensor and is the bug behind
+several walk-backs in our project).
+"""
+from __future__ import annotations
+
+from typing import List, Sequence, Tuple, Optional
+
+import torch
+import torch.nn.functional as F
+
+from .report import DiagnosticReport, DiagnosticThresholds
+
+
+# --------------------------------------------------------------------------- #
+# Diagnostic (a): residual stream norm
+# --------------------------------------------------------------------------- #
+
+def residual_norms(model, x: torch.Tensor) -> List[float]:
+ """Per-layer median ||h_l||_2 on the eval batch.
+
+ Args:
+ model: must support `model(x, return_hidden=True) -> (logits, hiddens)`
+ where hiddens is `[h_0, ..., h_L]` (each (B, ..., d_hidden)).
+ x: eval batch (already on the correct device, already in the input
+ shape the model wants).
+
+ Returns:
+ list of L+1 floats — the median (over batch) of ||h_l||_2 per layer.
+ """
+ model.eval()
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+ out: List[float] = []
+ for h in hiddens:
+ # If h has token dim (B, T, d), pool tokens to a single vector per
+ # sample by taking the cls token (index 0) — that's what the head
+ # actually consumes in our ViT.
+ if h.dim() == 3:
+ h = h[:, 0, :]
+ # NOTE: dim=-1 is mandatory; `.norm(-1)` would compute L_{-1} of the
+ # whole tensor and is wrong (see project_norm_minus1_footgun).
+ out.append(h.norm(dim=-1).median().item())
+ return out
+
+
+# --------------------------------------------------------------------------- #
+# Diagnostic (b): BP gradient norm at hidden layers
+# --------------------------------------------------------------------------- #
+
+def bp_grad_norms(
+ model,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ loss_fn=F.cross_entropy,
+) -> List[float]:
+ """Per-layer median ||∂L/∂h_l||_2 on the eval batch.
+
+ The BP gradient at hidden layers is the *reference vector* that offline
+ alignment metrics (Γ) compare against. When this is at the numerical
+ floor (~1e-10) on modern pre-LN residual nets, Γ measures alignment to a
+ noise vector and is no longer interpretable as "credit assignment
+ quality".
+
+ Args:
+ model: same contract as for `residual_norms`.
+ x: eval batch.
+ y: labels for the eval batch.
+ loss_fn: scalar loss given (logits, y); cross_entropy by default.
+
+ Returns:
+ list of L+1 floats — the median (over batch) of ||g_l||_2 per layer,
+ where g_l = ∂loss/∂h_l.
+ """
+ model.eval()
+ # Forward through the canonical path (embed -> blocks with explicit
+ # residual addition -> terminal LN -> head), keeping every intermediate
+ # h_l as a non-leaf in the autograd graph. `torch.autograd.grad` accepts
+ # non-leaf tensors as `inputs` and computes ∂L/∂h_l via the chain rule.
+ # NOTE: do not clone+detach the hidden states — that severs the graph.
+ with torch.enable_grad():
+ h = _embed(model, x)
+ hiddens: List[torch.Tensor] = [h]
+ for block in _iter_blocks(model):
+ f = block(h)
+ h = h + f
+ hiddens.append(h)
+ logits = _terminal(model, h)
+ loss = loss_fn(logits, y)
+ grads = torch.autograd.grad(loss, hiddens)
+ out: List[float] = []
+ for g in grads:
+ if g.dim() == 3:
+ g = g[:, 0, :]
+ out.append(g.norm(dim=-1).median().item())
+ return out
+
+
+def _embed(model, x):
+ if hasattr(model, "embed"):
+ return model.embed(x)
+ if hasattr(model, "patch_embed"):
+ # ViT-style
+ return model.patch_embed(x)
+ raise AttributeError(
+ "model must expose `.embed(x)` or `.patch_embed(x)` for the protocol"
+ )
+
+
+def _iter_blocks(model):
+ if hasattr(model, "blocks"):
+ return list(model.blocks)
+ raise AttributeError("model must expose `.blocks` (a Module list)")
+
+
+def _terminal(model, h):
+ """Apply terminal LN (if any) + classification head to hidden state."""
+ if hasattr(model, "out_ln") and hasattr(model, "out_head"):
+ return model.out_head(model.out_ln(h))
+ if hasattr(model, "norm") and hasattr(model, "head"):
+ # ViT-style naming
+ h = model.norm(h)
+ if h.dim() == 3:
+ h = h[:, 0, :]
+ return model.head(h)
+ raise AttributeError(
+ "model must expose `(out_ln, out_head)` or `(norm, head)` for the protocol"
+ )
+
+
+# --------------------------------------------------------------------------- #
+# Diagnostic (c): cross-batch direction stability
+# --------------------------------------------------------------------------- #
+
+def cross_batch_direction_stability(
+ model,
+ eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]],
+ layer_index: int,
+ loss_fn=F.cross_entropy,
+) -> float:
+ """Mean pairwise cosine of the BP-grad *direction* at layer `layer_index`
+ across disjoint minibatches.
+
+ A value near 1.0 means: the BP gradient direction at that hidden layer is
+ the same regardless of which samples are in the batch. That is the
+ fingerprint of a **sample-invariant global drift** — i.e. the reference
+ vector against which Γ is computed is NOT per-sample credit, it is a
+ constant artifact of the trained network's geometry. On healthy
+ BP-trained or EP-trained networks this value is small (~0.05-0.18); on
+ DFA/SB/CB pre-LN ResMLPs we see ~0.5-0.99.
+
+ Args:
+ model: same contract as the other diagnostics.
+ eval_batches: list of (x, y) pairs, each a *disjoint* minibatch on the
+ same device. Use ≥8 batches of ~128 samples each. The metric is
+ calibrated for this batch size: smaller batches let per-sample
+ noise dominate the batch-mean direction on healthy networks
+ (~0.10), giving the cleanest separation from drift-dominated
+ failure modes (~0.95). Larger batches average out per-sample
+ noise via CLT and shrink the gap.
+ layer_index: which layer's BP grad direction to compare.
+ loss_fn: scalar loss; cross_entropy by default.
+
+ Returns:
+ mean pairwise cosine in [-1, 1].
+ """
+ if len(eval_batches) < 2:
+ raise ValueError("need ≥2 eval batches to measure stability")
+ directions: List[torch.Tensor] = []
+ for x, y in eval_batches:
+ with torch.enable_grad():
+ h = _embed(model, x)
+ hiddens: List[torch.Tensor] = [h]
+ for block in _iter_blocks(model):
+ f = block(h)
+ h = h + f
+ hiddens.append(h)
+ logits = _terminal(model, h)
+ loss = loss_fn(logits, y)
+ grads = torch.autograd.grad(loss, hiddens)
+ g = grads[layer_index]
+ if g.dim() == 3:
+ g = g[:, 0, :]
+ # Reduce to a single direction per batch by averaging across samples,
+ # then normalizing. This is the natural "direction the gradient is
+ # pointing on this batch" signal.
+ gbar = g.mean(dim=0)
+ n = gbar.norm()
+ if n < 1e-30:
+ directions.append(torch.zeros_like(gbar))
+ else:
+ directions.append(gbar / n)
+ # Mean pairwise cosine
+ cos_vals: List[float] = []
+ for i in range(len(directions)):
+ for j in range(i + 1, len(directions)):
+ cos_vals.append(
+ float(torch.dot(directions[i], directions[j]).item())
+ )
+ return sum(cos_vals) / len(cos_vals)
+
+
+# --------------------------------------------------------------------------- #
+# Diagnostic (d) is just a comparison: caller supplies frozen_baseline_acc
+# --------------------------------------------------------------------------- #
+
+
+# --------------------------------------------------------------------------- #
+# Top-level convenience: run all four and produce a verdict
+# --------------------------------------------------------------------------- #
+
+def diagnose(
+ model,
+ eval_batches: Sequence[Tuple[torch.Tensor, torch.Tensor]],
+ headline_acc: float,
+ frozen_baseline_acc: Optional[float] = None,
+ method_name: str = "method",
+ notes: str = "",
+ layer_for_stability: Optional[int] = None,
+ thresholds: Optional[DiagnosticThresholds] = None,
+) -> DiagnosticReport:
+ """Run the full protocol on `model` and return a DiagnosticReport.
+
+ Args:
+ model: trained network. Must satisfy the duck-typed interface
+ described at the top of `protocol.py`.
+ eval_batches: ≥4 disjoint minibatches `(x, y)`, on the model's device.
+ headline_acc: the accuracy you would otherwise have reported as the
+ method's result on this network.
+ frozen_baseline_acc: accuracy of an architecture-matched baseline
+ with the deep blocks **frozen at random initialization** and only
+ embed/LN/head trained. If None, the (d)-diagnostic is skipped.
+ See `protocol/examples/frozen_baseline.py` for a reference impl.
+ method_name: free-form label, only used in the printed report.
+ notes: free-form architecture/dataset description for the report.
+ layer_for_stability: which hidden layer index to use for the cross-
+ batch stability check. Defaults to the middle layer.
+ thresholds: degeneracy thresholds; defaults to those used in our
+ paper (see DiagnosticThresholds for details).
+
+ Returns:
+ A DiagnosticReport summarizing each diagnostic and its verdict.
+ """
+ if thresholds is None:
+ thresholds = DiagnosticThresholds()
+ if not eval_batches:
+ raise ValueError("eval_batches must not be empty")
+ x0, y0 = eval_batches[0]
+
+ h_norms = residual_norms(model, x0)
+ g_norms = bp_grad_norms(model, x0, y0)
+
+ if layer_for_stability is None:
+ layer_for_stability = max(1, len(h_norms) // 2)
+ stability = cross_batch_direction_stability(
+ model, eval_batches, layer_index=layer_for_stability
+ )
+
+ return DiagnosticReport(
+ method_name=method_name,
+ notes=notes,
+ residual_norms=h_norms,
+ bp_grad_norms=g_norms,
+ stability_layer=layer_for_stability,
+ cross_batch_stability=stability,
+ headline_acc=headline_acc,
+ frozen_baseline_acc=frozen_baseline_acc,
+ thresholds=thresholds,
+ )