author    YurenHao0426 <Blackhao0426@gmail.com>  2026-05-04 19:50:45 -0500
committer YurenHao0426 <Blackhao0426@gmail.com>  2026-05-04 19:50:45 -0500
commit    b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch)
tree      f0e6afb5b3d448d1d6c35d9622d22d63073ca9a7 /protocol
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and the full reproduction pipeline.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'protocol')
-rw-r--r--  protocol/__init__.py        0
-rw-r--r--  protocol/example_usage.py   106
-rw-r--r--  protocol/fa_protocol.py     215
3 files changed, 321 insertions(+), 0 deletions(-)
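For orientation, the end-to-end API this commit adds, condensed from the usage example in protocol/fa_protocol.py's module docstring (the model construction and the frozen-baseline value are placeholders):

    from protocol.fa_protocol import FAProtocol

    # model must expose .blocks (nn.ModuleList) and forward(x, return_hidden=True)
    protocol = FAProtocol(model, x_eval, y_eval)
    report = protocol.run(frozen_baseline_acc=0.349)  # baseline computed separately
    print(report['verdict'])                          # 'PASS' or e.g. 'FAIL(D1+D2+D3)'
    print(protocol.summary(report))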
diff --git a/protocol/__init__.py b/protocol/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/protocol/__init__.py
diff --git a/protocol/example_usage.py b/protocol/example_usage.py
new file mode 100644
index 0000000..2b2c65e
--- /dev/null
+++ b/protocol/example_usage.py
@@ -0,0 +1,106 @@
+"""
+Minimal example: apply the FA evaluation protocol to a DFA-trained ResMLP.
+
+This script trains a model with DFA, then runs the three-diagnostic protocol.
+Expected output: FAIL(D1+D2+D3) — DFA on terminal-LN ResMLP triggers all diagnostics.
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.residual_mlp import ResidualMLP
+from protocol.fa_protocol import FAProtocol
+
+
+def get_cifar10(batch_size=128):
+    tv = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+    ])
+    tv_train = transforms.Compose([
+        transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+    ])
+    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+    return (torch.utils.data.DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+            torch.utils.data.DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def train_dfa(model, train_loader, device, epochs=30):
+    """Minimal DFA training (canonical: no clipping, mean reduction)."""
+    d = model.d_hidden
+    L = model.num_blocks
+    C = 10  # number of classes
+    # Fixed random feedback matrices, one per block: the defining feature of DFA.
+    Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+    block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01)
+    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+                           lr=1e-3, weight_decay=0.01)
+    for ep in range(1, epochs + 1):
+        model.train()
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            batch = x.size(0)
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+            # Output error e_T = softmax(logits) - one_hot(y).
+            e_T = logits.softmax(-1)
+            e_T[torch.arange(batch), y] -= 1
+            hL = hiddens[-1].detach()
+            # The head is trained with true gradients, as is standard in DFA.
+            head_opt.zero_grad()
+            F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+            head_opt.step()
+            for l in range(L):
+                # Project the output error through the fixed feedback matrix,
+                # RMS-normalize, and deliver it at the block output: backward on
+                # (f_l * a_hat).sum(-1).mean() produces the DFA update with a_hat
+                # as the delivered error signal.
+                a = (e_T @ Bs[l].T).detach()
+                rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+                f_l = model.blocks[l](hiddens[l].detach())
+                loss = (f_l * (a / rms)).sum(-1).mean()
+                block_opts[l].zero_grad()
+                loss.backward()
+                block_opts[l].step()
+            # The embedding reuses block 0's feedback matrix.
+            a0 = (e_T @ Bs[0].T).detach()
+            rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+            h0 = model.embed(x)
+            embed_opt.zero_grad()
+            (h0 * (a0 / rms0)).sum(-1).mean().backward()
+            embed_opt.step()
+        if ep % 10 == 0:
+            print(f"  DFA ep {ep}/{epochs}", flush=True)
+
+
+def main():
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    seed = 42
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+    print("Loading CIFAR-10...")
+    train_loader, test_loader = get_cifar10()
+
+    # Prepare eval buffer
+    xs, ys = [], []
+    for x, y in test_loader:
+        xs.append(x.view(x.size(0), -1))
+        ys.append(y)
+        if sum(xb.size(0) for xb in xs) >= 128:
+            break
+    x_eval = torch.cat(xs)[:128].to(device)
+    y_eval = torch.cat(ys)[:128].to(device)
+
+    # Train with DFA
+    print("Training DFA (30 epochs)...")
+    model = ResidualMLP(3072, 256, 10, 4).to(device)
+    train_dfa(model, train_loader, device, epochs=30)
+
+    # Run protocol
+    print("\nRunning protocol...")
+    protocol = FAProtocol(model, x_eval, y_eval)
+    report = protocol.run(frozen_baseline_acc=0.349)
+    print(protocol.summary(report))
+
+
+if __name__ == '__main__':
+    main()
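The frozen-blocks baseline passed above (frozen_baseline_acc=0.349) is computed outside this script. A minimal sketch of one way to obtain it, assuming the same ResidualMLP interface and that forward(x) without return_hidden returns logits; this helper is hypothetical and not part of this commit:

    def frozen_baseline_acc(train_loader, test_loader, device, epochs=5):
        # Freeze the randomly initialized embed + blocks; train only the head with BP.
        model = ResidualMLP(3072, 256, 10, 4).to(device)
        for p in model.embed.parameters():
            p.requires_grad_(False)
        for p in model.blocks.parameters():
            p.requires_grad_(False)
        opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
                          lr=1e-3, weight_decay=0.01)
        model.train()
        for _ in range(epochs):
            for x, y in train_loader:
                x = x.view(x.size(0), -1).to(device)
                opt.zero_grad()
                F.cross_entropy(model(x), y.to(device)).backward()
                opt.step()
        # Evaluate on the full test set.
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in test_loader:
                logits = model(x.view(x.size(0), -1).to(device))
                correct += (logits.argmax(-1).cpu() == y).sum().item()
                total += y.numel()
        return correct / total

Which modules stay frozen (whether the embedding is included) is an assumption here; the paper's baseline definition governs.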
diff --git a/protocol/fa_protocol.py b/protocol/fa_protocol.py
new file mode 100644
index 0000000..8d12939
--- /dev/null
+++ b/protocol/fa_protocol.py
@@ -0,0 +1,215 @@
+"""
+Reference implementation of the three-diagnostic FA evaluation protocol.
+
+Usage:
+ from protocol.fa_protocol import FAProtocol
+
+ protocol = FAProtocol(model, x_eval, y_eval)
+ report = protocol.run(frozen_baseline_acc=0.349)
+ print(report['verdict'])
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from typing import Optional
+
+
+class FAProtocol:
+    """Three-diagnostic evaluation protocol for feedback alignment methods.
+
+    Diagnostics:
+        D1 (Scale stability): max per-block residual growth
+            rho = max_l ||h_{l+1}|| / ||h_l||. Fires if rho > threshold (default 50).
+        D2 (Reference validity): BP gradient norm at the deepest hidden state.
+            Fires if ||g_L|| < 10 * eps, where eps is the cosine clamp floor.
+        D3 (Depth utility): test accuracy vs. the frozen-blocks baseline.
+            Fires if trained acc < frozen acc + margin (default 2 pp).
+
+    The protocol requires:
+        - a trained model with .blocks (nn.ModuleList) and forward(x, return_hidden=True),
+        - an eval batch (x_eval, y_eval),
+        - a frozen-blocks baseline accuracy for D3 (must be computed separately).
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        x_eval: torch.Tensor,
+        y_eval: torch.Tensor,
+        d1_threshold: float = 50.0,
+        d2_eps: float = 1e-8,
+        d2_factor: float = 10.0,
+        d3_margin_pp: float = 2.0,
+    ):
+        self.model = model
+        self.x_eval = x_eval
+        self.y_eval = y_eval
+        self.d1_threshold = d1_threshold
+        self.d2_floor = d2_factor * d2_eps
+        self.d3_margin = d3_margin_pp / 100.0
+
+    def _compute_hidden_norms(self, hiddens):
+        """Compute the median per-sample L2 norm at each hidden layer."""
+        norms = []
+        for h in hiddens:
+            if h.dim() == 4:    # conv: (B, C, H, W) -> pool to (B, C)
+                h_flat = F.adaptive_avg_pool2d(h, 1).flatten(1)
+            elif h.dim() == 3:  # transformer: (B, T, D) -> cls token
+                h_flat = h[:, 0]
+            else:
+                h_flat = h
+            norms.append(float(h_flat.norm(dim=-1).median().item()))
+        return norms
+
+    def _compute_bp_grad_norms(self, hiddens):
+        """Compute BP gradient norms at each hidden layer via a manual forward."""
+        model = self.model
+
+        # Rebuild the forward pass from the first hidden state with grad tracking.
+        hs = [hiddens[0].detach().clone().requires_grad_(True)]
+        for block in model.blocks:
+            h_next = block(hs[-1])
+            # Add the skip connection unless the block applies one internally
+            # (only possible when input and output shapes match).
+            if h_next.shape == hs[-1].shape and not self._block_has_internal_skip(block):
+                h_next = hs[-1] + h_next
+            hs.append(h_next)
+
+        # Forward through the head
+        h_final = hs[-1]
+        if h_final.dim() == 4:    # conv
+            h_final = F.adaptive_avg_pool2d(h_final, 1).flatten(1)
+        elif h_final.dim() == 3:  # transformer cls token
+            h_final = h_final[:, 0]
+        if hasattr(model, 'out_ln'):
+            h_final = model.out_ln(h_final)
+        logits = model.out_head(h_final)
+        loss = F.cross_entropy(logits, self.y_eval)
+        grads = torch.autograd.grad(loss, hs, allow_unused=True)
+
+        norms = []
+        for g in grads:
+            if g is None:
+                norms.append(0.0)
+                continue
+            if g.dim() == 4:
+                g_flat = F.adaptive_avg_pool2d(g, 1).flatten(1)
+            elif g.dim() == 3:
+                g_flat = g[:, 0]
+            else:
+                g_flat = g
+            norms.append(float(g_flat.norm(dim=-1).median().item()))
+        return norms
+
+    @staticmethod
+    def _block_has_internal_skip(block):
+        """Heuristic: whether the block's forward already applies a residual skip
+        (e.g., transformer blocks that compute x + f(x) internally)."""
+        # Conservative default: assume no internal skip; override per architecture.
+        return False
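+
+    # An illustrative override for architectures whose blocks already include the
+    # residual internally (hypothetical subclass, not part of this commit):
+    #
+    #     class TransformerFAProtocol(FAProtocol):
+    #         @staticmethod
+    #         def _block_has_internal_skip(block):
+    #             return True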
+
+    def run(self, frozen_baseline_acc: Optional[float] = None, test_acc: Optional[float] = None):
+        """Run all three diagnostics.
+
+        Args:
+            frozen_baseline_acc: accuracy of the frozen-blocks baseline (required for D3).
+            test_acc: test accuracy of the trained model. If None, computed from x_eval/y_eval.
+
+        Returns:
+            dict with 'verdict', 'test_acc', and per-diagnostic raw values.
+        """
+        self.model.eval()
+
+        # Forward pass to get hidden states
+        with torch.no_grad():
+            logits, hiddens = self.model(self.x_eval, return_hidden=True)
+
+        if test_acc is None:
+            test_acc = float((logits.argmax(-1) == self.y_eval).float().mean().item())
+
+        # D1: Scale stability
+        h_norms = self._compute_hidden_norms(hiddens)
+        growth_ratios = [h_norms[i + 1] / max(h_norms[i], 1e-12)
+                         for i in range(len(h_norms) - 1)]
+        max_growth = max(growth_ratios) if growth_ratios else 1.0
+        d1_fires = max_growth > self.d1_threshold
+
+        # D2: Reference validity
+        bp_grad_norms = self._compute_bp_grad_norms(hiddens)
+        g_L = bp_grad_norms[-1] if bp_grad_norms else 0.0
+        d2_fires = g_L < self.d2_floor
+
+        # D3: Depth utility (skipped when no frozen baseline is provided)
+        if frozen_baseline_acc is not None:
+            margin = test_acc - frozen_baseline_acc
+            d3_fires = margin < self.d3_margin
+        else:
+            margin = None
+            d3_fires = None
+
+        # Verdict: PASS only if no diagnostic fires
+        flags = []
+        if d1_fires:
+            flags.append('D1')
+        if d2_fires:
+            flags.append('D2')
+        if d3_fires:
+            flags.append('D3')
+        verdict = 'FAIL(' + '+'.join(flags) + ')' if flags else 'PASS'
+
+        return {
+            'verdict': verdict,
+            'test_acc': test_acc,
+            'diagnostics': {
+                'D1_scale_growth': {
+                    'max_growth': max_growth,
+                    'per_block_growth': growth_ratios,
+                    'hidden_norms': h_norms,
+                    'threshold': self.d1_threshold,
+                    'fires': d1_fires,
+                },
+                'D2_ref_validity': {
+                    'g_L': g_L,
+                    'bp_grad_norms': bp_grad_norms,
+                    'floor': self.d2_floor,
+                    'fires': d2_fires,
+                },
+                'D3_depth_utility': {
+                    'test_acc': test_acc,
+                    'frozen_baseline_acc': frozen_baseline_acc,
+                    'margin': margin,
+                    'margin_threshold': self.d3_margin,
+                    'fires': d3_fires,
+                },
+            },
+        }
+
+    def summary(self, report: dict) -> str:
+        """Human-readable summary of a protocol report."""
+        d = report['diagnostics']
+        lines = [
+            f"Verdict: {report['verdict']}",
+            f"Test accuracy: {report['test_acc']:.4f}",
+            f"D1 Scale stability: max growth = {d['D1_scale_growth']['max_growth']:.1f}x "
+            f"(threshold {d['D1_scale_growth']['threshold']}x) -> "
+            f"{'FIRE' if d['D1_scale_growth']['fires'] else 'pass'}",
+            f"D2 Reference validity: ||g_L|| = {d['D2_ref_validity']['g_L']:.2e} "
+            f"(floor {d['D2_ref_validity']['floor']:.0e}) -> "
+            f"{'FIRE' if d['D2_ref_validity']['fires'] else 'pass'}",
+        ]
+        if d['D3_depth_utility']['fires'] is not None:
+            lines.append(
+                f"D3 Depth utility: margin = {d['D3_depth_utility']['margin'] * 100:+.1f} pp "
+                f"(threshold {d['D3_depth_utility']['margin_threshold'] * 100:.0f} pp) -> "
+                f"{'FIRE' if d['D3_depth_utility']['fires'] else 'pass'}"
+            )
+        else:
+            lines.append("D3 Depth utility: not evaluated (no frozen baseline provided)")
+        return '\n'.join(lines)
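All three diagnostic thresholds are constructor arguments, so the protocol can be tightened or loosened per architecture. An illustrative override (the values here are placeholders, not recommendations from the paper):

    from protocol.fa_protocol import FAProtocol

    protocol = FAProtocol(
        model, x_eval, y_eval,
        d1_threshold=20.0,   # stricter residual-growth bound for D1
        d2_eps=1e-8,         # cosine clamp floor; D2 fires below d2_factor * d2_eps
        d3_margin_pp=5.0,    # require >= 5 pp over the frozen baseline for D3
    )
    report = protocol.run(frozen_baseline_acc=0.349)
    print(protocol.summary(report))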