summaryrefslogtreecommitdiff
path: root/experiments/perturbation_correlation_audit.py
blob: 163d3a847ee3a256c2f4055c4c9be2fadcd6ba45 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Perturbation correlation rho_l on existing checkpoints. Codex round 19's
recommended alternative metric to per-layer cosine — "a more direct 'is
the local signal useful?' test than cosine alone".

For each checkpoint, compute per-layer rho_l = pearson_correlation(
  predicted_loss_change = <a_l, eps * v>,
  true_loss_change = loss(h_l + eps * v) - loss(h_l)
)
where a_l = e_T @ B_l^T is DFA's local credit signal and v is a random
unit direction. Average over M=32 random directions.

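Compares (a case is skipped if its checkpoint file is missing):
  - Vanilla DFA s42 (existing checkpoint, ‖g‖ at floor)
  - Penalized DFA lam=1e-2 30 ep, seeds 42/123/456 (existing checkpoints,
    ‖g‖ healthy)
  - Vanilla DFA epoch-1 checkpoints, seeds 42/123/456 ("meaningful regime")
  BP s42 is not evaluated (see the BP note under the prediction below).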

Pre-registered prediction:
  - Vanilla DFA: deep rho ~0 (we expect random feedback in degenerate
    regime to give noise correlation)
  - Penalized DFA: deep rho > 0 if cos +0.17 reflects real local signal
    (mode 2 partial alleviation should be detectable by both metrics)
  - BP: per-layer ‘credit signal’ for BP is ambiguous — BP doesn't
    have a local credit signal in the DFA sense. Skip BP rho computation.

Run:
    CUDA_VISIBLE_DEVICES=2 python experiments/perturbation_correlation_audit.py
"""
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from metrics.credit_metrics import perturbation_correlation

# Repository root = parent of the directory holding this script
# (the script lives in <root>/experiments/); checkpoint paths are joined onto it.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_ROOT = os.path.dirname(_SCRIPT_DIR)


def load_eval(n=1024, device="cuda:0"):
    """Return the first ``n`` CIFAR-10 test samples, flattened, on ``device``.

    Images go through ToTensor + per-channel Normalize, then each image is
    flattened to a single row. The loader is unshuffled, so the subset is
    deterministic. Batches of 256 are pulled until at least ``n`` samples
    have been seen (or the test set runs out), then the result is trimmed.

    Returns:
        (x, y): flattened images and integer labels, both moved to ``device``.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    test_set = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=pipeline)
    batches = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=0)

    images, labels = [], []
    collected = 0
    for batch_x, batch_y in batches:
        images.append(batch_x.view(batch_x.size(0), -1))
        labels.append(batch_y)
        collected += batch_x.size(0)
        if collected >= n:
            break

    x_all = torch.cat(images)[:n]
    y_all = torch.cat(labels)[:n]
    return x_all.to(device), y_all.to(device)


def reconstruct_training_Bs(seed, d_hidden=256, num_blocks=4, num_classes=10, device="cuda:0"):
    """Regenerate DFA feedback matrices from ``seed`` by replaying RNG usage.

    Seeds torch (CPU and all CUDA devices) and numpy, builds and discards a
    ResidualMLP so parameter init consumes random numbers first, then draws
    one ``(d_hidden, num_classes)`` Gaussian matrix per block, each divided
    by sqrt(num_classes).

    NOTE(review): the throwaway model is built on the default device. It can
    therefore only shift the generator ``randn`` uses when ``device`` is CPU.
    Bs reconstructed on CPU vs CUDA presumably differ. Confirm this replays
    the training script's exact draw order/device.

    Returns:
        List of ``num_blocks`` tensors on ``device``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Consume RNG exactly as model construction did during training.
    ResidualMLP(3072, d_hidden, num_classes, num_blocks)

    norm = np.sqrt(num_classes)
    feedback = []
    for _block in range(num_blocks):
        feedback.append(torch.randn(d_hidden, num_classes, device=device) / norm)
    return feedback


def make_forward_fn(model, layer_index, x_eval, y_eval):
    """Build a closure mapping a hidden state at ``layer_index`` to per-sample loss.

    The closure applies the residual updates ``h <- h + blocks[i](h)`` for
    ``i = layer_index .. num_blocks - 1``, then ``out_ln`` and ``out_head``. It
    returns ``cross_entropy(..., reduction="none")`` against ``y_eval``, i.e.
    one loss per sample.

    ``x_eval`` is unused; it is kept so existing call sites keep working.
    The closure does not disable autograd itself; wrap calls in
    ``torch.no_grad()`` if no graph is wanted.
    """
    def loss_from(hidden):
        residual = hidden
        for block in (model.blocks[i] for i in range(layer_index, model.num_blocks)):
            residual = residual + block(residual)
        logits = model.out_head(model.out_ln(residual))
        return F.cross_entropy(logits, y_eval, reduction="none")

    return loss_from


def measure_rho(model, Bs, x_eval, y_eval, device, eps=1e-3, M=32):
    """Compute the perturbation correlation rho_l for every residual block.

    For each block l, the DFA local credit signal ``a_l = e_T @ B_l^T`` is
    handed to ``perturbation_correlation``. It is compared against the loss
    change produced by perturbing ``hiddens[l]`` and re-running the network
    from block l onward. Here ``e_T = softmax(logits) - one_hot(y)``, rebuilt
    from the last hidden state.

    Puts ``model`` into eval mode (it is not switched back).

    NOTE(review): this assumes ``hiddens[l]`` is the residual stream *entering*
    block l, with ``hiddens[-1]`` the final stream. If it is block l's output
    instead, ``make_forward_fn`` applies block l twice. Confirm against
    ``ResidualMLP.forward``.

    Args:
        model: network exposing ``num_blocks``, ``blocks``, ``out_ln``,
            ``out_head`` and ``forward(x, return_hidden=True)``.
        Bs: per-block feedback matrices; ``Bs[l].T`` right-multiplies e_T.
        x_eval, y_eval: evaluation inputs and integer labels.
        device: device for the row index used to subtract the one-hot target.
        eps: perturbation size passed to ``perturbation_correlation``.
        M: number of random directions passed to ``perturbation_correlation``.

    Returns:
        List of ``{"layer": l, "rho": rho}`` dicts, one per block, in order.
    """
    model.eval()
    with torch.no_grad():
        _, hiddens = model(x_eval, return_hidden=True)
        # DFA output error: softmax(logits) - one_hot(y).
        probs = F.softmax(model.out_head(model.out_ln(hiddens[-1])), dim=-1)
        e_T = probs.clone()
        rows = torch.arange(len(y_eval), device=device)
        e_T[rows, y_eval] -= 1

    records = []
    for layer in range(model.num_blocks):
        credit = (e_T @ Bs[layer].T).detach()
        rho = perturbation_correlation(
            hiddens[layer].detach().clone(),
            credit,
            make_forward_fn(model, layer, x_eval, y_eval),
            epsilon=eps,
            M=M,
        )
        records.append({"layer": layer, "rho": rho})
    return records


def load_dfa(seed, ckpt_path, device):
    """Load a ResidualMLP checkpoint plus the DFA feedback matrices for it.

    Accepts either a raw state_dict, or a dict with a ``"state_dict"`` entry
    and optionally ``"Bs"``. If no Bs are stored, they are regenerated from
    ``seed`` via ``reconstruct_training_Bs``. A note is printed in that case:
    if the reconstruction does not reproduce the training Bs, a_l is unrelated
    to the trained weights and rho would likely read ~0 for the wrong reason.

    Returns:
        (model, Bs): the model on ``device`` and a list of per-block tensors.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted,
    # locally produced checkpoints here.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    model = ResidualMLP(3072, 256, 10, 4).to(device)

    Bs = None
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        model.load_state_dict(ckpt["state_dict"])
        if "Bs" in ckpt:
            Bs = [b.to(device) for b in ckpt["Bs"]]
    else:
        model.load_state_dict(ckpt)

    if Bs is None:
        # Previously silent: make the reconstruction assumption visible in the report.
        print(f"  NOTE: no feedback matrices stored in {ckpt_path}; "
              f"reconstructing Bs from seed {seed} (must match training RNG order)")
        Bs = reconstruct_training_Bs(seed, device=device)
    return model, Bs


def _interpret_deep_rho(v_deep, p_deep, threshold=0.05):
    """Return report lines comparing vanilla vs penalized deep-layer mean rho.

    Only a *positive* penalized rho above ``threshold`` counts as a locally
    useful credit signal. rho > 0 means the predicted change <a_l, eps * v>
    moves with the true loss change. rho < 0 means it predicts the opposite
    sign, so it is not "useful". Results outside the two pre-registered
    patterns (including NaN) get an explicit line instead of silence.
    """
    lines = [
        f"  vanilla deep rho:   {v_deep:+.4f}",
        f"  penalized deep rho: {p_deep:+.4f}",
    ]
    vanilla_flat = abs(v_deep) < threshold
    if p_deep > threshold and vanilla_flat:
        lines += [
            f"  -> Penalized DFA's local credit signal is locally useful (rho > {threshold}),",
            "     vanilla DFA's is not. This triangulates the cos +0.17 finding via",
            "     a different metric (perturbation-based), strengthening mode 2 evidence.",
        ]
    elif abs(p_deep) < threshold and vanilla_flat:
        lines += [
            "  -> Both vanilla and penalized show ~0 perturbation correlation.",
            "     The cos +0.17 might be capturing direction-of-mean-gradient alignment",
            "     that doesn't translate to per-sample loss usefulness.",
        ]
    else:
        lines += [
            "  -> Outcome matches neither pre-registered pattern (e.g. negative",
            "     penalized rho, or non-negligible vanilla rho); inspect per-layer values.",
        ]
    return lines


def main():
    """Run the rho_l audit over every available checkpoint and print a report."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Single source of truth for the settings printed in the header and used below.
    eps, n_dirs, n_eval = 1e-3, 32, 1024
    print(f"Device: {device}")
    x_eval, y_eval = load_eval(n=n_eval, device=device)

    cases = [
        ("vanilla DFA s42",
         "results/confirmatory/checkpoints_A2/dfa_s42.pt", 42),
        ("penalized DFA s42 lam=1e-2 30ep",
         "results/dfa_pen_short/dfa_pen_lam0.01_s42.pt", 42),
        ("penalized DFA s123 lam=1e-2 30ep",
         "results/dfa_pen_short/dfa_pen_lam0.01_s123.pt", 123),
        ("penalized DFA s456 lam=1e-2 30ep",
         "results/dfa_pen_short/dfa_pen_lam0.01_s456.pt", 456),
        ("vanilla DFA s42 ep1 (meaningful regime)",
         "results/vanilla_dfa_early_ckpts/vanilla_dfa_s42_ep1.pt", 42),
        ("vanilla DFA s123 ep1 (meaningful regime)",
         "results/vanilla_dfa_early_ckpts/vanilla_dfa_s123_ep1.pt", 123),
        ("vanilla DFA s456 ep1 (meaningful regime)",
         "results/vanilla_dfa_early_ckpts/vanilla_dfa_s456_ep1.pt", 456),
    ]

    rule = "=" * 76
    print(rule)
    print("Perturbation correlation rho_l per layer")
    print(rule)
    print(f"(epsilon={eps:g}, M={n_dirs} random unit directions, n={n_eval} samples)")
    print()

    results = {}
    for label, ckpt, seed in cases:
        path = os.path.join(REPO_ROOT, ckpt)
        if not os.path.exists(path):
            # Include the label so the skip is not read as part of the previous case.
            print(f"=== {label} === SKIPPED ({path} not found)")
            continue
        print(f"=== {label} ===")
        model, Bs = load_dfa(seed, path, device)
        out = measure_rho(model, Bs, x_eval, y_eval, device, eps=eps, M=n_dirs)
        for entry in out:
            print(f"  l{entry['layer']}: rho = {entry['rho']:+.4f}")
        # float() accepts Python floats, numpy scalars and 0-d tensors alike,
        # so np.mean never sees (possibly CUDA) tensors.
        rhos = [float(e["rho"]) for e in out]
        layer_mean = float(np.mean(rhos))
        deep = float(np.mean(rhos[1:])) if len(rhos) > 1 else float("nan")
        print(f"  layer-mean rho: {layer_mean:+.4f}")
        print(f"  deep-layer mean rho (l1+): {deep:+.4f}")
        print()
        results[label] = {"per_layer": out, "layer_mean": layer_mean, "deep_mean": deep}

    print(rule)
    print("INTERPRETATION")
    print(rule)
    vanilla_key = "vanilla DFA s42"
    penalized_key = "penalized DFA s42 lam=1e-2 30ep"
    if vanilla_key in results and penalized_key in results:
        report = _interpret_deep_rho(
            results[vanilla_key]["deep_mean"], results[penalized_key]["deep_mean"]
        )
        for line in report:
            print(line)
    else:
        print("  (vanilla and penalized s42 checkpoints not both available; no comparison)")


if __name__ == "__main__":
    main()