diff options
Diffstat (limited to 'protocol/fa_protocol.py')
| -rw-r--r-- | protocol/fa_protocol.py | 215 |
1 file changed, 215 insertions, 0 deletions
"""
Reference implementation of the three-diagnostic FA evaluation protocol.

Usage:
    from protocol.fa_protocol import FAProtocol

    protocol = FAProtocol(model, x_eval, y_eval)
    report = protocol.run(frozen_baseline_acc=0.349)
    print(report['verdict'])
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional


class FAProtocol:
    """Three-diagnostic evaluation protocol for feedback alignment methods.

    Diagnostics:
      D1 (Scale stability): max per-block residual growth rho = max_l ||h_{l+1}|| / ||h_l||.
          Flags if rho > threshold (default 50).
      D2 (Reference validity): BP gradient norm at the deepest hidden state.
          Flags if ||g_L|| < 10 * eps, where eps is the cosine clamp floor.
      D3 (Depth utility): test accuracy vs frozen-blocks baseline.
          Flags if trained acc < frozen_acc + margin (default 2 pp).

    The protocol requires:
      - A trained model with .blocks (nn.ModuleList) and forward(x, return_hidden=True)
      - A test batch (x_eval, y_eval)
      - A frozen-blocks baseline accuracy (must be computed separately)
    """

    def __init__(
        self,
        model: nn.Module,
        x_eval: torch.Tensor,
        y_eval: torch.Tensor,
        d1_threshold: float = 50.0,
        d2_eps: float = 1e-8,
        d2_factor: float = 10.0,
        d3_margin_pp: float = 2.0,
    ):
        """Store the model, evaluation batch, and diagnostic thresholds.

        Args:
            model: trained model exposing ``.blocks`` (nn.ModuleList) and a
                ``forward(x, return_hidden=True)`` that returns (logits, hiddens).
            x_eval: evaluation inputs.
            y_eval: integer class labels matching ``x_eval``.
            d1_threshold: max allowed per-block norm growth ratio for D1.
            d2_eps: cosine clamp floor used by the FA method under test.
            d2_factor: D2 fires when ||g_L|| < d2_factor * d2_eps.
            d3_margin_pp: required accuracy margin over the frozen baseline,
                in percentage points.
        """
        self.model = model
        self.x_eval = x_eval
        self.y_eval = y_eval
        self.d1_threshold = d1_threshold
        # D2 floor is precomputed; only the product matters at check time.
        self.d2_floor = d2_factor * d2_eps
        # Stored as a fraction (e.g. 2 pp -> 0.02) to compare against accuracies.
        self.d3_margin = d3_margin_pp / 100.0

    def _compute_hidden_norms(self, hiddens) -> list:
        """Compute the median per-sample L2 norm at each hidden layer.

        Conv maps (B, C, H, W) are average-pooled to (B, C); transformer
        states (B, T, D) use the first (cls) token. Returns one float per
        entry of ``hiddens``.
        """
        norms = []
        for h in hiddens:
            if h.dim() == 4:  # conv: (B, C, H, W) -> pool to (B, C)
                h_flat = F.adaptive_avg_pool2d(h, 1).flatten(1)
            elif h.dim() == 3:  # transformer: (B, T, D) -> cls token
                h_flat = h[:, 0]
            else:
                h_flat = h
            norms.append(float(h_flat.norm(dim=-1).median().item()))
        return norms

    def _compute_bp_grad_norms(self, hiddens) -> list:
        """Compute BP gradient norms at each hidden layer via a manual forward.

        Rebuilds the forward pass from the detached input so every hidden
        state is a graph node we can differentiate the loss against, then
        returns the median per-sample gradient norm at each layer (0.0 for
        layers the loss does not reach).
        """
        model = self.model

        # Rebuild forward from the input hidden state with grad tracking.
        hs = [hiddens[0].detach().clone().requires_grad_(True)]
        for block in model.blocks:
            h_next = block(hs[-1])
            # If the block output has the same shape and the block does not
            # apply its own skip, assume the model wraps it residually.
            if h_next.shape == hs[-1].shape and not self._block_has_internal_skip(block):
                h_next = hs[-1] + h_next
            hs.append(h_next)

        # Forward through the head (mirrors _compute_hidden_norms pooling).
        h_final = hs[-1]
        if h_final.dim() == 4:  # conv
            h_final = F.adaptive_avg_pool2d(h_final, 1).flatten(1)
        elif h_final.dim() == 3:  # transformer cls token
            h_final = h_final[:, 0]
        if hasattr(model, 'out_ln'):
            h_final = model.out_ln(h_final)
        logits = model.out_head(h_final)
        loss = F.cross_entropy(logits, self.y_eval)

        # allow_unused: layers not on the loss path yield None -> norm 0.0.
        grads = torch.autograd.grad(loss, hs, allow_unused=True)

        norms = []
        for g in grads:
            if g is None:
                norms.append(0.0)
                continue
            if g.dim() == 4:
                g_flat = F.adaptive_avg_pool2d(g, 1).flatten(1)
            elif g.dim() == 3:
                g_flat = g[:, 0]
            else:
                g_flat = g
            norms.append(float(g_flat.norm(dim=-1).median().item()))
        return norms

    @staticmethod
    def _block_has_internal_skip(block) -> bool:
        """Heuristic: whether the block's forward already includes a residual skip.

        Conservative default: assume no internal skip, so the rebuilt forward
        adds x + f(x) whenever shapes allow. Override in a subclass for block
        types (e.g. transformer blocks) that apply their own residual.
        """
        return False

    def run(self, frozen_baseline_acc: Optional[float] = None, test_acc: Optional[float] = None) -> dict:
        """Run all three diagnostics.

        Args:
            frozen_baseline_acc: accuracy of the frozen-blocks baseline (required for D3).
            test_acc: test accuracy of the trained model. If None, computed from x_eval/y_eval.

        Returns:
            dict with 'verdict' ('PASS' or 'FAIL(D1+...)'), 'test_acc', and a
            'diagnostics' dict holding raw values and fire flags for D1-D3
            (D3 entries are None when no frozen baseline is supplied).
        """
        self.model.eval()

        # Forward pass to get hidden states (no grad; BP grads are rebuilt
        # separately in _compute_bp_grad_norms).
        with torch.no_grad():
            logits, hiddens = self.model(self.x_eval, return_hidden=True)

        if test_acc is None:
            test_acc = float((logits.argmax(-1) == self.y_eval).float().mean().item())

        # D1: Scale stability.
        h_norms = self._compute_hidden_norms(hiddens)
        # Guard against zero norms to keep the ratio finite.
        growth_ratios = [h_norms[i + 1] / max(h_norms[i], 1e-12)
                         for i in range(len(h_norms) - 1)]
        max_growth = max(growth_ratios) if growth_ratios else 1.0
        d1_fires = max_growth > self.d1_threshold

        # D2: Reference validity.
        bp_grad_norms = self._compute_bp_grad_norms(hiddens)
        g_L = bp_grad_norms[-1] if bp_grad_norms else 0.0
        d2_fires = g_L < self.d2_floor

        # D3: Depth utility (only when a frozen baseline is available).
        if frozen_baseline_acc is not None:
            margin = test_acc - frozen_baseline_acc
            d3_fires = margin < self.d3_margin
        else:
            margin = None
            d3_fires = None

        # Verdict: PASS only when no diagnostic fires (d3_fires=None counts
        # as not firing).
        flags = []
        if d1_fires:
            flags.append('D1')
        if d2_fires:
            flags.append('D2')
        if d3_fires:
            flags.append('D3')

        if not flags:
            verdict = 'PASS'
        else:
            verdict = 'FAIL(' + '+'.join(flags) + ')'

        return {
            'verdict': verdict,
            'test_acc': test_acc,
            'diagnostics': {
                'D1_scale_growth': {
                    'max_growth': max_growth,
                    'per_block_growth': growth_ratios,
                    'hidden_norms': h_norms,
                    'threshold': self.d1_threshold,
                    'fires': d1_fires,
                },
                'D2_ref_validity': {
                    'g_L': g_L,
                    'bp_grad_norms': bp_grad_norms,
                    'floor': self.d2_floor,
                    'fires': d2_fires,
                },
                'D3_depth_utility': {
                    'test_acc': test_acc,
                    'frozen_baseline_acc': frozen_baseline_acc,
                    'margin': margin,
                    'margin_threshold': self.d3_margin,
                    'fires': d3_fires,
                },
            },
        }

    def summary(self, report: dict) -> str:
        """Human-readable, multi-line summary of a protocol report."""
        d = report['diagnostics']
        lines = [
            f"Verdict: {report['verdict']}",
            f"Test accuracy: {report['test_acc']:.4f}",
            f"D1 Scale stability: max growth = {d['D1_scale_growth']['max_growth']:.1f}x " +
            f"(threshold {d['D1_scale_growth']['threshold']}x) -> " +
            f"{'FIRE' if d['D1_scale_growth']['fires'] else 'pass'}",
            f"D2 Reference validity: ||g_L|| = {d['D2_ref_validity']['g_L']:.2e} " +
            f"(floor {d['D2_ref_validity']['floor']:.0e}) -> " +
            f"{'FIRE' if d['D2_ref_validity']['fires'] else 'pass'}",
        ]
        if d['D3_depth_utility']['fires'] is not None:
            lines.append(
                f"D3 Depth utility: margin = {d['D3_depth_utility']['margin']*100:+.1f} pp " +
                f"(threshold {d['D3_depth_utility']['margin_threshold']*100:.0f} pp) -> " +
                f"{'FIRE' if d['D3_depth_utility']['fires'] else 'pass'}"
            )
        else:
            lines.append("D3 Depth utility: not evaluated (no frozen baseline provided)")
        return '\n'.join(lines)
