""" Minimal worked example showing how a future FA paper author would use the diagnostic protocol on their own model. Trains a fresh tiny ResMLP with DFA on CIFAR-10 for 5 epochs (so the script runs in <2 minutes on CPU), applies the protocol, and prints the verdict. This is the "API tutorial" version of the protocol. Real applications would train for 100 epochs and use a real test set; the structure is identical. Run: python -m protocol.examples.minimal_worked_example """ import os import sys import time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) sys.path.insert(0, REPO_ROOT) from models.residual_mlp import ResidualMLP # noqa: E402 from protocol import diagnose # noqa: E402 def get_loaders(batch_size=128): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tr = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=tv) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) return ( DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=0), DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0), ) def evaluate(model, loader, device): model.eval() correct = total = 0 with torch.no_grad(): for x, y in loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) preds = model(x).argmax(-1) correct += (preds == y).sum().item() total += x.size(0) return correct / total def train_dfa_one_epoch(model, train_loader, Bs, device, lr=1e-3): L = model.num_blocks opts = [optim.AdamW(b.parameters(), lr=lr) for b in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr) head_opt = optim.AdamW( list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr ) model.train() for x, y in train_loader: x = x.view(x.size(0), -1).to(device); y = y.to(device) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1 head_opt.zero_grad() F.cross_entropy(model.out_head(model.out_ln(hiddens[-1].detach())), y).backward() head_opt.step() for l in range(L): h_l = hiddens[l].detach() a = (e_T @ Bs[l].T).detach() rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 f = model.blocks[l](h_l) loss = (f * (a / rms)).sum(-1).mean() opts[l].zero_grad(); loss.backward(); opts[l].step() a0 = (e_T @ Bs[0].T).detach() rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_opt.zero_grad() (h0 * (a0 / rms0)).sum(-1).mean().backward() embed_opt.step() def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") print() print("Step 1: train a tiny 4-block d=128 ResMLP with DFA on CIFAR-10") print("(5 epochs, just for the example — real runs would use ~100 epochs)") print() train_loader, test_loader = get_loaders(batch_size=128) torch.manual_seed(42); np.random.seed(42) model = ResidualMLP(input_dim=3072, d_hidden=128, num_classes=10, num_blocks=4).to(device) Bs = [torch.randn(128, 10, device=device) / np.sqrt(10) for _ in range(4)] t0 = time.time() for ep in range(1, 6): train_dfa_one_epoch(model, train_loader, Bs, device) acc = evaluate(model, test_loader, device) print(f" epoch {ep}: test_acc = {acc:.4f} ({time.time()-t0:.0f}s elapsed)") print() print("Step 2: build the eval batches the protocol needs (8 batches × 128 samples)") eval_batches = [] tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) sub_loader = DataLoader(te, batch_size=128, shuffle=False, num_workers=0) for x, y in sub_loader: x = x.view(x.size(0), -1).to(device); y = y.to(device) eval_batches.append((x, y)) if len(eval_batches) >= 8: break print() print("Step 3: apply the protocol") print() final_acc = evaluate(model, test_loader, device) report = diagnose( model=model, eval_batches=eval_batches, headline_acc=final_acc, # In a real paper, you'd train an architecture-matched random-blocks # baseline and pass its accuracy here. For this example we use the # 3-seed mean from our paper (4-block d=256 ResMLP DFA-shallow). # The width is different (d=128 vs d=256) but the diagnostic # interpretation is the same. frozen_baseline_acc=0.349, method_name="DFA (5-epoch demo)", notes="4-block d=128 ResMLP, CIFAR-10, seed 42, 5 epochs", ) print(report) print() print("Step 4: interpret") print() if report.verdict == "trustworthy": print(" The protocol gave a trustworthy verdict, meaning the network is") print(" in the meaningful measurement regime. You can report headline") print(" accuracy and Γ alignment confidently.") else: print(" The protocol flagged this run for walk-back. Specifically:") print(f" {report.verdict}") print() print(" In a real paper using this protocol, you would either:") print(" (1) Walk back the headline claim and report the failure mode, OR") print(" (2) Modify the training (e.g., add a residual-stream penalty)") print(" to bring the diagnostics into the healthy regime, then re-") print(" apply the protocol to the modified network.") if __name__ == "__main__": main()