summaryrefslogtreecommitdiff
path: root/protocol/smoke_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'protocol/smoke_test.py')
-rw-r--r--protocol/smoke_test.py136
1 files changed, 136 insertions, 0 deletions
diff --git a/protocol/smoke_test.py b/protocol/smoke_test.py
new file mode 100644
index 0000000..272a9e5
--- /dev/null
+++ b/protocol/smoke_test.py
@@ -0,0 +1,136 @@
+"""
+Smoke test for the FA Diagnostic Protocol reference implementation.
+
+Loads BP-trained and DFA-trained ResMLP checkpoints from
+`results/confirmatory/checkpoints_A2/` and applies the protocol to each.
+The protocol should:
+ - return verdict="trustworthy" on the BP checkpoint (residual norms bounded,
+ BP grad ~5e-5 well above floor, low cross-batch direction stability),
+ - flag the DFA checkpoint as "needs walk-back" (residual stream exploded
+ by ~10^7×, BP grad at ~5e-10, drift-dominated direction).
+
+Run with:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test
+"""
+import os
+import sys
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+# Make `protocol` and `models` importable when invoked from repo root
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP # noqa: E402
+from protocol import diagnose # noqa: E402
+
+
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2")
+EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline")
+
+
+def load_eval_batches(n_batches: int = 4, batch_size: int = 1024, device="cuda:0"):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
+ batches = []
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batches.append((x, y))
+ if len(batches) >= n_batches:
+ break
+ return batches
+
+
+def evaluate(model, loader, device):
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def load_model(method: str, seed: int, device):
+ if method == "ep":
+ path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt")
+ else:
+ path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt")
+ ckpt = torch.load(path, map_location=device, weights_only=False)
+ # Try several common checkpoint formats
+ if isinstance(ckpt, dict) and "state_dict" in ckpt:
+ sd = ckpt["state_dict"]
+ elif isinstance(ckpt, dict) and "model" in ckpt:
+ sd = ckpt["model"]
+ elif isinstance(ckpt, dict):
+ sd = ckpt
+ else:
+ sd = ckpt.state_dict()
+ model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device)
+ model.load_state_dict(sd)
+ return model
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+
+ # Cross-batch stability is sensitive to batch size: with smaller batches
+ # the per-sample noise component dominates the batch-mean direction on
+ # healthy networks, giving the cleanest separation from the drift-
+ # dominated failure mode (~0.10 healthy vs ~0.95 drift at N=128).
+ eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
+
+ # Build a single test loader for headline acc
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ test_loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+
+ # BP: trustworthy. DFA: walked back. EP: trustworthy (the internal control
+ # — same architecture/metric/dataset, but EP doesn't blow up the residual
+ # stream, so the diagnostic protocol passes for EP even though EP's
+ # accuracy is also low. This is the paper's central comparison.)
+ for method, expected in [
+ ("bp", "trustworthy"),
+ ("dfa", "needs walk-back"),
+ ("ep", "trustworthy"),
+ ]:
+ print()
+ print(f"### {method.upper()} (seed 42)")
+ model = load_model(method, 42, device)
+ acc = evaluate(model, test_loader, device)
+ report = diagnose(
+ model=model,
+ eval_batches=eval_batches,
+ headline_acc=acc,
+ frozen_baseline_acc=None,
+ method_name=method.upper(),
+ notes="4-block d=256 ResMLP, CIFAR-10",
+ )
+ print(report)
+ # Sanity check
+ verdict = report.verdict
+ if expected == "trustworthy" and verdict != "trustworthy":
+ print(f"!! UNEXPECTED: BP should be trustworthy, got '{verdict}'")
+ elif expected == "needs walk-back" and not verdict.startswith("needs walk-back"):
+ print(f"!! UNEXPECTED: DFA should need walk-back, got '{verdict}'")
+ else:
+ print(f"OK: verdict matches expectation ('{expected}')")
+
+
+if __name__ == "__main__":
+ main()