diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/analyze_ln_jacobian.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/analyze_ln_jacobian.py')
| -rw-r--r-- | ep_run/analyze_ln_jacobian.py | 166 |
1 files changed, 166 insertions, 0 deletions
diff --git a/ep_run/analyze_ln_jacobian.py b/ep_run/analyze_ln_jacobian.py new file mode 100644 index 0000000..dcf2149 --- /dev/null +++ b/ep_run/analyze_ln_jacobian.py @@ -0,0 +1,166 @@ +"""Analyze LN Jacobian decomposition: how much does each component (scaling, mean-center, +radial removal) contribute to the gradient at each LN layer? + +Trains a small FA model for 250 steps, then on one diagnostic batch: +1. Forward with hooks to capture each LN's (x, z, sigma) +2. Backward to get g_tilde = dL/dz (gradient wrt LN output) +3. Decompose: true J_LN @ g_tilde vs center_scale vs projected vs identity(STE) +4. Report per-layer cosines and energy fractions + +Run for both softmax and sigmoid to explain why center_scale costs more on softmax. +""" +import json +import pickle +from pathlib import Path +import torch +import torch.nn as nn +import torch.nn.functional as F +from model_local import LocalGPT, LocalGPTConfig +from local_layers import initialize_dfa_targets +import numpy as np + + +def get_batch(data_dir, block_size, batch_size, device): + data = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r") + ix = torch.randint(len(data) - block_size - 1, (batch_size,)) + x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix]) + return x.to(device), y.to(device) + + +def analyze_one_config(attn_mode, device, data_dir): + """Train FA model for 250 steps, then analyze LN Jacobian on one batch.""" + torch.manual_seed(1337) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + + cfg = LocalGPTConfig( + block_size=64, vocab_size=meta["vocab_size"], + n_layer=4, n_head=4, n_embd=128, dropout=0.0, + attn_mode=attn_mode, method="fa", + ) + model = LocalGPT(cfg).to(device) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + model.train() + for step in range(250): + X, Y = get_batch(data_dir, cfg.block_size, 32, device) + _, loss = model(X, Y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Now diagnostic: hook into LN layers to capture forward quantities + ln_data = {} # name -> {x, z, sigma, g_tilde} + + def make_forward_hook(name): + def hook(module, input, output): + x = input[0].detach() + mu = x.mean(dim=-1, keepdim=True) + xc = x - mu + var = (xc * xc).mean(dim=-1, keepdim=True) + sigma = torch.sqrt(var + 1e-5) + z = xc / sigma + ln_data[name] = {"x": x, "z": z, "sigma": sigma} + output.retain_grad() + ln_data[name]["output_ref"] = output + return hook + + hooks = [] + for name, module in model.named_modules(): + if isinstance(module, nn.LayerNorm): + hooks.append(module.register_forward_hook(make_forward_hook(name))) + + # Forward + backward on diagnostic batch + model.eval() + X, Y = get_batch(data_dir, cfg.block_size, 32, device) + logits, loss = model(X, Y) + loss.backward() + + # Collect g_tilde for each LN + for name in ln_data: + out_ref = ln_data[name]["output_ref"] + if out_ref.grad is not None: + ln_data[name]["g_tilde"] = out_ref.grad.detach() + + for h in hooks: + h.remove() + + # Analyze decomposition + results = {} + for name, d in ln_data.items(): + if "g_tilde" not in d: + continue + g = d["g_tilde"] # (B, T, dim) + z = d["z"] + sigma = d["sigma"] + dim = g.shape[-1] + + # True LN Jacobian action: g_x = (1/sigma) * (g - mean(g) - z*mean(g*z)) + g_mean = g.mean(dim=-1, keepdim=True) + gz_mean = (g * z).mean(dim=-1, keepdim=True) + + g_true = (g - g_mean - z * gz_mean) / sigma # full LN backward + g_center = (g - g_mean) / sigma # center_scale only + g_ste = g # identity STE + + # Energy fractions: what fraction of ||g||^2 is in each removed component? + g_norm_sq = (g * g).sum(-1).mean() + mean_component = g_mean.expand_as(g) + radial_component = z * gz_mean + r_mean = (mean_component * mean_component).sum(-1).mean() / (g_norm_sq + 1e-12) + r_radial = (radial_component * radial_component).sum(-1).mean() / (g_norm_sq + 1e-12) + + # Cosines: how well does each surrogate match the true LN backward? + def batch_cos(a, b): + a_flat = a.reshape(-1, dim) + b_flat = b.reshape(-1, dim) + cos = F.cosine_similarity(a_flat, b_flat, dim=-1) + return cos.mean().item() + + cos_center = batch_cos(g_center, g_true) + cos_ste = batch_cos(g_ste, g_true) + cos_center_to_ste = batch_cos(g_center, g_ste) + + # Sigma statistics + sigma_mean = sigma.mean().item() + sigma_std = sigma.std().item() + + results[name] = { + "r_mean": r_mean.item(), + "r_radial": r_radial.item(), + "sigma_mean": sigma_mean, + "sigma_std": sigma_std, + "cos_center_vs_true": cos_center, + "cos_ste_vs_true": cos_ste, + } + + return results + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + data_dir = Path("data/shakespeare_char") + + for attn in ["softmax", "sigmoid"]: + print(f"\n{'='*60}") + print(f" Attention: {attn}") + print(f"{'='*60}") + results = analyze_one_config(attn, device, data_dir) + print(f"{'name':30s} {'r_mean':>8s} {'r_rad':>8s} {'σ_μ':>8s} {'cos_c/t':>8s} {'cos_s/t':>8s}") + print("-" * 80) + for name, r in sorted(results.items()): + print(f"{name:30s} {r['r_mean']:8.4f} {r['r_radial']:8.4f} " + f"{r['sigma_mean']:8.3f} {r['cos_center_vs_true']:8.4f} {r['cos_ste_vs_true']:8.4f}") + + # Summary + r_means = [r["r_mean"] for r in results.values()] + r_rads = [r["r_radial"] for r in results.values()] + cos_cs = [r["cos_center_vs_true"] for r in results.values()] + cos_ss = [r["cos_ste_vs_true"] for r in results.values()] + print(f"\n AVG r_mean={sum(r_means)/len(r_means):.4f} r_radial={sum(r_rads)/len(r_rads):.4f} " + f"cos_center={sum(cos_cs)/len(cos_cs):.4f} cos_ste={sum(cos_ss)/len(cos_ss):.4f}") + + +if __name__ == "__main__": + main() |
