diff options
Diffstat (limited to 'ep_run/probe_geometry.py')
| -rw-r--r-- | ep_run/probe_geometry.py | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/ep_run/probe_geometry.py b/ep_run/probe_geometry.py new file mode 100644 index 0000000..971e908 --- /dev/null +++ b/ep_run/probe_geometry.py @@ -0,0 +1,162 @@ +"""Geometric probes for trained models: CKA(model_a, model_b) and effective rank per layer. + +Usage: + python3 probe_geometry.py \ + --ckpts bp:runs_local/probe_bp/ckpt.pt bpfree:runs_local/probe_bpfree/ckpt.pt \ + --data_dir data/tinystories --batch_size 32 --n_batches 4 \ + --out probes/probe_results.json + +Outputs JSON with per-layer effective rank, and CKA matrix between every pair of named ckpts. +""" +import argparse +import json +import pickle +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F + +from model_local import LocalGPTConfig +from train_local_ce import LocalCETransformer + + +def load_model(ckpt_path, device): + blob = torch.load(ckpt_path, map_location=device, weights_only=False) + cfg_dict = blob["config"] + args = blob["args"] + cfg = LocalGPTConfig(**{k: v for k, v in cfg_dict.items() + if k in LocalGPTConfig.__dataclass_fields__}) + model = LocalCETransformer(cfg, translator_rank=args.get("translator_rank", 0)).to(device) + model.load_state_dict(blob["model_state"], strict=False) + model.eval() + return model, cfg, args + + +def get_fixed_batches(data_dir, block_size, batch_size, n_batches, device, seed=12345): + fn = data_dir / "val.bin" + data = np.memmap(fn, dtype=np.uint16, mode="r") + g = torch.Generator().manual_seed(seed) + batches = [] + for _ in range(n_batches): + ix = torch.randint(len(data) - block_size - 1, (batch_size,), generator=g) + x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) + batches.append(x.to(device, non_blocking=True)) + return batches + + +@torch.no_grad() +def collect_activations(model, batches): + """Run model on each batch, return list of per-layer activations stacked across batches. + Returns: list of length L+1, each element is tensor of shape (N_total_tokens, d_model) + """ + per_layer = None + for X in batches: + acts = model.forward_activations(X) # list of (B, T, d), length L+1 + if per_layer is None: + per_layer = [[] for _ in range(len(acts))] + for l, a in enumerate(acts): + per_layer[l].append(a.reshape(-1, a.size(-1)).float().cpu()) + return [torch.cat(parts, dim=0) for parts in per_layer] # list of (N, d) + + +def linear_cka(X, Y, center=True): + """Linear CKA between (N, d_x) and (N, d_y) matrices. + CKA(X,Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F) + """ + if center: + X = X - X.mean(dim=0, keepdim=True) + Y = Y - Y.mean(dim=0, keepdim=True) + XtY = X.T @ Y # (d_x, d_y) + num = (XtY ** 2).sum().item() + XtX = X.T @ X + YtY = Y.T @ Y + den = ((XtX ** 2).sum().sqrt() * (YtY ** 2).sum().sqrt()).item() + return num / max(den, 1e-12) + + +def effective_rank(X, eps=1e-12): + """Effective rank = exp(entropy of normalized eigenvalues).""" + X_centered = X - X.mean(dim=0, keepdim=True) + # Use eigendecomp of covariance for stability + cov = (X_centered.T @ X_centered) / max(X_centered.size(0) - 1, 1) + eigvals = torch.linalg.eigvalsh(cov).clamp_min(0.0) + s = eigvals.sum().item() + if s < eps: + return 0.0 + p = eigvals / s + p = p[p > eps] + H = -(p * p.log()).sum().item() + return float(np.exp(H)) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--ckpts", type=str, nargs="+", required=True, + help="space-separated NAME:PATH pairs, e.g. bp:runs/bp/ckpt.pt bpfree:runs/bpfree/ckpt.pt") + p.add_argument("--data_dir", type=str, default="data/tinystories") + p.add_argument("--batch_size", type=int, default=32) + p.add_argument("--n_batches", type=int, default=4) + p.add_argument("--block_size", type=int, default=512) + p.add_argument("--out", type=str, default="probes/probe_results.json") + args = p.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + data_dir = Path(args.data_dir) + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + + # Parse ckpt names + name_paths = [s.split(":", 1) for s in args.ckpts] + print(f"Probing {len(name_paths)} ckpts on {data_dir}, {args.n_batches} batches × {args.batch_size}") + + # Use block_size from first ckpt's training config + first_blob = torch.load(name_paths[0][1], map_location="cpu", weights_only=False) + block_size = first_blob["config"].get("block_size", args.block_size) + + batches = get_fixed_batches(data_dir, block_size, args.batch_size, args.n_batches, device) + + # Collect activations per ckpt + all_acts = {} # name → list of (N, d) + eff_ranks = {} # name → list of float per layer + for name, ckpt_path in name_paths: + print(f" loading {name} from {ckpt_path}") + model, cfg, train_args = load_model(ckpt_path, device) + acts = collect_activations(model, batches) + all_acts[name] = acts + eff_ranks[name] = [effective_rank(a) for a in acts] + print(f" layers: {len(acts)}, d_model: {acts[0].size(1)}, N: {acts[0].size(0)}") + print(f" eff_rank per layer: {[f'{r:.1f}' for r in eff_ranks[name]]}") + del model + torch.cuda.empty_cache() + + # CKA matrices: for each pair, compute (L+1) × (L+1) CKA matrix + cka_matrices = {} # "name_a:name_b" → list of lists + names = list(all_acts.keys()) + for i in range(len(names)): + for j in range(i, len(names)): + a, b = names[i], names[j] + La, Lb = len(all_acts[a]), len(all_acts[b]) + mat = [[0.0] * Lb for _ in range(La)] + for la in range(La): + for lb in range(Lb): + mat[la][lb] = linear_cka(all_acts[a][la], all_acts[b][lb]) + cka_matrices[f"{a}::{b}"] = mat + # Show diagonal (corresponding layer pairs) if same depth + if La == Lb: + diag = [mat[k][k] for k in range(La)] + print(f" CKA({a},{b}) diag: {[f'{v:.3f}' for v in diag]}") + + results = { + "ckpts": dict(name_paths), + "data_dir": str(data_dir), + "n_total_tokens": all_acts[names[0]][0].size(0), + "effective_rank": eff_ranks, + "cka": cka_matrices, + } + out_path.write_text(json.dumps(results, indent=2)) + print(f"\nWrote {out_path}") + + +if __name__ == "__main__": + main() |
