summaryrefslogtreecommitdiff
path: root/experiments/measure_direction_quality_existing_ckpt.py
blob: d150eb378a48edb131d7f376d46cbe09ba1a54a4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Fast direction-quality measurement using EXISTING checkpoints. No training.

Loads the vanilla DFA s42 checkpoint from
`results/confirmatory/checkpoints_A2/dfa_s42.pt` and computes the per-layer
cosine between DFA's local credit signal `a_l = e_T @ B_l^T` and the BP
gradient `g_l = ∂L/∂h_l`. This is the "Γ on the degenerate reference"
measurement — what the field-standard FA evaluation reports.

The catch: the training-time random feedback matrices Bs were not saved in
the existing checkpoint. We reconstruct them by replaying the training-time
RNG sequence (`torch.manual_seed(seed); ResidualMLP(...); randn(d, C)`),
matching what the original DFA trainer did.

For the "scale-fixed" comparison case, we'll need a penalized DFA checkpoint
(which `experiments/dfa_direction_quality_test.py` is currently saving in
the background). Once that lands, this script can be reused with --ckpt
pointing to the penalized one.

Run:
    CUDA_VISIBLE_DEVICES=2 python experiments/measure_direction_quality_existing_ckpt.py
"""
import os
import sys
import argparse

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP


def load_eval(n=2048, device="cuda:0"):
    """Return the first ``n`` CIFAR-10 test samples as flat tensors on ``device``.

    The test split (``train=False``) is read from ``./data`` and downloaded
    if it is missing. Images go through ``ToTensor`` and a per-channel
    ``Normalize``, then each one is flattened to a single vector.

    Args:
        n: Number of samples to keep, taken in dataset order (no shuffling).
        device: Device the returned tensors are moved to.

    Returns:
        A tuple ``(inputs, labels)``; both are truncated to at most ``n`` rows.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    )
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])
    test_set = torchvision.datasets.CIFAR10(
        "./data", train=False, download=True, transform=preprocess
    )
    batches = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=0)

    flat_images, labels = [], []
    collected = 0
    for images, targets in batches:
        # Collapse everything but the batch axis into one feature axis.
        flat_images.append(images.view(images.size(0), -1))
        labels.append(targets)
        collected += images.size(0)
        # Stop reading as soon as enough samples are buffered; the final
        # slice below trims any overshoot from the last batch.
        if collected >= n:
            break

    inputs = torch.cat(flat_images)[:n].to(device)
    answers = torch.cat(labels)[:n].to(device)
    return inputs, answers


def reconstruct_training_Bs(seed, d_hidden=256, num_blocks=4, num_classes=10, device="cuda:0"):
    """Re-derive the random feedback matrices by replaying the seeded RNG.

    The sequence is: seed torch / numpy / CUDA, build a throwaway
    ``ResidualMLP`` so its initialisation consumes random numbers, then draw
    ``num_blocks`` Gaussian matrices scaled by ``1 / sqrt(num_classes)``.

    Args:
        seed: Seed applied to the torch, numpy and CUDA generators.
        d_hidden: Row count of each feedback matrix (and model hidden width).
        num_blocks: Number of matrices to draw (and model block count).
        num_classes: Column count of each feedback matrix.
        device: Device on which the matrices are drawn.

    Returns:
        A list of ``num_blocks`` tensors of shape ``(d_hidden, num_classes)``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Built only for its side effect on the RNG state; the instance itself
    # is discarded immediately.
    ResidualMLP(3072, d_hidden, num_classes, num_blocks)

    # NOTE(review): ``torch.randn(..., device=device)`` draws from that
    # device's generator, so the result presumably matches training only when
    # the trainer drew its Bs on the same device type — confirm against the
    # trainer.
    scale = np.sqrt(num_classes)
    feedback = []
    for _ in range(num_blocks):
        # Draws are sequential: matrix k depends on the k-1 draws before it.
        feedback.append(torch.randn(d_hidden, num_classes, device=device) / scale)
    return feedback


def per_layer_bp_grads(model, x, y):
    """Compute the cross-entropy gradient at every residual-stream state.

    The forward pass is unrolled by hand so that each intermediate hidden
    state (the embedding output and the output of every residual block) is
    kept and can be differentiated against.

    Args:
        model: Object exposing ``embed``, an iterable ``blocks``, ``out_ln``
            and ``out_head`` callables.
        x: Input batch fed to ``model.embed``.
        y: Target class indices for ``F.cross_entropy``.

    Returns:
        A tuple ``(grads, logits)``: ``grads`` is a list with one gradient
        tensor per stored hidden state (``len(model.blocks) + 1`` entries),
        and ``logits`` is the detached model output.
    """
    # Gradients are forced on so the function also works when the caller is
    # inside a ``torch.no_grad()`` region.
    with torch.enable_grad():
        state = model.embed(x)
        trajectory = [state]
        for residual_fn in model.blocks:
            state = state + residual_fn(state)
            trajectory.append(state)
        scores = model.out_head(model.out_ln(state))
        objective = F.cross_entropy(scores, y)
        layer_grads = torch.autograd.grad(objective, trajectory)
    return list(layer_grads), scores.detach()


def cosine_no_clamp(a, b):
    """Return the cosine similarity of ``a`` and ``b`` along the last axis.

    Each input is divided by its own norm before the dot product is taken.
    The norm is floored at ``1e-30`` purely to avoid dividing by zero, so an
    all-zero row produces a cosine of 0.

    Args:
        a: Tensor whose last axis holds the vectors to compare.
        b: Tensor compatible with ``a`` for elementwise multiplication.

    Returns:
        Tensor of cosines with the last axis reduced away.
    """
    tiny = 1e-30

    def _unit(v):
        # Normalise before multiplying; multiplying the two norms together
        # first could underflow for very small vectors.
        return v / v.norm(dim=-1, keepdim=True).clamp_min(tiny)

    return (_unit(a) * _unit(b)).sum(dim=-1)


def measure(model, Bs, x, y):
    """Compare the DFA credit signal with the BP gradient at every layer.

    For each stored hidden state ``l`` in ``0 .. model.num_blocks`` the
    output error ``e_T = softmax(logits) - one_hot(y)`` is projected through
    a feedback matrix and compared, per sample, with the backprop gradient
    at that state.

    Args:
        model: Model passed to ``per_layer_bp_grads``; must expose
            ``num_blocks``.
        Bs: Sequence of feedback matrices; ``Bs[k].T`` must be
            right-multipliable by the output error.
        x: Input batch.
        y: Target class indices.

    Returns:
        A list of dicts, one per layer, with keys ``layer``, ``cos_mean``,
        ``cos_std``, ``g_l_norm_median`` and ``a_l_norm_median``.
    """
    depth = model.num_blocks
    bp_grads, logits = per_layer_bp_grads(model, x, y)

    # Output error: subtract 1 from the predicted probability of the target.
    err = F.softmax(logits, dim=-1).clone()
    sample_idx = torch.arange(len(y), device=y.device)
    err[sample_idx, y] -= 1

    def layer_stats(layer):
        # There are depth + 1 hidden states; indices past the last feedback
        # matrix are clamped, so the final state reuses Bs[depth - 1].
        feedback = Bs[min(layer, depth - 1)]
        credit = (err @ feedback.T).detach()
        grad = bp_grads[layer].detach()
        cosines = cosine_no_clamp(credit, grad)
        # The BP gradients come from a mean-reduced loss while ``err`` is
        # per-sample, so the two norm columns differ by a batch-size factor;
        # the cosine itself is scale-invariant.
        return {
            "layer": layer,
            "cos_mean": cosines.mean().item(),
            "cos_std": cosines.std().item(),
            "g_l_norm_median": grad.norm(dim=-1).median().item(),
            "a_l_norm_median": credit.norm(dim=-1).median().item(),
        }

    return [layer_stats(layer) for layer in range(depth + 1)]


def main():
    """Measure per-layer DFA-credit / BP-gradient alignment for one checkpoint.

    Command-line options:
        --seed:  Seed used to replay the training RNG when the checkpoint
                 carries no feedback matrices (default 42).
        --ckpt:  Checkpoint path. Either a bare state dict, or a dict with a
                 ``"state_dict"`` entry and optionally a ``"Bs"`` entry.
        --label: Free-form tag printed in the report header.

    Prints one line per layer followed by the mean of the per-layer cosine
    means. Nothing is written to disk.
    """
    p = argparse.ArgumentParser(
        description="Per-layer cosine between DFA credit and BP gradient "
                    "on an existing checkpoint (no training)."
    )
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--ckpt", type=str,
                   default="results/confirmatory/checkpoints_A2/dfa_s42.pt")
    p.add_argument("--label", type=str, default="vanilla DFA")
    args = p.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    x, y = load_eval(n=2048, device=device)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point --ckpt at checkpoints produced by this project.
    ckpt = torch.load(args.ckpt, map_location=device, weights_only=False)
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        # Direction-quality script saves dict with state_dict + Bs
        state_dict = ckpt["state_dict"]
        Bs = [b.to(device) for b in ckpt["Bs"]] if "Bs" in ckpt else None
    else:
        state_dict, Bs = ckpt, None

    # Both checkpoint layouts load into the same architecture, so the model
    # is built once here instead of once per branch.
    model = ResidualMLP(3072, 256, 10, 4).to(device)
    model.load_state_dict(state_dict)
    # Inference mode: keeps any dropout / batch-norm layers deterministic
    # during the measurement (a no-op if the model has none).
    model.eval()

    if Bs is None:
        print(f"  Reconstructing training Bs from RNG seed {args.seed}...")
        # The replay draws from the generator of `device`; a different device
        # type than the trainer's yields a different random stream.
        print(f"  NOTE: Bs are drawn from the {device.type} RNG stream; they "
              f"match training only if the trainer drew them on the same "
              f"device type.")
        Bs = reconstruct_training_Bs(args.seed, device=device)

    print(f"\n=== {args.label} (seed {args.seed}) ===")
    print(f"  ckpt: {args.ckpt}")
    out = measure(model, Bs, x, y)
    print("  per-layer DFA-credit vs BP-grad cosine:")
    for entry in out:
        print(f"    l{entry['layer']}: cos_mean={entry['cos_mean']:+.4f} "
              f"(±{entry['cos_std']:.4f})  ‖g‖={entry['g_l_norm_median']:.2e}  "
              f"‖a‖={entry['a_l_norm_median']:.2e}")
    mean_cos = np.mean([e["cos_mean"] for e in out])
    print(f"  layer-mean cos: {mean_cos:+.4f}")


# Script entry point: run the measurement only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()