summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:48:18 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:48:18 -0500
commitc2e145e162444b31ac5c66a90daa6bc0a1cda591 (patch)
treead38ef6b7295ecdb3de67e4474d89548e9ca80ba /protocol
parent3a520b203f4f0c75b37b2d5c34d461718729ea02 (diff)
Add random-init sanity check: protocol does not flag untrained networks
3-seed random init ResMLP gives chance accuracy (~10%) but the protocol verdict is 'trustworthy' on all 3 seeds: - residual norms ~8.7 across all layers (no growth, bounded) - BP gradient norms ~8e-3 (healthy, well above 1e-7 floor) - cross-batch stability 0.08-0.18 (in the BP/EP range) This is the answer to the likely reviewer question: 'is your protocol just flagging anything that doesn't perform well?' Answer: no. Random init is at chance and the protocol passes it. The walked-back trained methods are walked back because of the *measurements*, not because of the accuracy. Notable: random init g-norms (8e-3) are actually HIGHER than BP-trained ones (4e-4) — BP training reduces the gradient magnitude as loss decreases. So the protocol distinguishes 3 distinct regimes: (1) untrained healthy, (2) trained-and-still-healthy (BP/EP), (3) trained-into-pathology (DFA/SB/CB).
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/random_init_sanity.py96
1 files changed, 96 insertions, 0 deletions
diff --git a/protocol/examples/random_init_sanity.py b/protocol/examples/random_init_sanity.py
new file mode 100644
index 0000000..c4981fd
--- /dev/null
+++ b/protocol/examples/random_init_sanity.py
@@ -0,0 +1,96 @@
+"""
+Sanity check: apply the diagnostic protocol to a *randomly initialized*
+4-block d=256 ResMLP, with no training. The protocol should report
+"trustworthy" — random init is not in the failure regime, even though it
+also has trivial accuracy. The protocol is supposed to flag *trained-into-
+pathology* networks, not weak networks per se.
+
+This is the answer to the likely reviewer question: "is your protocol just
+flagging anything that doesn't perform well?" Answer: no. Random init
+performs at chance (10%) and the protocol passes it. The 3 walked-back
+trained methods are walked back because of the *measurements*, not the
+accuracy.
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.random_init_sanity
+"""
+import os
+import sys
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+
+from models.residual_mlp import ResidualMLP # noqa: E402
+from protocol import diagnose # noqa: E402
+
+
+def get_eval(n_batches, batch_size, device):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
+ batches = []
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batches.append((x, y))
+ if len(batches) >= n_batches:
+ break
+ return batches
+
+
+def chance_acc(model, device):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ eval_batches = get_eval(n_batches=10, batch_size=128, device=device)
+
+ print("=" * 72)
+ print("RANDOM-INIT SANITY CHECK (4-block d=256 ResMLP, CIFAR-10)")
+ print("=" * 72)
+ for seed in [42, 123, 456]:
+ torch.manual_seed(seed)
+ model = ResidualMLP(3072, 256, 10, 4).to(device)
+ # Skip BatchNorm-style update by not doing forward pass — just use
+ # random init parameters directly.
+ acc = chance_acc(model, device)
+ report = diagnose(
+ model=model,
+ eval_batches=eval_batches,
+ headline_acc=acc,
+ frozen_baseline_acc=None,
+ method_name=f"RANDOM_INIT seed {seed}",
+ notes="untrained, no parameter updates",
+ )
+ print(f"\n=== seed {seed} (chance acc: {acc:.4f}) ===")
+ print(report)
+
+
+if __name__ == "__main__":
+ main()