""" Apply the protocol's diagnostic logic to the SmallCNN architecture (3 conv blocks + 1 FC + head, BatchNorm, no terminal LayerNorm). The existing checkpoints are in `results/cnn_baseline/{method}_s{seed}.pt`. This is a custom audit script (not via `protocol.diagnose(...)`) because the CNN has 4D conv hidden states and no `model.embed` / `model.out_ln` attributes that the duck-typed protocol API expects. The diagnostic *logic* is identical: per-block growth of flattened ‖h_l‖, BP grad floor at the deepest hidden layer, frozen-blocks comparison. Why this matters: CNN with BatchNorm is a third architecture family (neither pre-LN ResMLP nor pre-LN ViT). Both BP and DFA should be informative test cases: - BP on CNN: should pass all diagnostics (sanity) - DFA on CNN: open question — BatchNorm normalizes per-feature, so the LN-driven gradient collapse mechanism may or may not apply Run: CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_cnn """ import os import sys import json import numpy as np import torch import torch.nn.functional as F import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) sys.path.insert(0, REPO_ROOT) sys.path.insert(0, os.path.join(REPO_ROOT, "experiments")) # Import the SmallCNN from the experiments script import importlib.util _spec = importlib.util.spec_from_file_location( "cnn_baseline_module", os.path.join(REPO_ROOT, "experiments/cnn_baseline.py"), ) _mod = importlib.util.module_from_spec(_spec) _spec.loader.exec_module(_mod) SmallCNN = _mod.SmallCNN CKPT_DIR = os.path.join(REPO_ROOT, "results/cnn_baseline") THRESHOLD_PER_BLOCK = 50.0 THRESHOLD_GFLOOR = 1e-7 def get_eval(n=1024, batch_size=128, device="cuda:0"): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0) batches = [] for x, y in loader: x, y = x.to(device), y.to(device) batches.append((x, y)) if sum(b[0].size(0) for b in batches) >= n: break return batches def evaluate(model, device): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) model.eval() correct = total = 0 with torch.no_grad(): for x, y in loader: x, y = x.to(device), y.to(device) preds = model(x).argmax(-1) correct += (preds == y).sum().item() total += x.size(0) return correct / total def per_layer_norms_and_grads(model, x, y): """For the CNN, return per-layer flattened ‖h_l‖ medians and ‖g_l‖ medians.""" model.eval() with torch.enable_grad(): h0 = model.blocks[0](x) h1 = model.blocks[1](h0) h2 = model.blocks[2](h1) h3 = model.blocks[3](h2.flatten(1)) logits = model.out_head(h3) hiddens = [h0, h1, h2, h3] loss = F.cross_entropy(logits, y) grads = torch.autograd.grad(loss, hiddens) h_norms = [] g_norms = [] for h, g in zip(hiddens, grads): h_flat = h.reshape(h.shape[0], -1) g_flat = g.reshape(g.shape[0], -1) h_norms.append(h_flat.norm(dim=-1).median().item()) g_norms.append(g_flat.norm(dim=-1).median().item()) return h_norms, g_norms def max_per_block_growth(h): if len(h) < 2: return 1.0 return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1)) def load_cnn(method, seed, device): path = os.path.join(CKPT_DIR, f"{method}_s{seed}.pt") sd = torch.load(path, map_location=device, weights_only=False) if isinstance(sd, dict) and "model_state" in sd: sd = sd["model_state"] elif isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"] model = SmallCNN().to(device) model.load_state_dict(sd) return model def main(): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") eval_batches = get_eval(n=1024, batch_size=128, device=device) x, y = eval_batches[0] methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"] print() print("=" * 100) print("CNN audit (SmallCNN: 3 conv + BN + 1 FC, NO terminal LN, CIFAR-10)") print("=" * 100) print(f" {'method':<16}{'seed':>6}{'acc':>8}{'h_max/h_min':>14}{'max/block':>14}{'||g_L||':>14} verdict") print(" " + "-" * 100) rows = [] for seed in [42, 123, 456]: for method in methods: try: model = load_cnn(method, seed, device) except Exception as e: print(f" {method:<16}{seed:>6} SKIPPED ({e})") continue acc = evaluate(model, device) h_norms, g_norms = per_layer_norms_and_grads(model, x, y) max_growth = max_per_block_growth(h_norms) h_ratio = max(h_norms) / max(min(h_norms), 1e-30) g_L = g_norms[-1] flags = [] if max_growth > THRESHOLD_PER_BLOCK: flags.append("(a)") if g_L < THRESHOLD_GFLOOR: flags.append("(b)") verdict = "trustworthy" if not flags else f"walk-back: {'+'.join(flags)}" rows.append({ "method": method, "seed": seed, "acc": acc, "h_norms": h_norms, "g_norms": g_norms, "max_per_block": max_growth, "verdict": verdict, }) print(f" {method:<16}{seed:>6}{acc:>8.4f}{h_ratio:>14.2e}{max_growth:>14.2e}{g_L:>14.2e} {verdict}") print() print("=" * 100) print("Per-method 3-seed mean (h_norms across all 4 hidden layers, g across all):") print("=" * 100) for method in methods: method_rows = [r for r in rows if r["method"] == method] if not method_rows: continue accs = np.array([r["acc"] for r in method_rows]) h_arrs = np.array([r["h_norms"] for r in method_rows]) g_arrs = np.array([r["g_norms"] for r in method_rows]) max_g = np.array([r["max_per_block"] for r in method_rows]) print(f" {method.upper()}: acc={accs.mean():.4f}±{accs.std():.4f}, " f"h_means={h_arrs.mean(0)}, g_means={g_arrs.mean(0)}, " f"max-per-block={max_g.mean():.2e}") out_path = os.path.join(REPO_ROOT, "results/protocol_audit/audit_cnn_3seed.json") with open(out_path, "w") as f: json.dump(rows, f, indent=2) print(f"\nSaved {out_path}") if __name__ == "__main__": main()