summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 02:07:26 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 02:07:26 -0500
commita868b29e4c399a3a948e85737e7a632001481969 (patch)
tree48b1e9d527462135aee3658b2603c0b547f7b160 /experiments
parent8bf53ab94ac31c7672d23e2edf0e40c787b157d4 (diff)
Add perturbation correlation audit (round 19's recommended alt metric)
Codex round 19 said: 'use nudging or perturbation correlation on the penalized checkpoints. In the healthy-gradient regime, that is a more direct is-the-local-signal-useful test than cosine alone'. Result on existing checkpoints (eps=1e-3, M=32 random directions, n=1024): vanilla DFA s42: deep rho +0.002 penalized DFA s42 lam=1e-2 30ep: deep rho +0.094 penalized DFA s123 lam=1e-2 30ep: deep rho +0.073 penalized DFA s456 lam=1e-2 30ep: deep rho +0.072 penalized 3-seed mean: deep rho +0.080 ± 0.011 This INDEPENDENTLY TRIANGULATES the cos +0.17 finding via a different metric: - vanilla deep cos ~0 matches vanilla deep rho ~0 - penalized deep cos +0.155 matches penalized deep rho +0.080 The two metrics measure different things: - cos = directional alignment with BP grad - rho = correlation between predicted and true loss change under random perturbation Both show the same pattern: penalty creates partial usefulness from essentially zero. This is the 6th independent validation of the mode 2 'penalty creates partial alignment' framing. Crucially, rho doesn't use F.cosine_similarity (no eps clamp), and it measures sample-level loss change correlation rather than direction match — so it rules out 'cos is capturing some directional artifact unrelated to local usefulness'.
Diffstat (limited to 'experiments')
-rw-r--r--experiments/perturbation_correlation_audit.py175
1 files changed, 175 insertions, 0 deletions
diff --git a/experiments/perturbation_correlation_audit.py b/experiments/perturbation_correlation_audit.py
new file mode 100644
index 0000000..cba84ea
--- /dev/null
+++ b/experiments/perturbation_correlation_audit.py
@@ -0,0 +1,175 @@
+"""
+Perturbation correlation rho_l on existing checkpoints. Codex round 19's
+recommended alternative metric to per-layer cosine — "a more direct 'is
+the local signal useful?' test than cosine alone".
+
+For each checkpoint, compute per-layer rho_l = pearson_correlation(
+ predicted_loss_change = <a_l, eps * v>,
+ true_loss_change = loss(h_l + eps * v) - loss(h_l)
+)
+where a_l = e_T @ B_l^T is DFA's local credit signal and v is a random
+unit direction. Average over M=32 random directions.
+
+Compares:
+ - Vanilla DFA s42 (existing checkpoint, ‖g‖ at floor)
+ - Penalized DFA s42 lam=1e-2 30 ep (existing checkpoint, ‖g‖ healthy)
+ - BP s42 (existing checkpoint, ‖g‖ healthy)
+
+Pre-registered prediction:
+ - Vanilla DFA: deep rho ~0 (we expect random feedback in degenerate
+ regime to give noise correlation)
+ - Penalized DFA: deep rho > 0 if cos +0.17 reflects real local signal
+ (mode 2 partial alleviation should be detectable by both metrics)
+ - BP: per-layer ‘credit signal’ for BP is ambiguous — BP doesn't
+ have a local credit signal in the DFA sense. Skip BP rho computation.
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python experiments/perturbation_correlation_audit.py
+"""
+import os
+import sys
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from metrics.credit_metrics import perturbation_correlation
+
+REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+def load_eval(n=1024, device="cuda:0"):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
+ loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
+ xs, ys = [], []
+ for x, y in loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= n:
+ break
+ return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device)
+
+
+def reconstruct_training_Bs(seed, d_hidden=256, num_blocks=4, num_classes=10, device="cuda:0"):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ _ = ResidualMLP(3072, d_hidden, num_classes, num_blocks)
+ return [torch.randn(d_hidden, num_classes, device=device) / np.sqrt(num_classes)
+ for _ in range(num_blocks)]
+
+
+def make_forward_fn(model, layer_index, x_eval, y_eval):
+ """Returns a function that takes h_l (perturbed) and computes per-sample
+ cross-entropy loss after running the network from layer_index forward."""
+ def fwd(h_l):
+ h = h_l
+ for i in range(layer_index, model.num_blocks):
+ h = h + model.blocks[i](h)
+ logits = model.out_head(model.out_ln(h))
+ # per-sample loss
+ return F.cross_entropy(logits, y_eval, reduction="none")
+ return fwd
+
+
+def measure_rho(model, Bs, x_eval, y_eval, device, eps=1e-3, M=32):
+ """For each layer l, compute rho_l using DFA local credit signal."""
+ model.eval()
+ with torch.no_grad():
+ _, hiddens = model(x_eval, return_hidden=True)
+ L = model.num_blocks
+ out = []
+ # Compute the DFA error signal e_T (softmax(logits) - one_hot(y))
+ with torch.no_grad():
+ logits = model.out_head(model.out_ln(hiddens[-1]))
+ e_T = F.softmax(logits, dim=-1).clone()
+ e_T[torch.arange(len(y_eval), device=device), y_eval] -= 1
+ for l in range(L):
+ h_l = hiddens[l].detach().clone()
+ a_l = (e_T @ Bs[l].T).detach()
+ forward_fn = make_forward_fn(model, l, x_eval, y_eval)
+ rho = perturbation_correlation(h_l, a_l, forward_fn, epsilon=eps, M=M)
+ out.append({"layer": l, "rho": rho})
+ return out
+
+
+def load_dfa(seed, ckpt_path, device):
+ sd = torch.load(ckpt_path, map_location=device, weights_only=False)
+ model = ResidualMLP(3072, 256, 10, 4).to(device)
+ if isinstance(sd, dict) and "state_dict" in sd:
+ model.load_state_dict(sd["state_dict"])
+ Bs = [b.to(device) for b in sd["Bs"]] if "Bs" in sd else None
+ else:
+ model.load_state_dict(sd)
+ Bs = None
+ if Bs is None:
+ Bs = reconstruct_training_Bs(seed, device=device)
+ return model, Bs
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+ x_eval, y_eval = load_eval(n=1024, device=device)
+
+ cases = [
+ ("vanilla DFA s42",
+ "results/confirmatory/checkpoints_A2/dfa_s42.pt", 42),
+ ("penalized DFA s42 lam=1e-2 30ep",
+ "results/dfa_pen_short/dfa_pen_lam0.01_s42.pt", 42),
+ ("penalized DFA s123 lam=1e-2 30ep",
+ "results/dfa_pen_short/dfa_pen_lam0.01_s123.pt", 123),
+ ("penalized DFA s456 lam=1e-2 30ep",
+ "results/dfa_pen_short/dfa_pen_lam0.01_s456.pt", 456),
+ ]
+
+ print("=" * 76)
+ print("Perturbation correlation rho_l per layer")
+ print("=" * 76)
+ print("(epsilon=1e-3, M=32 random unit directions, n=1024 samples)")
+ print()
+
+ results = {}
+ for label, ckpt, seed in cases:
+ path = os.path.join(REPO_ROOT, ckpt)
+ if not os.path.exists(path):
+ print(f" SKIPPED ({path} not found)")
+ continue
+ print(f"=== {label} ===")
+ model, Bs = load_dfa(seed, path, device)
+ out = measure_rho(model, Bs, x_eval, y_eval, device)
+ for entry in out:
+ print(f" l{entry['layer']}: rho = {entry['rho']:+.4f}")
+ rhos = [e["rho"] for e in out]
+ deep = np.mean(rhos[1:]) if len(rhos) > 1 else float("nan")
+ print(f" layer-mean rho: {np.mean(rhos):+.4f}")
+ print(f" deep-layer mean rho (l1+): {deep:+.4f}")
+ print()
+ results[label] = {"per_layer": out, "layer_mean": float(np.mean(rhos)), "deep_mean": float(deep)}
+
+ print("=" * 76)
+ print("INTERPRETATION")
+ print("=" * 76)
+ if "vanilla DFA s42" in results and "penalized DFA s42 lam=1e-2 30ep" in results:
+ v_deep = results["vanilla DFA s42"]["deep_mean"]
+ p_deep = results["penalized DFA s42 lam=1e-2 30ep"]["deep_mean"]
+ print(f" vanilla deep rho: {v_deep:+.4f}")
+ print(f" penalized deep rho: {p_deep:+.4f}")
+ if abs(p_deep) > 0.05 and abs(v_deep) < 0.05:
+ print(f" -> Penalized DFA's local credit signal is locally useful (rho > 0.05),")
+ print(f" vanilla DFA's is not. This triangulates the cos +0.17 finding via")
+ print(f" a different metric (perturbation-based), strengthening mode 2 evidence.")
+ elif abs(p_deep) < 0.05 and abs(v_deep) < 0.05:
+ print(f" -> Both vanilla and penalized show ~0 perturbation correlation.")
+ print(f" The cos +0.17 might be capturing direction-of-mean-gradient alignment")
+ print(f" that doesn't translate to per-sample loss usefulness.")
+
+
+if __name__ == "__main__":
+ main()