summaryrefslogtreecommitdiff
path: root/ep_run/analyze_ln_jacobian.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/analyze_ln_jacobian.py')
-rw-r--r--ep_run/analyze_ln_jacobian.py166
1 files changed, 166 insertions, 0 deletions
diff --git a/ep_run/analyze_ln_jacobian.py b/ep_run/analyze_ln_jacobian.py
new file mode 100644
index 0000000..dcf2149
--- /dev/null
+++ b/ep_run/analyze_ln_jacobian.py
@@ -0,0 +1,166 @@
+"""Analyze LN Jacobian decomposition: how much does each component (scaling, mean-center,
+radial removal) contribute to the gradient at each LN layer?
+
+Trains a small FA model for 250 steps, then on one diagnostic batch:
+1. Forward with hooks to capture each LN's (x, z, sigma)
+2. Backward to get g_tilde = dL/dz (gradient wrt LN output)
+3. Decompose: true J_LN @ g_tilde vs center_scale vs projected vs identity(STE)
+4. Report per-layer cosines and energy fractions
+
+Run for both softmax and sigmoid to explain why center_scale costs more on softmax.
+"""
+import json
+import pickle
+from pathlib import Path
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model_local import LocalGPT, LocalGPTConfig
+from local_layers import initialize_dfa_targets
+import numpy as np
+
+
def get_batch(data_dir, block_size, batch_size, device):
    """Sample a random batch of next-token pairs from ``train.bin``.

    Opens ``data_dir / "train.bin"`` as a read-only ``uint16`` memmap, draws
    ``batch_size`` random start offsets with ``torch.randint`` and cuts a
    window of ``block_size`` tokens at each offset.  The targets are the same
    windows shifted one position to the right.

    Args:
        data_dir: Path-like object supporting ``/``; must contain ``train.bin``.
        block_size: Number of tokens per sampled window.
        batch_size: Number of windows to sample.
        device: Device the resulting tensors are moved to.

    Returns:
        A tuple ``(x, y)`` of int64 tensors, each of shape
        ``(batch_size, block_size)``, located on ``device``.
    """
    tokens = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r")
    # The exclusive upper bound leaves room for the one-token target shift.
    starts = torch.randint(len(tokens) - block_size - 1, (batch_size,))

    def window(begin):
        # Widen to int64 while copying the slice out of the memmap.
        chunk = tokens[begin:begin + block_size]
        return torch.from_numpy(chunk.astype(np.int64))

    inputs = torch.stack([window(start) for start in starts])
    targets = torch.stack([window(start + 1) for start in starts])
    return inputs.to(device), targets.to(device)
+
+
def _batch_cosine(a, b):
    """Return the mean cosine similarity between ``a`` and ``b`` over the last axis."""
    width = a.shape[-1]
    cos = F.cosine_similarity(a.reshape(-1, width), b.reshape(-1, width), dim=-1)
    return cos.mean().item()


def _decompose_ln_gradient(g, z, sigma):
    """Compare the exact LN backward with its surrogates for one layer.

    Args:
        g: Gradient with respect to the normalized activations ``z``.
        z: Normalized activations ``(x - mean) / sigma``.
        sigma: Per-position standard deviation, broadcastable against ``g``.

    Returns:
        A dict of Python floats with energy fractions of the removed
        components, sigma statistics, and cosine similarities between the
        surrogate gradients and the exact LN backward.
    """
    g_mean = g.mean(dim=-1, keepdim=True)
    gz_mean = (g * z).mean(dim=-1, keepdim=True)

    # Exact LN Jacobian action: g_x = (1/sigma) * (g - mean(g) - z * mean(g*z)).
    g_true = (g - g_mean - z * gz_mean) / sigma
    g_center = (g - g_mean) / sigma  # center_scale: drops the radial term
    g_ste = g  # identity straight-through estimator

    # Fraction of ||g||^2 carried by each component the exact backward removes.
    g_norm_sq = (g * g).sum(-1).mean()
    mean_component = g_mean.expand_as(g)
    radial_component = z * gz_mean
    r_mean = (mean_component * mean_component).sum(-1).mean() / (g_norm_sq + 1e-12)
    r_radial = (radial_component * radial_component).sum(-1).mean() / (g_norm_sq + 1e-12)

    return {
        "r_mean": r_mean.item(),
        "r_radial": r_radial.item(),
        "sigma_mean": sigma.mean().item(),
        "sigma_std": sigma.std().item(),
        "cos_center_vs_true": _batch_cosine(g_center, g_true),
        "cos_ste_vs_true": _batch_cosine(g_ste, g_true),
        "cos_center_vs_ste": _batch_cosine(g_center, g_ste),
    }


def analyze_one_config(attn_mode, device, data_dir, *, train_steps=250, batch_size=32):
    """Train a small FA model, then analyze the LN Jacobian on one batch.

    The model is trained with AdamW for ``train_steps`` steps.  Afterwards a
    forward hook on every ``nn.LayerNorm`` records the normalized activations
    and sigma for one diagnostic batch, the gradient arriving at each LN
    output is captured via ``retain_grad``, and that gradient is decomposed
    into the parts removed by the exact LN backward.

    Args:
        attn_mode: Value passed to ``LocalGPTConfig(attn_mode=...)``.
        device: Device the model and batches are placed on.
        data_dir: ``pathlib.Path`` containing ``meta.pkl`` and ``train.bin``.
        train_steps: Number of optimizer steps before the diagnostic pass.
        batch_size: Batch size for both training and the diagnostic batch.

    Returns:
        Dict mapping each LayerNorm module name to a dict of floats with keys
        ``r_mean``, ``r_radial``, ``sigma_mean``, ``sigma_std``,
        ``cos_center_vs_true``, ``cos_ste_vs_true`` and ``cos_center_vs_ste``.
        Layers whose output received no gradient are omitted.
    """
    torch.manual_seed(1337)
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # data_dir at a meta.pkl produced by a trusted data-preparation script.
    with open(data_dir / "meta.pkl", "rb") as f:
        meta = pickle.load(f)

    cfg = LocalGPTConfig(
        block_size=64, vocab_size=meta["vocab_size"],
        n_layer=4, n_head=4, n_embd=128, dropout=0.0,
        attn_mode=attn_mode, method="fa",
    )
    model = LocalGPT(cfg).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.train()
    for _ in range(train_steps):
        X, Y = get_batch(data_dir, cfg.block_size, batch_size, device)
        _, loss = model(X, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Diagnostic pass: per-LN forward quantities, keyed by module name.
    ln_data = {}  # name -> {x, z, sigma, weight, output_ref}

    def make_forward_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].detach()
            mu = x.mean(dim=-1, keepdim=True)
            xc = x - mu
            var = (xc * xc).mean(dim=-1, keepdim=True)
            # Use the module's own epsilon so sigma matches its forward pass.
            sigma = torch.sqrt(var + module.eps)
            z = xc / sigma
            # Keep the gradient of this non-leaf output after backward().
            output.retain_grad()
            weight = getattr(module, "weight", None)
            ln_data[name] = {
                "x": x,
                "z": z,
                "sigma": sigma,
                "weight": None if weight is None else weight.detach(),
                "output_ref": output,
            }
        return hook

    hooks = []
    try:
        for name, module in model.named_modules():
            if isinstance(module, nn.LayerNorm):
                hooks.append(module.register_forward_hook(make_forward_hook(name)))

        # Forward + backward on the diagnostic batch.
        model.eval()
        X, Y = get_batch(data_dir, cfg.block_size, batch_size, device)
        _, loss = model(X, Y)
        loss.backward()
    finally:
        # Always detach the hooks, even if the diagnostic pass raised.
        for handle in hooks:
            handle.remove()

    results = {}
    for name, d in ln_data.items():
        grad_out = d["output_ref"].grad
        if grad_out is None:
            continue
        g = grad_out.detach()  # gradient w.r.t. the LN output, (..., dim)
        if d["weight"] is not None:
            # The LN output is z * weight + bias, so the gradient w.r.t. the
            # normalized activations z is the output gradient times weight.
            g = g * d["weight"]
        results[name] = _decompose_ln_gradient(g, d["z"], d["sigma"])

    return results
+
+
def main():
    """Run the LN Jacobian analysis for softmax and sigmoid attention.

    For each attention mode this calls ``analyze_one_config`` on the
    ``data/shakespeare_char`` dataset, prints one table row per LayerNorm
    module (sorted by name) and then the averages across layers.  If no
    layer produced results, a short notice is printed instead of the table.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = Path("data/shakespeare_char")

    for attn in ["softmax", "sigmoid"]:
        print(f"\n{'='*60}")
        print(f" Attention: {attn}")
        print(f"{'='*60}")
        results = analyze_one_config(attn, device, data_dir)
        if not results:
            # Without this guard the averages below divide by zero.
            print(" No LayerNorm gradients were captured; nothing to report.")
            continue

        print(f"{'name':30s} {'r_mean':>8s} {'r_rad':>8s} {'σ_μ':>8s} {'cos_c/t':>8s} {'cos_s/t':>8s}")
        print("-" * 80)
        for name, r in sorted(results.items()):
            print(f"{name:30s} {r['r_mean']:8.4f} {r['r_radial']:8.4f} "
                  f"{r['sigma_mean']:8.3f} {r['cos_center_vs_true']:8.4f} {r['cos_ste_vs_true']:8.4f}")

        # Summary: unweighted mean of each metric across LN layers.
        r_means = [r["r_mean"] for r in results.values()]
        r_rads = [r["r_radial"] for r in results.values()]
        cos_cs = [r["cos_center_vs_true"] for r in results.values()]
        cos_ss = [r["cos_ste_vs_true"] for r in results.values()]
        print(f"\n AVG r_mean={sum(r_means)/len(r_means):.4f} r_radial={sum(r_rads)/len(r_rads):.4f} "
              f"cos_center={sum(cos_cs)/len(cos_cs):.4f} cos_ste={sum(cos_ss)/len(cos_ss):.4f}")


if __name__ == "__main__":
    main()