From b480d0cdc21f944e4adccf6e81cc939b0450c5e9 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 4 May 2026 19:50:45 -0500 Subject: Initial submission code: FA evaluation protocol + reproduction scripts Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and full reproduction pipeline. Co-Authored-By: Claude Opus 4.6 (1M context) --- README.md | 81 ++++++++++ metrics/__init__.py | 0 metrics/credit_metrics.py | 156 ++++++++++++++++++ models/__init__.py | 0 models/residual_mlp.py | 75 +++++++++ models/small_resnet.py | 74 +++++++++ models/vit_mini.py | 109 +++++++++++++ protocol/__init__.py | 0 protocol/example_usage.py | 106 ++++++++++++ protocol/fa_protocol.py | 215 +++++++++++++++++++++++++ reproduce/__init__.py | 0 reproduce/frozen_baseline.py | 86 ++++++++++ reproduce/penalty_sweep.py | 176 ++++++++++++++++++++ reproduce/run_all.sh | 85 ++++++++++ reproduce/train_methods.py | 376 +++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 4 + 16 files changed, 1543 insertions(+) create mode 100644 README.md create mode 100644 metrics/__init__.py create mode 100644 metrics/credit_metrics.py create mode 100644 models/__init__.py create mode 100644 models/residual_mlp.py create mode 100644 models/small_resnet.py create mode 100644 models/vit_mini.py create mode 100644 protocol/__init__.py create mode 100644 protocol/example_usage.py create mode 100644 protocol/fa_protocol.py create mode 100644 reproduce/__init__.py create mode 100644 reproduce/frozen_baseline.py create mode 100644 reproduce/penalty_sweep.py create mode 100755 reproduce/run_all.sh create mode 100644 reproduce/train_methods.py create mode 100644 requirements.txt diff --git a/README.md b/README.md new file mode 100644 index 0000000..3345986 --- /dev/null +++ b/README.md @@ -0,0 +1,81 @@ +# What Accuracy and Gradient Cosine Miss: Evaluating Feedback Alignment via Scale Stability, Reference Validity, and Depth Utility + +Code for the NeurIPS 2026 Evaluations & Datasets Track submission. + +## Structure + +``` +submission_code/ + protocol/ # Reference evaluation artifact (the main deliverable) + fa_protocol.py # Three-diagnostic protocol: drop-in evaluator + example_usage.py # Minimal example applying protocol to a trained model + models/ # Architectures used in the paper + residual_mlp.py # Pre-LayerNorm ResMLP (primary audit architecture) + vit_mini.py # ViT-Mini (transformer with terminal LN) + small_resnet.py # SmallResNet (BatchNorm, no LN) + metrics/ + credit_metrics.py # Cosine similarity, nudging test, perturbation correlation + reproduce/ # Scripts to reproduce all paper results + train_methods.py # Train BP/FA/DFA on any architecture + compute diagnostics + frozen_baseline.py # Frozen-blocks and shallow baselines + penalty_sweep.py # Penalty intervention (lambda sweep + fresh-B null) + run_all.sh # One-command full reproduction + requirements.txt +``` + +## Quick start: applying the protocol to your own FA method + +```python +from protocol.fa_protocol import FAProtocol + +# Your trained model and a test batch +model = ... # nn.Module with .blocks, .out_head, .out_ln (or without) +x_eval, y_eval = ... # (N, d_input), (N,) + +protocol = FAProtocol(model, x_eval, y_eval) +report = protocol.run() + +print(report['verdict']) # 'PASS' or 'FAIL(D1+D2)' etc. +print(report['diagnostics']) # {'D1_scale_growth': ..., 'D2_ref_validity': ..., 'D3_depth_utility': ...} +``` + +## Reproducing paper results + +```bash +# Full reproduction (all figures, tables, appendices) +cd reproduce/ +bash run_all.sh --gpu 0 + +# Or individual experiments: +python train_methods.py --arch resmlp --methods bp fa dfa --seeds 42 123 456 --epochs 100 +python frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100 +python penalty_sweep.py --lambdas 0 1e-4 1e-2 --seeds 42 123 456 --epochs 30 +``` + +## Requirements + +- Python >= 3.10 +- PyTorch >= 2.0 +- torchvision +- numpy +- scipy (for perturbation correlation) + +## Protocol specification + +The protocol consists of three diagnostic checks applied to a trained model: + +| Diagnostic | What it measures | Threshold | Fires when | +|---|---|---|---| +| D1: Scale stability | max per-block residual growth | > 50x | Residual stream explodes | +| D2: Reference validity | deepest-layer BP gradient norm | < 10 * eps | BP reference collapses below cosine clamp | +| D3: Depth utility | test accuracy vs frozen-blocks baseline | < 2 pp above | Trained blocks do not outperform random | + +**Verdict logic**: A run fails the protocol if either Mode 1 (D1 and D2 both flag) or Depth Utility (D3 flags) is triggered. Passing all three diagnostics does not certify the method as effective; it rules out the specific class of silent failures the audit revealed. + +## Training inventory + +The paper reports 125 training runs across 5+ architectures, 10 training methods, and 17 experimental settings. Total estimated GPU time: ~12 hours on a single A100/A6000. + +## License + +MIT diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/metrics/credit_metrics.py b/metrics/credit_metrics.py new file mode 100644 index 0000000..516dca2 --- /dev/null +++ b/metrics/credit_metrics.py @@ -0,0 +1,156 @@ +""" +Credit assignment diagnostic metrics: +1. Exact costate cosine (for toy LQ) +2. Local perturbation correlation rho_l +3. Nudging test Delta_l^nudge +4. Offline BP cosine Gamma_l +5. Bridge residual R_l +6. Feature drift M_l +""" +import torch +import torch.nn.functional as F +import numpy as np +from scipy.stats import pearsonr + + +def cosine_similarity_batch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Compute cosine similarity between a and b along last dim, averaged over batch.""" + a_flat = a.reshape(a.shape[0], -1) + b_flat = b.reshape(b.shape[0], -1) + cos = F.cosine_similarity(a_flat, b_flat, dim=-1) + return cos.mean().item() + + +def perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=32): + """ + Compute local perturbation correlation rho_l. + + Args: + h_l: (batch, d) hidden state at layer l + a_l: (batch, d) credit signal at layer l + forward_fn: callable that takes h_l -> scalar loss (averaged over batch dims handled inside) + epsilon: perturbation magnitude + M: number of random directions + + Returns: + rho: Pearson correlation between predicted and true loss changes + """ + batch_size, d = h_l.shape + device = h_l.device + + pred_list = [] + true_list = [] + + base_loss = forward_fn(h_l) # (batch,) or scalar + + for _ in range(M): + v = torch.randn(batch_size, d, device=device) + v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) + + # Predicted change: + delta_pred = (a_l * (epsilon * v)).sum(dim=-1) # (batch,) + + # True change: forward from perturbed h + perturbed_loss = forward_fn(h_l + epsilon * v) # (batch,) + delta_true = perturbed_loss - base_loss # (batch,) + + pred_list.append(delta_pred.detach().cpu().numpy()) + true_list.append(delta_true.detach().cpu().numpy()) + + pred_arr = np.concatenate(pred_list) + true_arr = np.concatenate(true_list) + + if np.std(pred_arr) < 1e-12 or np.std(true_arr) < 1e-12: + return 0.0 + + rho, _ = pearsonr(pred_arr, true_arr) + return float(rho) + + +def nudging_test(h_l, a_l, forward_fn, eta=0.01): + """ + Nudging test: check if moving h_l in -a_l direction decreases loss. + + Args: + h_l: (batch, d) hidden state + a_l: (batch, d) credit signal + forward_fn: callable h -> loss per sample (batch,) + eta: step size + + Returns: + mean delta_nudge (negative is good) + """ + rms_a = (a_l ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_normed = a_l / rms_a + h_nudged = h_l - eta * a_normed + + base_loss = forward_fn(h_l) + nudged_loss = forward_fn(h_nudged) + delta = (nudged_loss - base_loss).mean().item() + return delta + + +def offline_bp_cosine(a_l, bp_grad_l): + """ + Compute offline BP cosine similarity. + a_l: (batch, d) credit signal + bp_grad_l: (batch, d) true BP gradient at layer l + """ + return cosine_similarity_batch(a_l, bp_grad_l) + + +def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s, h_l_next_noisy_list, t_l_next, lam=0.1): + """ + Compute bridge residual R_l. + + Args: + V_phi: value network + V_bar_phi: EMA target value network + h_l: (batch, d) + t_l: (batch,) + s: (batch, s_dim) + h_l_next_noisy_list: list of K tensors (batch, d), noisy next states + t_l_next: (batch,) + lam: temperature + + Returns: + mean absolute bridge residual + """ + with torch.no_grad(): + V_current = V_phi(h_l, t_l, s) # (batch,) + + # Compute soft-min target + K = len(h_l_next_noisy_list) + log_terms = [] + for h_next in h_l_next_noisy_list: + V_next = V_bar_phi(h_next, t_l_next, s) # (batch,) + log_terms.append(-V_next / lam) + + log_terms = torch.stack(log_terms, dim=-1) # (batch, K) + V_target = -lam * torch.logsumexp(log_terms, dim=-1) + lam * np.log(K) + + residual = (V_current - V_target).abs().mean().item() + return residual + + +def feature_drift(model_init_params, model_final_params): + """ + Compute per-layer feature drift M_l = ||W_final - W_init||_F / ||W_init||_F. + + Args: + model_init_params: dict of {name: tensor} initial parameters + model_final_params: dict of {name: tensor} final parameters + + Returns: + dict of {name: drift_ratio} + """ + drifts = {} + for name in model_init_params: + if name in model_final_params: + w_init = model_init_params[name] + w_final = model_final_params[name] + init_norm = w_init.norm().item() + if init_norm > 1e-8: + drift = (w_final - w_init).norm().item() / init_norm + drifts[name] = drift + return drifts diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/residual_mlp.py b/models/residual_mlp.py new file mode 100644 index 0000000..6827057 --- /dev/null +++ b/models/residual_mlp.py @@ -0,0 +1,75 @@ +""" +Deep Residual MLP for classification. +Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head. +Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l)) +""" +import torch +import torch.nn as nn + + +class ResidualBlock(nn.Module): + """Single pre-LayerNorm residual MLP block.""" + + def __init__(self, d_hidden: int, w2_std: float = 0.01): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.w1 = nn.Linear(d_hidden, d_hidden) + self.w2 = nn.Linear(d_hidden, d_hidden) + # Small init for residual branch (or larger if used as a non-residual stack) + nn.init.normal_(self.w2.weight, std=w2_std) + nn.init.zeros_(self.w2.bias) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + """Returns the residual F_l(h), NOT h + F_l(h).""" + z = self.ln(h) + z = self.w1(z) + z = torch.nn.functional.gelu(z) + z = self.w2(z) + return z + + +class ResidualMLP(nn.Module): + """Deep residual MLP: embed -> L blocks -> LN -> output head.""" + + def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int, + residual_add: bool = True, w2_std: float = 0.01): + super().__init__() + self.embed = nn.Linear(input_dim, d_hidden) + self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)]) + self.out_ln = nn.LayerNorm(d_hidden) + self.out_head = nn.Linear(d_hidden, num_classes) + self.num_blocks = num_blocks + self.d_hidden = d_hidden + self.residual_add = residual_add + + def forward(self, x: torch.Tensor, return_hidden: bool = False): + """ + Args: + x: (batch, input_dim) + return_hidden: if True, also return list of hidden states [h_0, ..., h_L] + Returns: + logits: (batch, num_classes) + hiddens: list of (batch, d_hidden) if return_hidden + """ + h = self.embed(x) + hiddens = [h] if return_hidden else None + + for block in self.blocks: + f = block(h) + h = h + f if self.residual_add else f + if return_hidden: + hiddens.append(h) + + logits = self.out_head(self.out_ln(h)) + + if return_hidden: + return logits, hiddens + return logits + + def forward_from_layer(self, h: torch.Tensor, start_layer: int): + """Run forward from a given layer index to output. Used for perturbation tests.""" + for i in range(start_layer, self.num_blocks): + f = self.blocks[i](h) + h = h + f if self.residual_add else f + logits = self.out_head(self.out_ln(h)) + return logits diff --git a/models/small_resnet.py b/models/small_resnet.py new file mode 100644 index 0000000..10b122e --- /dev/null +++ b/models/small_resnet.py @@ -0,0 +1,74 @@ +""" +Small CIFAR-10 ResNet for the FA-evaluation paper. Standard BatchNorm-based +post-activation residual blocks (no LayerNorm). 4 residual blocks at width 64. + +Supports `num_blocks=0` (shallow baseline: just embed → bn → head) and frozen +blocks via `requires_grad=False` on `.blocks` parameters. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + """Standard ResNet BasicBlock with BatchNorm. Pre-activation NOT used; this is + the post-activation form: relu(BN(W2 * relu(BN(W1 x)))) + x. d_hidden in == d_hidden out. + No stride / downsampling — all blocks operate at the same spatial resolution + after the initial stem. This keeps the architecture simple and matches the + "4 residual blocks at fixed width" structure of our ResMLP and ViT-Mini comparisons. + """ + def __init__(self, d_hidden): + super().__init__() + self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(d_hidden) + self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(d_hidden) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out = F.relu(x + out) + return out + + +class SmallResNet(nn.Module): + """Small CIFAR-10 ResNet: + - 3x3 conv stem (3 → d_hidden) + BN + ReLU + - num_blocks BasicBlocks at the same width and resolution + - global average pool + - linear classification head + + `num_blocks=0` gives the shallow baseline (just stem → pool → head). + """ + def __init__(self, d_hidden=64, num_classes=10, num_blocks=4): + super().__init__() + self.stem_conv = nn.Conv2d(3, d_hidden, kernel_size=3, padding=1, bias=False) + self.stem_bn = nn.BatchNorm2d(d_hidden) + self.blocks = nn.ModuleList([BasicBlock(d_hidden) for _ in range(num_blocks)]) + self.out_head = nn.Linear(d_hidden, num_classes) + self.num_blocks = num_blocks + self.d_hidden = d_hidden + + def stem(self, x): + # x: (B, 3, 32, 32) + if x.dim() == 2: + x = x.view(x.size(0), 3, 32, 32) + h = F.relu(self.stem_bn(self.stem_conv(x))) + return h + + def forward(self, x, return_hidden=False): + h = self.stem(x) # (B, d, 32, 32) + hiddens = [h] if return_hidden else None + for block in self.blocks: + h = block(h) + if return_hidden: + hiddens.append(h) + h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1) # (B, d) + logits = self.out_head(h_pool) + if return_hidden: + return logits, hiddens + return logits + + # Convenience alias for snapshot script compatibility (treats stem as the embed) + def embed(self, x): + return self.stem(x) diff --git a/models/vit_mini.py b/models/vit_mini.py new file mode 100644 index 0000000..af6ba60 --- /dev/null +++ b/models/vit_mini.py @@ -0,0 +1,109 @@ +""" +Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before +the classification head — the architecture P4 should target. + +Designed to be compatible with the snapshot evolution / DFA training framework. +Each TransformerBlock is a "layer" for FA-style local credit purposes. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TransformerBlock(nn.Module): + """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x)).""" + def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0): + super().__init__() + self.ln1 = nn.LayerNorm(d_model) + self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) + self.ln2 = nn.LayerNorm(d_model) + mlp_hidden = int(d_model * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(d_model, mlp_hidden), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(mlp_hidden, d_model), + nn.Dropout(dropout), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Self-attention sublayer + x_norm = self.ln1(x) + attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False) + x = x + attn_out + # MLP sublayer + x = x + self.mlp(self.ln2(x)) + return x + + +class ViTMini(nn.Module): + """Minimal Vision Transformer for CIFAR-10. + Patch size 4x4 → 64 patches per image. Plus a learned cls token. + Pre-LN with terminal LayerNorm before the head. + """ + def __init__( + self, + image_size: int = 32, + patch_size: int = 4, + in_channels: int = 3, + num_classes: int = 10, + d_model: int = 128, + n_heads: int = 4, + num_blocks: int = 4, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + ): + super().__init__() + assert image_size % patch_size == 0 + n_patches = (image_size // patch_size) ** 2 + self.n_tokens = n_patches + 1 # +1 for cls token + + # Patch embedding via Conv2d (equivalent to flatten + linear) + self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size) + self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) + self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model)) + nn.init.trunc_normal_(self.cls_token, std=0.02) + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + self.blocks = nn.ModuleList([ + TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks) + ]) + self.out_ln = nn.LayerNorm(d_model) # terminal LN — the P4 trigger + self.out_head = nn.Linear(d_model, num_classes) + + self.num_blocks = num_blocks + self.d_model = d_model + self.d_hidden = d_model # alias for compatibility with snapshot script + + def embed(self, x: torch.Tensor) -> torch.Tensor: + """Take a flat-CIFAR input (B, 3072) or image (B, 3, 32, 32) → token sequence (B, 65, d_model).""" + if x.dim() == 2: # flat input + x = x.view(x.size(0), 3, 32, 32) + # x: (B, 3, 32, 32) + x = self.patch_embed(x) # (B, d, 8, 8) + x = x.flatten(2).transpose(1, 2) # (B, 64, d) + cls = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat([cls, x], dim=1) # (B, 65, d) + x = x + self.pos_embed + return x + + def forward(self, x: torch.Tensor, return_hidden: bool = False): + h = self.embed(x) # (B, 65, d) + hiddens = [h] if return_hidden else None + for block in self.blocks: + h = block(h) + if return_hidden: + hiddens.append(h) + # Take cls token, normalize, classify + h_cls = self.out_ln(h[:, 0]) # (B, d) + logits = self.out_head(h_cls) + if return_hidden: + return logits, hiddens + return logits + + def forward_from_layer(self, h: torch.Tensor, start_layer: int): + """Run forward from a given block index. h has shape (B, n_tokens, d).""" + for i in range(start_layer, self.num_blocks): + h = self.blocks[i](h) + h_cls = self.out_ln(h[:, 0]) + return self.out_head(h_cls) diff --git a/protocol/__init__.py b/protocol/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/protocol/example_usage.py b/protocol/example_usage.py new file mode 100644 index 0000000..2b2c65e --- /dev/null +++ b/protocol/example_usage.py @@ -0,0 +1,106 @@ +""" +Minimal example: apply the FA evaluation protocol to a DFA-trained ResMLP. + +This script trains a model with DFA, then runs the three-diagnostic protocol. +Expected output: FAIL(D1+D2+D3) — DFA on terminal-LN ResMLP triggers all diagnostics. +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +import numpy as np + +from models.residual_mlp import ResidualMLP +from protocol.fa_protocol import FAProtocol + + +def get_cifar10(batch_size=128): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return (torch.utils.data.DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + torch.utils.data.DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) + + +def train_dfa(model, train_loader, device, epochs=30): + """Minimal DFA training (canonical: no clipping, mean reduction).""" + d = model.d_hidden + L = model.num_blocks + C = 10 + Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=1e-3, weight_decay=0.01) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL = hiddens[-1].detach() + head_opt.zero_grad() + F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward() + head_opt.step() + for l in range(L): + a = (e_T @ Bs[l].T).detach() + rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](hiddens[l].detach()) + loss = (f_l * (a / rms)).sum(-1).mean() + block_opts[l].zero_grad(); loss.backward(); block_opts[l].step() + a0 = (e_T @ Bs[0].T).detach() + rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_opt.zero_grad(); (h0 * (a0 / rms0)).sum(-1).mean().backward(); embed_opt.step() + if ep % 10 == 0: + print(f" DFA ep {ep}/{epochs}", flush=True) + + +def main(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + seed = 42 + torch.manual_seed(seed); np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + print("Loading CIFAR-10...") + train_loader, test_loader = get_cifar10() + + # Prepare eval buffer + xs, ys = [], [] + for x, y in test_loader: + xs.append(x.view(x.size(0), -1)); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 128: + break + x_eval = torch.cat(xs)[:128].to(device) + y_eval = torch.cat(ys)[:128].to(device) + + # Train with DFA + print("Training DFA (30 epochs)...") + model = ResidualMLP(3072, 256, 10, 4).to(device) + train_dfa(model, train_loader, device, epochs=30) + + # Run protocol + print("\nRunning protocol...") + protocol = FAProtocol(model, x_eval, y_eval) + report = protocol.run(frozen_baseline_acc=0.349) + print(protocol.summary(report)) + + +if __name__ == '__main__': + main() diff --git a/protocol/fa_protocol.py b/protocol/fa_protocol.py new file mode 100644 index 0000000..8d12939 --- /dev/null +++ b/protocol/fa_protocol.py @@ -0,0 +1,215 @@ +""" +Reference implementation of the three-diagnostic FA evaluation protocol. + +Usage: + from protocol.fa_protocol import FAProtocol + + protocol = FAProtocol(model, x_eval, y_eval) + report = protocol.run(frozen_baseline_acc=0.349) + print(report['verdict']) +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from typing import Optional + + +class FAProtocol: + """Three-diagnostic evaluation protocol for feedback alignment methods. + + Diagnostics: + D1 (Scale stability): max per-block residual growth rho = max_l ||h_{l+1}|| / ||h_l||. + Flags if rho > threshold (default 50). + D2 (Reference validity): BP gradient norm at the deepest hidden state. + Flags if ||g_L|| < 10 * eps, where eps is the cosine clamp floor. + D3 (Depth utility): test accuracy vs frozen-blocks baseline. + Flags if trained acc < frozen_acc + margin (default 2 pp). + + The protocol requires: + - A trained model with .blocks (nn.ModuleList) and forward(x, return_hidden=True) + - A test batch (x_eval, y_eval) + - A frozen-blocks baseline accuracy (must be computed separately) + """ + + def __init__( + self, + model: nn.Module, + x_eval: torch.Tensor, + y_eval: torch.Tensor, + d1_threshold: float = 50.0, + d2_eps: float = 1e-8, + d2_factor: float = 10.0, + d3_margin_pp: float = 2.0, + ): + self.model = model + self.x_eval = x_eval + self.y_eval = y_eval + self.d1_threshold = d1_threshold + self.d2_floor = d2_factor * d2_eps + self.d3_margin = d3_margin_pp / 100.0 + + def _compute_hidden_norms(self, hiddens): + """Compute median per-sample L2 norm at each hidden layer.""" + norms = [] + for h in hiddens: + if h.dim() == 4: # conv: (B, C, H, W) -> pool to (B, C) + h_flat = F.adaptive_avg_pool2d(h, 1).flatten(1) + elif h.dim() == 3: # transformer: (B, T, D) -> cls token + h_flat = h[:, 0] + else: + h_flat = h + norms.append(float(h_flat.norm(dim=-1).median().item())) + return norms + + def _compute_bp_grad_norms(self, hiddens): + """Compute BP gradient norms at each hidden layer via manual forward.""" + model = self.model + L = len(hiddens) - 1 # number of blocks + + # Rebuild forward from hidden states with grad tracking + hs = [hiddens[0].detach().clone().requires_grad_(True)] + for i, block in enumerate(model.blocks): + if hasattr(block, 'forward'): + h_next = block(hs[-1]) + # Check if block includes residual (output same shape, skip connection) + if h_next.shape == hs[-1].shape and not self._block_has_internal_skip(block): + h_next = hs[-1] + h_next + hs.append(h_next) + + # Forward through head + h_final = hs[-1] + if h_final.dim() == 4: # conv + h_final = F.adaptive_avg_pool2d(h_final, 1).flatten(1) + elif h_final.dim() == 3: # transformer cls token + h_final = h_final[:, 0] + if hasattr(model, 'out_ln'): + h_final = model.out_ln(h_final) + logits = model.out_head(h_final) + loss = F.cross_entropy(logits, self.y_eval) + grads = torch.autograd.grad(loss, hs, allow_unused=True) + + norms = [] + for g in grads: + if g is None: + norms.append(0.0) + continue + if g.dim() == 4: + g_flat = F.adaptive_avg_pool2d(g, 1).flatten(1) + elif g.dim() == 3: + g_flat = g[:, 0] + else: + g_flat = g + norms.append(float(g_flat.norm(dim=-1).median().item())) + return norms + + @staticmethod + def _block_has_internal_skip(block): + """Heuristic: check if the block's forward already includes a residual skip.""" + src = type(block).forward.__qualname__ + # Blocks that compute x + f(x) internally (e.g., transformer blocks) + return False # conservative default; override if needed + + def run(self, frozen_baseline_acc: Optional[float] = None, test_acc: Optional[float] = None): + """Run all three diagnostics. + + Args: + frozen_baseline_acc: accuracy of the frozen-blocks baseline (required for D3). + test_acc: test accuracy of the trained model. If None, computed from x_eval/y_eval. + + Returns: + dict with 'diagnostics', 'verdict', and raw values. + """ + self.model.eval() + + # Forward pass to get hidden states + with torch.no_grad(): + logits, hiddens = self.model(self.x_eval, return_hidden=True) + + if test_acc is None: + test_acc = float((logits.argmax(-1) == self.y_eval).float().mean().item()) + + # D1: Scale stability + h_norms = self._compute_hidden_norms(hiddens) + growth_ratios = [h_norms[i+1] / max(h_norms[i], 1e-12) + for i in range(len(h_norms) - 1)] + max_growth = max(growth_ratios) if growth_ratios else 1.0 + d1_fires = max_growth > self.d1_threshold + + # D2: Reference validity + bp_grad_norms = self._compute_bp_grad_norms(hiddens) + g_L = bp_grad_norms[-1] if bp_grad_norms else 0.0 + d2_fires = g_L < self.d2_floor + + # D3: Depth utility + if frozen_baseline_acc is not None: + margin = test_acc - frozen_baseline_acc + d3_fires = margin < self.d3_margin + else: + margin = None + d3_fires = None + + # Verdict + mode1 = d1_fires and d2_fires + flags = [] + if d1_fires: + flags.append('D1') + if d2_fires: + flags.append('D2') + if d3_fires: + flags.append('D3') + + if not flags: + verdict = 'PASS' + else: + verdict = 'FAIL(' + '+'.join(flags) + ')' + + return { + 'verdict': verdict, + 'test_acc': test_acc, + 'diagnostics': { + 'D1_scale_growth': { + 'max_growth': max_growth, + 'per_block_growth': growth_ratios, + 'hidden_norms': h_norms, + 'threshold': self.d1_threshold, + 'fires': d1_fires, + }, + 'D2_ref_validity': { + 'g_L': g_L, + 'bp_grad_norms': bp_grad_norms, + 'floor': self.d2_floor, + 'fires': d2_fires, + }, + 'D3_depth_utility': { + 'test_acc': test_acc, + 'frozen_baseline_acc': frozen_baseline_acc, + 'margin': margin, + 'margin_threshold': self.d3_margin, + 'fires': d3_fires, + }, + }, + } + + def summary(self, report: dict) -> str: + """Human-readable summary of a protocol report.""" + d = report['diagnostics'] + lines = [ + f"Verdict: {report['verdict']}", + f"Test accuracy: {report['test_acc']:.4f}", + f"D1 Scale stability: max growth = {d['D1_scale_growth']['max_growth']:.1f}x " + f"(threshold {d['D1_scale_growth']['threshold']}x) -> " + f"{'FIRE' if d['D1_scale_growth']['fires'] else 'pass'}", + f"D2 Reference validity: ||g_L|| = {d['D2_ref_validity']['g_L']:.2e} " + f"(floor {d['D2_ref_validity']['floor']:.0e}) -> " + f"{'FIRE' if d['D2_ref_validity']['fires'] else 'pass'}", + ] + if d['D3_depth_utility']['fires'] is not None: + lines.append( + f"D3 Depth utility: margin = {d['D3_depth_utility']['margin']*100:+.1f} pp " + f"(threshold {d['D3_depth_utility']['margin_threshold']*100:.0f} pp) -> " + f"{'FIRE' if d['D3_depth_utility']['fires'] else 'pass'}" + ) + else: + lines.append("D3 Depth utility: not evaluated (no frozen baseline provided)") + return '\n'.join(lines) diff --git a/reproduce/__init__.py b/reproduce/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/reproduce/frozen_baseline.py b/reproduce/frozen_baseline.py new file mode 100644 index 0000000..08368a2 --- /dev/null +++ b/reproduce/frozen_baseline.py @@ -0,0 +1,86 @@ +""" +Frozen-blocks baseline: train only embed/head with blocks frozen at random init. + +Usage: + python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100 +""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from reproduce.train_methods import get_data, evaluate, make_model + + +def freeze_blocks(model): + for p in model.blocks.parameters(): + p.requires_grad_(False) + for m in model.blocks.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): + m.eval() + + +def train_frozen(model, train_loader, test_loader, device, epochs, is_conv): + opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=0.01) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + model.train() + for m in model.blocks.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): + m.eval() + for x, y in train_loader: + x, y = x.to(device), y.to(device) + if not is_conv: x = x.view(x.size(0), -1) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == epochs: + acc = evaluate(model, test_loader, device, is_conv) + print(f" [Frozen] ep {ep}: acc={acc:.4f}", flush=True) + return evaluate(model, test_loader, device, is_conv) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet']) + p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100']) + p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456]) + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--gpu', type=int, default=0) + p.add_argument('--output_dir', type=str, default='results/frozen_baselines') + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + train_loader, test_loader, num_classes = get_data(args.dataset, 128) + + results = {} + for seed in args.seeds: + print(f"\n--- Frozen baseline seed={seed} ---", flush=True) + torch.manual_seed(seed); np.random.seed(seed) + if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + model, is_conv = make_model(args.arch, num_classes, device) + freeze_blocks(model) + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + print(f" {trainable}/{total} trainable params", flush=True) + acc = train_frozen(model, train_loader, test_loader, device, args.epochs, is_conv) + results[f's{seed}'] = acc + print(f" FINAL: {acc:.4f}", flush=True) + + results['config'] = vars(args) + results['mean'] = float(np.mean([results[f's{s}'] for s in args.seeds])) + results['std'] = float(np.std([results[f's{s}'] for s in args.seeds], ddof=1)) + out_path = os.path.join(args.output_dir, f'frozen_{args.arch}_{args.dataset}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved: {out_path}") + print(f"Frozen baseline: {results['mean']:.4f} +/- {results['std']:.4f}") + + +if __name__ == '__main__': + main() diff --git a/reproduce/penalty_sweep.py b/reproduce/penalty_sweep.py new file mode 100644 index 0000000..b6b913d --- /dev/null +++ b/reproduce/penalty_sweep.py @@ -0,0 +1,176 @@ +""" +Penalty intervention sweep: DFA + lambda x {0, 1e-4, 1e-2} with per-epoch trajectory. +Includes fresh-B null calibration on the lambda=1e-2 checkpoint. + +Usage: + python reproduce/penalty_sweep.py --seeds 42 123 456 --gpu 0 +""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from reproduce.train_methods import get_data, evaluate, make_model, _pool_hidden, _get_head_logits +from metrics.credit_metrics import cosine_similarity_batch + + +def train_dfa_trajectory(seed, train_loader, test_loader, device, epochs, lam, num_classes=10): + """DFA with per-epoch ||h_L||, ||g_L|| logging.""" + torch.manual_seed(seed); np.random.seed(seed) + if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + from models.residual_mlp import ResidualMLP + model = ResidualMLP(3072, 256, num_classes, 4).to(device) + d, L, C = 256, 4, num_classes + Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=1e-3, weight_decay=0.01) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + + # Eval buffer + xs, ys = [], [] + for x, y in test_loader: + xs.append(x.view(x.size(0), -1)); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 128: break + x_eval = torch.cat(xs)[:128].to(device) + y_eval = torch.cat(ys)[:128].to(device) + + def diagnose(): + model.eval() + with torch.no_grad(): + _, hi = model(x_eval, return_hidden=True) + h_L = hi[-1].norm(dim=-1).median().item() + h0 = model.embed(x_eval) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(hs[-1] + b(hs[-1])) + logits = model.out_head(model.out_ln(hs[-1])) + loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + g_L = grads[-1].norm(dim=-1).median().item() + acc = (logits.argmax(-1) == y_eval).float().mean().item() + model.train() + return h_L, g_L, acc + + log = [] + h, g, a = diagnose() + log.append({'epoch': 0, 'h_L': h, 'g_L': g, 'acc': a}) + + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL = hiddens[-1].detach() + head_opt.zero_grad() + F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward() + head_opt.step() + for l in range(L): + a_dfa = (e_T @ Bs[l].T).detach() + rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](hiddens[l].detach()) + local_loss = (f_l * (a_dfa / rms)).sum(-1).mean() + if lam > 0: + local_loss = local_loss + lam * (f_l ** 2).sum(-1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + a0 = (e_T @ Bs[0].T).detach() + rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_opt.zero_grad(); (h0 * (a0 / rms0)).sum(-1).mean().backward(); embed_opt.step() + for s in all_sch: s.step() + h, g, a = diagnose() + log.append({'epoch': ep, 'h_L': h, 'g_L': g, 'acc': a}) + if ep % 10 == 0 or ep == epochs: + print(f" [lam={lam}] s={seed} ep {ep}: ||h_L||={h:.3e} ||g_L||={g:.3e} acc={a:.4f}", flush=True) + + return log, model, Bs + + +def fresh_b_null(model, x_eval, y_eval, training_Bs, n_draws=20): + """Fresh-B null calibration on a trained checkpoint.""" + model.eval() + d, L, C = 256, 4, len(training_Bs[0][0]) if training_Bs[0].dim() == 2 else 10 + device = x_eval.device + + def deep_cos_with_Bs(Bs): + h0 = model.embed(x_eval) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(hs[-1] + b(hs[-1])) + logits = model.out_head(model.out_ln(hs[-1])) + loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + with torch.no_grad(): + e_T = logits.softmax(-1) + e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 + cos_layers = [] + for l in range(L): + a = (e_T @ Bs[l].T).detach() + cos_layers.append(cosine_similarity_batch(a, grads[l].detach())) + return float(np.mean(cos_layers[1:])) # deep = exclude layer 0 + + train_cos = deep_cos_with_Bs(training_Bs) + fresh_cos = [] + for _ in range(n_draws): + fresh_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + fresh_cos.append(deep_cos_with_Bs(fresh_Bs)) + + return { + 'training_Bs_deep_cos': train_cos, + 'fresh_Bs_deep_mean': float(np.mean(fresh_cos)), + 'fresh_Bs_deep_std_ddof1': float(np.std(fresh_cos, ddof=1)), + 'n_draws': n_draws, + } + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456]) + p.add_argument('--epochs', type=int, default=30) + p.add_argument('--lambdas', nargs='+', type=float, default=[0.0, 1e-4, 1e-2]) + p.add_argument('--gpu', type=int, default=0) + p.add_argument('--output_dir', type=str, default='results/penalty_sweep') + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + train_loader, test_loader, _ = get_data('cifar10', 128) + + results = {} + for lam in args.lambdas: + lam_key = f'lam_{lam}' + results[lam_key] = {} + for seed in args.seeds: + print(f"\n=== lambda={lam}, seed={seed} ===", flush=True) + log, model, Bs = train_dfa_trajectory(seed, train_loader, test_loader, device, args.epochs, lam) + results[lam_key][str(seed)] = log + + # Fresh-B null on lambda=1e-2, seed=42 only + if lam == 1e-2 and seed == 42: + xs, ys = [], [] + for x, y in test_loader: + xs.append(x.view(x.size(0), -1)); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 128: break + x_eval = torch.cat(xs)[:128].to(device) + y_eval = torch.cat(ys)[:128].to(device) + null = fresh_b_null(model, x_eval, y_eval, Bs) + results['fresh_b_null'] = null + print(f" Fresh-B: training={null['training_Bs_deep_cos']:+.4f}, " + f"fresh={null['fresh_Bs_deep_mean']:+.4f} +/- {null['fresh_Bs_deep_std_ddof1']:.4f}") + + with open(os.path.join(args.output_dir, 'penalty_sweep.json'), 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved: {args.output_dir}/penalty_sweep.json") + + +if __name__ == '__main__': + main() diff --git a/reproduce/run_all.sh b/reproduce/run_all.sh new file mode 100755 index 0000000..35c3587 --- /dev/null +++ b/reproduce/run_all.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# Full reproduction of all paper results. +# Usage: bash reproduce/run_all.sh --gpu 0 +# Estimated time: ~12 hours on a single A100/A6000. + +GPU=${1:-0} +export CUDA_VISIBLE_DEVICES=$GPU + +echo "============================================================" +echo "FA Evaluation Protocol — Full Reproduction" +echo "GPU: $GPU" +echo "Start: $(date)" +echo "============================================================" + +cd "$(dirname "$0")/.." + +# ─── Section 2: Primary audit (ResMLP d=256 L=4, 100ep) ───────────────── + +echo "" +echo "=== Section 2: Primary audit (BP/FA/DFA, 3 seeds, 100ep) ===" +python reproduce/train_methods.py --arch resmlp --methods bp fa dfa \ + --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/sec2_primary_audit + +echo "" +echo "=== Section 2: Frozen baseline ===" +python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 \ + --epochs 100 --gpu 0 --output_dir results/sec2_frozen + +# ─── Section 4.1: Cross-architecture (ViT, ResNet) ─────────────────────── + +echo "" +echo "=== Section 4.1: ViT-Mini (BP/FA/DFA, 3 seeds, 60ep) ===" +python reproduce/train_methods.py --arch vit --methods bp fa dfa \ + --seeds 42 123 456 --epochs 60 --gpu 0 --output_dir results/sec4_vit + +echo "" +echo "=== Section 4.1: ViT-Mini frozen baseline ===" +python reproduce/frozen_baseline.py --arch vit --seeds 42 123 456 \ + --epochs 60 --gpu 0 --output_dir results/sec4_vit_frozen + +echo "" +echo "=== Section 4.1: SmallResNet (BP/FA/DFA, 3 seeds, 100ep) ===" +python reproduce/train_methods.py --arch resnet --methods bp fa dfa \ + --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/sec4_resnet + +echo "" +echo "=== Section 4.1: SmallResNet frozen baseline ===" +python reproduce/frozen_baseline.py --arch resnet --seeds 42 123 456 \ + --epochs 100 --gpu 0 --output_dir results/sec4_resnet_frozen + +# ─── Section 4.2: Penalty intervention ─────────────────────────────────── + +echo "" +echo "=== Section 4.2: DFA penalty sweep (lambda=0, 1e-4, 1e-2, 30ep) ===" +python reproduce/penalty_sweep.py --seeds 42 123 456 --epochs 30 --gpu 0 \ + --output_dir results/sec4_penalty + +# ─── Section 5.2: Representative setting (d=512 L=2) ──────────────────── + +echo "" +echo "=== Section 5.2: d=512 L=2 (BP/FA/DFA, 3 seeds, 100ep) ===" +python reproduce/train_methods.py --arch resmlp_d512_L2 --methods bp fa dfa \ + --seeds 1 2 5 --epochs 100 --gpu 0 --output_dir results/sec5_d512_L2 + +echo "" +echo "=== Section 5.2: d=512 L=2 frozen baseline ===" +python reproduce/frozen_baseline.py --arch resmlp_d512_L2 --seeds 1 2 5 \ + --epochs 100 --gpu 0 --output_dir results/sec5_d512_L2_frozen + +# ─── Appendix: CIFAR-100 ──────────────────────────────────────────────── + +echo "" +echo "=== Appendix: CIFAR-100 (BP/FA/DFA, 3 seeds, 100ep) ===" +python reproduce/train_methods.py --arch resmlp --dataset cifar100 --methods bp fa dfa \ + --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/app_cifar100 + +echo "" +echo "=== Appendix: CIFAR-100 frozen baseline ===" +python reproduce/frozen_baseline.py --arch resmlp --dataset cifar100 --seeds 42 123 456 \ + --epochs 100 --gpu 0 --output_dir results/app_cifar100_frozen + +echo "" +echo "============================================================" +echo "Full reproduction done: $(date)" +echo "============================================================" diff --git a/reproduce/train_methods.py b/reproduce/train_methods.py new file mode 100644 index 0000000..c430b90 --- /dev/null +++ b/reproduce/train_methods.py @@ -0,0 +1,376 @@ +""" +Train BP/FA/DFA on a specified architecture and compute protocol diagnostics. + +Usage: + python reproduce/train_methods.py --arch resmlp --methods bp fa dfa \ + --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/main_audit + +Architectures: resmlp (d=256 L=4), resmlp_d512_L2, vit, resnet +""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from models.vit_mini import ViTMini +from models.small_resnet import SmallResNet +from metrics.credit_metrics import cosine_similarity_batch, nudging_test + + +# ─── Data ──────────────────────────────────────────────────────────────── + +def get_data(dataset='cifar10', batch_size=128): + if dataset == 'cifar10': + mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) + Dataset = torchvision.datasets.CIFAR10 + num_classes = 10 + else: + mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) + Dataset = torchvision.datasets.CIFAR100 + num_classes = 100 + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), transforms.Normalize(mean, std)]) + tv_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + tr = Dataset('./data', True, download=True, transform=tv_train) + te = Dataset('./data', False, download=True, transform=tv_test) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), + num_classes) + + +def evaluate(model, loader, device, is_conv=False): + model.eval() + c = n = 0 + with torch.no_grad(): + for x, y in loader: + x, y = x.to(device), y.to(device) + if not is_conv: + x = x.view(x.size(0), -1) + c += (model(x).argmax(-1) == y).sum().item() + n += x.size(0) + return c / n + + +# ─── Model construction ───────────────────────────────────────────────── + +def make_model(arch, num_classes, device): + if arch == 'resmlp': + return ResidualMLP(3072, 256, num_classes, 4).to(device), False + elif arch == 'resmlp_d512_L2': + return ResidualMLP(3072, 512, num_classes, 2).to(device), False + elif arch == 'vit': + return ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=num_classes).to(device), True + elif arch == 'resnet': + return SmallResNet(64, num_classes, 4).to(device), True + else: + raise ValueError(f"Unknown arch: {arch}") + + +# ─── Training functions ───────────────────────────────────────────────── + +def train_bp(model, train_loader, test_loader, device, epochs, is_conv): + opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + if not is_conv: x = x.view(x.size(0), -1) + logits = model(x) + loss = F.cross_entropy(logits, y) + opt.zero_grad(); loss.backward(); opt.step() + tl += loss.item() * x.size(0); tc += (logits.argmax(1) == y).sum().item(); tn += x.size(0) + sch.step() + log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, device, is_conv)) + if ep % 10 == 0 or ep == epochs: + print(f" [BP] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log + + +def _get_embed_head_params(model, is_conv): + """Get embed and head parameter groups.""" + if is_conv and hasattr(model, 'stem_conv'): + embed_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()) + head_params = list(model.out_head.parameters()) + elif hasattr(model, 'patch_embed'): # ViT + embed_params = list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed] + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + else: # ResMLP + embed_params = list(model.embed.parameters()) + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + return embed_params, head_params + + +def _pool_hidden(h): + if h.dim() == 4: return F.adaptive_avg_pool2d(h, 1).flatten(1) + if h.dim() == 3: return h[:, 0] # cls token + return h + + +def _get_head_logits(model, h_pool): + if hasattr(model, 'out_ln'): + return model.out_head(model.out_ln(h_pool)) + return model.out_head(h_pool) + + +def _block_residual(model, block, h_l, is_conv): + """Compute block residual f_l = block(h_l) - h_l for blocks with internal skip.""" + out = block(h_l) + if is_conv or hasattr(block, 'attn'): # ResNet/ViT blocks include skip internally + return out - h_l + return out # ResMLP blocks return f_l only + + +def train_dfa(model, train_loader, test_loader, device, epochs, is_conv, num_classes): + d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model + L = model.num_blocks + C = num_classes + Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks] + embed_params, head_params = _get_embed_head_params(model, is_conv) + embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01) + head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + if not is_conv: x = x.view(x.size(0), -1) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + h_pool = _pool_hidden(hiddens[-1].detach()) + head_opt.zero_grad() + F.cross_entropy(_get_head_logits(model, h_pool), y).backward() + head_opt.step() + for l in range(L): + h_l = hiddens[l].detach() + a = (e_T @ Bs[l].T).detach() + rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = _block_residual(model, model.blocks[l], h_l, is_conv) + if f_l.dim() > 2: + a_b = a_norm.unsqueeze(-1).unsqueeze(-1).expand_as(f_l) + local_loss = (f_l * a_b).sum(dim=1).mean() + else: + local_loss = (f_l * a_norm).sum(-1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + # Embed + a0 = (e_T @ Bs[0].T).detach() + rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + if is_conv: + h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x) + else: + h0 = model.embed(x) + a0_n = a0 / rms0 + if h0.dim() > 2: + a0_b = a0_n.unsqueeze(-1).unsqueeze(-1).expand_as(h0) + embed_loss = (h0 * a0_b).sum(dim=1).mean() + else: + embed_loss = (h0 * a0_n).sum(-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch + log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, device, is_conv)) + if ep % 10 == 0 or ep == epochs: + print(f" [DFA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log, Bs + + +def train_fa(model, train_loader, test_loader, device, epochs, is_conv, num_classes): + d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model + L = model.num_blocks + Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks] + embed_params, head_params = _get_embed_head_params(model, is_conv) + embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01) + head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + if not is_conv: x = x.view(x.size(0), -1) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + # Head — grad before step + h_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True) + logits_out = _get_head_logits(model, h_pool) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad(); loss_out.backward() + a_credit = h_pool.grad.detach() + head_opt.step() + # Top-down blocks + for l in range(L - 1, -1, -1): + h_l = hiddens[l].detach() + rms = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_credit / rms + f_l = _block_residual(model, model.blocks[l], h_l, is_conv) + if f_l.dim() > 2: + a_b = a_norm.unsqueeze(-1).unsqueeze(-1).expand_as(f_l) + local_loss = (f_l * a_b).sum(dim=1).mean() + else: + local_loss = (f_l * a_norm).sum(-1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + a_credit = (a_credit @ Bs[l]).detach() + # Embed + rms0 = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + if is_conv: + h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x) + else: + h0 = model.embed(x) + a0_n = a_credit / rms0 + if h0.dim() > 2: + a0_b = a0_n.unsqueeze(-1).unsqueeze(-1).expand_as(h0) + embed_loss = (h0 * a0_b).sum(dim=1).mean() + else: + embed_loss = (h0 * a0_n).sum(-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch + log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, device, is_conv)) + if ep % 10 == 0 or ep == epochs: + print(f" [FA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log, Bs + + +# ─── Diagnostics ───────────────────────────────────────────────────────── + +def compute_diagnostics(model, x_eval, y_eval, device, method_name, dfa_Bs=None, fa_Bs=None, is_conv=False): + """Compute per-layer cosine, ||g_l||, ||h_l|| and nudging.""" + model.eval() + L = model.num_blocks + + with torch.no_grad(): + logits, hiddens = model(x_eval, return_hidden=True) + + h_norms = [float(_pool_hidden(h).norm(dim=-1).median().item()) for h in hiddens] + + # BP grads + h0 = model.embed(x_eval) if hasattr(model, 'embed') else model.stem(x_eval) + hs = [h0.clone().requires_grad_(True)] + for block in model.blocks: + hs.append(block(hs[-1])) + h_final = _pool_hidden(hs[-1]) + if hasattr(model, 'out_ln'): + h_final = model.out_ln(h_final) + out_logits = model.out_head(h_final) + loss = F.cross_entropy(out_logits, y_eval) + grads = torch.autograd.grad(loss, hs) + g_norms = [float(_pool_hidden(g).norm(dim=-1).median().item()) for g in grads] + + # Per-layer cosine + with torch.no_grad(): + e_T = out_logits.softmax(-1) + e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 + + bp_cosine = [] + if method_name == 'bp': + bp_cosine = [1.0] * L + elif method_name == 'dfa' and dfa_Bs is not None: + for l in range(L): + a = (e_T @ dfa_Bs[l].T).detach() + g_pool = _pool_hidden(grads[l]).detach() + bp_cosine.append(cosine_similarity_batch(a, g_pool)) + elif method_name == 'fa' and fa_Bs is not None: + hL_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True) + logits_fa = _get_head_logits(model, hL_pool) + loss_fa = F.cross_entropy(logits_fa, y_eval) + a_credit = torch.autograd.grad(loss_fa, hL_pool)[0].detach() + for l in range(L - 1, -1, -1): + g_pool = _pool_hidden(grads[l]).detach() + bp_cosine.insert(0, cosine_similarity_batch(a_credit, g_pool)) + a_credit = (a_credit @ fa_Bs[l]).detach() + + model.train() + return { + 'bp_cosine': bp_cosine, + 'bp_grad_norms_per_layer': g_norms, + 'hidden_norms_per_layer': h_norms, + } + + +# ─── Main ──────────────────────────────────────────────────────────────── + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet']) + p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100']) + p.add_argument('--methods', nargs='+', default=['bp', 'fa', 'dfa']) + p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456]) + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--gpu', type=int, default=0) + p.add_argument('--output_dir', type=str, default='results/reproduce') + p.add_argument('--penalty_lam', type=float, default=0.0) + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + train_loader, test_loader, num_classes = get_data(args.dataset, 128) + + # Eval buffer + xs, ys = [], [] + for x, y in test_loader: + xs.append(x); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 128: break + x_eval_raw = torch.cat(xs)[:128].to(device) + y_eval = torch.cat(ys)[:128].to(device) + + results = {} + for seed in args.seeds: + print(f"\n{'='*60}\nSeed {seed}\n{'='*60}", flush=True) + results[str(seed)] = {} + + for method in args.methods: + print(f"\n--- {method.upper()} ---", flush=True) + torch.manual_seed(seed); np.random.seed(seed) + if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + model, is_conv = make_model(args.arch, num_classes, device) + x_eval = x_eval_raw if is_conv else x_eval_raw.view(x_eval_raw.size(0), -1) + + if method == 'bp': + log = train_bp(model, train_loader, test_loader, device, args.epochs, is_conv) + diag = compute_diagnostics(model, x_eval, y_eval, device, 'bp', is_conv=is_conv) + results[str(seed)]['bp'] = {'log': log, 'diagnostics': diag} + elif method == 'dfa': + log, Bs = train_dfa(model, train_loader, test_loader, device, args.epochs, is_conv, num_classes) + diag = compute_diagnostics(model, x_eval, y_eval, device, 'dfa', dfa_Bs=Bs, is_conv=is_conv) + results[str(seed)]['dfa'] = {'log': log, 'diagnostics': diag} + elif method == 'fa': + log, Bs = train_fa(model, train_loader, test_loader, device, args.epochs, is_conv, num_classes) + diag = compute_diagnostics(model, x_eval, y_eval, device, 'fa', fa_Bs=Bs, is_conv=is_conv) + results[str(seed)]['fa'] = {'log': log, 'diagnostics': diag} + + results['config'] = vars(args) + out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved: {out_path}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fd142b5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +torch>=2.0 +torchvision>=0.15 +numpy>=1.24 +scipy>=1.10 -- cgit v1.2.3