diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:48:18 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:48:18 -0500 |
| commit | c2e145e162444b31ac5c66a90daa6bc0a1cda591 (patch) | |
| tree | ad38ef6b7295ecdb3de67e4474d89548e9ca80ba /protocol | |
| parent | 3a520b203f4f0c75b37b2d5c34d461718729ea02 (diff) | |
Add random-init sanity check: protocol does not flag untrained networks
3-seed random init ResMLP gives chance accuracy (~10%) but the protocol
verdict is 'trustworthy' on all 3 seeds:
- residual norms ~8.7 across all layers (no growth, bounded)
- BP gradient norms ~8e-3 (healthy, well above 1e-7 floor)
- cross-batch stability 0.08-0.18 (in the BP/EP range)
This is the answer to the likely reviewer question: 'is your protocol just
flagging anything that doesn't perform well?' Answer: no. Random init is
at chance and the protocol passes it. The walked-back trained methods are
walked back because of the *measurements*, not because of the accuracy.
Notable: random init g-norms (8e-3) are actually HIGHER than BP-trained
ones (4e-4) — BP training reduces the gradient magnitude as loss decreases.
So the protocol distinguishes 3 distinct regimes: (1) untrained healthy,
(2) trained-and-still-healthy (BP/EP), (3) trained-into-pathology (DFA/SB/CB).
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/random_init_sanity.py | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/protocol/examples/random_init_sanity.py b/protocol/examples/random_init_sanity.py new file mode 100644 index 0000000..c4981fd --- /dev/null +++ b/protocol/examples/random_init_sanity.py @@ -0,0 +1,96 @@ +""" +Sanity check: apply the diagnostic protocol to a *randomly initialized* +4-block d=256 ResMLP, with no training. The protocol should report +"trustworthy" — random init is not in the failure regime, even though it +also has trivial accuracy. The protocol is supposed to flag *trained-into- +pathology* networks, not weak networks per se. + +This is the answer to the likely reviewer question: "is your protocol just +flagging anything that doesn't perform well?" Answer: no. Random init +performs at chance (10%) and the protocol passes it. The 3 walked-back +trained methods are walked back because of the *measurements*, not the +accuracy. + +Run: + CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.random_init_sanity +""" +import os +import sys + +import torch +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.insert(0, REPO_ROOT) + +from models.residual_mlp import ResidualMLP # noqa: E402 +from protocol import diagnose # noqa: E402 + + +def get_eval(n_batches, batch_size, device): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0) + batches = [] + for x, y in loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batches.append((x, y)) + if len(batches) >= n_batches: + break + return batches + + +def chance_acc(model, device): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) + model.eval() + correct = total = 0 + with torch.no_grad(): + for x, y in loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + preds = model(x).argmax(-1) + correct += (preds == y).sum().item() + total += x.size(0) + return correct / total + + +def main(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + eval_batches = get_eval(n_batches=10, batch_size=128, device=device) + + print("=" * 72) + print("RANDOM-INIT SANITY CHECK (4-block d=256 ResMLP, CIFAR-10)") + print("=" * 72) + for seed in [42, 123, 456]: + torch.manual_seed(seed) + model = ResidualMLP(3072, 256, 10, 4).to(device) + # Skip BatchNorm-style update by not doing forward pass — just use + # random init parameters directly. + acc = chance_acc(model, device) + report = diagnose( + model=model, + eval_batches=eval_batches, + headline_acc=acc, + frozen_baseline_acc=None, + method_name=f"RANDOM_INIT seed {seed}", + notes="untrained, no parameter updates", + ) + print(f"\n=== seed {seed} (chance acc: {acc:.4f}) ===") + print(report) + + +if __name__ == "__main__": + main() |
