""" Perturbation correlation rho_l on existing checkpoints.

Codex round 19's recommended alternative metric to per-layer cosine — "a more
direct 'is the local signal useful?' test than cosine alone".

For each checkpoint, compute per-layer

    rho_l = pearson_correlation(
        predicted_loss_change = eps * <a_l, v>,
        true_loss_change      = loss(h_l + eps * v) - loss(h_l),
    )

where a_l = e_T @ B_l^T is DFA's local credit signal and v is a random unit
direction. Average over M=32 random directions.

Compares:
  - Vanilla DFA s42 (existing checkpoint, ‖g‖ at floor)
  - Penalized DFA s42 lam=1e-2 30 ep (existing checkpoint, ‖g‖ healthy)
  - BP s42 (existing checkpoint, ‖g‖ healthy) — listed for reference, skipped
    (see below)
main() additionally runs penalized DFA seeds 123/456 and epoch-1 vanilla DFA
checkpoints.

Pre-registered prediction:
  - Vanilla DFA: deep rho ~0 (we expect random feedback in the degenerate
    regime to give noise-level correlation)
  - Penalized DFA: deep rho > 0 if cos +0.17 reflects a real local signal
    (mode 2 partial alleviation should be detectable by both metrics)
  - BP: a per-layer 'credit signal' is ambiguous for BP — BP doesn't have a
    local credit signal in the DFA sense. BP rho computation is skipped.

Run: CUDA_VISIBLE_DEVICES=2 python experiments/perturbation_correlation_audit.py
"""
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

from models.residual_mlp import ResidualMLP  # noqa: E402
from metrics.credit_metrics import perturbation_correlation  # noqa: E402

# |deep-layer mean rho| below this is treated as "no correlation" in the
# interpretation section; above it (with positive sign) as "locally useful".
RHO_THRESHOLD = 0.05

# Fixed seed applied before each checkpoint's measurement so every checkpoint
# is probed with the same random directions (paired comparison) and reruns are
# reproducible.
DIRECTION_SEED = 0


def load_eval(n=1024, device="cuda:0"):
    """Return the first ``n`` CIFAR-10 test images (flattened) and their labels.

    Images are normalized with the CIFAR-10 per-channel mean/std and flattened
    per sample via ``view(batch, -1)``; both tensors are moved to ``device``.
    The test split is read from (and downloaded into, if missing) ``./data``
    relative to the current working directory.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
    loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
    xs, ys = [], []
    seen = 0
    for x, y in loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        seen += x.size(0)
        if seen >= n:
            break
    return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device)


def reconstruct_training_Bs(seed, d_hidden=256, num_blocks=4, num_classes=10, device="cuda:0"):
    """Regenerate the fixed random DFA feedback matrices B_l for ``seed``.

    Replays an RNG sequence: seed torch/numpy/CUDA, build (and discard) a
    ResidualMLP so its parameter init consumes RNG draws, then draw one
    ``(d_hidden, num_classes)`` Gaussian matrix per block scaled by
    ``1 / sqrt(num_classes)``.

    NOTE(review): this only reproduces the training-time Bs if the training
    script performed exactly these steps in this order on the same device —
    ``torch.randn(..., device="cuda")`` draws from the CUDA generator, not the
    CPU generator the model init used. Confirm against the training code.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    _ = ResidualMLP(3072, d_hidden, num_classes, num_blocks)
    return [
        torch.randn(d_hidden, num_classes, device=device) / np.sqrt(num_classes)
        for _ in range(num_blocks)
    ]


def make_forward_fn(model, layer_index, x_eval, y_eval):
    """Return ``fwd(h_l)`` that runs ``model`` from block ``layer_index`` onward.

    ``fwd`` applies residual blocks ``layer_index .. num_blocks-1``
    (``h = h + block(h)``), then ``out_ln`` and ``out_head``, and returns the
    per-sample cross-entropy against ``y_eval`` (``reduction="none"``).
    ``x_eval`` is unused; it is kept for interface compatibility.
    """
    def fwd(h_l):
        h = h_l
        for i in range(layer_index, model.num_blocks):
            h = h + model.blocks[i](h)
        logits = model.out_head(model.out_ln(h))
        # per-sample loss
        return F.cross_entropy(logits, y_eval, reduction="none")
    return fwd


def measure_rho(model, Bs, x_eval, y_eval, device, eps=1e-3, M=32, seed=None):
    """Compute the perturbation correlation rho_l for every block of ``model``.

    Args:
        model: ResidualMLP; called as ``model(x, return_hidden=True)``.
        Bs: per-block DFA feedback matrices; used as ``e_T @ Bs[l].T``.
        x_eval, y_eval: evaluation inputs and integer class labels.
        device: unused; kept for interface compatibility.
        eps: perturbation magnitude passed to ``perturbation_correlation``.
        M: number of random directions passed to ``perturbation_correlation``.
        seed: if not None, reseed the global torch and numpy RNGs first so the
            random directions are reproducible and shared across checkpoints.

    Returns:
        list of ``{"layer": l, "rho": float}`` dicts, one per block.
    """
    if seed is not None:
        # NOTE(review): assumes perturbation_correlation draws its directions
        # from the global torch/numpy RNG — confirm in metrics/credit_metrics.
        torch.manual_seed(seed)
        np.random.seed(seed)
    model.eval()
    with torch.no_grad():
        _, hiddens = model(x_eval, return_hidden=True)
        # DFA output error e_T = softmax(logits) - one_hot(y), with logits
        # recomputed from the last hidden state.
        logits = model.out_head(model.out_ln(hiddens[-1]))
        one_hot = F.one_hot(y_eval, num_classes=logits.size(-1)).to(logits.dtype)
        e_T = F.softmax(logits, dim=-1) - one_hot

    out = []
    for l in range(model.num_blocks):
        # NOTE(review): forward_fn re-applies blocks l..L-1 to hiddens[l]; this
        # is only correct if hiddens[l] is the *input* to block l (i.e. hiddens
        # holds num_blocks + 1 states). Confirm against ResidualMLP.forward.
        h_l = hiddens[l].detach().clone()
        a_l = (e_T @ Bs[l].T).detach()
        forward_fn = make_forward_fn(model, l, x_eval, y_eval)
        rho = perturbation_correlation(h_l, a_l, forward_fn, epsilon=eps, M=M)
        # float() accepts Python floats, numpy scalars and 0-d (CUDA) tensors,
        # so downstream np.mean / formatting never sees a device tensor.
        out.append({"layer": l, "rho": float(rho)})
    return out


def load_dfa(seed, ckpt_path, device):
    """Load a DFA ResidualMLP checkpoint and its feedback matrices.

    Accepts either a raw state_dict or a dict with a ``"state_dict"`` key and
    an optional ``"Bs"`` list. If no Bs are stored, they are regenerated with
    ``reconstruct_training_Bs(seed)``.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects — only use on
    # trusted, locally produced checkpoints.
    sd = torch.load(ckpt_path, map_location=device, weights_only=False)
    model = ResidualMLP(3072, 256, 10, 4).to(device)
    Bs = None
    if isinstance(sd, dict) and "state_dict" in sd:
        model.load_state_dict(sd["state_dict"])
        if "Bs" in sd:
            Bs = [b.to(device) for b in sd["Bs"]]
    else:
        model.load_state_dict(sd)
    if Bs is None:
        Bs = reconstruct_training_Bs(seed, device=device)
    return model, Bs


def main():
    """Measure rho_l for each available checkpoint and interpret vanilla vs penalized DFA."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    x_eval, y_eval = load_eval(n=1024, device=device)

    cases = [
        ("vanilla DFA s42", "results/confirmatory/checkpoints_A2/dfa_s42.pt", 42),
        ("penalized DFA s42 lam=1e-2 30ep", "results/dfa_pen_short/dfa_pen_lam0.01_s42.pt", 42),
        ("penalized DFA s123 lam=1e-2 30ep", "results/dfa_pen_short/dfa_pen_lam0.01_s123.pt", 123),
        ("penalized DFA s456 lam=1e-2 30ep", "results/dfa_pen_short/dfa_pen_lam0.01_s456.pt", 456),
        ("vanilla DFA s42 ep1 (meaningful regime)",
         "results/vanilla_dfa_early_ckpts/vanilla_dfa_s42_ep1.pt", 42),
        ("vanilla DFA s123 ep1 (meaningful regime)",
         "results/vanilla_dfa_early_ckpts/vanilla_dfa_s123_ep1.pt", 123),
        ("vanilla DFA s456 ep1 (meaningful regime)",
         "results/vanilla_dfa_early_ckpts/vanilla_dfa_s456_ep1.pt", 456),
    ]

    print("=" * 76)
    print("Perturbation correlation rho_l per layer")
    print("=" * 76)
    print("(epsilon=1e-3, M=32 random unit directions, n=1024 samples)")
    print()

    results = {}
    for label, ckpt, seed in cases:
        path = os.path.join(REPO_ROOT, ckpt)
        if not os.path.exists(path):
            print(f" SKIPPED ({path} not found)")
            continue
        print(f"=== {label} ===")
        model, Bs = load_dfa(seed, path, device)
        out = measure_rho(model, Bs, x_eval, y_eval, device, seed=DIRECTION_SEED)
        for entry in out:
            print(f" l{entry['layer']}: rho = {entry['rho']:+.4f}")
        rhos = [e["rho"] for e in out]
        layer_mean = float(np.mean(rhos))
        deep = float(np.mean(rhos[1:])) if len(rhos) > 1 else float("nan")
        print(f" layer-mean rho: {layer_mean:+.4f}")
        print(f" deep-layer mean rho (l1+): {deep:+.4f}")
        print()
        results[label] = {"per_layer": out, "layer_mean": layer_mean, "deep_mean": deep}

    print("=" * 76)
    print("INTERPRETATION")
    print("=" * 76)
    vanilla_key = "vanilla DFA s42"
    penalized_key = "penalized DFA s42 lam=1e-2 30ep"
    if vanilla_key in results and penalized_key in results:
        v_deep = results[vanilla_key]["deep_mean"]
        p_deep = results[penalized_key]["deep_mean"]
        print(f" vanilla deep rho: {v_deep:+.4f}")
        print(f" penalized deep rho: {p_deep:+.4f}")
        vanilla_null = abs(v_deep) < RHO_THRESHOLD
        # Signed test: a strongly *negative* rho means the local signal
        # anti-predicts the loss change, which must not count as "useful".
        if p_deep > RHO_THRESHOLD and vanilla_null:
            print(f" -> Penalized DFA's local credit signal is locally useful (rho > {RHO_THRESHOLD}),")
            print(" vanilla DFA's is not. This triangulates the cos +0.17 finding via")
            print(" a different metric (perturbation-based), strengthening mode 2 evidence.")
        elif abs(p_deep) < RHO_THRESHOLD and vanilla_null:
            print(" -> Both vanilla and penalized show ~0 perturbation correlation.")
            print(" The cos +0.17 might be capturing direction-of-mean-gradient alignment")
            print(" that doesn't translate to per-sample loss usefulness.")
        else:
            print(" -> Neither pre-registered pattern holds (e.g. negative penalized rho")
            print(" or non-null vanilla rho); inspect the per-layer values above.")


if __name__ == "__main__":
    main()