""" Reference implementation of the three-diagnostic FA evaluation protocol. Usage: from protocol.fa_protocol import FAProtocol protocol = FAProtocol(model, x_eval, y_eval) report = protocol.run(frozen_baseline_acc=0.349) print(report['verdict']) """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from typing import Optional class FAProtocol: """Three-diagnostic evaluation protocol for feedback alignment methods. Diagnostics: D1 (Scale stability): max per-block residual growth rho = max_l ||h_{l+1}|| / ||h_l||. Flags if rho > threshold (default 50). D2 (Reference validity): BP gradient norm at the deepest hidden state. Flags if ||g_L|| < 10 * eps, where eps is the cosine clamp floor. D3 (Depth utility): test accuracy vs frozen-blocks baseline. Flags if trained acc < frozen_acc + margin (default 2 pp). The protocol requires: - A trained model with .blocks (nn.ModuleList) and forward(x, return_hidden=True) - A test batch (x_eval, y_eval) - A frozen-blocks baseline accuracy (must be computed separately) """ def __init__( self, model: nn.Module, x_eval: torch.Tensor, y_eval: torch.Tensor, d1_threshold: float = 50.0, d2_eps: float = 1e-8, d2_factor: float = 10.0, d3_margin_pp: float = 2.0, ): self.model = model self.x_eval = x_eval self.y_eval = y_eval self.d1_threshold = d1_threshold self.d2_floor = d2_factor * d2_eps self.d3_margin = d3_margin_pp / 100.0 def _compute_hidden_norms(self, hiddens): """Compute median per-sample L2 norm at each hidden layer.""" norms = [] for h in hiddens: if h.dim() == 4: # conv: (B, C, H, W) -> pool to (B, C) h_flat = F.adaptive_avg_pool2d(h, 1).flatten(1) elif h.dim() == 3: # transformer: (B, T, D) -> cls token h_flat = h[:, 0] else: h_flat = h norms.append(float(h_flat.norm(dim=-1).median().item())) return norms def _compute_bp_grad_norms(self, hiddens): """Compute BP gradient norms at each hidden layer via manual forward.""" model = self.model L = len(hiddens) - 1 # number of blocks # Rebuild forward from hidden states with grad tracking hs = [hiddens[0].detach().clone().requires_grad_(True)] for i, block in enumerate(model.blocks): if hasattr(block, 'forward'): h_next = block(hs[-1]) # Check if block includes residual (output same shape, skip connection) if h_next.shape == hs[-1].shape and not self._block_has_internal_skip(block): h_next = hs[-1] + h_next hs.append(h_next) # Forward through head h_final = hs[-1] if h_final.dim() == 4: # conv h_final = F.adaptive_avg_pool2d(h_final, 1).flatten(1) elif h_final.dim() == 3: # transformer cls token h_final = h_final[:, 0] if hasattr(model, 'out_ln'): h_final = model.out_ln(h_final) logits = model.out_head(h_final) loss = F.cross_entropy(logits, self.y_eval) grads = torch.autograd.grad(loss, hs, allow_unused=True) norms = [] for g in grads: if g is None: norms.append(0.0) continue if g.dim() == 4: g_flat = F.adaptive_avg_pool2d(g, 1).flatten(1) elif g.dim() == 3: g_flat = g[:, 0] else: g_flat = g norms.append(float(g_flat.norm(dim=-1).median().item())) return norms @staticmethod def _block_has_internal_skip(block): """Heuristic: check if the block's forward already includes a residual skip.""" src = type(block).forward.__qualname__ # Blocks that compute x + f(x) internally (e.g., transformer blocks) return False # conservative default; override if needed def run(self, frozen_baseline_acc: Optional[float] = None, test_acc: Optional[float] = None): """Run all three diagnostics. Args: frozen_baseline_acc: accuracy of the frozen-blocks baseline (required for D3). 
    def run(self, frozen_baseline_acc: Optional[float] = None,
            test_acc: Optional[float] = None):
        """Run all three diagnostics.

        Args:
            frozen_baseline_acc: accuracy of the frozen-blocks baseline
                (required for D3).
            test_acc: test accuracy of the trained model. If None, it is
                computed from x_eval/y_eval.

        Returns:
            dict with 'verdict', 'test_acc', and per-diagnostic raw values.
        """
        self.model.eval()

        # Forward pass to get hidden states.
        with torch.no_grad():
            logits, hiddens = self.model(self.x_eval, return_hidden=True)
        if test_acc is None:
            test_acc = float((logits.argmax(-1) == self.y_eval).float().mean().item())

        # D1: scale stability.
        h_norms = self._compute_hidden_norms(hiddens)
        growth_ratios = [h_norms[i + 1] / max(h_norms[i], 1e-12)
                         for i in range(len(h_norms) - 1)]
        max_growth = max(growth_ratios) if growth_ratios else 1.0
        d1_fires = max_growth > self.d1_threshold

        # D2: reference validity.
        bp_grad_norms = self._compute_bp_grad_norms(hiddens)
        g_L = bp_grad_norms[-1] if bp_grad_norms else 0.0
        d2_fires = g_L < self.d2_floor

        # D3: depth utility.
        if frozen_baseline_acc is not None:
            margin = test_acc - frozen_baseline_acc
            d3_fires = margin < self.d3_margin
        else:
            margin = None
            d3_fires = None

        # Verdict.
        flags = []
        if d1_fires:
            flags.append('D1')
        if d2_fires:
            flags.append('D2')
        if d3_fires:
            flags.append('D3')
        if not flags:
            verdict = 'PASS'
        else:
            verdict = 'FAIL(' + '+'.join(flags) + ')'

        return {
            'verdict': verdict,
            'test_acc': test_acc,
            'diagnostics': {
                'D1_scale_growth': {
                    'max_growth': max_growth,
                    'per_block_growth': growth_ratios,
                    'hidden_norms': h_norms,
                    'threshold': self.d1_threshold,
                    'fires': d1_fires,
                },
                'D2_ref_validity': {
                    'g_L': g_L,
                    'bp_grad_norms': bp_grad_norms,
                    'floor': self.d2_floor,
                    'fires': d2_fires,
                },
                'D3_depth_utility': {
                    'test_acc': test_acc,
                    'frozen_baseline_acc': frozen_baseline_acc,
                    'margin': margin,
                    'margin_threshold': self.d3_margin,
                    'fires': d3_fires,
                },
            },
        }

    def summary(self, report: dict) -> str:
        """Human-readable summary of a protocol report."""
        d = report['diagnostics']
        lines = [
            f"Verdict: {report['verdict']}",
            f"Test accuracy: {report['test_acc']:.4f}",
            f"D1 Scale stability: max growth = {d['D1_scale_growth']['max_growth']:.1f}x "
            f"(threshold {d['D1_scale_growth']['threshold']}x) -> "
            f"{'FIRE' if d['D1_scale_growth']['fires'] else 'pass'}",
            f"D2 Reference validity: ||g_L|| = {d['D2_ref_validity']['g_L']:.2e} "
            f"(floor {d['D2_ref_validity']['floor']:.0e}) -> "
            f"{'FIRE' if d['D2_ref_validity']['fires'] else 'pass'}",
        ]
        if d['D3_depth_utility']['fires'] is not None:
            lines.append(
                f"D3 Depth utility: margin = {d['D3_depth_utility']['margin'] * 100:+.1f} pp "
                f"(threshold {d['D3_depth_utility']['margin_threshold'] * 100:.0f} pp) -> "
                f"{'FIRE' if d['D3_depth_utility']['fires'] else 'pass'}"
            )
        else:
            lines.append("D3 Depth utility: not evaluated (no frozen baseline provided)")
        return '\n'.join(lines)
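
# -----------------------------------------------------------------------------
# Minimal smoke test. The toy model below is NOT part of the protocol; it is a
# hypothetical stand-in that satisfies the documented contract (.blocks as an
# nn.ModuleList, forward(x, return_hidden=True) returning (logits, hiddens),
# and an out_head attribute). The shapes, seed, and the frozen-baseline value
# passed to run() are arbitrary illustration choices.
if __name__ == '__main__':

    class _ToyResidualMLP(nn.Module):
        """Hypothetical residual MLP matching the FAProtocol model contract."""

        def __init__(self, dim=32, depth=4, n_classes=10):
            super().__init__()
            self.blocks = nn.ModuleList(
                [nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(depth)]
            )
            self.out_head = nn.Linear(dim, n_classes)

        def forward(self, x, return_hidden=False):
            hiddens = [x]
            for block in self.blocks:
                # External residual: matches the re-forward logic in
                # FAProtocol._compute_bp_grad_norms (no internal skip).
                x = x + block(x)
                hiddens.append(x)
            logits = self.out_head(x)
            return (logits, hiddens) if return_hidden else logits

    torch.manual_seed(0)
    model = _ToyResidualMLP()
    x_eval = torch.randn(64, 32)
    y_eval = torch.randint(0, 10, (64,))

    protocol = FAProtocol(model, x_eval, y_eval)
    # 0.10 is a placeholder frozen-blocks accuracy (chance level for 10
    # classes), not a measured baseline.
    report = protocol.run(frozen_baseline_acc=0.10)
    print(protocol.summary(report))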