""" Sanity check: apply the diagnostic protocol to a *randomly initialized* 4-block d=256 ResMLP, with no training. The protocol should report "trustworthy" — random init is not in the failure regime, even though it also has trivial accuracy. The protocol is supposed to flag *trained-into- pathology* networks, not weak networks per se. This is the answer to the likely reviewer question: "is your protocol just flagging anything that doesn't perform well?" Answer: no. Random init performs at chance (10%) and the protocol passes it. The 3 walked-back trained methods are walked back because of the *measurements*, not the accuracy. Run: CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.random_init_sanity """ import os import sys import torch import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) sys.path.insert(0, REPO_ROOT) from models.residual_mlp import ResidualMLP # noqa: E402 from protocol import diagnose # noqa: E402 def get_eval(n_batches, batch_size, device): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0) batches = [] for x, y in loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batches.append((x, y)) if len(batches) >= n_batches: break return batches def chance_acc(model, device): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) model.eval() correct = total = 0 with torch.no_grad(): for x, y in loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) preds = model(x).argmax(-1) correct += (preds == y).sum().item() total += x.size(0) return correct / total def main(): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") eval_batches = get_eval(n_batches=10, batch_size=128, device=device) print("=" * 72) print("RANDOM-INIT SANITY CHECK (4-block d=256 ResMLP, CIFAR-10)") print("=" * 72) for seed in [42, 123, 456]: torch.manual_seed(seed) model = ResidualMLP(3072, 256, 10, 4).to(device) # Skip BatchNorm-style update by not doing forward pass — just use # random init parameters directly. acc = chance_acc(model, device) report = diagnose( model=model, eval_batches=eval_batches, headline_acc=acc, frozen_baseline_acc=None, method_name=f"RANDOM_INIT seed {seed}", notes="untrained, no parameter updates", ) print(f"\n=== seed {seed} (chance acc: {acc:.4f}) ===") print(report) if __name__ == "__main__": main()