summaryrefslogtreecommitdiff
path: root/ep_run/probe_geometry.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/probe_geometry.py')
-rw-r--r--ep_run/probe_geometry.py162
1 files changed, 162 insertions, 0 deletions
diff --git a/ep_run/probe_geometry.py b/ep_run/probe_geometry.py
new file mode 100644
index 0000000..971e908
--- /dev/null
+++ b/ep_run/probe_geometry.py
@@ -0,0 +1,162 @@
+"""Geometric probes for trained models: CKA(model_a, model_b) and effective rank per layer.
+
+Usage:
+ python3 probe_geometry.py \
+ --ckpts bp:runs_local/probe_bp/ckpt.pt bpfree:runs_local/probe_bpfree/ckpt.pt \
+ --data_dir data/tinystories --batch_size 32 --n_batches 4 \
+ --out probes/probe_results.json
+
+Outputs JSON with per-layer effective rank, and CKA matrix between every pair of named ckpts.
+"""
+import argparse
+import json
+import pickle
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from model_local import LocalGPTConfig
+from train_local_ce import LocalCETransformer
+
+
+def load_model(ckpt_path, device):
+ blob = torch.load(ckpt_path, map_location=device, weights_only=False)
+ cfg_dict = blob["config"]
+ args = blob["args"]
+ cfg = LocalGPTConfig(**{k: v for k, v in cfg_dict.items()
+ if k in LocalGPTConfig.__dataclass_fields__})
+ model = LocalCETransformer(cfg, translator_rank=args.get("translator_rank", 0)).to(device)
+ model.load_state_dict(blob["model_state"], strict=False)
+ model.eval()
+ return model, cfg, args
+
+
+def get_fixed_batches(data_dir, block_size, batch_size, n_batches, device, seed=12345):
+ fn = data_dir / "val.bin"
+ data = np.memmap(fn, dtype=np.uint16, mode="r")
+ g = torch.Generator().manual_seed(seed)
+ batches = []
+ for _ in range(n_batches):
+ ix = torch.randint(len(data) - block_size - 1, (batch_size,), generator=g)
+ x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix])
+ batches.append(x.to(device, non_blocking=True))
+ return batches
+
+
+@torch.no_grad()
+def collect_activations(model, batches):
+ """Run model on each batch, return list of per-layer activations stacked across batches.
+ Returns: list of length L+1, each element is tensor of shape (N_total_tokens, d_model)
+ """
+ per_layer = None
+ for X in batches:
+ acts = model.forward_activations(X) # list of (B, T, d), length L+1
+ if per_layer is None:
+ per_layer = [[] for _ in range(len(acts))]
+ for l, a in enumerate(acts):
+ per_layer[l].append(a.reshape(-1, a.size(-1)).float().cpu())
+ return [torch.cat(parts, dim=0) for parts in per_layer] # list of (N, d)
+
+
+def linear_cka(X, Y, center=True):
+ """Linear CKA between (N, d_x) and (N, d_y) matrices.
+ CKA(X,Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
+ """
+ if center:
+ X = X - X.mean(dim=0, keepdim=True)
+ Y = Y - Y.mean(dim=0, keepdim=True)
+ XtY = X.T @ Y # (d_x, d_y)
+ num = (XtY ** 2).sum().item()
+ XtX = X.T @ X
+ YtY = Y.T @ Y
+ den = ((XtX ** 2).sum().sqrt() * (YtY ** 2).sum().sqrt()).item()
+ return num / max(den, 1e-12)
+
+
+def effective_rank(X, eps=1e-12):
+ """Effective rank = exp(entropy of normalized eigenvalues)."""
+ X_centered = X - X.mean(dim=0, keepdim=True)
+ # Use eigendecomp of covariance for stability
+ cov = (X_centered.T @ X_centered) / max(X_centered.size(0) - 1, 1)
+ eigvals = torch.linalg.eigvalsh(cov).clamp_min(0.0)
+ s = eigvals.sum().item()
+ if s < eps:
+ return 0.0
+ p = eigvals / s
+ p = p[p > eps]
+ H = -(p * p.log()).sum().item()
+ return float(np.exp(H))
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument("--ckpts", type=str, nargs="+", required=True,
+ help="space-separated NAME:PATH pairs, e.g. bp:runs/bp/ckpt.pt bpfree:runs/bpfree/ckpt.pt")
+ p.add_argument("--data_dir", type=str, default="data/tinystories")
+ p.add_argument("--batch_size", type=int, default=32)
+ p.add_argument("--n_batches", type=int, default=4)
+ p.add_argument("--block_size", type=int, default=512)
+ p.add_argument("--out", type=str, default="probes/probe_results.json")
+ args = p.parse_args()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ data_dir = Path(args.data_dir)
+ out_path = Path(args.out)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Parse ckpt names
+ name_paths = [s.split(":", 1) for s in args.ckpts]
+ print(f"Probing {len(name_paths)} ckpts on {data_dir}, {args.n_batches} batches × {args.batch_size}")
+
+ # Use block_size from first ckpt's training config
+ first_blob = torch.load(name_paths[0][1], map_location="cpu", weights_only=False)
+ block_size = first_blob["config"].get("block_size", args.block_size)
+
+ batches = get_fixed_batches(data_dir, block_size, args.batch_size, args.n_batches, device)
+
+ # Collect activations per ckpt
+ all_acts = {} # name → list of (N, d)
+ eff_ranks = {} # name → list of float per layer
+ for name, ckpt_path in name_paths:
+ print(f" loading {name} from {ckpt_path}")
+ model, cfg, train_args = load_model(ckpt_path, device)
+ acts = collect_activations(model, batches)
+ all_acts[name] = acts
+ eff_ranks[name] = [effective_rank(a) for a in acts]
+ print(f" layers: {len(acts)}, d_model: {acts[0].size(1)}, N: {acts[0].size(0)}")
+ print(f" eff_rank per layer: {[f'{r:.1f}' for r in eff_ranks[name]]}")
+ del model
+ torch.cuda.empty_cache()
+
+ # CKA matrices: for each pair, compute (L+1) × (L+1) CKA matrix
+ cka_matrices = {} # "name_a:name_b" → list of lists
+ names = list(all_acts.keys())
+ for i in range(len(names)):
+ for j in range(i, len(names)):
+ a, b = names[i], names[j]
+ La, Lb = len(all_acts[a]), len(all_acts[b])
+ mat = [[0.0] * Lb for _ in range(La)]
+ for la in range(La):
+ for lb in range(Lb):
+ mat[la][lb] = linear_cka(all_acts[a][la], all_acts[b][lb])
+ cka_matrices[f"{a}::{b}"] = mat
+ # Show diagonal (corresponding layer pairs) if same depth
+ if La == Lb:
+ diag = [mat[k][k] for k in range(La)]
+ print(f" CKA({a},{b}) diag: {[f'{v:.3f}' for v in diag]}")
+
+ results = {
+ "ckpts": dict(name_paths),
+ "data_dir": str(data_dir),
+ "n_total_tokens": all_acts[names[0]][0].size(0),
+ "effective_rank": eff_ranks,
+ "cka": cka_matrices,
+ }
+ out_path.write_text(json.dumps(results, indent=2))
+ print(f"\nWrote {out_path}")
+
+
+if __name__ == "__main__":
+ main()