summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--protocol/examples/audit_cnn.py197
-rw-r--r--results/protocol_audit/audit_cnn_3seed.json287
2 files changed, 484 insertions, 0 deletions
diff --git a/protocol/examples/audit_cnn.py b/protocol/examples/audit_cnn.py
new file mode 100644
index 0000000..890c1ec
--- /dev/null
+++ b/protocol/examples/audit_cnn.py
@@ -0,0 +1,197 @@
+"""
+Apply the protocol's diagnostic logic to the SmallCNN architecture (3 conv
+blocks + 1 FC + head, BatchNorm, no terminal LayerNorm). The existing
+checkpoints are in `results/cnn_baseline/{method}_s{seed}.pt`.
+
+This is a custom audit script (not via `protocol.diagnose(...)`) because
+the CNN has 4D conv hidden states and no `model.embed` / `model.out_ln`
+attributes that the duck-typed protocol API expects. The diagnostic
+*logic* is identical: per-block growth of flattened ‖h_l‖, BP grad floor
+at the deepest hidden layer, frozen-blocks comparison.
+
+Why this matters: CNN with BatchNorm is a third architecture family
+(neither pre-LN ResMLP nor pre-LN ViT). Both BP and DFA should be
+informative test cases:
+ - BP on CNN: should pass all diagnostics (sanity)
+ - DFA on CNN: open question — BatchNorm normalizes per-feature, so the
+ LN-driven gradient collapse mechanism may or may not apply
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_cnn
+"""
+import os
+import sys
+import json
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+sys.path.insert(0, os.path.join(REPO_ROOT, "experiments"))
+
+# Import the SmallCNN from the experiments script
+import importlib.util
+_spec = importlib.util.spec_from_file_location(
+ "cnn_baseline_module",
+ os.path.join(REPO_ROOT, "experiments/cnn_baseline.py"),
+)
+_mod = importlib.util.module_from_spec(_spec)
+_spec.loader.exec_module(_mod)
+SmallCNN = _mod.SmallCNN
+
+
+CKPT_DIR = os.path.join(REPO_ROOT, "results/cnn_baseline")
+THRESHOLD_PER_BLOCK = 50.0
+THRESHOLD_GFLOOR = 1e-7
+
+
+def get_eval(n=1024, batch_size=128, device="cuda:0"):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
+ batches = []
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ batches.append((x, y))
+ if sum(b[0].size(0) for b in batches) >= n:
+ break
+ return batches
+
+
+def evaluate(model, device):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def per_layer_norms_and_grads(model, x, y):
+ """For the CNN, return per-layer flattened ‖h_l‖ medians and ‖g_l‖ medians."""
+ model.eval()
+ with torch.enable_grad():
+ h0 = model.blocks[0](x)
+ h1 = model.blocks[1](h0)
+ h2 = model.blocks[2](h1)
+ h3 = model.blocks[3](h2.flatten(1))
+ logits = model.out_head(h3)
+ hiddens = [h0, h1, h2, h3]
+ loss = F.cross_entropy(logits, y)
+ grads = torch.autograd.grad(loss, hiddens)
+
+ h_norms = []
+ g_norms = []
+ for h, g in zip(hiddens, grads):
+ h_flat = h.reshape(h.shape[0], -1)
+ g_flat = g.reshape(g.shape[0], -1)
+ h_norms.append(h_flat.norm(dim=-1).median().item())
+ g_norms.append(g_flat.norm(dim=-1).median().item())
+ return h_norms, g_norms
+
+
+def max_per_block_growth(h):
+ if len(h) < 2:
+ return 1.0
+ return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1))
+
+
+def load_cnn(method, seed, device):
+ path = os.path.join(CKPT_DIR, f"{method}_s{seed}.pt")
+ sd = torch.load(path, map_location=device, weights_only=False)
+ if isinstance(sd, dict) and "model_state" in sd:
+ sd = sd["model_state"]
+ elif isinstance(sd, dict) and "state_dict" in sd:
+ sd = sd["state_dict"]
+ model = SmallCNN().to(device)
+ model.load_state_dict(sd)
+ return model
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+ eval_batches = get_eval(n=1024, batch_size=128, device=device)
+ x, y = eval_batches[0]
+
+ methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
+ print()
+ print("=" * 100)
+ print("CNN audit (SmallCNN: 3 conv + BN + 1 FC, NO terminal LN, CIFAR-10)")
+ print("=" * 100)
+ print(f" {'method':<16}{'seed':>6}{'acc':>8}{'h_max/h_min':>14}{'max/block':>14}{'||g_L||':>14} verdict")
+ print(" " + "-" * 100)
+
+ rows = []
+ for seed in [42, 123, 456]:
+ for method in methods:
+ try:
+ model = load_cnn(method, seed, device)
+ except Exception as e:
+ print(f" {method:<16}{seed:>6} SKIPPED ({e})")
+ continue
+ acc = evaluate(model, device)
+ h_norms, g_norms = per_layer_norms_and_grads(model, x, y)
+ max_growth = max_per_block_growth(h_norms)
+ h_ratio = max(h_norms) / max(min(h_norms), 1e-30)
+ g_L = g_norms[-1]
+ flags = []
+ if max_growth > THRESHOLD_PER_BLOCK:
+ flags.append("(a)")
+ if g_L < THRESHOLD_GFLOOR:
+ flags.append("(b)")
+ verdict = "trustworthy" if not flags else f"walk-back: {'+'.join(flags)}"
+ rows.append({
+ "method": method,
+ "seed": seed,
+ "acc": acc,
+ "h_norms": h_norms,
+ "g_norms": g_norms,
+ "max_per_block": max_growth,
+ "verdict": verdict,
+ })
+ print(f" {method:<16}{seed:>6}{acc:>8.4f}{h_ratio:>14.2e}{max_growth:>14.2e}{g_L:>14.2e} {verdict}")
+
+ print()
+ print("=" * 100)
+ print("Per-method 3-seed mean (h_norms across all 4 hidden layers, g across all):")
+ print("=" * 100)
+ for method in methods:
+ method_rows = [r for r in rows if r["method"] == method]
+ if not method_rows:
+ continue
+ accs = np.array([r["acc"] for r in method_rows])
+ h_arrs = np.array([r["h_norms"] for r in method_rows])
+ g_arrs = np.array([r["g_norms"] for r in method_rows])
+ max_g = np.array([r["max_per_block"] for r in method_rows])
+ print(f" {method.upper()}: acc={accs.mean():.4f}±{accs.std():.4f}, "
+ f"h_means={h_arrs.mean(0)}, g_means={g_arrs.mean(0)}, "
+ f"max-per-block={max_g.mean():.2e}")
+
+ out_path = os.path.join(REPO_ROOT, "results/protocol_audit/audit_cnn_3seed.json")
+ with open(out_path, "w") as f:
+ json.dump(rows, f, indent=2)
+ print(f"\nSaved {out_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/results/protocol_audit/audit_cnn_3seed.json b/results/protocol_audit/audit_cnn_3seed.json
new file mode 100644
index 0000000..e31d90a
--- /dev/null
+++ b/results/protocol_audit/audit_cnn_3seed.json
@@ -0,0 +1,287 @@
+[
+ {
+ "method": "bp",
+ "seed": 42,
+ "acc": 0.8621,
+ "h_norms": [
+ 92.19713592529297,
+ 36.073795318603516,
+ 31.53321647644043,
+ 37.99840545654297
+ ],
+ "g_norms": [
+ 0.0003265859850216657,
+ 0.0005140144494362175,
+ 0.00032046454725787044,
+ 4.230972990626469e-05
+ ],
+ "max_per_block": 1.2050278944722594,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "dfa",
+ "seed": 42,
+ "acc": 0.5526,
+ "h_norms": [
+ 250.00730895996094,
+ 312.03765869140625,
+ 338.56951904296875,
+ 72491.6875
+ ],
+ "g_norms": [
+ 0.008991315960884094,
+ 0.004240375477820635,
+ 0.0019398012664169073,
+ 0.0007538454374298453
+ ],
+ "max_per_block": 214.1116769900361,
+ "verdict": "walk-back: (a)"
+ },
+ {
+ "method": "state_bridge",
+ "seed": 42,
+ "acc": 0.632,
+ "h_norms": [
+ 82.06472778320312,
+ 58.023536682128906,
+ 63.79484176635742,
+ 146.3084716796875
+ ],
+ "g_norms": [
+ 0.012001844123005867,
+ 0.0053405677899718285,
+ 0.002992230001837015,
+ 0.0019513164879754186
+ ],
+ "max_per_block": 2.2934216564958096,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "credit_bridge",
+ "seed": 42,
+ "acc": 0.3357,
+ "h_norms": [
+ 188.0859375,
+ 184.6356658935547,
+ 197.04025268554688,
+ 21333.876953125
+ ],
+ "g_norms": [
+ 0.011781061999499798,
+ 0.009520439431071281,
+ 0.007203294429928064,
+ 0.0029390468262135983
+ ],
+ "max_per_block": 108.27166866848961,
+ "verdict": "walk-back: (a)"
+ },
+ {
+ "method": "ep",
+ "seed": 42,
+ "acc": 0.5033,
+ "h_norms": [
+ 83.98963165283203,
+ 80.78277587890625,
+ 69.89965057373047,
+ 23.533870697021484
+ ],
+ "g_norms": [
+ 0.02113482914865017,
+ 0.01996646821498871,
+ 0.04826148599386215,
+ 0.6656972765922546
+ ],
+ "max_per_block": 0.9618184326943926,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "bp",
+ "seed": 123,
+ "acc": 0.8683,
+ "h_norms": [
+ 95.01183319091797,
+ 37.34379196166992,
+ 30.323930740356445,
+ 41.04403305053711
+ ],
+ "g_norms": [
+ 0.00020042041433043778,
+ 0.0002769956481643021,
+ 0.0002367425913689658,
+ 2.9566466764663346e-05
+ ],
+ "max_per_block": 1.3535195487013123,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "dfa",
+ "seed": 123,
+ "acc": 0.5501,
+ "h_norms": [
+ 259.4548034667969,
+ 386.1385498046875,
+ 345.1282958984375,
+ 81147.859375
+ ],
+ "g_norms": [
+ 0.0176572035998106,
+ 0.006263771094381809,
+ 0.004082133527845144,
+ 0.0012416314566507936
+ ],
+ "max_per_block": 235.12375061498798,
+ "verdict": "walk-back: (a)"
+ },
+ {
+ "method": "state_bridge",
+ "seed": 123,
+ "acc": 0.6277,
+ "h_norms": [
+ 78.61927032470703,
+ 61.77223587036133,
+ 62.21416473388672,
+ 168.63580322265625
+ ],
+ "g_norms": [
+ 0.01707093045115471,
+ 0.005183099303394556,
+ 0.0028397536370903254,
+ 0.001994567457586527
+ ],
+ "max_per_block": 2.7105692721902597,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "credit_bridge",
+ "seed": 123,
+ "acc": 0.3132,
+ "h_norms": [
+ 153.3919219970703,
+ 174.97061157226562,
+ 205.44915771484375,
+ 18492.830078125
+ ],
+ "g_norms": [
+ 0.011560788378119469,
+ 0.00875047780573368,
+ 0.007237838581204414,
+ 0.0031556652393192053
+ ],
+ "max_per_block": 90.01171036092735,
+ "verdict": "walk-back: (a)"
+ },
+ {
+ "method": "ep",
+ "seed": 123,
+ "acc": 0.4897,
+ "h_norms": [
+ 86.96959686279297,
+ 82.86459350585938,
+ 57.40688705444336,
+ 1.998042345046997
+ ],
+ "g_norms": [
+ 0.01942148432135582,
+ 0.0226480383425951,
+ 0.06046672910451889,
+ 1.157748818397522
+ ],
+ "max_per_block": 0.9527995586387525,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "bp",
+ "seed": 456,
+ "acc": 0.8681,
+ "h_norms": [
+ 96.83692169189453,
+ 37.44154739379883,
+ 31.123756408691406,
+ 42.27854919433594
+ ],
+ "g_norms": [
+ 0.00014156661927700043,
+ 0.00021504472533706576,
+ 0.00015155959408730268,
+ 1.7612564988667145e-05
+ ],
+ "max_per_block": 1.3584012366363825,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "dfa",
+ "seed": 456,
+ "acc": 0.5954,
+ "h_norms": [
+ 206.5431671142578,
+ 266.57421875,
+ 254.53468322753906,
+ 66974.734375
+ ],
+ "g_norms": [
+ 0.00448678620159626,
+ 0.002167485887184739,
+ 0.0012352537596598268,
+ 0.0004585048300214112
+ ],
+ "max_per_block": 263.12616232000306,
+ "verdict": "walk-back: (a)"
+ },
+ {
+ "method": "state_bridge",
+ "seed": 456,
+ "acc": 0.6396,
+ "h_norms": [
+ 71.60630798339844,
+ 56.15557098388672,
+ 63.141014099121094,
+ 137.78231811523438
+ ],
+ "g_norms": [
+ 0.014506030827760696,
+ 0.005259184632450342,
+ 0.0027562177274376154,
+ 0.001790599781088531
+ ],
+ "max_per_block": 2.1821366045679693,
+ "verdict": "trustworthy"
+ },
+ {
+ "method": "credit_bridge",
+ "seed": 456,
+ "acc": 0.3251,
+ "h_norms": [
+ 169.41067504882812,
+ 151.55250549316406,
+ 177.73605346679688,
+ 16139.1640625
+ ],
+ "g_norms": [
+ 0.009933868423104286,
+ 0.00681547075510025,
+ 0.004268169403076172,
+ 0.0029866439290344715
+ ],
+ "max_per_block": 90.80410950789441,
+ "verdict": "walk-back: (a)"
+ },
+ {
+ "method": "ep",
+ "seed": 456,
+ "acc": 0.5432,
+ "h_norms": [
+ 84.82034301757812,
+ 85.0868911743164,
+ 163.61471557617188,
+ 5375.328125
+ ],
+ "g_norms": [
+ 0.028629517182707787,
+ 0.01810075342655182,
+ 0.0257416944950819,
+ 0.3386228382587433
+ ],
+ "max_per_block": 32.85357375142385,
+ "verdict": "trustworthy"
+ }
+] \ No newline at end of file