author     YurenHao0426 <Blackhao0426@gmail.com>  2026-05-04 19:50:45 -0500
committer  YurenHao0426 <Blackhao0426@gmail.com>  2026-05-04 19:50:45 -0500
commit     b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch)
tree       f0e6afb5b3d448d1d6c35d9622d22d63073ca9a7
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and full reproduction pipeline. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
-rw-r--r--   README.md                       81
-rw-r--r--   metrics/__init__.py              0
-rw-r--r--   metrics/credit_metrics.py      156
-rw-r--r--   models/__init__.py               0
-rw-r--r--   models/residual_mlp.py          75
-rw-r--r--   models/small_resnet.py          74
-rw-r--r--   models/vit_mini.py             109
-rw-r--r--   protocol/__init__.py             0
-rw-r--r--   protocol/example_usage.py      106
-rw-r--r--   protocol/fa_protocol.py        215
-rw-r--r--   reproduce/__init__.py            0
-rw-r--r--   reproduce/frozen_baseline.py    86
-rw-r--r--   reproduce/penalty_sweep.py     176
-rwxr-xr-x   reproduce/run_all.sh            85
-rw-r--r--   reproduce/train_methods.py     376
-rw-r--r--   requirements.txt                 4
16 files changed, 1543 insertions, 0 deletions
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3345986
--- /dev/null
+++ b/README.md
@@ -0,0 +1,81 @@
+# What Accuracy and Gradient Cosine Miss: Evaluating Feedback Alignment via Scale Stability, Reference Validity, and Depth Utility
+
+Code for the NeurIPS 2026 Evaluations & Datasets Track submission.
+
+## Structure
+
+```
+submission_code/
+ protocol/ # Reference evaluation artifact (the main deliverable)
+ fa_protocol.py # Three-diagnostic protocol: drop-in evaluator
+ example_usage.py # Minimal example applying protocol to a trained model
+ models/ # Architectures used in the paper
+ residual_mlp.py # Pre-LayerNorm ResMLP (primary audit architecture)
+ vit_mini.py # ViT-Mini (transformer with terminal LN)
+ small_resnet.py # SmallResNet (BatchNorm, no LN)
+ metrics/
+ credit_metrics.py # Cosine similarity, nudging test, perturbation correlation
+ reproduce/ # Scripts to reproduce all paper results
+ train_methods.py # Train BP/FA/DFA on any architecture + compute diagnostics
+ frozen_baseline.py # Frozen-blocks and shallow baselines
+ penalty_sweep.py # Penalty intervention (lambda sweep + fresh-B null)
+ run_all.sh # One-command full reproduction
+ requirements.txt
+```
+
+## Quick start: applying the protocol to your own FA method
+
+```python
+from protocol.fa_protocol import FAProtocol
+
+# Your trained model, a test batch, and a frozen-blocks baseline accuracy
+model = ...           # nn.Module with .blocks, .out_head, .out_ln (or without)
+x_eval, y_eval = ...  # (N, d_input), (N,)
+frozen_acc = ...      # frozen-blocks baseline accuracy (see reproduce/frozen_baseline.py); needed for D3
+
+protocol = FAProtocol(model, x_eval, y_eval)
+report = protocol.run(frozen_baseline_acc=frozen_acc)  # omit the argument to skip D3
+
+print(report['verdict']) # 'PASS' or 'FAIL(D1+D2)' etc.
+print(report['diagnostics']) # {'D1_scale_growth': ..., 'D2_ref_validity': ..., 'D3_depth_utility': ...}
+```
+
+## Reproducing paper results
+
+```bash
+# Full reproduction (all figures, tables, appendices)
+cd reproduce/
+bash run_all.sh --gpu 0
+
+# Or individual experiments:
+python train_methods.py --arch resmlp --methods bp fa dfa --seeds 42 123 456 --epochs 100
+python frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
+python penalty_sweep.py --lambdas 0 1e-4 1e-2 --seeds 42 123 456 --epochs 30
+```
+
+## Requirements
+
+- Python >= 3.10
+- PyTorch >= 2.0
+- torchvision
+- numpy
+- scipy (for perturbation correlation)
+
+## Protocol specification
+
+The protocol consists of three diagnostic checks applied to a trained model:
+
+| Diagnostic | What it measures | Threshold | Fires when |
+|---|---|---|---|
+| D1: Scale stability | max per-block residual growth | > 50x | Residual stream explodes |
+| D2: Reference validity | deepest-layer BP gradient norm | < 10 * eps | BP reference collapses below cosine clamp |
+| D3: Depth utility | test accuracy vs frozen-blocks baseline | < 2 pp above | Trained blocks do not outperform random |
+
+**Verdict logic**: A run fails the protocol if either Mode 1 (D1 and D2 both flag) or Depth Utility (D3 flags) is triggered. Passing all three diagnostics does not certify the method as effective; it rules out the specific class of silent failures the audit revealed.
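+
+A minimal sketch of how the flags combine into the verdict string, mirroring `FAProtocol.run` in `protocol/fa_protocol.py` (default thresholds as in the table; any fired diagnostic appears in the FAIL verdict). The standalone `verdict` helper below is illustrative only:
+
+```python
+def verdict(max_growth, g_L, margin,
+            d1_threshold=50.0, d2_floor=1e-7, d3_margin=0.02):
+    """Combine the three diagnostic flags into a protocol verdict."""
+    flags = []
+    if max_growth > d1_threshold:   # D1: residual-stream scale explodes
+        flags.append('D1')
+    if g_L < d2_floor:              # D2: deepest-layer BP gradient collapses
+        flags.append('D2')
+    if margin < d3_margin:          # D3: no margin over the frozen-blocks baseline
+        flags.append('D3')
+    return 'PASS' if not flags else 'FAIL(' + '+'.join(flags) + ')'
+```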
+
+## Training inventory
+
+The paper reports 125 training runs across 5+ architectures, 10 training methods, and 17 experimental settings. Total estimated GPU time: ~12 hours on a single A100/A6000.
+
+## License
+
+MIT
diff --git a/metrics/__init__.py b/metrics/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/metrics/__init__.py
diff --git a/metrics/credit_metrics.py b/metrics/credit_metrics.py
new file mode 100644
index 0000000..516dca2
--- /dev/null
+++ b/metrics/credit_metrics.py
@@ -0,0 +1,156 @@
+"""
+Credit assignment diagnostic metrics:
+1. Exact costate cosine (for toy LQ)
+2. Local perturbation correlation rho_l
+3. Nudging test Delta_l^nudge
+4. Offline BP cosine Gamma_l
+5. Bridge residual R_l
+6. Feature drift M_l
+"""
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.stats import pearsonr
+
+
+def cosine_similarity_batch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """Compute cosine similarity between a and b along last dim, averaged over batch."""
+ a_flat = a.reshape(a.shape[0], -1)
+ b_flat = b.reshape(b.shape[0], -1)
+ cos = F.cosine_similarity(a_flat, b_flat, dim=-1)
+ return cos.mean().item()
+
+
+def perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=32):
+ """
+ Compute local perturbation correlation rho_l.
+
+ Args:
+ h_l: (batch, d) hidden state at layer l
+ a_l: (batch, d) credit signal at layer l
+ forward_fn: callable that takes h_l -> scalar loss (averaged over batch dims handled inside)
+ epsilon: perturbation magnitude
+ M: number of random directions
+
+ Returns:
+ rho: Pearson correlation between predicted and true loss changes
+ """
+ batch_size, d = h_l.shape
+ device = h_l.device
+
+ pred_list = []
+ true_list = []
+
+ base_loss = forward_fn(h_l) # (batch,) or scalar
+
+ for _ in range(M):
+ v = torch.randn(batch_size, d, device=device)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+
+ # Predicted change: <a_l, epsilon * v>
+ delta_pred = (a_l * (epsilon * v)).sum(dim=-1) # (batch,)
+
+ # True change: forward from perturbed h
+ perturbed_loss = forward_fn(h_l + epsilon * v) # (batch,)
+ delta_true = perturbed_loss - base_loss # (batch,)
+
+ pred_list.append(delta_pred.detach().cpu().numpy())
+ true_list.append(delta_true.detach().cpu().numpy())
+
+ pred_arr = np.concatenate(pred_list)
+ true_arr = np.concatenate(true_list)
+
+ if np.std(pred_arr) < 1e-12 or np.std(true_arr) < 1e-12:
+ return 0.0
+
+ rho, _ = pearsonr(pred_arr, true_arr)
+ return float(rho)
+
+
+def nudging_test(h_l, a_l, forward_fn, eta=0.01):
+ """
+ Nudging test: check if moving h_l in -a_l direction decreases loss.
+
+ Args:
+ h_l: (batch, d) hidden state
+ a_l: (batch, d) credit signal
+ forward_fn: callable h -> loss per sample (batch,)
+ eta: step size
+
+ Returns:
+ mean delta_nudge (negative is good)
+ """
+ rms_a = (a_l ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_normed = a_l / rms_a
+ h_nudged = h_l - eta * a_normed
+
+ base_loss = forward_fn(h_l)
+ nudged_loss = forward_fn(h_nudged)
+ delta = (nudged_loss - base_loss).mean().item()
+ return delta
+
+
+def offline_bp_cosine(a_l, bp_grad_l):
+ """
+ Compute offline BP cosine similarity.
+ a_l: (batch, d) credit signal
+ bp_grad_l: (batch, d) true BP gradient at layer l
+ """
+ return cosine_similarity_batch(a_l, bp_grad_l)
+
+
+def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s, h_l_next_noisy_list, t_l_next, lam=0.1):
+ """
+ Compute bridge residual R_l.
+
+ Args:
+ V_phi: value network
+ V_bar_phi: EMA target value network
+ h_l: (batch, d)
+ t_l: (batch,)
+ s: (batch, s_dim)
+ h_l_next_noisy_list: list of K tensors (batch, d), noisy next states
+ t_l_next: (batch,)
+ lam: temperature
+
+ Returns:
+ mean absolute bridge residual
+ """
+ with torch.no_grad():
+ V_current = V_phi(h_l, t_l, s) # (batch,)
+
+ # Compute soft-min target
+ K = len(h_l_next_noisy_list)
+ log_terms = []
+ for h_next in h_l_next_noisy_list:
+ V_next = V_bar_phi(h_next, t_l_next, s) # (batch,)
+ log_terms.append(-V_next / lam)
+
+ log_terms = torch.stack(log_terms, dim=-1) # (batch, K)
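+ # Soft-min over the K noisy next states: -lam * log((1/K) * sum_k exp(-V_k / lam)),
+ # which approaches min_k V_k as lam -> 0.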
+ V_target = -lam * torch.logsumexp(log_terms, dim=-1) + lam * np.log(K)
+
+ residual = (V_current - V_target).abs().mean().item()
+ return residual
+
+
+def feature_drift(model_init_params, model_final_params):
+ """
+ Compute per-layer feature drift M_l = ||W_final - W_init||_F / ||W_init||_F.
+
+ Args:
+ model_init_params: dict of {name: tensor} initial parameters
+ model_final_params: dict of {name: tensor} final parameters
+
+ Returns:
+ dict of {name: drift_ratio}
+ """
+ drifts = {}
+ for name in model_init_params:
+ if name in model_final_params:
+ w_init = model_init_params[name]
+ w_final = model_final_params[name]
+ init_norm = w_init.norm().item()
+ if init_norm > 1e-8:
+ drift = (w_final - w_init).norm().item() / init_norm
+ drifts[name] = drift
+ return drifts
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
new file mode 100644
index 0000000..6827057
--- /dev/null
+++ b/models/residual_mlp.py
@@ -0,0 +1,75 @@
+"""
+Deep Residual MLP for classification.
+Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
+Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
+"""
+import torch
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+ """Single pre-LayerNorm residual MLP block."""
+
+ def __init__(self, d_hidden: int, w2_std: float = 0.01):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w1 = nn.Linear(d_hidden, d_hidden)
+ self.w2 = nn.Linear(d_hidden, d_hidden)
+ # Small init for residual branch (or larger if used as a non-residual stack)
+ nn.init.normal_(self.w2.weight, std=w2_std)
+ nn.init.zeros_(self.w2.bias)
+
+ def forward(self, h: torch.Tensor) -> torch.Tensor:
+ """Returns the residual F_l(h), NOT h + F_l(h)."""
+ z = self.ln(h)
+ z = self.w1(z)
+ z = torch.nn.functional.gelu(z)
+ z = self.w2(z)
+ return z
+
+
+class ResidualMLP(nn.Module):
+ """Deep residual MLP: embed -> L blocks -> LN -> output head."""
+
+ def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int,
+ residual_add: bool = True, w2_std: float = 0.01):
+ super().__init__()
+ self.embed = nn.Linear(input_dim, d_hidden)
+ self.blocks = nn.ModuleList([ResidualBlock(d_hidden, w2_std=w2_std) for _ in range(num_blocks)])
+ self.out_ln = nn.LayerNorm(d_hidden)
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+ self.residual_add = residual_add
+
+ def forward(self, x: torch.Tensor, return_hidden: bool = False):
+ """
+ Args:
+ x: (batch, input_dim)
+ return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
+ Returns:
+ logits: (batch, num_classes)
+ hiddens: list of (batch, d_hidden) if return_hidden
+ """
+ h = self.embed(x)
+ hiddens = [h] if return_hidden else None
+
+ for block in self.blocks:
+ f = block(h)
+ h = h + f if self.residual_add else f
+ if return_hidden:
+ hiddens.append(h)
+
+ logits = self.out_head(self.out_ln(h))
+
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+ """Run forward from a given layer index to output. Used for perturbation tests."""
+ for i in range(start_layer, self.num_blocks):
+ f = self.blocks[i](h)
+ h = h + f if self.residual_add else f
+ logits = self.out_head(self.out_ln(h))
+ return logits
diff --git a/models/small_resnet.py b/models/small_resnet.py
new file mode 100644
index 0000000..10b122e
--- /dev/null
+++ b/models/small_resnet.py
@@ -0,0 +1,74 @@
+"""
+Small CIFAR-10 ResNet for the FA-evaluation paper. Standard BatchNorm-based
+post-activation residual blocks (no LayerNorm). 4 residual blocks at width 64.
+
+Supports `num_blocks=0` (shallow baseline: just stem → pool → head) and frozen
+blocks via `requires_grad=False` on `.blocks` parameters.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+ """Standard ResNet BasicBlock with BatchNorm. Pre-activation NOT used; this is
+ the post-activation form: relu(BN(W2 * relu(BN(W1 x)))) + x. d_hidden in == d_hidden out.
+ No stride / downsampling — all blocks operate at the same spatial resolution
+ after the initial stem. This keeps the architecture simple and matches the
+ "4 residual blocks at fixed width" structure of our ResMLP and ViT-Mini comparisons.
+ """
+ def __init__(self, d_hidden):
+ super().__init__()
+ self.conv1 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(d_hidden)
+ self.conv2 = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(d_hidden)
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out = F.relu(x + out)
+ return out
+
+
+class SmallResNet(nn.Module):
+ """Small CIFAR-10 ResNet:
+ - 3x3 conv stem (3 → d_hidden) + BN + ReLU
+ - num_blocks BasicBlocks at the same width and resolution
+ - global average pool
+ - linear classification head
+
+ `num_blocks=0` gives the shallow baseline (just stem → pool → head).
+ """
+ def __init__(self, d_hidden=64, num_classes=10, num_blocks=4):
+ super().__init__()
+ self.stem_conv = nn.Conv2d(3, d_hidden, kernel_size=3, padding=1, bias=False)
+ self.stem_bn = nn.BatchNorm2d(d_hidden)
+ self.blocks = nn.ModuleList([BasicBlock(d_hidden) for _ in range(num_blocks)])
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+
+ def stem(self, x):
+ # x: (B, 3, 32, 32)
+ if x.dim() == 2:
+ x = x.view(x.size(0), 3, 32, 32)
+ h = F.relu(self.stem_bn(self.stem_conv(x)))
+ return h
+
+ def forward(self, x, return_hidden=False):
+ h = self.stem(x) # (B, d, 32, 32)
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ h = block(h)
+ if return_hidden:
+ hiddens.append(h)
+ h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1) # (B, d)
+ logits = self.out_head(h_pool)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ # Convenience alias for snapshot script compatibility (treats stem as the embed)
+ def embed(self, x):
+ return self.stem(x)
diff --git a/models/vit_mini.py b/models/vit_mini.py
new file mode 100644
index 0000000..af6ba60
--- /dev/null
+++ b/models/vit_mini.py
@@ -0,0 +1,109 @@
+"""
+Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before
+the classification head — the architecture P4 should target.
+
+Designed to be compatible with the snapshot evolution / DFA training framework.
+Each TransformerBlock is a "layer" for FA-style local credit purposes.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TransformerBlock(nn.Module):
+ """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x))."""
+ def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
+ self.ln2 = nn.LayerNorm(d_model)
+ mlp_hidden = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model, mlp_hidden),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(mlp_hidden, d_model),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Self-attention sublayer
+ x_norm = self.ln1(x)
+ attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
+ x = x + attn_out
+ # MLP sublayer
+ x = x + self.mlp(self.ln2(x))
+ return x
+
+
+class ViTMini(nn.Module):
+ """Minimal Vision Transformer for CIFAR-10.
+ Patch size 4x4 → 64 patches per image. Plus a learned cls token.
+ Pre-LN with terminal LayerNorm before the head.
+ """
+ def __init__(
+ self,
+ image_size: int = 32,
+ patch_size: int = 4,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ d_model: int = 128,
+ n_heads: int = 4,
+ num_blocks: int = 4,
+ mlp_ratio: float = 4.0,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ assert image_size % patch_size == 0
+ n_patches = (image_size // patch_size) ** 2
+ self.n_tokens = n_patches + 1 # +1 for cls token
+
+ # Patch embedding via Conv2d (equivalent to flatten + linear)
+ self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model))
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+ self.blocks = nn.ModuleList([
+ TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks)
+ ])
+ self.out_ln = nn.LayerNorm(d_model) # terminal LN — the P4 trigger
+ self.out_head = nn.Linear(d_model, num_classes)
+
+ self.num_blocks = num_blocks
+ self.d_model = d_model
+ self.d_hidden = d_model # alias for compatibility with snapshot script
+
+ def embed(self, x: torch.Tensor) -> torch.Tensor:
+ """Take a flat-CIFAR input (B, 3072) or image (B, 3, 32, 32) → token sequence (B, 65, d_model)."""
+ if x.dim() == 2: # flat input
+ x = x.view(x.size(0), 3, 32, 32)
+ # x: (B, 3, 32, 32)
+ x = self.patch_embed(x) # (B, d, 8, 8)
+ x = x.flatten(2).transpose(1, 2) # (B, 64, d)
+ cls = self.cls_token.expand(x.size(0), -1, -1)
+ x = torch.cat([cls, x], dim=1) # (B, 65, d)
+ x = x + self.pos_embed
+ return x
+
+ def forward(self, x: torch.Tensor, return_hidden: bool = False):
+ h = self.embed(x) # (B, 65, d)
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ h = block(h)
+ if return_hidden:
+ hiddens.append(h)
+ # Take cls token, normalize, classify
+ h_cls = self.out_ln(h[:, 0]) # (B, d)
+ logits = self.out_head(h_cls)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+ """Run forward from a given block index. h has shape (B, n_tokens, d)."""
+ for i in range(start_layer, self.num_blocks):
+ h = self.blocks[i](h)
+ h_cls = self.out_ln(h[:, 0])
+ return self.out_head(h_cls)
diff --git a/protocol/__init__.py b/protocol/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/protocol/__init__.py
diff --git a/protocol/example_usage.py b/protocol/example_usage.py
new file mode 100644
index 0000000..2b2c65e
--- /dev/null
+++ b/protocol/example_usage.py
@@ -0,0 +1,106 @@
+"""
+Minimal example: apply the FA evaluation protocol to a DFA-trained ResMLP.
+
+This script trains a model with DFA, then runs the three-diagnostic protocol.
+Expected output: FAIL(D1+D2+D3) — DFA on terminal-LN ResMLP triggers all diagnostics.
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.residual_mlp import ResidualMLP
+from protocol.fa_protocol import FAProtocol
+
+
+def get_cifar10(batch_size=128):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (torch.utils.data.DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ torch.utils.data.DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def train_dfa(model, train_loader, device, epochs=30):
+ """Minimal DFA training (canonical: no clipping, mean reduction)."""
+ d = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=1e-3, weight_decay=0.01)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
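+ # e_T = softmax(logits) - one_hot(y): the output error DFA broadcasts through fixed random matrices B_l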
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL = hiddens[-1].detach()
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+ head_opt.step()
+ for l in range(L):
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](hiddens[l].detach())
+ loss = (f_l * (a / rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); loss.backward(); block_opts[l].step()
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_opt.zero_grad(); (h0 * (a0 / rms0)).sum(-1).mean().backward(); embed_opt.step()
+ if ep % 10 == 0:
+ print(f" DFA ep {ep}/{epochs}", flush=True)
+
+
+def main():
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ seed = 42
+ torch.manual_seed(seed); np.random.seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+
+ print("Loading CIFAR-10...")
+ train_loader, test_loader = get_cifar10()
+
+ # Prepare eval buffer
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128:
+ break
+ x_eval = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+
+ # Train with DFA
+ print("Training DFA (30 epochs)...")
+ model = ResidualMLP(3072, 256, 10, 4).to(device)
+ train_dfa(model, train_loader, device, epochs=30)
+
+ # Run protocol
+ print("\nRunning protocol...")
+ protocol = FAProtocol(model, x_eval, y_eval)
+ report = protocol.run(frozen_baseline_acc=0.349)  # frozen-blocks baseline accuracy (compute with reproduce/frozen_baseline.py)
+ print(protocol.summary(report))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/protocol/fa_protocol.py b/protocol/fa_protocol.py
new file mode 100644
index 0000000..8d12939
--- /dev/null
+++ b/protocol/fa_protocol.py
@@ -0,0 +1,215 @@
+"""
+Reference implementation of the three-diagnostic FA evaluation protocol.
+
+Usage:
+ from protocol.fa_protocol import FAProtocol
+
+ protocol = FAProtocol(model, x_eval, y_eval)
+ report = protocol.run(frozen_baseline_acc=0.349)
+ print(report['verdict'])
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from typing import Optional
+
+
+class FAProtocol:
+ """Three-diagnostic evaluation protocol for feedback alignment methods.
+
+ Diagnostics:
+ D1 (Scale stability): max per-block residual growth rho = max_l ||h_{l+1}|| / ||h_l||.
+ Flags if rho > threshold (default 50).
+ D2 (Reference validity): BP gradient norm at the deepest hidden state.
+ Flags if ||g_L|| < 10 * eps, where eps is the cosine clamp floor.
+ D3 (Depth utility): test accuracy vs frozen-blocks baseline.
+ Flags if trained acc < frozen_acc + margin (default 2 pp).
+
+ The protocol requires:
+ - A trained model with .blocks (nn.ModuleList) and forward(x, return_hidden=True)
+ - A test batch (x_eval, y_eval)
+ - A frozen-blocks baseline accuracy (must be computed separately)
+ """
+
+ def __init__(
+ self,
+ model: nn.Module,
+ x_eval: torch.Tensor,
+ y_eval: torch.Tensor,
+ d1_threshold: float = 50.0,
+ d2_eps: float = 1e-8,
+ d2_factor: float = 10.0,
+ d3_margin_pp: float = 2.0,
+ ):
+ self.model = model
+ self.x_eval = x_eval
+ self.y_eval = y_eval
+ self.d1_threshold = d1_threshold
+ self.d2_floor = d2_factor * d2_eps
+ self.d3_margin = d3_margin_pp / 100.0
+
+ def _compute_hidden_norms(self, hiddens):
+ """Compute median per-sample L2 norm at each hidden layer."""
+ norms = []
+ for h in hiddens:
+ if h.dim() == 4: # conv: (B, C, H, W) -> pool to (B, C)
+ h_flat = F.adaptive_avg_pool2d(h, 1).flatten(1)
+ elif h.dim() == 3: # transformer: (B, T, D) -> cls token
+ h_flat = h[:, 0]
+ else:
+ h_flat = h
+ norms.append(float(h_flat.norm(dim=-1).median().item()))
+ return norms
+
+ def _compute_bp_grad_norms(self, hiddens):
+ """Compute BP gradient norms at each hidden layer via manual forward."""
+ model = self.model
+ L = len(hiddens) - 1 # number of blocks
+
+ # Rebuild forward from hidden states with grad tracking
+ hs = [hiddens[0].detach().clone().requires_grad_(True)]
+ for block in model.blocks:
+ h_next = block(hs[-1])
+ # Add the skip only for blocks that return the residual branch f(h) (e.g. ResidualBlock);
+ # blocks whose forward already computes h + f(h) are used as-is.
+ if h_next.shape == hs[-1].shape and not self._block_has_internal_skip(block):
+ h_next = hs[-1] + h_next
+ hs.append(h_next)
+
+ # Forward through head
+ h_final = hs[-1]
+ if h_final.dim() == 4: # conv
+ h_final = F.adaptive_avg_pool2d(h_final, 1).flatten(1)
+ elif h_final.dim() == 3: # transformer cls token
+ h_final = h_final[:, 0]
+ if hasattr(model, 'out_ln'):
+ h_final = model.out_ln(h_final)
+ logits = model.out_head(h_final)
+ loss = F.cross_entropy(logits, self.y_eval)
+ grads = torch.autograd.grad(loss, hs, allow_unused=True)
+
+ norms = []
+ for g in grads:
+ if g is None:
+ norms.append(0.0)
+ continue
+ if g.dim() == 4:
+ g_flat = F.adaptive_avg_pool2d(g, 1).flatten(1)
+ elif g.dim() == 3:
+ g_flat = g[:, 0]
+ else:
+ g_flat = g
+ norms.append(float(g_flat.norm(dim=-1).median().item()))
+ return norms
+
+ @staticmethod
+ def _block_has_internal_skip(block):
+ """Heuristic: does the block's forward already include the residual skip?
+ In this repo, transformer blocks (which have .attn) and ResNet BasicBlocks (which have .bn1)
+ compute x + f(x) internally; ResMLP blocks return only the residual branch f(x)."""
+ return hasattr(block, 'attn') or hasattr(block, 'bn1')
+
+ def run(self, frozen_baseline_acc: Optional[float] = None, test_acc: Optional[float] = None):
+ """Run all three diagnostics.
+
+ Args:
+ frozen_baseline_acc: accuracy of the frozen-blocks baseline (required for D3).
+ test_acc: test accuracy of the trained model. If None, computed from x_eval/y_eval.
+
+ Returns:
+ dict with 'diagnostics', 'verdict', and raw values.
+ """
+ self.model.eval()
+
+ # Forward pass to get hidden states
+ with torch.no_grad():
+ logits, hiddens = self.model(self.x_eval, return_hidden=True)
+
+ if test_acc is None:
+ test_acc = float((logits.argmax(-1) == self.y_eval).float().mean().item())
+
+ # D1: Scale stability
+ h_norms = self._compute_hidden_norms(hiddens)
+ growth_ratios = [h_norms[i+1] / max(h_norms[i], 1e-12)
+ for i in range(len(h_norms) - 1)]
+ max_growth = max(growth_ratios) if growth_ratios else 1.0
+ d1_fires = max_growth > self.d1_threshold
+
+ # D2: Reference validity
+ bp_grad_norms = self._compute_bp_grad_norms(hiddens)
+ g_L = bp_grad_norms[-1] if bp_grad_norms else 0.0
+ d2_fires = g_L < self.d2_floor
+
+ # D3: Depth utility
+ if frozen_baseline_acc is not None:
+ margin = test_acc - frozen_baseline_acc
+ d3_fires = margin < self.d3_margin
+ else:
+ margin = None
+ d3_fires = None
+
+ # Verdict
+ mode1 = d1_fires and d2_fires
+ flags = []
+ if d1_fires:
+ flags.append('D1')
+ if d2_fires:
+ flags.append('D2')
+ if d3_fires:
+ flags.append('D3')
+
+ if not flags:
+ verdict = 'PASS'
+ else:
+ verdict = 'FAIL(' + '+'.join(flags) + ')'
+
+ return {
+ 'verdict': verdict,
+ 'test_acc': test_acc,
+ 'diagnostics': {
+ 'D1_scale_growth': {
+ 'max_growth': max_growth,
+ 'per_block_growth': growth_ratios,
+ 'hidden_norms': h_norms,
+ 'threshold': self.d1_threshold,
+ 'fires': d1_fires,
+ },
+ 'D2_ref_validity': {
+ 'g_L': g_L,
+ 'bp_grad_norms': bp_grad_norms,
+ 'floor': self.d2_floor,
+ 'fires': d2_fires,
+ },
+ 'D3_depth_utility': {
+ 'test_acc': test_acc,
+ 'frozen_baseline_acc': frozen_baseline_acc,
+ 'margin': margin,
+ 'margin_threshold': self.d3_margin,
+ 'fires': d3_fires,
+ },
+ },
+ }
+
+ def summary(self, report: dict) -> str:
+ """Human-readable summary of a protocol report."""
+ d = report['diagnostics']
+ lines = [
+ f"Verdict: {report['verdict']}",
+ f"Test accuracy: {report['test_acc']:.4f}",
+ f"D1 Scale stability: max growth = {d['D1_scale_growth']['max_growth']:.1f}x "
+ f"(threshold {d['D1_scale_growth']['threshold']}x) -> "
+ f"{'FIRE' if d['D1_scale_growth']['fires'] else 'pass'}",
+ f"D2 Reference validity: ||g_L|| = {d['D2_ref_validity']['g_L']:.2e} "
+ f"(floor {d['D2_ref_validity']['floor']:.0e}) -> "
+ f"{'FIRE' if d['D2_ref_validity']['fires'] else 'pass'}",
+ ]
+ if d['D3_depth_utility']['fires'] is not None:
+ lines.append(
+ f"D3 Depth utility: margin = {d['D3_depth_utility']['margin']*100:+.1f} pp "
+ f"(threshold {d['D3_depth_utility']['margin_threshold']*100:.0f} pp) -> "
+ f"{'FIRE' if d['D3_depth_utility']['fires'] else 'pass'}"
+ )
+ else:
+ lines.append("D3 Depth utility: not evaluated (no frozen baseline provided)")
+ return '\n'.join(lines)
diff --git a/reproduce/__init__.py b/reproduce/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/reproduce/__init__.py
diff --git a/reproduce/frozen_baseline.py b/reproduce/frozen_baseline.py
new file mode 100644
index 0000000..08368a2
--- /dev/null
+++ b/reproduce/frozen_baseline.py
@@ -0,0 +1,86 @@
+"""
+Frozen-blocks baseline: train only embed/head with blocks frozen at random init.
+
+Usage:
+ python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from reproduce.train_methods import get_data, evaluate, make_model
+
+
+def freeze_blocks(model):
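+ # Freeze all block parameters and put BatchNorm layers in eval mode so running statistics stay at their init values.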
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+
+
+def train_frozen(model, train_loader, test_loader, device, epochs, is_conv):
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=0.01)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv: x = x.view(x.size(0), -1)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device, is_conv)
+ print(f" [Frozen] ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device, is_conv)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
+ p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
+ p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/frozen_baselines')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ train_loader, test_loader, num_classes = get_data(args.dataset, 128)
+
+ results = {}
+ for seed in args.seeds:
+ print(f"\n--- Frozen baseline seed={seed} ---", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed)
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+ model, is_conv = make_model(args.arch, num_classes, device)
+ freeze_blocks(model)
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ total = sum(p.numel() for p in model.parameters())
+ print(f" {trainable}/{total} trainable params", flush=True)
+ acc = train_frozen(model, train_loader, test_loader, device, args.epochs, is_conv)
+ results[f's{seed}'] = acc
+ print(f" FINAL: {acc:.4f}", flush=True)
+
+ results['config'] = vars(args)
+ results['mean'] = float(np.mean([results[f's{s}'] for s in args.seeds]))
+ results['std'] = float(np.std([results[f's{s}'] for s in args.seeds], ddof=1))
+ out_path = os.path.join(args.output_dir, f'frozen_{args.arch}_{args.dataset}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {out_path}")
+ print(f"Frozen baseline: {results['mean']:.4f} +/- {results['std']:.4f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/reproduce/penalty_sweep.py b/reproduce/penalty_sweep.py
new file mode 100644
index 0000000..b6b913d
--- /dev/null
+++ b/reproduce/penalty_sweep.py
@@ -0,0 +1,176 @@
+"""
+Penalty intervention sweep: DFA + lambda x {0, 1e-4, 1e-2} with per-epoch trajectory.
+Includes fresh-B null calibration on the lambda=1e-2 checkpoint.
+
+Usage:
+ python reproduce/penalty_sweep.py --seeds 42 123 456 --gpu 0
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from reproduce.train_methods import get_data
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+def train_dfa_trajectory(seed, train_loader, test_loader, device, epochs, lam, num_classes=10):
+ """DFA with per-epoch ||h_L||, ||g_L|| logging."""
+ torch.manual_seed(seed); np.random.seed(seed)
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+ from models.residual_mlp import ResidualMLP
+ model = ResidualMLP(3072, 256, num_classes, 4).to(device)
+ d, L, C = 256, 4, num_classes
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=1e-3, weight_decay=0.01)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+ # Eval buffer
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128: break
+ x_eval = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+
+ def diagnose():
+ model.eval()
+ with torch.no_grad():
+ _, hi = model(x_eval, return_hidden=True)
+ h_L = hi[-1].norm(dim=-1).median().item()
+ h0 = model.embed(x_eval)
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_L = grads[-1].norm(dim=-1).median().item()
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ model.train()
+ return h_L, g_L, acc
+
+ log = []
+ h, g, a = diagnose()
+ log.append({'epoch': 0, 'h_L': h, 'g_L': g, 'acc': a})
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL = hiddens[-1].detach()
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+ head_opt.step()
+ for l in range(L):
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](hiddens[l].detach())
+ local_loss = (f_l * (a_dfa / rms)).sum(-1).mean()
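+ # lam > 0 adds an L2 penalty on the block output f_l (the penalty intervention described in the script docstring)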
+ if lam > 0:
+ local_loss = local_loss + lam * (f_l ** 2).sum(-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_opt.zero_grad(); (h0 * (a0 / rms0)).sum(-1).mean().backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ h, g, a = diagnose()
+ log.append({'epoch': ep, 'h_L': h, 'g_L': g, 'acc': a})
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [lam={lam}] s={seed} ep {ep}: ||h_L||={h:.3e} ||g_L||={g:.3e} acc={a:.4f}", flush=True)
+
+ return log, model, Bs
+
+
+def fresh_b_null(model, x_eval, y_eval, training_Bs, n_draws=20):
+ """Fresh-B null calibration on a trained checkpoint."""
+ model.eval()
+ d, C = training_Bs[0].shape # each B_l has shape (d_hidden, num_classes)
+ L = len(training_Bs)
+ device = x_eval.device
+
+ def deep_cos_with_Bs(Bs):
+ h0 = model.embed(x_eval)
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ with torch.no_grad():
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+ cos_layers = []
+ for l in range(L):
+ a = (e_T @ Bs[l].T).detach()
+ cos_layers.append(cosine_similarity_batch(a, grads[l].detach()))
+ return float(np.mean(cos_layers[1:])) # deep = exclude layer 0
+
+ train_cos = deep_cos_with_Bs(training_Bs)
+ fresh_cos = []
+ for _ in range(n_draws):
+ fresh_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ fresh_cos.append(deep_cos_with_Bs(fresh_Bs))
+
+ return {
+ 'training_Bs_deep_cos': train_cos,
+ 'fresh_Bs_deep_mean': float(np.mean(fresh_cos)),
+ 'fresh_Bs_deep_std_ddof1': float(np.std(fresh_cos, ddof=1)),
+ 'n_draws': n_draws,
+ }
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
+ p.add_argument('--epochs', type=int, default=30)
+ p.add_argument('--lambdas', nargs='+', type=float, default=[0.0, 1e-4, 1e-2])
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/penalty_sweep')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ train_loader, test_loader, _ = get_data('cifar10', 128)
+
+ results = {}
+ for lam in args.lambdas:
+ lam_key = f'lam_{lam}'
+ results[lam_key] = {}
+ for seed in args.seeds:
+ print(f"\n=== lambda={lam}, seed={seed} ===", flush=True)
+ log, model, Bs = train_dfa_trajectory(seed, train_loader, test_loader, device, args.epochs, lam)
+ results[lam_key][str(seed)] = log
+
+ # Fresh-B null on lambda=1e-2, seed=42 only
+ if lam == 1e-2 and seed == 42:
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128: break
+ x_eval = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+ null = fresh_b_null(model, x_eval, y_eval, Bs)
+ results['fresh_b_null'] = null
+ print(f" Fresh-B: training={null['training_Bs_deep_cos']:+.4f}, "
+ f"fresh={null['fresh_Bs_deep_mean']:+.4f} +/- {null['fresh_Bs_deep_std_ddof1']:.4f}")
+
+ with open(os.path.join(args.output_dir, 'penalty_sweep.json'), 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {args.output_dir}/penalty_sweep.json")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/reproduce/run_all.sh b/reproduce/run_all.sh
new file mode 100755
index 0000000..35c3587
--- /dev/null
+++ b/reproduce/run_all.sh
@@ -0,0 +1,85 @@
+#!/bin/bash
+# Full reproduction of all paper results.
+# Usage: bash reproduce/run_all.sh --gpu 0
+# Estimated time: ~12 hours on a single A100/A6000.
+
+# Accept either `bash run_all.sh 0` or the documented `bash run_all.sh --gpu 0` form.
+if [ "$1" = "--gpu" ]; then GPU=${2:-0}; else GPU=${1:-0}; fi
+export CUDA_VISIBLE_DEVICES=$GPU
+
+echo "============================================================"
+echo "FA Evaluation Protocol — Full Reproduction"
+echo "GPU: $GPU"
+echo "Start: $(date)"
+echo "============================================================"
+
+cd "$(dirname "$0")/.."
+
+# ─── Section 2: Primary audit (ResMLP d=256 L=4, 100ep) ─────────────────
+
+echo ""
+echo "=== Section 2: Primary audit (BP/FA/DFA, 3 seeds, 100ep) ==="
+python reproduce/train_methods.py --arch resmlp --methods bp fa dfa \
+ --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/sec2_primary_audit
+
+echo ""
+echo "=== Section 2: Frozen baseline ==="
+python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 \
+ --epochs 100 --gpu 0 --output_dir results/sec2_frozen
+
+# ─── Section 4.1: Cross-architecture (ViT, ResNet) ───────────────────────
+
+echo ""
+echo "=== Section 4.1: ViT-Mini (BP/FA/DFA, 3 seeds, 60ep) ==="
+python reproduce/train_methods.py --arch vit --methods bp fa dfa \
+ --seeds 42 123 456 --epochs 60 --gpu 0 --output_dir results/sec4_vit
+
+echo ""
+echo "=== Section 4.1: ViT-Mini frozen baseline ==="
+python reproduce/frozen_baseline.py --arch vit --seeds 42 123 456 \
+ --epochs 60 --gpu 0 --output_dir results/sec4_vit_frozen
+
+echo ""
+echo "=== Section 4.1: SmallResNet (BP/FA/DFA, 3 seeds, 100ep) ==="
+python reproduce/train_methods.py --arch resnet --methods bp fa dfa \
+ --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/sec4_resnet
+
+echo ""
+echo "=== Section 4.1: SmallResNet frozen baseline ==="
+python reproduce/frozen_baseline.py --arch resnet --seeds 42 123 456 \
+ --epochs 100 --gpu 0 --output_dir results/sec4_resnet_frozen
+
+# ─── Section 4.2: Penalty intervention ───────────────────────────────────
+
+echo ""
+echo "=== Section 4.2: DFA penalty sweep (lambda=0, 1e-4, 1e-2, 30ep) ==="
+python reproduce/penalty_sweep.py --seeds 42 123 456 --epochs 30 --gpu 0 \
+ --output_dir results/sec4_penalty
+
+# ─── Section 5.2: Representative setting (d=512 L=2) ────────────────────
+
+echo ""
+echo "=== Section 5.2: d=512 L=2 (BP/FA/DFA, 3 seeds, 100ep) ==="
+python reproduce/train_methods.py --arch resmlp_d512_L2 --methods bp fa dfa \
+ --seeds 1 2 5 --epochs 100 --gpu 0 --output_dir results/sec5_d512_L2
+
+echo ""
+echo "=== Section 5.2: d=512 L=2 frozen baseline ==="
+python reproduce/frozen_baseline.py --arch resmlp_d512_L2 --seeds 1 2 5 \
+ --epochs 100 --gpu 0 --output_dir results/sec5_d512_L2_frozen
+
+# ─── Appendix: CIFAR-100 ────────────────────────────────────────────────
+
+echo ""
+echo "=== Appendix: CIFAR-100 (BP/FA/DFA, 3 seeds, 100ep) ==="
+python reproduce/train_methods.py --arch resmlp --dataset cifar100 --methods bp fa dfa \
+ --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/app_cifar100
+
+echo ""
+echo "=== Appendix: CIFAR-100 frozen baseline ==="
+python reproduce/frozen_baseline.py --arch resmlp --dataset cifar100 --seeds 42 123 456 \
+ --epochs 100 --gpu 0 --output_dir results/app_cifar100_frozen
+
+echo ""
+echo "============================================================"
+echo "Full reproduction done: $(date)"
+echo "============================================================"
diff --git a/reproduce/train_methods.py b/reproduce/train_methods.py
new file mode 100644
index 0000000..c430b90
--- /dev/null
+++ b/reproduce/train_methods.py
@@ -0,0 +1,376 @@
+"""
+Train BP/FA/DFA on a specified architecture and compute protocol diagnostics.
+
+Usage:
+ python reproduce/train_methods.py --arch resmlp --methods bp fa dfa \
+ --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/main_audit
+
+Architectures: resmlp (d=256 L=4), resmlp_d512_L2, vit, resnet
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from models.vit_mini import ViTMini
+from models.small_resnet import SmallResNet
+from metrics.credit_metrics import cosine_similarity_batch, nudging_test
+
+
+# ─── Data ────────────────────────────────────────────────────────────────
+
+def get_data(dataset='cifar10', batch_size=128):
+ if dataset == 'cifar10':
+ mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
+ Dataset = torchvision.datasets.CIFAR10
+ num_classes = 10
+ else:
+ mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
+ Dataset = torchvision.datasets.CIFAR100
+ num_classes = 100
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(), transforms.Normalize(mean, std)])
+ tv_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
+ tr = Dataset('./data', True, download=True, transform=tv_train)
+ te = Dataset('./data', False, download=True, transform=tv_test)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ num_classes)
+
+
+def evaluate(model, loader, device, is_conv=False):
+ model.eval()
+ c = n = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv:
+ x = x.view(x.size(0), -1)
+ c += (model(x).argmax(-1) == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+# ─── Model construction ─────────────────────────────────────────────────
+
+def make_model(arch, num_classes, device):
+ if arch == 'resmlp':
+ return ResidualMLP(3072, 256, num_classes, 4).to(device), False
+ elif arch == 'resmlp_d512_L2':
+ return ResidualMLP(3072, 512, num_classes, 2).to(device), False
+ elif arch == 'vit':
+ return ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=num_classes).to(device), True
+ elif arch == 'resnet':
+ return SmallResNet(64, num_classes, 4).to(device), True
+ else:
+ raise ValueError(f"Unknown arch: {arch}")
+
+
+# ─── Training functions ─────────────────────────────────────────────────
+
+def train_bp(model, train_loader, test_loader, device, epochs, is_conv):
+ opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv: x = x.view(x.size(0), -1)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ tl += loss.item() * x.size(0); tc += (logits.argmax(1) == y).sum().item(); tn += x.size(0)
+ sch.step()
+ log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, device, is_conv))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [BP] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log
+
+
+def _get_embed_head_params(model, is_conv):
+ """Get embed and head parameter groups."""
+ if is_conv and hasattr(model, 'stem_conv'):
+ embed_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters())
+ head_params = list(model.out_head.parameters())
+ elif hasattr(model, 'patch_embed'): # ViT
+ embed_params = list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed]
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ else: # ResMLP
+ embed_params = list(model.embed.parameters())
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ return embed_params, head_params
+
+
+def _pool_hidden(h):
+ if h.dim() == 4: return F.adaptive_avg_pool2d(h, 1).flatten(1)
+ if h.dim() == 3: return h[:, 0] # cls token
+ return h
+
+
+def _get_head_logits(model, h_pool):
+ if hasattr(model, 'out_ln'):
+ return model.out_head(model.out_ln(h_pool))
+ return model.out_head(h_pool)
+
+
+def _block_residual(model, block, h_l, is_conv):
+ """Compute block residual f_l = block(h_l) - h_l for blocks with internal skip."""
+ out = block(h_l)
+ if is_conv or hasattr(block, 'attn'): # ResNet/ViT blocks include skip internally
+ return out - h_l
+ return out # ResMLP blocks return f_l only
+
+
+def train_dfa(model, train_loader, test_loader, device, epochs, is_conv, num_classes):
+ d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model
+ L = model.num_blocks
+ C = num_classes
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+ embed_params, head_params = _get_embed_head_params(model, is_conv)
+ embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01)
+ head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv: x = x.view(x.size(0), -1)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ h_pool = _pool_hidden(hiddens[-1].detach())
+ head_opt.zero_grad()
+ F.cross_entropy(_get_head_logits(model, h_pool), y).backward()
+ head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = _block_residual(model, model.blocks[l], h_l, is_conv)
+ if f_l.dim() == 4: # conv maps: credit is per-channel; broadcast over H, W
+ a_b = a_norm.unsqueeze(-1).unsqueeze(-1)
+ local_loss = (f_l * a_b).sum(dim=1).mean()
+ elif f_l.dim() == 3: # token sequences: credit is per-feature; broadcast over tokens
+ local_loss = (f_l * a_norm.unsqueeze(1)).sum(-1).mean()
+ else:
+ local_loss = (f_l * a_norm).sum(-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ # Embed
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ if is_conv:
+ h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x)
+ else:
+ h0 = model.embed(x)
+ a0_n = a0 / rms0
+ if h0.dim() == 4: # conv stem output: broadcast credit over H, W
+ embed_loss = (h0 * a0_n.unsqueeze(-1).unsqueeze(-1)).sum(dim=1).mean()
+ elif h0.dim() == 3: # token embeddings: broadcast credit over tokens
+ embed_loss = (h0 * a0_n.unsqueeze(1)).sum(-1).mean()
+ else:
+ embed_loss = (h0 * a0_n).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch
+ log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, device, is_conv))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [DFA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log, Bs
+
+
+def train_fa(model, train_loader, test_loader, device, epochs, is_conv, num_classes):
+ d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model
+ L = model.num_blocks
+ Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+ embed_params, head_params = _get_embed_head_params(model, is_conv)
+ embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01)
+ head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv: x = x.view(x.size(0), -1)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ # Head — grad before step
+ h_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True)
+ logits_out = _get_head_logits(model, h_pool)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward()
+ a_credit = h_pool.grad.detach()
+ head_opt.step()
+ # Top-down blocks
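+ # a_credit starts as the exact gradient at the head input and is then carried backward through fixed random matrices Bs[l] (feedback alignment)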
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()
+ rms = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_credit / rms
+ f_l = _block_residual(model, model.blocks[l], h_l, is_conv)
+ if f_l.dim() == 4: # conv maps: credit is per-channel; broadcast over H, W
+ a_b = a_norm.unsqueeze(-1).unsqueeze(-1)
+ local_loss = (f_l * a_b).sum(dim=1).mean()
+ elif f_l.dim() == 3: # token sequences: credit is per-feature; broadcast over tokens
+ local_loss = (f_l * a_norm.unsqueeze(1)).sum(-1).mean()
+ else:
+ local_loss = (f_l * a_norm).sum(-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ a_credit = (a_credit @ Bs[l]).detach()
+ # Embed
+ rms0 = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ if is_conv:
+ h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x)
+ else:
+ h0 = model.embed(x)
+ a0_n = a_credit / rms0
+ if h0.dim() == 4: # conv stem output: broadcast credit over H, W
+ embed_loss = (h0 * a0_n.unsqueeze(-1).unsqueeze(-1)).sum(dim=1).mean()
+ elif h0.dim() == 3: # token embeddings: broadcast credit over tokens
+ embed_loss = (h0 * a0_n.unsqueeze(1)).sum(-1).mean()
+ else:
+ embed_loss = (h0 * a0_n).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch
+ log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, device, is_conv))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [FA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log, Bs
+
+
+# ─── Diagnostics ─────────────────────────────────────────────────────────
+
+def compute_diagnostics(model, x_eval, y_eval, device, method_name, dfa_Bs=None, fa_Bs=None, is_conv=False):
+ """Compute per-layer cosine, ||g_l||, ||h_l|| and nudging."""
+ model.eval()
+ L = model.num_blocks
+
+ with torch.no_grad():
+ logits, hiddens = model(x_eval, return_hidden=True)
+
+ h_norms = [float(_pool_hidden(h).norm(dim=-1).median().item()) for h in hiddens]
+
+ # BP grads
+ h0 = model.embed(x_eval) if hasattr(model, 'embed') else model.stem(x_eval)
+ hs = [h0.clone().requires_grad_(True)]
+ for block in model.blocks:
+ # ResMLP blocks return only the residual branch f(h), so add the skip here;
+ # ResNet/ViT blocks already apply it internally (handled by _block_residual).
+ hs.append(hs[-1] + _block_residual(model, block, hs[-1], is_conv))
+ h_final = _pool_hidden(hs[-1])
+ if hasattr(model, 'out_ln'):
+ h_final = model.out_ln(h_final)
+ out_logits = model.out_head(h_final)
+ loss = F.cross_entropy(out_logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_norms = [float(_pool_hidden(g).norm(dim=-1).median().item()) for g in grads]
+
+ # Per-layer cosine
+ with torch.no_grad():
+ e_T = out_logits.softmax(-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+
+ bp_cosine = []
+ if method_name == 'bp':
+ bp_cosine = [1.0] * L
+ elif method_name == 'dfa' and dfa_Bs is not None:
+ for l in range(L):
+ a = (e_T @ dfa_Bs[l].T).detach()
+ g_pool = _pool_hidden(grads[l]).detach()
+ bp_cosine.append(cosine_similarity_batch(a, g_pool))
+ elif method_name == 'fa' and fa_Bs is not None:
+ hL_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True)
+ logits_fa = _get_head_logits(model, hL_pool)
+ loss_fa = F.cross_entropy(logits_fa, y_eval)
+ a_credit = torch.autograd.grad(loss_fa, hL_pool)[0].detach()
+ for l in range(L - 1, -1, -1):
+ g_pool = _pool_hidden(grads[l]).detach()
+ bp_cosine.insert(0, cosine_similarity_batch(a_credit, g_pool))
+ a_credit = (a_credit @ fa_Bs[l]).detach()
+
+ model.train()
+ return {
+ 'bp_cosine': bp_cosine,
+ 'bp_grad_norms_per_layer': g_norms,
+ 'hidden_norms_per_layer': h_norms,
+ }
+
+
+# ─── Main ────────────────────────────────────────────────────────────────
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
+ p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
+ p.add_argument('--methods', nargs='+', default=['bp', 'fa', 'dfa'])
+ p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/reproduce')
+ p.add_argument('--penalty_lam', type=float, default=0.0)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ train_loader, test_loader, num_classes = get_data(args.dataset, 128)
+
+ # Eval buffer
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128: break
+ x_eval_raw = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+
+ results = {}
+ for seed in args.seeds:
+ print(f"\n{'='*60}\nSeed {seed}\n{'='*60}", flush=True)
+ results[str(seed)] = {}
+
+ for method in args.methods:
+ print(f"\n--- {method.upper()} ---", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed)
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+ model, is_conv = make_model(args.arch, num_classes, device)
+ x_eval = x_eval_raw if is_conv else x_eval_raw.view(x_eval_raw.size(0), -1)
+
+ if method == 'bp':
+ log = train_bp(model, train_loader, test_loader, device, args.epochs, is_conv)
+ diag = compute_diagnostics(model, x_eval, y_eval, device, 'bp', is_conv=is_conv)
+ results[str(seed)]['bp'] = {'log': log, 'diagnostics': diag}
+ elif method == 'dfa':
+ log, Bs = train_dfa(model, train_loader, test_loader, device, args.epochs, is_conv, num_classes)
+ diag = compute_diagnostics(model, x_eval, y_eval, device, 'dfa', dfa_Bs=Bs, is_conv=is_conv)
+ results[str(seed)]['dfa'] = {'log': log, 'diagnostics': diag}
+ elif method == 'fa':
+ log, Bs = train_fa(model, train_loader, test_loader, device, args.epochs, is_conv, num_classes)
+ diag = compute_diagnostics(model, x_eval, y_eval, device, 'fa', fa_Bs=Bs, is_conv=is_conv)
+ results[str(seed)]['fa'] = {'log': log, 'diagnostics': diag}
+
+ results['config'] = vars(args)
+ out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..fd142b5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch>=2.0
+torchvision>=0.15
+numpy>=1.24
+scipy>=1.10