"""Analyze the softmax attention Jacobian: diagonal (local) vs off-diagonal (lateral) parts.

For a single query row with attention weights ``A`` and upstream gradient
``g = dL/dA``, the softmax Jacobian ``J = diag(A) - A A^T`` maps ``g`` to:

    g_S      = A ⊙ g - A (A^T g)    full backward, includes the lateral sum over keys
    g_S_diag = A ⊙ (1 - A) ⊙ g      diagonal-only, element-wise (same form as sigmoid')
    g_S_ste  = g                    identity straight-through estimator (STE)

This script measures:
  1. How much gradient energy lies in the diagonal vs off-diagonal components.
  2. The cosine between the full backward and the diagonal-only / STE
     approximations, on a softmax FA model briefly trained on real data.
  3. A per-layer, per-head breakdown.
  4. Whether removing the lateral sum is catastrophic or tolerable.
"""
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn  # noqa: F401  (currently unused in this script)
import torch.nn.functional as F

from model_local import LocalGPT, LocalGPTConfig


def get_batch(data_dir, block_size, batch_size, device):
    """Sample a random batch of (input, next-token target) sequences from ``train.bin``.

    Args:
        data_dir: ``Path`` to a directory containing ``train.bin`` (uint16 token ids).
        block_size: Length of each sampled sequence.
        batch_size: Number of sequences to sample.
        device: Device the returned tensors are moved to.

    Returns:
        ``(x, y)``: int64 tensors of shape ``(batch_size, block_size)`` where
        ``y[:, t]`` is the token following ``x[:, t]``.
    """
    data = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r")
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)


def _cosine(a, b):
    """Return the batch-mean of per-sample cosine similarity between ``a`` and ``b`` (flattened)."""
    batch = a.shape[0]
    return F.cosine_similarity(a.reshape(batch, -1), b.reshape(batch, -1), dim=-1).mean().item()


def _install_attn_hook(name, module, attn_data):
    """Replace ``module.forward`` with an equivalent that records attention tensors.

    The replacement stores, under ``attn_data[name]``, the masked scores, the
    detached attention weights, and the live attention tensor whose ``.grad``
    (dL/dA) is filled in by the next ``backward()``.

    NOTE(review): this re-implements the attention forward as q/k/v projections
    + causal softmax + o_proj + resid_drop with no attention-weight dropout, and
    assumes ``causal_mask`` is a boolean (T_max, T_max) buffer with True meaning
    "may attend". Confirm this matches ``model_local``'s original forward.
    """
    def hooked_forward(x):
        B, T, C = x.shape
        q = module.q_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2)
        k = module.k_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2)
        v = module.v_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) * (module.head_dim ** -0.5)
        mask = module.causal_mask[:T, :T]
        scores = scores.masked_fill(~mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        # attn is a non-leaf tensor, so autograd only stores attn.grad when asked.
        # retain_grad() keeps the full graph (including q/k paths that feed the
        # gradients of earlier layers) intact; detaching here would distort dL/dA.
        if attn.requires_grad:
            attn.retain_grad()
        attn_data[name] = {
            "scores": scores.detach(),
            "attn": attn.detach(),
            "attn_for_grad": attn,
        }
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return module.resid_drop(module.o_proj(out))

    module.forward = hooked_forward
    return module


def _head_stats(A_h, g_h):
    """Compare full, diagonal-only and STE softmax backward for one attention head.

    Args:
        A_h: Attention weights of shape (B, T, T); each query row sums to 1.
        g_h: Upstream gradient dL/dA with the same shape.

    Returns:
        Dict with ``A_mean`` (mean of nonzero weights), ``entropy`` (mean per-row
        entropy), ``r_diag`` / ``r_offdiag`` (energy of the diagonal / lateral
        component relative to the full backward), ``cos_diag`` and ``cos_ste``.
    """
    # Full backward per query row: J^T g = A ⊙ (g - sum_j A_j g_j).
    Ag_sum = (A_h * g_h).sum(dim=-1, keepdim=True)  # (B, T, 1)
    g_full = A_h * (g_h - Ag_sum)
    # Diagonal-only (element-wise, sigmoid-like): A ⊙ (1 - A) ⊙ g.
    g_diag = A_h * (1 - A_h) * g_h
    # Lateral remainder: g_full - g_diag = -A_i * sum_{j != i} A_j g_j.
    g_off = g_full - g_diag

    full_energy = g_full.pow(2).sum((-2, -1)).mean()
    diag_energy = g_diag.pow(2).sum((-2, -1)).mean()
    off_energy = g_off.pow(2).sum((-2, -1)).mean()

    # Masked (future) positions have exactly zero weight; exclude them from the mean.
    A_mean = A_h[A_h > 0].mean().item()
    entropy = -(A_h * (A_h + 1e-10).log()).sum(-1).mean().item()

    return {
        "A_mean": A_mean,
        "entropy": entropy,
        "r_diag": (diag_energy / (full_energy + 1e-12)).item(),
        # Not 1 - r_diag: the cross term 2<g_diag, g_off> makes the ratios not sum to 1.
        "r_offdiag": (off_energy / (full_energy + 1e-12)).item(),
        "cos_diag": _cosine(g_diag, g_full),
        "cos_ste": _cosine(g_h, g_full),  # STE backward is the identity: g_ste = g
    }


def main(*, train_steps=500, batch_size=32):
    """Train a small softmax FA model, then report per-head softmax-Jacobian statistics.

    Args:
        train_steps: Number of AdamW steps before the diagnostic pass.
        batch_size: Batch size for both training and the diagnostic batch.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = Path("data/shakespeare_char")
    torch.manual_seed(1337)

    # NOTE: pickle can execute arbitrary code; only load a trusted meta.pkl.
    with open(data_dir / "meta.pkl", "rb") as f:
        meta = pickle.load(f)

    # Train a softmax FA model briefly to get meaningful attention patterns.
    cfg = LocalGPTConfig(
        block_size=64, vocab_size=meta["vocab_size"], n_layer=4, n_head=4, n_embd=128,
        dropout=0.0, attn_mode="softmax", method="fa",
    )
    model = LocalGPT(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    model.train()
    loss = None
    for _ in range(train_steps):
        X, Y = get_batch(data_dir, cfg.block_size, batch_size, device)
        _, loss = model(X, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if loss is not None:
        print(f"Trained {train_steps} steps, final loss: {loss.item():.3f}")

    # Hook every attention module to capture scores, weights and dL/dA.
    attn_data = {}
    for name, module in model.named_modules():
        if hasattr(module, "q_proj") and hasattr(module, "k_proj"):
            _install_attn_hook(name, module, attn_data)

    # Forward + backward on a fresh diagnostic batch. Clear leftover parameter
    # grads from the last training step so they are not accumulated into.
    model.eval()
    model.zero_grad(set_to_none=True)
    X, Y = get_batch(data_dir, cfg.block_size, batch_size, device)
    _, loss = model(X, Y)
    loss.backward()

    print(f"\n{'layer':30s} {'A_mean':>8s} {'A_entropy':>10s} {'r_diag':>8s} {'r_offdiag':>10s} "
          f"{'cos_diag':>9s} {'cos_ste':>8s}")
    print("-" * 100)
    for name, d in sorted(attn_data.items()):
        attn_ref = d.get("attn_for_grad")
        if attn_ref is None or attn_ref.grad is None:
            print(f"{name:30s} (no grad captured)")
            continue
        A = d["attn"]  # (B, n_head, T, T)
        g = attn_ref.grad.detach()  # (B, n_head, T, T) = dL/dA
        for h in range(A.shape[1]):
            s = _head_stats(A[:, h], g[:, h])
            label = f"{name}.head{h:d}"
            print(f"{label:30s} {s['A_mean']:8.4f} {s['entropy']:10.3f} {s['r_diag']:8.3f} "
                  f"{s['r_offdiag']:10.3f} {s['cos_diag']:9.4f} {s['cos_ste']:8.4f}")

    print("\nKey: r_diag = ||g_diag||^2 / ||g_full||^2 (energy in diagonal/element-wise part)")
    print("     r_offdiag = ||g_full - g_diag||^2 / ||g_full||^2 (energy in lateral part; "
          "r_diag + r_offdiag != 1 in general)")
    print("     cos_diag = cosine(diagonal-only, full softmax backward)")
    print("     cos_ste = cosine(identity STE, full softmax backward)")
    print("\nIf cos_diag ≈ 1: diagonal-only (sigmoid-like) approximation is good → lateral sum not needed")
    print("If cos_diag << 1: off-diagonal (lateral sum) is critical → need to keep or find local surrogate")


if __name__ == "__main__":
    main()