From 7b64702ad970c16171142665365e16a8e1737190 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 7 Apr 2026 22:20:48 -0500 Subject: Add FA diagnostic protocol reference implementation Codex round 15 #1 priority for the E&D-track paper: - protocol/protocol.py: 4 diagnostics (residual norms, BP grad norms, cross-batch direction stability, and a frozen-baseline comparator) - protocol/report.py: DiagnosticReport with per-diagnostic verdicts and pretty-printer - protocol/smoke_test.py: validates BP/DFA/EP checkpoints produce the expected verdicts (BP/EP trustworthy; DFA walked back via residual explosion + BP grad at floor) - protocol/README.md: usage, audit cases, threshold rationale - protocol/CHECKLIST.md: 6 evaluation pipeline pitfalls (norm(-1), cosine_similarity eps clamp, fp16 underflow, Bs reproducibility, aggregation, layer-0 dominance) - protocol/REPORTING_TEMPLATE.md: per-method fillable form for FA papers --- protocol/smoke_test.py | 136 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 protocol/smoke_test.py (limited to 'protocol/smoke_test.py') diff --git a/protocol/smoke_test.py b/protocol/smoke_test.py new file mode 100644 index 0000000..272a9e5 --- /dev/null +++ b/protocol/smoke_test.py @@ -0,0 +1,136 @@ +""" +Smoke test for the FA Diagnostic Protocol reference implementation. + +Loads BP-trained and DFA-trained ResMLP checkpoints from +`results/confirmatory/checkpoints_A2/` and applies the protocol to each. +The protocol should: + - return verdict="trustworthy" on the BP checkpoint (residual norms bounded, + BP grad ~5e-5 well above floor, low cross-batch direction stability), + - flag the DFA checkpoint as "needs walk-back" (residual stream exploded + by ~10^7×, BP grad at ~5e-10, drift-dominated direction). + +Run with: + CUDA_VISIBLE_DEVICES=2 python -m protocol.smoke_test +""" +import os +import sys + +import torch +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +# Make `protocol` and `models` importable when invoked from repo root +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP # noqa: E402 +from protocol import diagnose # noqa: E402 + + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2") +EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline") + + +def load_eval_batches(n_batches: int = 4, batch_size: int = 1024, device="cuda:0"): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0) + batches = [] + for x, y in loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batches.append((x, y)) + if len(batches) >= n_batches: + break + return batches + + +def evaluate(model, loader, device): + model.eval() + correct = total = 0 + with torch.no_grad(): + for x, y in loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + preds = model(x).argmax(-1) + correct += (preds == y).sum().item() + total += x.size(0) + return correct / total + + +def load_model(method: str, seed: int, device): + if method == "ep": + path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt") + else: + path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt") + ckpt = torch.load(path, map_location=device, weights_only=False) + # Try several common checkpoint formats + if isinstance(ckpt, dict) and "state_dict" in ckpt: + sd = ckpt["state_dict"] + elif isinstance(ckpt, dict) and "model" in ckpt: + sd = ckpt["model"] + elif isinstance(ckpt, dict): + sd = ckpt + else: + sd = ckpt.state_dict() + model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device) + model.load_state_dict(sd) + return model + + +def main(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + # Cross-batch stability is sensitive to batch size: with smaller batches + # the per-sample noise component dominates the batch-mean direction on + # healthy networks, giving the cleanest separation from the drift- + # dominated failure mode (~0.10 healthy vs ~0.95 drift at N=128). + eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device) + + # Build a single test loader for headline acc + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + test_loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) + + # BP: trustworthy. DFA: walked back. EP: trustworthy (the internal control + # — same architecture/metric/dataset, but EP doesn't blow up the residual + # stream, so the diagnostic protocol passes for EP even though EP's + # accuracy is also low. This is the paper's central comparison.) + for method, expected in [ + ("bp", "trustworthy"), + ("dfa", "needs walk-back"), + ("ep", "trustworthy"), + ]: + print() + print(f"### {method.upper()} (seed 42)") + model = load_model(method, 42, device) + acc = evaluate(model, test_loader, device) + report = diagnose( + model=model, + eval_batches=eval_batches, + headline_acc=acc, + frozen_baseline_acc=None, + method_name=method.upper(), + notes="4-block d=256 ResMLP, CIFAR-10", + ) + print(report) + # Sanity check + verdict = report.verdict + if expected == "trustworthy" and verdict != "trustworthy": + print(f"!! UNEXPECTED: BP should be trustworthy, got '{verdict}'") + elif expected == "needs walk-back" and not verdict.startswith("needs walk-back"): + print(f"!! UNEXPECTED: DFA should need walk-back, got '{verdict}'") + else: + print(f"OK: verdict matches expectation ('{expected}')") + + +if __name__ == "__main__": + main() -- cgit v1.2.3