r"""Geometric probes for trained models: CKA(model_a, model_b) and effective rank per layer.

Usage:
    python3 probe_geometry.py \
        --ckpts bp:runs_local/probe_bp/ckpt.pt bpfree:runs_local/probe_bpfree/ckpt.pt \
        --data_dir data/tinystories --batch_size 32 --n_batches 4 \
        --out probes/probe_results.json

Outputs JSON with per-layer effective rank, and a CKA matrix between every pair
of named ckpts (including each ckpt with itself).
"""
import argparse
import json
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from model_local import LocalGPTConfig
from train_local_ce import LocalCETransformer


def load_model(ckpt_path, device):
    """Rebuild a LocalCETransformer from a checkpoint and switch it to eval mode.

    Args:
        ckpt_path: path to a checkpoint dict with "config", "args" and
            "model_state" entries.
        device: device the checkpoint tensors and the model are mapped to.

    Returns:
        (model, cfg, args): the eval-mode model, the LocalGPTConfig built from
        the stored config, and the stored training-args mapping.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    blob = torch.load(ckpt_path, map_location=device, weights_only=False)
    cfg_dict = blob["config"]
    args = blob["args"]
    # Keep only the keys LocalGPTConfig declares; anything else in the stored config is ignored.
    cfg = LocalGPTConfig(**{k: v for k, v in cfg_dict.items() if k in LocalGPTConfig.__dataclass_fields__})
    model = LocalCETransformer(cfg, translator_rank=args.get("translator_rank", 0)).to(device)
    # strict=False tolerates key mismatches, but report them: a silently missing
    # weight stays at its random init and would corrupt every downstream probe.
    incompatible = model.load_state_dict(blob["model_state"], strict=False)
    if incompatible.missing_keys:
        print(
            f"WARNING: {ckpt_path}: {len(incompatible.missing_keys)} missing keys, "
            f"e.g. {incompatible.missing_keys[:5]}"
        )
    if incompatible.unexpected_keys:
        print(
            f"WARNING: {ckpt_path}: {len(incompatible.unexpected_keys)} unexpected keys, "
            f"e.g. {incompatible.unexpected_keys[:5]}"
        )
    model.eval()
    return model, cfg, args


def get_fixed_batches(data_dir, block_size, batch_size, n_batches, device, seed=12345):
    """Sample a reproducible set of token windows from ``data_dir / "val.bin"``.

    The file is memory-mapped as a flat uint16 array. Window starts come from a
    generator seeded with ``seed``, so every checkpoint sees the same inputs.

    Args:
        data_dir: directory (str or Path) containing ``val.bin``.
        block_size: length of each window.
        batch_size: windows per batch.
        n_batches: number of batches to draw.
        device: device the returned tensors are moved to.
        seed: seed for the window-start generator.

    Returns:
        list of ``n_batches`` int64 tensors of shape (batch_size, block_size).

    Raises:
        ValueError: if the file is too short to hold a window of ``block_size``.
    """
    fn = Path(data_dir) / "val.bin"
    data = np.memmap(fn, dtype=np.uint16, mode="r")
    max_start = len(data) - block_size - 1
    if max_start <= 0:
        raise ValueError(f"{fn} holds {len(data)} tokens, too few for block_size={block_size}")
    g = torch.Generator().manual_seed(seed)
    batches = []
    for _ in range(n_batches):
        ix = torch.randint(max_start, (batch_size,), generator=g)
        x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix.tolist()])
        batches.append(x.to(device, non_blocking=True))
    return batches


@torch.no_grad()
def collect_activations(model, batches):
    """Run ``model.forward_activations`` on each batch and stack the results per layer.

    Args:
        model: object exposing ``forward_activations(X)``, which returns a list
            of (B, T, d) tensors (L+1 of them per this file's usage).
        batches: non-empty list of input batches.

    Returns:
        list with one float32 CPU tensor per returned activation, each of shape
        (N_total_tokens, d_model), where N_total_tokens sums B*T over batches.

    Raises:
        ValueError: if ``batches`` is empty.
    """
    if not batches:
        raise ValueError("collect_activations needs at least one batch")
    per_layer = None
    for X in batches:
        acts = model.forward_activations(X)  # list of (B, T, d), length L+1
        if per_layer is None:
            per_layer = [[] for _ in acts]
        for layer, a in enumerate(acts):
            # Flatten batch and time into one token axis; move off the accelerator.
            per_layer[layer].append(a.reshape(-1, a.size(-1)).float().cpu())
    return [torch.cat(parts, dim=0) for parts in per_layer]  # list of (N, d)


def _gram_fro_norm(X):
    """Return ||X^T X||_F for an (N, d) tensor as a Python float."""
    return (X.T @ X).pow(2).sum().sqrt().item()


def linear_cka(X, Y, center=True):
    """Linear CKA between an (N, d_x) and an (N, d_y) matrix.

    CKA(X,Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)

    Args:
        X, Y: matrices with the same number of rows (samples).
        center: subtract each column's mean first (the standard definition).

    Returns:
        float; the denominator is floored at 1e-12 to avoid division by zero.
    """
    if center:
        X = X - X.mean(dim=0, keepdim=True)
        Y = Y - Y.mean(dim=0, keepdim=True)
    num = (X.T @ Y).pow(2).sum().item()
    den = _gram_fro_norm(X) * _gram_fro_norm(Y)
    return num / max(den, 1e-12)


def _cka_from_centered(acts_a, norms_a, acts_b, norms_b):
    """Linear-CKA matrix between two lists of already column-centered (N, d) tensors.

    ``norms_a[i]`` / ``norms_b[j]`` must be ``_gram_fro_norm`` of the matching
    tensor. Passing them in lets the caller compute each layer's normalizer once
    instead of once per layer pair. Entry [i][j] matches
    ``linear_cka(acts_a[i], acts_b[j])`` up to floating-point rounding.
    """
    mat = []
    for xa, na in zip(acts_a, norms_a):
        row = []
        for xb, nb in zip(acts_b, norms_b):
            num = (xa.T @ xb).pow(2).sum().item()
            row.append(num / max(na * nb, 1e-12))
        mat.append(row)
    return mat


def effective_rank(X, eps=1e-12):
    """Effective rank = exp(entropy of normalized eigenvalues).

    The eigenvalues are those of the covariance of the column-centered (N, d)
    input, i.e. squared singular values of the centered data. Roy & Vetterli's
    original definition normalizes the singular values themselves, so values
    from the two definitions are not directly comparable.

    Args:
        X: (N, d) tensor of samples.
        eps: if the eigenvalues sum below ``eps`` the result is 0.0; normalized
            eigenvalues at or below ``eps`` are left out of the entropy.

    Returns:
        float: 0.0 for (numerically) zero variance, otherwise between 1 and d.
    """
    X_centered = X - X.mean(dim=0, keepdim=True)
    # Eigendecompose the (d, d) covariance: cheaper than an SVD of the (N, d) data.
    cov = (X_centered.T @ X_centered) / max(X_centered.size(0) - 1, 1)
    eigvals = torch.linalg.eigvalsh(cov).clamp_min(0.0)  # clamp tiny negative round-off
    s = eigvals.sum().item()
    if s < eps:
        return 0.0
    probs = eigvals / s
    probs = probs[probs > eps]  # 0 * log 0 contributes nothing; avoid log(0)
    H = -(probs * probs.log()).sum().item()
    return float(np.exp(H))


def _parse_ckpt_specs(parser, specs):
    """Split NAME:PATH strings into (name, path) pairs; exit via parser.error on bad input."""
    name_paths = []
    for spec in specs:
        name, sep, path = spec.partition(":")
        if not sep or not name or not path:
            parser.error(f"--ckpts entry {spec!r} is not of the form NAME:PATH")
        name_paths.append((name, path))
    names = [name for name, _ in name_paths]
    duplicates = sorted({name for name in names if names.count(name) > 1})
    if duplicates:
        # Results are keyed by name, so a repeated name would silently overwrite another ckpt.
        parser.error(f"duplicate --ckpts names: {duplicates}")
    return name_paths


def _resolve_block_size(cli_block_size, first_ckpt_path):
    """Return --block_size if given, else the first ckpt's config block_size, else 512."""
    if cli_block_size is not None:
        return cli_block_size
    # Only reached when the flag is absent; the loaded blob is freed on return.
    blob = torch.load(first_ckpt_path, map_location="cpu", weights_only=False)
    return blob["config"].get("block_size", 512)


def main():
    """Parse CLI args, probe every checkpoint on identical batches, and write the JSON report."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpts",
        type=str,
        nargs="+",
        required=True,
        help="space-separated NAME:PATH pairs, e.g. bp:runs/bp/ckpt.pt bpfree:runs/bpfree/ckpt.pt",
    )
    parser.add_argument("--data_dir", type=str, default="data/tinystories")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--n_batches", type=int, default=4)
    parser.add_argument(
        "--block_size",
        type=int,
        default=None,
        help="tokens per sequence (default: block_size from the first ckpt's config, else 512)",
    )
    parser.add_argument("--out", type=str, default="probes/probe_results.json")
    args = parser.parse_args()

    name_paths = _parse_ckpt_specs(parser, args.ckpts)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = Path(args.data_dir)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Probing {len(name_paths)} ckpts on {data_dir}, {args.n_batches} batches × {args.batch_size}")

    block_size = _resolve_block_size(args.block_size, name_paths[0][1])
    batches = get_fixed_batches(data_dir, block_size, args.batch_size, args.n_batches, device)

    # Collect activations per ckpt
    all_acts = {}  # name -> list of column-centered (N, d), one per layer
    gram_norms = {}  # name -> list of ||X^T X||_F per layer (CKA normalizers)
    eff_ranks = {}  # name -> list of float per layer
    for name, ckpt_path in name_paths:
        print(f" loading {name} from {ckpt_path}")
        model, _, _ = load_model(ckpt_path, device)
        acts = collect_activations(model, batches)
        eff_ranks[name] = [effective_rank(a) for a in acts]
        print(f" layers: {len(acts)}, d_model: {acts[0].size(1)}, N: {acts[0].size(0)}")
        print(f" eff_rank per layer: {[f'{r:.1f}' for r in eff_ranks[name]]}")
        # Center each layer once, in place, so the pairwise CKA loop below does not
        # redo it for every layer pair (effective_rank has already consumed the raw values).
        for a in acts:
            a.sub_(a.mean(dim=0, keepdim=True))
        all_acts[name] = acts
        gram_norms[name] = [_gram_fro_norm(a) for a in acts]
        del model
        torch.cuda.empty_cache()

    # CKA matrices: for each pair (including self-pairs), a (L_a+1) x (L_b+1) matrix.
    cka_matrices = {}  # "name_a::name_b" -> list of lists
    names = list(all_acts)
    for i, a in enumerate(names):
        for b in names[i:]:
            mat = _cka_from_centered(all_acts[a], gram_norms[a], all_acts[b], gram_norms[b])
            cka_matrices[f"{a}::{b}"] = mat
            # Show diagonal (corresponding layer pairs) if same depth
            if len(all_acts[a]) == len(all_acts[b]):
                diag = [row[k] for k, row in enumerate(mat)]
                print(f" CKA({a},{b}) diag: {[f'{v:.3f}' for v in diag]}")

    results = {
        "ckpts": dict(name_paths),
        "data_dir": str(data_dir),
        "n_total_tokens": all_acts[names[0]][0].size(0),
        "effective_rank": eff_ranks,
        "cka": cka_matrices,
    }
    out_path.write_text(json.dumps(results, indent=2))
    print(f"\nWrote {out_path}")


if __name__ == "__main__":
    main()