""" Null calibration of the +0.17 deep-layer cosine on penalized DFA. Codex round 19 critical control: same penalized checkpoint, but compute the cosine with FRESH random Bs (not the training-time Bs). If +0.17 was real signal that the network adapted to its training-time Bs, fresh Bs should give cosine ≈ 0. If +0.17 was an artifact of how the cosine is computed (e.g., a property of the penalized network independent of the Bs), fresh Bs should also give ~+0.17. Run: CUDA_VISIBLE_DEVICES=2 python experiments/null_calibration_penalized_cos.py """ import os import sys import argparse import numpy as np import torch import torch.nn.functional as F import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP def load_eval(n=2048, device="cuda:0"): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) xs, ys = [], [] for x, y in loader: xs.append(x.view(x.size(0), -1)); ys.append(y) if sum(xb.size(0) for xb in xs) >= n: break return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device) def per_layer_bp_grads(model, x, y): with torch.enable_grad(): h = model.embed(x) hiddens = [h] for block in model.blocks: h = h + block(h) hiddens.append(h) logits = model.out_head(model.out_ln(h)) loss = F.cross_entropy(logits, y) grads = torch.autograd.grad(loss, hiddens) return list(grads), logits.detach() def cosine_no_clamp(a, b): eps = 1e-30 an = a.norm(dim=-1, keepdim=True).clamp_min(eps) bn = b.norm(dim=-1, keepdim=True).clamp_min(eps) return ((a / an) * (b / bn)).sum(dim=-1) def measure_with_Bs(model, Bs, x, y): L = model.num_blocks grads, logits = per_layer_bp_grads(model, x, y) e_T = F.softmax(logits, dim=-1).clone() e_T[torch.arange(len(y), device=y.device), y] -= 1 out = [] for l in range(L + 1): b_idx = min(l, L - 1) a_l = (e_T @ Bs[b_idx].T).detach() g_l = grads[l].detach() cos = cosine_no_clamp(a_l, g_l) out.append({"layer": l, "cos_mean": float(cos.mean().item())}) return out def main(): p = argparse.ArgumentParser() p.add_argument("--ckpt", type=str, default="results/dfa_pen_short/dfa_pen_lam0.01_s42.pt") p.add_argument("--n_fresh", type=int, default=20, help="number of fresh Bs draws") args = p.parse_args() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"Loading {args.ckpt}") sd = torch.load(args.ckpt, map_location=device, weights_only=False) model = ResidualMLP(3072, 256, 10, 4).to(device) model.load_state_dict(sd["state_dict"]) Bs_train = [b.to(device) for b in sd["Bs"]] print(f"Test acc: {sd.get('test_acc', 'unknown')}") print() x, y = load_eval(n=2048, device=device) print("=" * 72) print("REFERENCE: training-time Bs") print("=" * 72) out_train = measure_with_Bs(model, Bs_train, x, y) for entry in out_train: print(f" l{entry['layer']}: cos_mean={entry['cos_mean']:+.4f}") train_mean = np.mean([e['cos_mean'] for e in out_train]) train_deep = np.mean([e['cos_mean'] for e in out_train[1:]]) print(f" layer-mean: {train_mean:+.4f}") print(f" deep-layer mean (l1-l4): {train_deep:+.4f}") print() print("=" * 72) print(f"NULL CALIBRATION: {args.n_fresh} fresh random Bs draws") print("=" * 72) fresh_results = [] for k in range(args.n_fresh): torch.manual_seed(10000 + k) Bs_fresh = [torch.randn(256, 10, device=device) / np.sqrt(10) for _ in range(4)] out_fresh = measure_with_Bs(model, Bs_fresh, x, y) fresh_results.append(out_fresh) deep_mean = np.mean([e['cos_mean'] for e in out_fresh[1:]]) per_layer_str = ", ".join(f"{e['cos_mean']:+.4f}" for e in out_fresh) print(f" fresh #{k}: per-layer = [{per_layer_str}], deep mean {deep_mean:+.4f}") # Aggregate arr = np.array([[e['cos_mean'] for e in r] for r in fresh_results]) # (n_fresh, n_layers) print() print(f" Across {args.n_fresh} fresh Bs draws (mean ± std per layer):") for l in range(arr.shape[1]): print(f" l{l}: {arr[:,l].mean():+.4f} ± {arr[:,l].std():.4f}") fresh_deep_mean = arr[:, 1:].mean() fresh_deep_std = arr[:, 1:].std() print(f" fresh-Bs deep-layer mean: {fresh_deep_mean:+.4f} ± {fresh_deep_std:.4f}") print() print("=" * 72) print("INTERPRETATION") print("=" * 72) print(f" Training-Bs deep cos: {train_deep:+.4f}") print(f" Fresh-Bs deep cos: {fresh_deep_mean:+.4f}") print() if abs(fresh_deep_mean) < 0.05: print(f" Fresh Bs give ~0 cosine (|{fresh_deep_mean:.4f}| < 0.05)") print(f" → The +{train_deep:.4f} on training Bs is REAL signal that the network") print(f" adapted to its specific Bs during training.") elif abs(fresh_deep_mean) > 0.10: print(f" Fresh Bs give SIMILAR cosine to training Bs (|{fresh_deep_mean:.4f}| > 0.10)") print(f" → The +{train_deep:.4f} is NOT specifically about the training Bs.") print(f" It could be a property of the BP grad direction itself in the") print(f" penalized regime — i.e. the BP grad and ANY random direction give") print(f" a similar partial alignment. This would weaken the 'partial credit") print(f" quality' interpretation.") else: print(f" Fresh Bs give intermediate cosine ({fresh_deep_mean:.4f})") print(f" → Mixed: the training Bs are partially specific, partially generic.") if __name__ == "__main__": main()