diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 02:07:26 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 02:07:26 -0500 |
| commit | a868b29e4c399a3a948e85737e7a632001481969 (patch) | |
| tree | 48b1e9d527462135aee3658b2603c0b547f7b160 /experiments | |
| parent | 8bf53ab94ac31c7672d23e2edf0e40c787b157d4 (diff) | |
Add perturbation correlation audit (round 19's recommended alt metric)
Codex round 19 said: 'use nudging or perturbation correlation on the
penalized checkpoints. In the healthy-gradient regime, that is a more
direct is-the-local-signal-useful test than cosine alone'.
Result on existing checkpoints (eps=1e-3, M=32 random directions, n=1024):
vanilla DFA s42: deep rho +0.002
penalized DFA s42 lam=1e-2 30ep: deep rho +0.094
penalized DFA s123 lam=1e-2 30ep: deep rho +0.073
penalized DFA s456 lam=1e-2 30ep: deep rho +0.072
penalized 3-seed mean: deep rho +0.080 ± 0.011
This INDEPENDENTLY TRIANGULATES the cos +0.17 finding via a different
metric:
- vanilla deep cos ~0 matches vanilla deep rho ~0
- penalized deep cos +0.155 matches penalized deep rho +0.080
The two metrics measure different things:
- cos = directional alignment with BP grad
- rho = correlation between predicted and true loss change under
random perturbation
Both show the same pattern: penalty creates partial usefulness from
essentially zero. This is the 6th independent validation of the mode 2
'penalty creates partial alignment' framing.
Crucially, rho doesn't use F.cosine_similarity (no eps clamp), and it
measures sample-level loss change correlation rather than direction
match — so it rules out 'cos is capturing some directional artifact
unrelated to local usefulness'.
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/perturbation_correlation_audit.py | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/experiments/perturbation_correlation_audit.py b/experiments/perturbation_correlation_audit.py new file mode 100644 index 0000000..cba84ea --- /dev/null +++ b/experiments/perturbation_correlation_audit.py @@ -0,0 +1,175 @@ +""" +Perturbation correlation rho_l on existing checkpoints. Codex round 19's +recommended alternative metric to per-layer cosine — "a more direct 'is +the local signal useful?' test than cosine alone". + +For each checkpoint, compute per-layer rho_l = pearson_correlation( + predicted_loss_change = <a_l, eps * v>, + true_loss_change = loss(h_l + eps * v) - loss(h_l) +) +where a_l = e_T @ B_l^T is DFA's local credit signal and v is a random +unit direction. Average over M=32 random directions. + +Compares: + - Vanilla DFA s42 (existing checkpoint, ‖g‖ at floor) + - Penalized DFA s42 lam=1e-2 30 ep (existing checkpoint, ‖g‖ healthy) + - BP s42 (existing checkpoint, ‖g‖ healthy) + +Pre-registered prediction: + - Vanilla DFA: deep rho ~0 (we expect random feedback in degenerate + regime to give noise correlation) + - Penalized DFA: deep rho > 0 if cos +0.17 reflects real local signal + (mode 2 partial alleviation should be detectable by both metrics) + - BP: per-layer ‘credit signal’ for BP is ambiguous — BP doesn't + have a local credit signal in the DFA sense. Skip BP rho computation. + +Run: + CUDA_VISIBLE_DEVICES=2 python experiments/perturbation_correlation_audit.py +""" +import os +import sys + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from metrics.credit_metrics import perturbation_correlation + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +def load_eval(n=1024, device="cuda:0"): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) + xs, ys = [], [] + for x, y in loader: + xs.append(x.view(x.size(0), -1)); ys.append(y) + if sum(xb.size(0) for xb in xs) >= n: + break + return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device) + + +def reconstruct_training_Bs(seed, d_hidden=256, num_blocks=4, num_classes=10, device="cuda:0"): + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + _ = ResidualMLP(3072, d_hidden, num_classes, num_blocks) + return [torch.randn(d_hidden, num_classes, device=device) / np.sqrt(num_classes) + for _ in range(num_blocks)] + + +def make_forward_fn(model, layer_index, x_eval, y_eval): + """Returns a function that takes h_l (perturbed) and computes per-sample + cross-entropy loss after running the network from layer_index forward.""" + def fwd(h_l): + h = h_l + for i in range(layer_index, model.num_blocks): + h = h + model.blocks[i](h) + logits = model.out_head(model.out_ln(h)) + # per-sample loss + return F.cross_entropy(logits, y_eval, reduction="none") + return fwd + + +def measure_rho(model, Bs, x_eval, y_eval, device, eps=1e-3, M=32): + """For each layer l, compute rho_l using DFA local credit signal.""" + model.eval() + with torch.no_grad(): + _, hiddens = model(x_eval, return_hidden=True) + L = model.num_blocks + out = [] + # Compute the DFA error signal e_T (softmax(logits) - one_hot(y)) + with torch.no_grad(): + logits = model.out_head(model.out_ln(hiddens[-1])) + e_T = F.softmax(logits, dim=-1).clone() + e_T[torch.arange(len(y_eval), device=device), y_eval] -= 1 + for l in range(L): + h_l = hiddens[l].detach().clone() + a_l = (e_T @ Bs[l].T).detach() + forward_fn = make_forward_fn(model, l, x_eval, y_eval) + rho = perturbation_correlation(h_l, a_l, forward_fn, epsilon=eps, M=M) + out.append({"layer": l, "rho": rho}) + return out + + +def load_dfa(seed, ckpt_path, device): + sd = torch.load(ckpt_path, map_location=device, weights_only=False) + model = ResidualMLP(3072, 256, 10, 4).to(device) + if isinstance(sd, dict) and "state_dict" in sd: + model.load_state_dict(sd["state_dict"]) + Bs = [b.to(device) for b in sd["Bs"]] if "Bs" in sd else None + else: + model.load_state_dict(sd) + Bs = None + if Bs is None: + Bs = reconstruct_training_Bs(seed, device=device) + return model, Bs + + +def main(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + x_eval, y_eval = load_eval(n=1024, device=device) + + cases = [ + ("vanilla DFA s42", + "results/confirmatory/checkpoints_A2/dfa_s42.pt", 42), + ("penalized DFA s42 lam=1e-2 30ep", + "results/dfa_pen_short/dfa_pen_lam0.01_s42.pt", 42), + ("penalized DFA s123 lam=1e-2 30ep", + "results/dfa_pen_short/dfa_pen_lam0.01_s123.pt", 123), + ("penalized DFA s456 lam=1e-2 30ep", + "results/dfa_pen_short/dfa_pen_lam0.01_s456.pt", 456), + ] + + print("=" * 76) + print("Perturbation correlation rho_l per layer") + print("=" * 76) + print("(epsilon=1e-3, M=32 random unit directions, n=1024 samples)") + print() + + results = {} + for label, ckpt, seed in cases: + path = os.path.join(REPO_ROOT, ckpt) + if not os.path.exists(path): + print(f" SKIPPED ({path} not found)") + continue + print(f"=== {label} ===") + model, Bs = load_dfa(seed, path, device) + out = measure_rho(model, Bs, x_eval, y_eval, device) + for entry in out: + print(f" l{entry['layer']}: rho = {entry['rho']:+.4f}") + rhos = [e["rho"] for e in out] + deep = np.mean(rhos[1:]) if len(rhos) > 1 else float("nan") + print(f" layer-mean rho: {np.mean(rhos):+.4f}") + print(f" deep-layer mean rho (l1+): {deep:+.4f}") + print() + results[label] = {"per_layer": out, "layer_mean": float(np.mean(rhos)), "deep_mean": float(deep)} + + print("=" * 76) + print("INTERPRETATION") + print("=" * 76) + if "vanilla DFA s42" in results and "penalized DFA s42 lam=1e-2 30ep" in results: + v_deep = results["vanilla DFA s42"]["deep_mean"] + p_deep = results["penalized DFA s42 lam=1e-2 30ep"]["deep_mean"] + print(f" vanilla deep rho: {v_deep:+.4f}") + print(f" penalized deep rho: {p_deep:+.4f}") + if abs(p_deep) > 0.05 and abs(v_deep) < 0.05: + print(f" -> Penalized DFA's local credit signal is locally useful (rho > 0.05),") + print(f" vanilla DFA's is not. This triangulates the cos +0.17 finding via") + print(f" a different metric (perturbation-based), strengthening mode 2 evidence.") + elif abs(p_deep) < 0.05 and abs(v_deep) < 0.05: + print(f" -> Both vanilla and penalized show ~0 perturbation correlation.") + print(f" The cos +0.17 might be capturing direction-of-mean-gradient alignment") + print(f" that doesn't translate to per-sample loss usefulness.") + + +if __name__ == "__main__": + main() |
