"""
Smoke test for the FA Diagnostic Protocol reference implementation.

Loads BP- and DFA-trained ResMLP checkpoints from
`results/confirmatory/checkpoints_A2/` and an EP-trained checkpoint from
`results/ep_baseline/`, then applies the protocol to each. The protocol should:

- return verdict="trustworthy" on the BP checkpoint (residual norms bounded,
  BP grad ~5e-5 well above floor, low cross-batch direction stability),
- flag the DFA checkpoint as "needs walk-back" (residual stream exploded by
  ~10^7×, BP grad at ~5e-10, drift-dominated direction),
- return verdict="trustworthy" on the EP checkpoint (the internal control).

The process exits with status 1 if any verdict disagrees with expectation,
and 0 otherwise.

Run with:
    CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test
"""
import os
import sys

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Make `protocol` and `models` importable when invoked from repo root
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP  # noqa: E402
from protocol import diagnose  # noqa: E402

REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2")
EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline")

# CIFAR-10 per-channel normalization statistics, shared by every loader below.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
DATA_ROOT = "./data"

# Seed of the checkpoints exercised by this smoke test.
SEED = 42

# (method, expected verdict). BP: trustworthy. DFA: walked back. EP:
# trustworthy (the internal control — same architecture/metric/dataset, but EP
# doesn't blow up the residual stream, so the diagnostic protocol passes for EP
# even though EP's accuracy is also low. This is the paper's central comparison.)
EXPECTATIONS = (
    ("bp", "trustworthy"),
    ("dfa", "needs walk-back"),
    ("ep", "trustworthy"),
)


def _cifar10_test_set():
    """Return the normalized CIFAR-10 test split, downloading it if missing."""
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])
    return torchvision.datasets.CIFAR10(DATA_ROOT, train=False, download=True, transform=tv)


def load_eval_batches(n_batches: int = 4, batch_size: int = 1024, device="cuda:0"):
    """Return the first ``n_batches`` CIFAR-10 test batches, flattened and on ``device``.

    Args:
        n_batches: Number of batches to collect; ``<= 0`` yields an empty list.
        batch_size: Samples per batch (the last batch may be smaller).
        device: Target device for the tensors.

    Returns:
        A list of ``(x, y)`` tuples where ``x`` is reshaped to ``(batch, -1)``.
    """
    if n_batches <= 0:
        # Without this guard the loop below would still append one batch.
        return []
    loader = DataLoader(_cifar10_test_set(), batch_size=batch_size, shuffle=False, num_workers=0)
    batches = []
    for x, y in loader:
        x = x.view(x.size(0), -1).to(device)
        y = y.to(device)
        batches.append((x, y))
        if len(batches) >= n_batches:
            break
    return batches


def evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` over ``loader`` (inputs flattened per sample).

    Puts the model in eval mode and runs without gradient tracking.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            preds = model(x).argmax(-1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    if total == 0:
        raise ValueError("evaluate() received an empty loader")
    return correct / total


def load_model(method: str, seed: int, device):
    """Build a 4-block d=256 ResidualMLP and load the ``method``/``seed`` checkpoint.

    EP checkpoints live in ``EP_CHECKPOINT_DIR``; all others in ``CHECKPOINT_DIR``.
    Accepts checkpoints saved as ``{"state_dict": ...}``, ``{"model": ...}``,
    a bare state dict, or a pickled module.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    if method == "ep":
        path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt")
    else:
        path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    # weights_only=False unpickles arbitrary objects (needed for whole-module
    # checkpoints). Only use this on checkpoints produced by this repository.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    # Try several common checkpoint formats
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        sd = ckpt["state_dict"]
    elif isinstance(ckpt, dict) and "model" in ckpt:
        sd = ckpt["model"]
    elif isinstance(ckpt, dict):
        sd = ckpt
    else:
        sd = ckpt.state_dict()
    model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device)
    model.load_state_dict(sd)
    return model


def main():
    """Run the protocol on every checkpoint in EXPECTATIONS.

    Returns:
        The number of checkpoints whose verdict did not match expectation.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Cross-batch stability is sensitive to batch size: with smaller batches
    # the per-sample noise component dominates the batch-mean direction on
    # healthy networks, giving the cleanest separation from the drift-
    # dominated failure mode (~0.10 healthy vs ~0.95 drift at N=128).
    eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)

    # Build a single test loader for headline acc
    test_loader = DataLoader(_cifar10_test_set(), batch_size=256, shuffle=False, num_workers=0)

    failures = 0
    for method, expected in EXPECTATIONS:
        name = method.upper()
        print()
        print(f"### {name} (seed {SEED})")
        model = load_model(method, SEED, device)
        acc = evaluate(model, test_loader, device)
        report = diagnose(
            model=model,
            eval_batches=eval_batches,
            headline_acc=acc,
            frozen_baseline_acc=None,
            method_name=name,
            notes="4-block d=256 ResMLP, CIFAR-10",
        )
        print(report)

        # Sanity check. "trustworthy" must match exactly; walk-back verdicts
        # may carry a suffix, so only the prefix is compared.
        verdict = report.verdict
        if expected == "trustworthy":
            ok = verdict == "trustworthy"
            wanted = "be trustworthy"
        else:
            ok = verdict.startswith(expected)
            wanted = "need walk-back"
        if ok:
            print(f"OK: verdict matches expectation ('{expected}')")
        else:
            failures += 1
            print(f"!! UNEXPECTED: {name} should {wanted}, got '{verdict}'")
    return failures


if __name__ == "__main__":
    sys.exit(1 if main() else 0)