summaryrefslogtreecommitdiff
path: root/ep_run/probe_geometry.py
blob: 971e908d384b1786ce969817274856c9e69ab4f9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Geometric probes for trained models: CKA(model_a, model_b) and effective rank per layer.

Usage:
  python3 probe_geometry.py \
      --ckpts bp:runs_local/probe_bp/ckpt.pt bpfree:runs_local/probe_bpfree/ckpt.pt \
      --data_dir data/tinystories --batch_size 32 --n_batches 4 \
      --out probes/probe_results.json

Outputs a JSON file with the per-layer effective rank of every checkpoint and a layer-by-layer CKA matrix for every pair of named checkpoints (each checkpoint is also paired with itself).
"""
import argparse
import json
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from model_local import LocalGPTConfig
from train_local_ce import LocalCETransformer


def load_model(ckpt_path, device):
    """Rebuild a ``LocalCETransformer`` from a checkpoint and switch it to eval mode.

    Args:
        ckpt_path: path to a checkpoint containing ``"config"``, ``"args"`` and
            ``"model_state"`` entries.
        device: device the checkpoint tensors and the model are loaded onto.

    Returns:
        ``(model, cfg, args)``: the eval-mode model, the reconstructed
        ``LocalGPTConfig``, and the raw ``"args"`` entry stored in the checkpoint
        (presumably a dict, since ``.get`` is called on it; verify against the
        training script).
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    blob = torch.load(ckpt_path, map_location=device, weights_only=False)
    cfg_dict = blob["config"]
    args = blob["args"]
    # Keep only keys that are declared fields of the config dataclass.
    known_fields = LocalGPTConfig.__dataclass_fields__
    cfg = LocalGPTConfig(**{k: v for k, v in cfg_dict.items() if k in known_fields})
    model = LocalCETransformer(cfg, translator_rank=args.get("translator_rank", 0)).to(device)
    # strict=False tolerates key mismatches, but a missing weight would leave a
    # randomly initialised parameter in the probed model. Report any mismatch
    # instead of hiding it.
    incompatible = model.load_state_dict(blob["model_state"], strict=False)
    missing, unexpected = incompatible.missing_keys, incompatible.unexpected_keys
    if missing or unexpected:
        print(f"    WARNING: {ckpt_path}: {len(missing)} missing / {len(unexpected)} unexpected "
              f"state_dict keys (strict=False); e.g. missing={missing[:3]} unexpected={unexpected[:3]}")
    model.eval()
    return model, cfg, args


def get_fixed_batches(data_dir, block_size, batch_size, n_batches, device, seed=12345):
    """Sample a deterministic set of token windows from ``<data_dir>/val.bin``.

    The file is read as a flat uint16 array of token ids. Window starts are drawn
    from a private ``torch.Generator`` seeded with ``seed``, so identical arguments
    always give identical batches. This lets different checkpoints be probed on
    the same inputs.

    Args:
        data_dir: directory containing ``val.bin`` (``str`` or ``Path``).
        block_size: tokens per window.
        batch_size: windows per batch.
        n_batches: number of batches to return.
        device: device each batch is moved to.
        seed: generator seed for the window start offsets.

    Returns:
        List of ``n_batches`` int64 tensors of shape ``(batch_size, block_size)``.

    Raises:
        ValueError: if ``val.bin`` holds fewer than ``block_size + 2`` tokens.
    """
    fn = Path(data_dir) / "val.bin"
    data = np.memmap(fn, dtype=np.uint16, mode="r")
    # Starts come from [0, len - block_size - 1); this keeps every window, plus
    # one following token, inside the file.
    n_starts = len(data) - block_size - 1
    if n_starts <= 0:
        raise ValueError(
            f"{fn} has {len(data)} tokens; need at least {block_size + 2} "
            f"to sample windows of block_size={block_size}"
        )
    g = torch.Generator().manual_seed(seed)
    batches = []
    for _ in range(n_batches):
        ix = torch.randint(n_starts, (batch_size,), generator=g)
        x = torch.stack(
            [torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix.tolist()]
        )
        batches.append(x.to(device, non_blocking=True))
    return batches


@torch.no_grad()
def collect_activations(model, batches):
    """Run ``model.forward_activations`` on each batch and stack activations per entry.

    Args:
        model: object exposing ``forward_activations(X)``, which returns a list of
            tensors whose last dimension is the hidden size (presumably
            ``(B, T, d)`` each, with L+1 entries; confirm against the model).
        batches: non-empty iterable of input tensors.

    Returns:
        One float32 CPU tensor of shape ``(N_total_tokens, d)`` per activation
        entry. The rows of all batches are concatenated in iteration order.

    Raises:
        ValueError: if ``batches`` is empty, or if the model returns a different
            number of activation entries for different batches.
    """
    per_layer = None
    for X in batches:
        acts = model.forward_activations(X)
        if per_layer is None:
            per_layer = [[] for _ in acts]
        elif len(acts) != len(per_layer):
            # A changing entry count would silently misalign the layers.
            raise ValueError(
                f"forward_activations returned {len(acts)} entries, expected {len(per_layer)}"
            )
        for layer_idx, a in enumerate(acts):
            # Flatten all leading dims into rows: one row per token.
            per_layer[layer_idx].append(a.reshape(-1, a.size(-1)).float().cpu())
    if per_layer is None:
        raise ValueError("collect_activations: 'batches' is empty, nothing to probe")
    return [torch.cat(parts, dim=0) for parts in per_layer]


def linear_cka(X, Y, center=True):
    """Linear CKA between an (N, d_x) and an (N, d_y) matrix.

    CKA(X, Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F), computed on the
    column-centred matrices when ``center`` is true.

    Args:
        X: 2-D tensor of shape (N, d_x); any real or integer dtype.
        Y: 2-D tensor of shape (N, d_y); row i must correspond to row i of X.
        center: subtract per-column means before comparing.

    Returns:
        A Python float. If either Gram norm is (numerically) zero, the result
        is 0.0.

    Raises:
        ValueError: if X and Y have different numbers of rows.
    """
    if X.size(0) != Y.size(0):
        raise ValueError(f"linear_cka: row mismatch ({X.size(0)} vs {Y.size(0)})")
    # Work in float64. Integer inputs cannot be averaged. In half precision the
    # sums of squared inner products overflow to inf, which yields inf/inf = nan.
    X = X.double()
    Y = Y.double()
    if center:
        X = X - X.mean(dim=0, keepdim=True)
        Y = Y - Y.mean(dim=0, keepdim=True)
    num = ((X.T @ Y) ** 2).sum().item()
    norm_x = ((X.T @ X) ** 2).sum().sqrt()
    norm_y = ((Y.T @ Y) ** 2).sum().sqrt()
    den = (norm_x * norm_y).item()
    return num / max(den, 1e-12)


def effective_rank(X, eps=1e-12):
    """Return exp(H), where H is the Shannon entropy of X's normalised covariance spectrum.

    Rows of ``X`` are samples and columns are features. The result is 0.0 when
    there is no variance at all. For d columns it is at most d, reached when
    all covariance eigenvalues are equal.

    Args:
        X: 2-D tensor of shape (N, d).
        eps: if the eigenvalues sum to less than this, 0.0 is returned.
            Normalised eigenvalues not above it are left out of the entropy.
    """
    centred = X - X.mean(dim=0, keepdim=True)
    n_rows = centred.size(0)
    # Unbiased (N - 1) normaliser, guarded so tiny inputs never divide by <= 0.
    scale = n_rows - 1 if n_rows > 1 else 1
    # Use the symmetric eigensolver on the (d, d) covariance. Negative
    # eigenvalues can only be round-off, so clip them to zero.
    spectrum = torch.linalg.eigvalsh(centred.T @ centred / scale)
    spectrum = torch.clamp(spectrum, min=0.0)
    total = spectrum.sum().item()
    if total < eps:
        return 0.0
    probs = spectrum / total
    probs = probs[probs > eps]
    entropy = -torch.sum(probs * torch.log(probs)).item()
    return float(np.exp(entropy))


def _parse_ckpt_specs(parser, specs):
    """Split ``NAME:PATH`` strings into ``(name, path)`` tuples.

    Only the first ``:`` is used as the separator, so paths may contain colons.
    A malformed entry or a repeated name calls ``parser.error``, which exits.
    A repeated name would otherwise silently overwrite an earlier checkpoint's
    results.
    """
    name_paths = []
    seen = set()
    for spec in specs:
        name, sep, path = spec.partition(":")
        if not (sep and name and path):
            parser.error(f"--ckpts entry {spec!r} is not of the form NAME:PATH")
        if name in seen:
            parser.error(f"--ckpts name {name!r} is given more than once")
        seen.add(name)
        name_paths.append((name, path))
    return name_paths


def _block_size_from_ckpt(ckpt_path, default):
    """Return ``config["block_size"]`` from a checkpoint, or ``default`` if it is absent.

    The loaded blob is local to this helper. Its tensors are therefore released
    on return, instead of staying alive for the whole probe run.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only probe
    # trusted checkpoints.
    blob = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    return blob["config"].get("block_size", default)


def _cka_matrices(all_acts):
    """Compute a layer-by-layer linear-CKA matrix for every unordered checkpoint pair.

    Pairs follow insertion order, and each checkpoint is also paired with
    itself. Keys are ``"name_a::name_b"`` and values are lists of lists
    (rows: layers of a; columns: layers of b). When both checkpoints have the
    same number of activation entries, the diagonal is printed.
    """
    cka_matrices = {}
    names = list(all_acts)
    for i, a in enumerate(names):
        for b in names[i:]:
            mat = [[linear_cka(xa, xb) for xb in all_acts[b]] for xa in all_acts[a]]
            cka_matrices[f"{a}::{b}"] = mat
            if len(all_acts[a]) == len(all_acts[b]):
                diag = [mat[k][k] for k in range(len(mat))]
                print(f"  CKA({a},{b}) diag: {[f'{v:.3f}' for v in diag]}")
    return cka_matrices


def main():
    """Probe each ``--ckpts`` checkpoint on identical validation batches and write JSON.

    Per-layer activations of every checkpoint are collected, and their effective
    rank is computed. Linear CKA is then computed between every pair of
    checkpoints, and everything is written to ``--out``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpts", type=str, nargs="+", required=True,
                   help="space-separated NAME:PATH pairs, e.g. bp:runs/bp/ckpt.pt bpfree:runs/bpfree/ckpt.pt")
    p.add_argument("--data_dir", type=str, default="data/tinystories")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--n_batches", type=int, default=4)
    p.add_argument("--block_size", type=int, default=512,
                   help="fallback only: used when the first ckpt's config has no block_size")
    p.add_argument("--out", type=str, default="probes/probe_results.json")
    args = p.parse_args()

    # Validate the NAME:PATH specs before touching the filesystem.
    name_paths = _parse_ckpt_specs(p, args.ckpts)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = Path(args.data_dir)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Probing {len(name_paths)} ckpts on {data_dir}, {args.n_batches} batches × {args.batch_size}")

    # Every ckpt is probed with the first ckpt's block_size.
    # NOTE(review): ckpts configured with different block sizes are not detected.
    block_size = _block_size_from_ckpt(name_paths[0][1], args.block_size)
    batches = get_fixed_batches(data_dir, block_size, args.batch_size, args.n_batches, device)

    all_acts = {}   # name -> list of (N, d) activation matrices
    eff_ranks = {}  # name -> list of per-layer effective ranks
    for name, ckpt_path in name_paths:
        print(f"  loading {name} from {ckpt_path}")
        model, _cfg, _train_args = load_model(ckpt_path, device)
        acts = collect_activations(model, batches)
        all_acts[name] = acts
        eff_ranks[name] = [effective_rank(a) for a in acts]
        print(f"    layers: {len(acts)}, d_model: {acts[0].size(1)}, N: {acts[0].size(0)}")
        print(f"    eff_rank per layer: {[f'{r:.1f}' for r in eff_ranks[name]]}")
        # Release the model before loading the next checkpoint.
        del model
        if device == "cuda":
            torch.cuda.empty_cache()

    cka_matrices = _cka_matrices(all_acts)

    first_acts = next(iter(all_acts.values()))
    results = {
        "ckpts": dict(name_paths),
        "data_dir": str(data_dir),
        "n_total_tokens": first_acts[0].size(0),
        "effective_rank": eff_ranks,
        "cka": cka_matrices,
    }
    out_path.write_text(json.dumps(results, indent=2))
    print(f"\nWrote {out_path}")


if __name__ == "__main__":
    main()