summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:42:06 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:42:06 -0500
commit4cd716757b50a1f4217a3ffdf8ee624c270b7a23 (patch)
tree4754c2a6fc28f6a91f724498e2eaf7598a2f12fa /protocol
parentacc86add44e0cac8701307f936029770edd50891 (diff)
Add CNN third-architecture audit: BN, no terminal LN
5 methods × 3 seeds on the SmallCNN (3 conv + BN + 1 FC + head, no terminal LN) using existing checkpoints in results/cnn_baseline/. Key findings: BP CNN: 0.866 acc, max/block 1.3, trustworthy State Bridge CNN: 0.633 acc, max/block 2.4, trustworthy EP CNN: 0.512 acc, max/block 12, trustworthy DFA CNN: 0.566 acc, max/block 237, walked back via (a) Credit Bridge CNN: 0.325 acc, max/block 96, walked back via (a) CRITICAL: diagnostic (b) ||g_L|| floor NEVER fires on CNN for any method. The deepest BP grad is at ~1e-5 to 6e-1, all well above the 1e-7 floor. This is the cleanest confirmation that terminal LayerNorm is the structural cause of the catastrophic gradient collapse in (b). Without out_ln, the BP grad does NOT collapse to the floor, even on DFA. The scale pathology (a) still appears on DFA and CB, but the gradient collapse pathology (b) is specific to terminal-LN architectures. DFA CNN's accuracy (56.6%) is much higher than DFA ResMLP (30.8%) or DFA ViT (23.7%) — partially because the scale pathology is less catastrophic without the LN-driven gradient cancellation amplifying it. This is the cross-architecture mechanism story made concrete.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/audit_cnn.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/protocol/examples/audit_cnn.py b/protocol/examples/audit_cnn.py
new file mode 100644
index 0000000..890c1ec
--- /dev/null
+++ b/protocol/examples/audit_cnn.py
@@ -0,0 +1,197 @@
+"""
+Apply the protocol's diagnostic logic to the SmallCNN architecture (3 conv
+blocks + 1 FC + head, BatchNorm, no terminal LayerNorm). The existing
+checkpoints are in `results/cnn_baseline/{method}_s{seed}.pt`.
+
+This is a custom audit script (not via `protocol.diagnose(...)`) because
+the CNN has 4D conv hidden states and no `model.embed` / `model.out_ln`
+attributes that the duck-typed protocol API expects. The diagnostic
+*logic* is identical: per-block growth of flattened ‖h_l‖, BP grad floor
+at the deepest hidden layer, frozen-blocks comparison.
+
+Why this matters: CNN with BatchNorm is a third architecture family
+(neither pre-LN ResMLP nor pre-LN ViT). Both BP and DFA should be
+informative test cases:
+ - BP on CNN: should pass all diagnostics (sanity)
+ - DFA on CNN: open question — BatchNorm normalizes per-feature, so the
+ LN-driven gradient collapse mechanism may or may not apply
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_cnn
+"""
+import os
+import sys
+import json
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+sys.path.insert(0, os.path.join(REPO_ROOT, "experiments"))
+
+# Import the SmallCNN from the experiments script
+import importlib.util
+_spec = importlib.util.spec_from_file_location(
+ "cnn_baseline_module",
+ os.path.join(REPO_ROOT, "experiments/cnn_baseline.py"),
+)
+_mod = importlib.util.module_from_spec(_spec)
+_spec.loader.exec_module(_mod)
+SmallCNN = _mod.SmallCNN
+
+
+CKPT_DIR = os.path.join(REPO_ROOT, "results/cnn_baseline")
+THRESHOLD_PER_BLOCK = 50.0
+THRESHOLD_GFLOOR = 1e-7
+
+
+def get_eval(n=1024, batch_size=128, device="cuda:0"):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)
+ batches = []
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ batches.append((x, y))
+ if sum(b[0].size(0) for b in batches) >= n:
+ break
+ return batches
+
+
+def evaluate(model, device):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+ model.eval()
+ correct = total = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ preds = model(x).argmax(-1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def per_layer_norms_and_grads(model, x, y):
+ """For the CNN, return per-layer flattened ‖h_l‖ medians and ‖g_l‖ medians."""
+ model.eval()
+ with torch.enable_grad():
+ h0 = model.blocks[0](x)
+ h1 = model.blocks[1](h0)
+ h2 = model.blocks[2](h1)
+ h3 = model.blocks[3](h2.flatten(1))
+ logits = model.out_head(h3)
+ hiddens = [h0, h1, h2, h3]
+ loss = F.cross_entropy(logits, y)
+ grads = torch.autograd.grad(loss, hiddens)
+
+ h_norms = []
+ g_norms = []
+ for h, g in zip(hiddens, grads):
+ h_flat = h.reshape(h.shape[0], -1)
+ g_flat = g.reshape(g.shape[0], -1)
+ h_norms.append(h_flat.norm(dim=-1).median().item())
+ g_norms.append(g_flat.norm(dim=-1).median().item())
+ return h_norms, g_norms
+
+
+def max_per_block_growth(h):
+ if len(h) < 2:
+ return 1.0
+ return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1))
+
+
+def load_cnn(method, seed, device):
+ path = os.path.join(CKPT_DIR, f"{method}_s{seed}.pt")
+ sd = torch.load(path, map_location=device, weights_only=False)
+ if isinstance(sd, dict) and "model_state" in sd:
+ sd = sd["model_state"]
+ elif isinstance(sd, dict) and "state_dict" in sd:
+ sd = sd["state_dict"]
+ model = SmallCNN().to(device)
+ model.load_state_dict(sd)
+ return model
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+ eval_batches = get_eval(n=1024, batch_size=128, device=device)
+ x, y = eval_batches[0]
+
+ methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
+ print()
+ print("=" * 100)
+ print("CNN audit (SmallCNN: 3 conv + BN + 1 FC, NO terminal LN, CIFAR-10)")
+ print("=" * 100)
+ print(f" {'method':<16}{'seed':>6}{'acc':>8}{'h_max/h_min':>14}{'max/block':>14}{'||g_L||':>14} verdict")
+ print(" " + "-" * 100)
+
+ rows = []
+ for seed in [42, 123, 456]:
+ for method in methods:
+ try:
+ model = load_cnn(method, seed, device)
+ except Exception as e:
+ print(f" {method:<16}{seed:>6} SKIPPED ({e})")
+ continue
+ acc = evaluate(model, device)
+ h_norms, g_norms = per_layer_norms_and_grads(model, x, y)
+ max_growth = max_per_block_growth(h_norms)
+ h_ratio = max(h_norms) / max(min(h_norms), 1e-30)
+ g_L = g_norms[-1]
+ flags = []
+ if max_growth > THRESHOLD_PER_BLOCK:
+ flags.append("(a)")
+ if g_L < THRESHOLD_GFLOOR:
+ flags.append("(b)")
+ verdict = "trustworthy" if not flags else f"walk-back: {'+'.join(flags)}"
+ rows.append({
+ "method": method,
+ "seed": seed,
+ "acc": acc,
+ "h_norms": h_norms,
+ "g_norms": g_norms,
+ "max_per_block": max_growth,
+ "verdict": verdict,
+ })
+ print(f" {method:<16}{seed:>6}{acc:>8.4f}{h_ratio:>14.2e}{max_growth:>14.2e}{g_L:>14.2e} {verdict}")
+
+ print()
+ print("=" * 100)
+ print("Per-method 3-seed mean (h_norms across all 4 hidden layers, g across all):")
+ print("=" * 100)
+ for method in methods:
+ method_rows = [r for r in rows if r["method"] == method]
+ if not method_rows:
+ continue
+ accs = np.array([r["acc"] for r in method_rows])
+ h_arrs = np.array([r["h_norms"] for r in method_rows])
+ g_arrs = np.array([r["g_norms"] for r in method_rows])
+ max_g = np.array([r["max_per_block"] for r in method_rows])
+ print(f" {method.upper()}: acc={accs.mean():.4f}±{accs.std():.4f}, "
+ f"h_means={h_arrs.mean(0)}, g_means={g_arrs.mean(0)}, "
+ f"max-per-block={max_g.mean():.2e}")
+
+ out_path = os.path.join(REPO_ROOT, "results/protocol_audit/audit_cnn_3seed.json")
+ with open(out_path, "w") as f:
+ json.dump(rows, f, indent=2)
+ print(f"\nSaved {out_path}")
+
+
+if __name__ == "__main__":
+ main()