Diffstat (limited to 'protocol')
-rw-r--r--  protocol/__init__.py       |   0
-rw-r--r--  protocol/example_usage.py  | 106
-rw-r--r--  protocol/fa_protocol.py    | 215
3 files changed, 321 insertions, 0 deletions
diff --git a/protocol/__init__.py b/protocol/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/protocol/__init__.py
diff --git a/protocol/example_usage.py b/protocol/example_usage.py
new file mode 100644
index 0000000..2b2c65e
--- /dev/null
+++ b/protocol/example_usage.py
@@ -0,0 +1,106 @@
+"""
+Minimal example: apply the FA evaluation protocol to a DFA-trained ResMLP.
+
+This script trains a model with DFA, then runs the three-diagnostic protocol.
+Expected output: FAIL(D1+D2+D3), since DFA on a terminal-LN ResMLP triggers all three diagnostics.
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.residual_mlp import ResidualMLP
+from protocol.fa_protocol import FAProtocol
+
+
+def get_cifar10(batch_size=128):
+    tv = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+    ])
+    tv_train = transforms.Compose([
+        transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+    ])
+    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+    return (torch.utils.data.DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+            torch.utils.data.DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def train_dfa(model, train_loader, device, epochs=30):
+    """Minimal DFA training (canonical: no clipping, mean reduction)."""
+    d = model.d_hidden
+    L = model.num_blocks
+    C = 10
+    Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+    block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01)
+    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+                           lr=1e-3, weight_decay=0.01)
+    for ep in range(1, epochs + 1):
+        model.train()
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device); y = y.to(device)
+            batch = x.size(0)
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+            e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+            hL = hiddens[-1].detach()
+            head_opt.zero_grad()
+            F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+            head_opt.step()
+            for l in range(L):
+                a = (e_T @ Bs[l].T).detach()
+                rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+                f_l = model.blocks[l](hiddens[l].detach())
+                loss = (f_l * (a / rms)).sum(-1).mean()
+                block_opts[l].zero_grad(); loss.backward(); block_opts[l].step()
+            a0 = (e_T @ Bs[0].T).detach()
+            rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+            h0 = model.embed(x)
+            embed_opt.zero_grad(); (h0 * (a0 / rms0)).sum(-1).mean().backward(); embed_opt.step()
+        if ep % 10 == 0:
+            print(f"  DFA ep {ep}/{epochs}", flush=True)
+
+
+def main():
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    seed = 42
+    torch.manual_seed(seed); np.random.seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+    print("Loading CIFAR-10...")
+    train_loader, test_loader = get_cifar10()
+
+    # Prepare eval buffer
+    xs, ys = [], []
+    for x, y in test_loader:
+        xs.append(x.view(x.size(0), -1)); ys.append(y)
+        if sum(xb.size(0) for xb in xs) >= 128:
+            break
+    x_eval = torch.cat(xs)[:128].to(device)
+    y_eval = torch.cat(ys)[:128].to(device)
+
+    # Train with DFA
+    print("Training DFA (30 epochs)...")
+    model = ResidualMLP(3072, 256, 10, 4).to(device)
+    train_dfa(model, train_loader, device, epochs=30)
+
+    # Run protocol
+    print("\nRunning protocol...")
+    protocol = FAProtocol(model, x_eval, y_eval)
+    report = protocol.run(frozen_baseline_acc=0.349)
+    print(protocol.summary(report))
+
+
+if __name__ == '__main__':
+    main()
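The `frozen_baseline_acc=0.349` passed to `protocol.run` above is an input the script does not compute; the protocol deliberately requires the frozen-blocks baseline to come from a separate run. A minimal sketch of how such a number could be produced, reusing the imports above and assuming `model(x)` returns plain logits when `return_hidden` is not requested (the helper name `frozen_blocks_baseline` is illustrative, not part of this commit):

    # Hypothetical helper (not in this commit): train only embed + head with
    # the residual blocks frozen at initialization, then report test accuracy.
    def frozen_blocks_baseline(train_loader, test_loader, device, epochs=30):
        model = ResidualMLP(3072, 256, 10, 4).to(device)
        for p in model.blocks.parameters():
            p.requires_grad_(False)  # blocks stay at their random init
        opt = optim.AdamW(
            list(model.embed.parameters()) + list(model.out_head.parameters())
            + list(model.out_ln.parameters()), lr=1e-3, weight_decay=0.01)
        for _ in range(epochs):
            model.train()
            for x, y in train_loader:
                x = x.view(x.size(0), -1).to(device); y = y.to(device)
                opt.zero_grad()
                F.cross_entropy(model(x), y).backward()  # assumes model(x) -> logits
                opt.step()
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.view(x.size(0), -1).to(device); y = y.to(device)
                correct += (model(x).argmax(-1) == y).sum().item()
                total += y.numel()
        return correct / total

Whatever this returns for your seed and schedule is the number D3 compares against; 0.349 is simply the value the authors obtained for this configuration.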
diff --git a/protocol/fa_protocol.py b/protocol/fa_protocol.py
new file mode 100644
index 0000000..8d12939
--- /dev/null
+++ b/protocol/fa_protocol.py
@@ -0,0 +1,215 @@
+"""
+Reference implementation of the three-diagnostic FA evaluation protocol.
+
+Usage:
+    from protocol.fa_protocol import FAProtocol
+
+    protocol = FAProtocol(model, x_eval, y_eval)
+    report = protocol.run(frozen_baseline_acc=0.349)
+    print(report['verdict'])
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from typing import Optional
+
+
+class FAProtocol:
+    """Three-diagnostic evaluation protocol for feedback alignment methods.
+
+    Diagnostics:
+        D1 (scale stability): max per-block residual growth rho = max_l ||h_{l+1}|| / ||h_l||.
+            Fires if rho > threshold (default 50).
+        D2 (reference validity): BP gradient norm at the deepest hidden state.
+            Fires if ||g_L|| < 10 * eps, where eps is the cosine clamp floor.
+        D3 (depth utility): test accuracy vs. the frozen-blocks baseline.
+            Fires if trained acc < frozen acc + margin (default 2 pp).
+
+    The protocol requires:
+        - a trained model with .blocks (nn.ModuleList) and forward(x, return_hidden=True)
+        - a test batch (x_eval, y_eval)
+        - a frozen-blocks baseline accuracy (must be computed separately)
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        x_eval: torch.Tensor,
+        y_eval: torch.Tensor,
+        d1_threshold: float = 50.0,
+        d2_eps: float = 1e-8,
+        d2_factor: float = 10.0,
+        d3_margin_pp: float = 2.0,
+    ):
+        self.model = model
+        self.x_eval = x_eval
+        self.y_eval = y_eval
+        self.d1_threshold = d1_threshold
+        self.d2_floor = d2_factor * d2_eps
+        self.d3_margin = d3_margin_pp / 100.0
+
+    def _compute_hidden_norms(self, hiddens):
+        """Median per-sample L2 norm at each hidden layer."""
+        norms = []
+        for h in hiddens:
+            if h.dim() == 4:    # conv: (B, C, H, W) -> pool to (B, C)
+                h_flat = F.adaptive_avg_pool2d(h, 1).flatten(1)
+            elif h.dim() == 3:  # transformer: (B, T, D) -> CLS token
+                h_flat = h[:, 0]
+            else:
+                h_flat = h
+            norms.append(float(h_flat.norm(dim=-1).median().item()))
+        return norms
+
+    def _compute_bp_grad_norms(self, hiddens):
+        """BP gradient norms at each hidden layer, via a manual forward pass."""
+        model = self.model
+
+        # Rebuild the forward pass from the recorded hidden states with grad tracking.
+        hs = [hiddens[0].detach().clone().requires_grad_(True)]
+        for block in model.blocks:
+            h_next = block(hs[-1])
+            # Add the residual skip here unless the block applies one internally
+            # (shapes must match for a skip connection to make sense).
+            if h_next.shape == hs[-1].shape and not self._block_has_internal_skip(block):
+                h_next = hs[-1] + h_next
+            hs.append(h_next)
+
+        # Forward through the head
+        h_final = hs[-1]
+        if h_final.dim() == 4:    # conv
+            h_final = F.adaptive_avg_pool2d(h_final, 1).flatten(1)
+        elif h_final.dim() == 3:  # transformer CLS token
+            h_final = h_final[:, 0]
+        if hasattr(model, 'out_ln'):
+            h_final = model.out_ln(h_final)
+        logits = model.out_head(h_final)
+        loss = F.cross_entropy(logits, self.y_eval)
+        grads = torch.autograd.grad(loss, hs, allow_unused=True)
+
+        norms = []
+        for g in grads:
+            if g is None:
+                norms.append(0.0)
+                continue
+            if g.dim() == 4:
+                g_flat = F.adaptive_avg_pool2d(g, 1).flatten(1)
+            elif g.dim() == 3:
+                g_flat = g[:, 0]
+            else:
+                g_flat = g
+            norms.append(float(g_flat.norm(dim=-1).median().item()))
+        return norms
+
+    @staticmethod
+    def _block_has_internal_skip(block):
+        """Heuristic: does the block's forward already include a residual skip?
+
+        Blocks that compute x + f(x) internally (e.g., transformer blocks) should
+        return True here. Conservative default: False; override if needed.
+        """
+        return False
+
+    def run(self, frozen_baseline_acc: Optional[float] = None, test_acc: Optional[float] = None):
+        """Run all three diagnostics.
+
+        Args:
+            frozen_baseline_acc: accuracy of the frozen-blocks baseline (required for D3).
+            test_acc: test accuracy of the trained model. If None, computed from x_eval/y_eval.
+
+        Returns:
+            dict with 'verdict', 'test_acc', and per-diagnostic raw values under 'diagnostics'.
+        """
+        self.model.eval()
+
+        # Forward pass to get hidden states
+        with torch.no_grad():
+            logits, hiddens = self.model(self.x_eval, return_hidden=True)
+
+        if test_acc is None:
+            test_acc = float((logits.argmax(-1) == self.y_eval).float().mean().item())
+
+        # D1: scale stability
+        h_norms = self._compute_hidden_norms(hiddens)
+        growth_ratios = [h_norms[i + 1] / max(h_norms[i], 1e-12)
+                         for i in range(len(h_norms) - 1)]
+        max_growth = max(growth_ratios) if growth_ratios else 1.0
+        d1_fires = max_growth > self.d1_threshold
+
+        # D2: reference validity
+        bp_grad_norms = self._compute_bp_grad_norms(hiddens)
+        g_L = bp_grad_norms[-1] if bp_grad_norms else 0.0
+        d2_fires = g_L < self.d2_floor
+
+        # D3: depth utility (skipped when no frozen baseline is provided)
+        if frozen_baseline_acc is not None:
+            margin = test_acc - frozen_baseline_acc
+            d3_fires = margin < self.d3_margin
+        else:
+            margin = None
+            d3_fires = None
+
+        # Verdict
+        flags = []
+        if d1_fires:
+            flags.append('D1')
+        if d2_fires:
+            flags.append('D2')
+        if d3_fires:
+            flags.append('D3')
+
+        verdict = 'PASS' if not flags else 'FAIL(' + '+'.join(flags) + ')'
+
+        return {
+            'verdict': verdict,
+            'test_acc': test_acc,
+            'diagnostics': {
+                'D1_scale_growth': {
+                    'max_growth': max_growth,
+                    'per_block_growth': growth_ratios,
+                    'hidden_norms': h_norms,
+                    'threshold': self.d1_threshold,
+                    'fires': d1_fires,
+                },
+                'D2_ref_validity': {
+                    'g_L': g_L,
+                    'bp_grad_norms': bp_grad_norms,
+                    'floor': self.d2_floor,
+                    'fires': d2_fires,
+                },
+                'D3_depth_utility': {
+                    'test_acc': test_acc,
+                    'frozen_baseline_acc': frozen_baseline_acc,
+                    'margin': margin,
+                    'margin_threshold': self.d3_margin,
+                    'fires': d3_fires,
+                },
+            },
+        }
f"(threshold {d['D3_depth_utility']['margin_threshold']*100:.0f} pp) -> " + f"{'FIRE' if d['D3_depth_utility']['fires'] else 'pass'}" + ) + else: + lines.append("D3 Depth utility: not evaluated (no frozen baseline provided)") + return '\n'.join(lines) |
