"""Analyze the LayerNorm Jacobian decomposition of a feedback-alignment model.

How much does each component of the LN backward pass (1/sigma scaling,
mean-centering, radial removal) contribute to the gradient at each LN layer?

Trains a small FA model for 250 steps, then on one diagnostic batch:
1. Forward with hooks to capture each LN's (z, sigma), where
   z = (x - mean(x)) / sigma is the normalized input *before* the affine map.
2. Backward to get g_tilde = dL/dz. Autograd exposes dL/dy for the LN output
   y = weight * z + bias, so g_tilde = weight * dL/dy for affine LayerNorms.
3. Decompose: true J_LN^T g_tilde vs center_scale vs identity (STE).
4. Report per-layer cosines and energy fractions.

Run for both softmax and sigmoid attention to explain why center_scale costs
more on softmax.
"""

import json
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_local import LocalGPT, LocalGPTConfig
from local_layers import initialize_dfa_targets

SEED = 1337
TRAIN_STEPS = 250
BATCH_SIZE = 32
LR = 1e-3
# Guards the energy-fraction denominators against an all-zero gradient.
NORM_EPS = 1e-12


def get_batch(data_dir, block_size, batch_size, device):
    """Sample a random batch of token sequences and their next-token targets.

    Reads ``data_dir / "train.bin"`` as a flat uint16 token stream and returns
    ``(x, y)``: int64 tensors of shape (batch_size, block_size) on *device*,
    where ``y`` is ``x`` shifted forward by one token.
    """
    data = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r")
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)


def _train(model, data_dir, block_size, device, steps, batch_size):
    """Train *model* in place for *steps* AdamW steps on random batches."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    model.train()
    for _ in range(steps):
        X, Y = get_batch(data_dir, block_size, batch_size, device)
        _, loss = model(X, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def _capture_ln_grads(model, X, Y, fold_affine=True):
    """Run one forward/backward pass and collect per-LayerNorm quantities.

    Returns a dict mapping module name -> {"z", "sigma", "g_tilde"} for every
    ``nn.LayerNorm`` whose output received a gradient. ``z`` and ``sigma`` are
    recomputed from the LN input over the last dimension using the module's
    own ``eps``. ``g_tilde`` is dL/dz: the output gradient multiplied by the
    LN weight when *fold_affine* is true and the module has a weight.
    """
    ln_data = {}

    def make_forward_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].detach()
            xc = x - x.mean(dim=-1, keepdim=True)
            var = (xc * xc).mean(dim=-1, keepdim=True)
            sigma = torch.sqrt(var + module.eps)
            entry = {"z": xc / sigma, "sigma": sigma, "weight": None, "output_ref": None}
            if module.weight is not None:
                entry["weight"] = module.weight.detach()
            # retain_grad() raises on tensors that do not require grad, so only
            # track LN outputs that actually participate in the backward pass.
            if output.requires_grad:
                output.retain_grad()
                entry["output_ref"] = output
            ln_data[name] = entry
        return hook

    hooks = [
        module.register_forward_hook(make_forward_hook(name))
        for name, module in model.named_modules()
        if isinstance(module, nn.LayerNorm)
    ]
    try:
        _, loss = model(X, Y)
        loss.backward()
    finally:
        # Remove the hooks even if the pass fails so they do not linger on the model.
        for h in hooks:
            h.remove()

    captured = {}
    for name, entry in ln_data.items():
        out_ref = entry["output_ref"]
        if out_ref is None or out_ref.grad is None:
            continue
        g = out_ref.grad.detach()
        # dL/dz = weight * dL/dy: the normalization Jacobian acts on the
        # gradient *before* the elementwise affine scale.
        if fold_affine and entry["weight"] is not None:
            g = g * entry["weight"]
        captured[name] = {"z": entry["z"], "sigma": entry["sigma"], "g_tilde": g}
    return captured


def _batch_cos(a, b):
    """Mean per-token cosine similarity between two tensors of shape (..., dim)."""
    dim = a.shape[-1]
    return F.cosine_similarity(a.reshape(-1, dim), b.reshape(-1, dim), dim=-1).mean().item()


def _decompose_ln_grad(g, z, sigma):
    """Compare the exact LN backward with its center_scale and STE surrogates.

    With g = dL/dz, the exact LayerNorm backward (no affine) is
        dL/dx = (g - mean(g) - z * mean(g * z)) / sigma,
    i.e. mean removal, radial (along-z) removal, then 1/sigma scaling.
    Returns energy fractions, sigma statistics and surrogate cosines.
    """
    g_mean = g.mean(dim=-1, keepdim=True)
    gz_mean = (g * z).mean(dim=-1, keepdim=True)
    g_true = (g - g_mean - z * gz_mean) / sigma   # full LN backward
    g_center = (g - g_mean) / sigma               # center_scale: drops the radial term
    g_ste = g                                     # identity STE

    # Energy fractions: share of ||g||^2 carried by each component that the
    # exact backward removes (token-averaged numerator over token-averaged ||g||^2).
    g_norm_sq = (g * g).sum(-1).mean() + NORM_EPS
    mean_component = g_mean.expand_as(g)
    radial_component = z * gz_mean
    r_mean = (mean_component * mean_component).sum(-1).mean() / g_norm_sq
    r_radial = (radial_component * radial_component).sum(-1).mean() / g_norm_sq

    return {
        "r_mean": r_mean.item(),
        "r_radial": r_radial.item(),
        "sigma_mean": sigma.mean().item(),
        "sigma_std": sigma.std().item(),
        "cos_center_vs_true": _batch_cos(g_center, g_true),
        "cos_ste_vs_true": _batch_cos(g_ste, g_true),
        "cos_center_vs_ste": _batch_cos(g_center, g_ste),
    }


def analyze_one_config(attn_mode, device, data_dir, *, train_steps=TRAIN_STEPS,
                       batch_size=BATCH_SIZE, fold_affine=True):
    """Train an FA model, then analyze the LN Jacobian on one diagnostic batch.

    Args:
        attn_mode: Attention variant passed to ``LocalGPTConfig`` (e.g. "softmax").
        device: Torch device string for the model and batches.
        data_dir: Directory (str or Path) containing ``meta.pkl`` and ``train.bin``.
        train_steps: Number of training steps before the diagnostic pass.
        batch_size: Batch size for both training and the diagnostic batch.
        fold_affine: Multiply the LN output gradient by the LN weight so the
            decomposition acts on dL/dz rather than dL/dy.

    Returns:
        Dict mapping LayerNorm module name -> metrics dict with keys
        "r_mean", "r_radial", "sigma_mean", "sigma_std", "cos_center_vs_true",
        "cos_ste_vs_true", "cos_center_vs_ste". Empty if no ``nn.LayerNorm``
        output received a gradient.
    """
    data_dir = Path(data_dir)
    torch.manual_seed(SEED)
    # NOTE: pickle.load can execute arbitrary code; only use trusted data_dir contents.
    with open(data_dir / "meta.pkl", "rb") as f:
        meta = pickle.load(f)

    cfg = LocalGPTConfig(
        block_size=64, vocab_size=meta["vocab_size"], n_layer=4, n_head=4,
        n_embd=128, dropout=0.0, attn_mode=attn_mode, method="fa",
    )
    model = LocalGPT(cfg).to(device)
    _train(model, data_dir, cfg.block_size, device, train_steps, batch_size)

    # Diagnostic pass on a fresh batch.
    model.eval()
    X, Y = get_batch(data_dir, cfg.block_size, batch_size, device)
    captured = _capture_ln_grads(model, X, Y, fold_affine=fold_affine)
    return {
        name: _decompose_ln_grad(d["g_tilde"], d["z"], d["sigma"])
        for name, d in captured.items()
    }


def _print_report(results):
    """Print the per-layer table and the cross-layer averages for one config."""
    print(f"{'name':30s} {'r_mean':>8s} {'r_rad':>8s} {'σ_μ':>8s} "
          f"{'cos_c/t':>8s} {'cos_s/t':>8s} {'cos_c/s':>8s}")
    print("-" * 84)
    for name, r in sorted(results.items()):
        print(f"{name:30s} {r['r_mean']:8.4f} {r['r_radial']:8.4f} "
              f"{r['sigma_mean']:8.3f} {r['cos_center_vs_true']:8.4f} "
              f"{r['cos_ste_vs_true']:8.4f} {r['cos_center_vs_ste']:8.4f}")

    if not results:
        # Nothing to average; avoid dividing by zero.
        print("\n  no nn.LayerNorm gradients captured; nothing to average")
        return

    def avg(key):
        return sum(r[key] for r in results.values()) / len(results)

    print(f"\n  AVG r_mean={avg('r_mean'):.4f} r_radial={avg('r_radial'):.4f} "
          f"cos_center={avg('cos_center_vs_true'):.4f} cos_ste={avg('cos_ste_vs_true'):.4f}")


def main():
    """Run the LN Jacobian analysis for softmax and sigmoid attention."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = Path("data/shakespeare_char")
    for attn in ["softmax", "sigmoid"]:
        print(f"\n{'=' * 60}")
        print(f"  Attention: {attn}")
        print(f"{'=' * 60}")
        _print_report(analyze_one_config(attn, device, data_dir))


if __name__ == "__main__":
    main()