summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:33:49 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:33:49 -0500
commitacc86add44e0cac8701307f936029770edd50891 (patch)
tree59ddcebd72975307e4a58dd174af4dbc4d99a304 /protocol
parentd5185a3cc692fe96c93bbc5d7b286b7080ba7458 (diff)
Add minimal worked example: end-to-end protocol usage tutorial
5-epoch DFA training on CIFAR-10 + apply protocol + interpret verdict. Self-contained, runs on CPU in <2 minutes. Demonstrates the API a future paper author would use: 1. train your model (any FA-style method) 2. build eval_batches from your test loader 3. call diagnose(model, eval_batches, headline_acc, frozen_baseline_acc) 4. read report.verdict; walk back if 'needs walk-back' Not run during this session to avoid GPU contention with the in-flight direction-quality and ViT/ResNet experiments.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/minimal_worked_example.py164
1 files changed, 164 insertions, 0 deletions
diff --git a/protocol/examples/minimal_worked_example.py b/protocol/examples/minimal_worked_example.py
new file mode 100644
index 0000000..6b8ec2f
--- /dev/null
+++ b/protocol/examples/minimal_worked_example.py
@@ -0,0 +1,164 @@
+"""
+Minimal worked example showing how a future FA paper author would use the
+diagnostic protocol on their own model. Trains a fresh tiny ResMLP with DFA
+on CIFAR-10 for 5 epochs (so the script runs in <2 minutes on CPU), applies
+the protocol, and prints the verdict.
+
+This is the "API tutorial" version of the protocol. Real applications would
+train for 100 epochs and use a real test set; the structure is identical.
+
+Run:
+ python -m protocol.examples.minimal_worked_example
+"""
+import os
+import sys
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+
+from models.residual_mlp import ResidualMLP # noqa: E402
+from protocol import diagnose # noqa: E402
+
+
+def get_loaders(batch_size=128):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=tv)
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=0),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0),
+ )
+
+
+def evaluate(model, loader, device):
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def train_dfa_one_epoch(model, train_loader, Bs, device, lr=1e-3):
+ L = model.num_blocks
+ opts = [optim.AdamW(b.parameters(), lr=lr) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr
+ )
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(model.out_ln(hiddens[-1].detach())), y).backward()
+ head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](h_l)
+ loss = (f * (a / rms)).sum(-1).mean()
+ opts[l].zero_grad(); loss.backward(); opts[l].step()
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_opt.zero_grad()
+ (h0 * (a0 / rms0)).sum(-1).mean().backward()
+ embed_opt.step()
+
+
+def main():
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+ print()
+ print("Step 1: train a tiny 4-block d=128 ResMLP with DFA on CIFAR-10")
+ print("(5 epochs, just for the example — real runs would use ~100 epochs)")
+ print()
+
+ train_loader, test_loader = get_loaders(batch_size=128)
+ torch.manual_seed(42); np.random.seed(42)
+ model = ResidualMLP(input_dim=3072, d_hidden=128, num_classes=10, num_blocks=4).to(device)
+ Bs = [torch.randn(128, 10, device=device) / np.sqrt(10) for _ in range(4)]
+
+ t0 = time.time()
+ for ep in range(1, 6):
+ train_dfa_one_epoch(model, train_loader, Bs, device)
+ acc = evaluate(model, test_loader, device)
+ print(f" epoch {ep}: test_acc = {acc:.4f} ({time.time()-t0:.0f}s elapsed)")
+
+ print()
+ print("Step 2: build the eval batches the protocol needs (8 batches × 128 samples)")
+ eval_batches = []
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ sub_loader = DataLoader(te, batch_size=128, shuffle=False, num_workers=0)
+ for x, y in sub_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ eval_batches.append((x, y))
+ if len(eval_batches) >= 8:
+ break
+
+ print()
+ print("Step 3: apply the protocol")
+ print()
+ final_acc = evaluate(model, test_loader, device)
+ report = diagnose(
+ model=model,
+ eval_batches=eval_batches,
+ headline_acc=final_acc,
+ # In a real paper, you'd train an architecture-matched random-blocks
+ # baseline and pass its accuracy here. For this example we use the
+ # 3-seed mean from our paper (4-block d=256 ResMLP DFA-shallow).
+ # The width is different (d=128 vs d=256) but the diagnostic
+ # interpretation is the same.
+ frozen_baseline_acc=0.349,
+ method_name="DFA (5-epoch demo)",
+ notes="4-block d=128 ResMLP, CIFAR-10, seed 42, 5 epochs",
+ )
+ print(report)
+
+ print()
+ print("Step 4: interpret")
+ print()
+ if report.verdict == "trustworthy":
+ print(" The protocol gave a trustworthy verdict, meaning the network is")
+ print(" in the meaningful measurement regime. You can report headline")
+ print(" accuracy and Γ alignment confidently.")
+ else:
+ print(" The protocol flagged this run for walk-back. Specifically:")
+ print(f" {report.verdict}")
+ print()
+ print(" In a real paper using this protocol, you would either:")
+ print(" (1) Walk back the headline claim and report the failure mode, OR")
+ print(" (2) Modify the training (e.g., add a residual-stream penalty)")
+ print(" to bring the diagnostics into the healthy regime, then re-")
+ print(" apply the protocol to the modified network.")
+
+
+if __name__ == "__main__":
+ main()