diff options
| -rw-r--r-- | protocol/examples/audit_cnn.py | 197 | ||||
| -rw-r--r-- | results/protocol_audit/audit_cnn_3seed.json | 287 |
2 files changed, 484 insertions, 0 deletions
diff --git a/protocol/examples/audit_cnn.py b/protocol/examples/audit_cnn.py
new file mode 100644
index 0000000..890c1ec
--- /dev/null
+++ b/protocol/examples/audit_cnn.py
@@ -0,0 +1,220 @@
+"""
+Apply the protocol's diagnostic logic to the SmallCNN architecture (3 conv
+blocks + 1 FC + head, BatchNorm, no terminal LayerNorm). The existing
+checkpoints are in `results/cnn_baseline/{method}_s{seed}.pt`.
+
+This is a custom audit script (not via `protocol.diagnose(...)`) because
+the CNN has 4D conv hidden states and no `model.embed` / `model.out_ln`
+attributes that the duck-typed protocol API expects. Two diagnostics are
+computed here: (a) per-block growth of flattened ‖h_l‖ and (b) the BP grad
+floor at the deepest hidden layer. The protocol's frozen-blocks comparison
+is NOT implemented by this script.
+
+Why this matters: CNN with BatchNorm is a third architecture family
+(neither pre-LN ResMLP nor pre-LN ViT). Both BP and DFA should be
+informative test cases:
+  - BP on CNN: should pass all diagnostics (sanity)
+  - DFA on CNN: open question — BatchNorm normalizes per-feature, so the
+    LN-driven gradient collapse mechanism may or may not apply
+
+Run:
+    CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_cnn
+"""
+import functools
+import json
+import os
+import sys
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+REPO_ROOT = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+sys.path.insert(0, os.path.join(REPO_ROOT, "experiments"))
+
+# Import the SmallCNN from the experiments script
+import importlib.util
+_spec = importlib.util.spec_from_file_location(
+    "cnn_baseline_module",
+    os.path.join(REPO_ROOT, "experiments/cnn_baseline.py"),
+)
+_mod = importlib.util.module_from_spec(_spec)
+_spec.loader.exec_module(_mod)
+SmallCNN = _mod.SmallCNN
+
+
+CKPT_DIR = os.path.join(REPO_ROOT, "results/cnn_baseline")
+# Flag (a) fires when any consecutive-layer norm ratio exceeds this value.
+THRESHOLD_PER_BLOCK = 50.0
+# Flag (b) fires when the deepest hidden layer's gradient norm is below this.
+THRESHOLD_GFLOOR = 1e-7
+
+
+@functools.lru_cache(maxsize=1)
+def _cifar10_test_set():
+    """Build the normalized CIFAR-10 test split once and reuse it across calls."""
+    tv = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+    ])
+    return torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+
+
+def get_eval(n=1024, batch_size=128, device="cuda:0"):
+    """Return leading test batches, in dataset order, covering at least `n` samples.
+
+    Each element is an `(x, y)` pair already moved to `device`.
+    """
+    loader = DataLoader(_cifar10_test_set(), batch_size=batch_size, shuffle=False, num_workers=0)
+    batches = []
+    seen = 0  # running sample count; avoids re-summing every batch each step
+    for x, y in loader:
+        batches.append((x.to(device), y.to(device)))
+        seen += x.size(0)
+        if seen >= n:
+            break
+    return batches
+
+
+def evaluate(model, device):
+    """Return top-1 accuracy of `model` over the full CIFAR-10 test split."""
+    loader = DataLoader(_cifar10_test_set(), batch_size=256, shuffle=False, num_workers=0)
+    model.eval()
+    correct = total = 0
+    with torch.no_grad():
+        for x, y in loader:
+            x, y = x.to(device), y.to(device)
+            preds = model(x).argmax(-1)
+            correct += (preds == y).sum().item()
+            total += x.size(0)
+    return correct / total
+
+
+def per_layer_norms_and_grads(model, x, y):
+    """For the CNN, return per-layer flattened ‖h_l‖ medians and ‖g_l‖ medians.
+
+    Each entry is the median over the batch of the per-sample L2 norm of the
+    flattened activation (or of the loss gradient w.r.t. that activation).
+    """
+    model.eval()
+    with torch.enable_grad():
+        h0 = model.blocks[0](x)
+        h1 = model.blocks[1](h0)
+        h2 = model.blocks[2](h1)
+        h3 = model.blocks[3](h2.flatten(1))
+        logits = model.out_head(h3)
+        hiddens = [h0, h1, h2, h3]
+        loss = F.cross_entropy(logits, y)
+        grads = torch.autograd.grad(loss, hiddens)
+
+    h_norms = []
+    g_norms = []
+    for h, g in zip(hiddens, grads):
+        h_flat = h.reshape(h.shape[0], -1)
+        g_flat = g.reshape(g.shape[0], -1)
+        h_norms.append(h_flat.norm(dim=-1).median().item())
+        g_norms.append(g_flat.norm(dim=-1).median().item())
+    return h_norms, g_norms
+
+
+def max_per_block_growth(h):
+    """Return the largest ratio h[i + 1] / h[i] over consecutive entries (1.0 if < 2)."""
+    if len(h) < 2:
+        return 1.0
+    return max(nxt / max(cur, 1e-30) for cur, nxt in zip(h, h[1:]))
+
+
+def load_cnn(method, seed, device):
+    """Load `{method}_s{seed}.pt` from CKPT_DIR into a fresh SmallCNN on `device`."""
+    path = os.path.join(CKPT_DIR, f"{method}_s{seed}.pt")
+    # NOTE(review): weights_only=False unpickles arbitrary objects; only use
+    # with checkpoints you produced yourself.
+    sd = torch.load(path, map_location=device, weights_only=False)
+    if isinstance(sd, dict) and "model_state" in sd:
+        sd = sd["model_state"]
+    elif isinstance(sd, dict) and "state_dict" in sd:
+        sd = sd["state_dict"]
+    model = SmallCNN().to(device)
+    model.load_state_dict(sd)
+    return model
+
+
+def main():
+    """Audit every available (method, seed) checkpoint and save the rows as JSON."""
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print(f"Device: {device}")
+    eval_batches = get_eval(n=1024, batch_size=128, device=device)
+    # Only the first batch feeds the norm/gradient diagnostics below.
+    x, y = eval_batches[0]
+
+    methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
+    print()
+    print("=" * 100)
+    print("CNN audit (SmallCNN: 3 conv + BN + 1 FC, NO terminal LN, CIFAR-10)")
+    print("=" * 100)
+    print(f" {'method':<16}{'seed':>6}{'acc':>8}{'h_max/h_min':>14}{'max/block':>14}{'||g_L||':>14} verdict")
+    print(" " + "-" * 100)
+
+    rows = []
+    for seed in [42, 123, 456]:
+        for method in methods:
+            try:
+                model = load_cnn(method, seed, device)
+            except Exception as e:
+                # Best-effort: a missing or incompatible checkpoint is reported, not fatal.
+                print(f" {method:<16}{seed:>6} SKIPPED ({e})")
+                continue
+            acc = evaluate(model, device)
+            h_norms, g_norms = per_layer_norms_and_grads(model, x, y)
+            max_growth = max_per_block_growth(h_norms)
+            h_ratio = max(h_norms) / max(min(h_norms), 1e-30)
+            g_L = g_norms[-1]
+            flags = []
+            if max_growth > THRESHOLD_PER_BLOCK:
+                flags.append("(a)")
+            if g_L < THRESHOLD_GFLOOR:
+                flags.append("(b)")
+            verdict = "trustworthy" if not flags else f"walk-back: {'+'.join(flags)}"
+            rows.append({
+                "method": method,
+                "seed": seed,
+                "acc": acc,
+                "h_norms": h_norms,
+                "g_norms": g_norms,
+                "max_per_block": max_growth,
+                "verdict": verdict,
+            })
+            print(f" {method:<16}{seed:>6}{acc:>8.4f}{h_ratio:>14.2e}{max_growth:>14.2e}{g_L:>14.2e} {verdict}")
+
+    print()
+    print("=" * 100)
+    print("Per-method 3-seed mean (h_norms across all 4 hidden layers, g across all):")
+    print("=" * 100)
+    for method in methods:
+        method_rows = [r for r in rows if r["method"] == method]
+        if not method_rows:
+            continue
+        accs = np.array([r["acc"] for r in method_rows])
+        h_arrs = np.array([r["h_norms"] for r in method_rows])
+        g_arrs = np.array([r["g_norms"] for r in method_rows])
+        max_g = np.array([r["max_per_block"] for r in method_rows])
+        print(f" {method.upper()}: acc={accs.mean():.4f}±{accs.std():.4f}, "
+              f"h_means={h_arrs.mean(0)}, g_means={g_arrs.mean(0)}, "
+              f"max-per-block={max_g.mean():.2e}")
+
+    out_path = os.path.join(REPO_ROOT, "results/protocol_audit/audit_cnn_3seed.json")
+    # The output directory may not exist on a fresh checkout.
+    os.makedirs(os.path.dirname(out_path), exist_ok=True)
+    with open(out_path, "w") as f:
+        json.dump(rows, f, indent=2)
+    print(f"\nSaved {out_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/results/protocol_audit/audit_cnn_3seed.json b/results/protocol_audit/audit_cnn_3seed.json new file mode 100644 index 0000000..e31d90a --- /dev/null +++ b/results/protocol_audit/audit_cnn_3seed.json @@ -0,0 +1,287 @@ +[ + { + "method": "bp", + "seed": 42, + "acc": 0.8621, + "h_norms": [ + 92.19713592529297, + 36.073795318603516, + 31.53321647644043, + 37.99840545654297 + ], + "g_norms": [ + 0.0003265859850216657, + 0.0005140144494362175, + 0.00032046454725787044, + 4.230972990626469e-05 + ], + "max_per_block": 1.2050278944722594, + "verdict": "trustworthy" + }, + { + "method": "dfa", + "seed": 42, + "acc": 0.5526, + "h_norms": [ + 250.00730895996094, + 312.03765869140625, + 338.56951904296875, + 72491.6875 + ], + "g_norms": [ + 0.008991315960884094, + 0.004240375477820635, + 0.0019398012664169073, + 0.0007538454374298453 + ], + "max_per_block": 214.1116769900361, + "verdict": "walk-back: (a)" + }, + { + "method": "state_bridge", + "seed": 42, + "acc": 0.632, + "h_norms": [ + 82.06472778320312, + 58.023536682128906, + 63.79484176635742, + 146.3084716796875 + ], + "g_norms": [ + 0.012001844123005867, + 0.0053405677899718285, + 0.002992230001837015, + 0.0019513164879754186 + ], + "max_per_block": 2.2934216564958096, + "verdict": "trustworthy" + }, + { + "method": "credit_bridge", + "seed": 42, + "acc": 0.3357, + "h_norms": [ + 
188.0859375, + 184.6356658935547, + 197.04025268554688, + 21333.876953125 + ], + "g_norms": [ + 0.011781061999499798, + 0.009520439431071281, + 0.007203294429928064, + 0.0029390468262135983 + ], + "max_per_block": 108.27166866848961, + "verdict": "walk-back: (a)" + }, + { + "method": "ep", + "seed": 42, + "acc": 0.5033, + "h_norms": [ + 83.98963165283203, + 80.78277587890625, + 69.89965057373047, + 23.533870697021484 + ], + "g_norms": [ + 0.02113482914865017, + 0.01996646821498871, + 0.04826148599386215, + 0.6656972765922546 + ], + "max_per_block": 0.9618184326943926, + "verdict": "trustworthy" + }, + { + "method": "bp", + "seed": 123, + "acc": 0.8683, + "h_norms": [ + 95.01183319091797, + 37.34379196166992, + 30.323930740356445, + 41.04403305053711 + ], + "g_norms": [ + 0.00020042041433043778, + 0.0002769956481643021, + 0.0002367425913689658, + 2.9566466764663346e-05 + ], + "max_per_block": 1.3535195487013123, + "verdict": "trustworthy" + }, + { + "method": "dfa", + "seed": 123, + "acc": 0.5501, + "h_norms": [ + 259.4548034667969, + 386.1385498046875, + 345.1282958984375, + 81147.859375 + ], + "g_norms": [ + 0.0176572035998106, + 0.006263771094381809, + 0.004082133527845144, + 0.0012416314566507936 + ], + "max_per_block": 235.12375061498798, + "verdict": "walk-back: (a)" + }, + { + "method": "state_bridge", + "seed": 123, + "acc": 0.6277, + "h_norms": [ + 78.61927032470703, + 61.77223587036133, + 62.21416473388672, + 168.63580322265625 + ], + "g_norms": [ + 0.01707093045115471, + 0.005183099303394556, + 0.0028397536370903254, + 0.001994567457586527 + ], + "max_per_block": 2.7105692721902597, + "verdict": "trustworthy" + }, + { + "method": "credit_bridge", + "seed": 123, + "acc": 0.3132, + "h_norms": [ + 153.3919219970703, + 174.97061157226562, + 205.44915771484375, + 18492.830078125 + ], + "g_norms": [ + 0.011560788378119469, + 0.00875047780573368, + 0.007237838581204414, + 0.0031556652393192053 + ], + "max_per_block": 90.01171036092735, + "verdict": "walk-back: 
(a)" + }, + { + "method": "ep", + "seed": 123, + "acc": 0.4897, + "h_norms": [ + 86.96959686279297, + 82.86459350585938, + 57.40688705444336, + 1.998042345046997 + ], + "g_norms": [ + 0.01942148432135582, + 0.0226480383425951, + 0.06046672910451889, + 1.157748818397522 + ], + "max_per_block": 0.9527995586387525, + "verdict": "trustworthy" + }, + { + "method": "bp", + "seed": 456, + "acc": 0.8681, + "h_norms": [ + 96.83692169189453, + 37.44154739379883, + 31.123756408691406, + 42.27854919433594 + ], + "g_norms": [ + 0.00014156661927700043, + 0.00021504472533706576, + 0.00015155959408730268, + 1.7612564988667145e-05 + ], + "max_per_block": 1.3584012366363825, + "verdict": "trustworthy" + }, + { + "method": "dfa", + "seed": 456, + "acc": 0.5954, + "h_norms": [ + 206.5431671142578, + 266.57421875, + 254.53468322753906, + 66974.734375 + ], + "g_norms": [ + 0.00448678620159626, + 0.002167485887184739, + 0.0012352537596598268, + 0.0004585048300214112 + ], + "max_per_block": 263.12616232000306, + "verdict": "walk-back: (a)" + }, + { + "method": "state_bridge", + "seed": 456, + "acc": 0.6396, + "h_norms": [ + 71.60630798339844, + 56.15557098388672, + 63.141014099121094, + 137.78231811523438 + ], + "g_norms": [ + 0.014506030827760696, + 0.005259184632450342, + 0.0027562177274376154, + 0.001790599781088531 + ], + "max_per_block": 2.1821366045679693, + "verdict": "trustworthy" + }, + { + "method": "credit_bridge", + "seed": 456, + "acc": 0.3251, + "h_norms": [ + 169.41067504882812, + 151.55250549316406, + 177.73605346679688, + 16139.1640625 + ], + "g_norms": [ + 0.009933868423104286, + 0.00681547075510025, + 0.004268169403076172, + 0.0029866439290344715 + ], + "max_per_block": 90.80410950789441, + "verdict": "walk-back: (a)" + }, + { + "method": "ep", + "seed": 456, + "acc": 0.5432, + "h_norms": [ + 84.82034301757812, + 85.0868911743164, + 163.61471557617188, + 5375.328125 + ], + "g_norms": [ + 0.028629517182707787, + 0.01810075342655182, + 0.0257416944950819, + 
0.3386228382587433 + ], + "max_per_block": 32.85357375142385, + "verdict": "trustworthy" + } +]
\ No newline at end of file |
