diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/analyze_softmax_jacobian.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/analyze_softmax_jacobian.py')
| -rw-r--r-- | ep_run/analyze_softmax_jacobian.py | 168 |
1 files changed, 168 insertions, 0 deletions
diff --git a/ep_run/analyze_softmax_jacobian.py b/ep_run/analyze_softmax_jacobian.py new file mode 100644 index 0000000..91ebd70 --- /dev/null +++ b/ep_run/analyze_softmax_jacobian.py @@ -0,0 +1,168 @@ +"""Analyze softmax attention Jacobian: decompose into diagonal (local) vs off-diagonal (lateral). + +The softmax Jacobian J = diag(A) - AA^T acts on gradient g as: + g_S = A ⊙ g - A * (A^T g) (full, has lateral sum) + g_S_diag = A ⊙ (1-A) ⊙ g (diagonal-only, element-wise, same formula as sigmoid) + g_S_ste = g (identity STE) + +This script measures: +1. How much energy is in diagonal vs off-diagonal components +2. Cosine between full vs diagonal-only vs STE on real FA training data +3. Per-head, per-layer breakdown +4. Whether removing the lateral sum is catastrophic or tolerable +""" +import pickle +from pathlib import Path +import torch +import torch.nn as nn +import torch.nn.functional as F +from model_local import LocalGPT, LocalGPTConfig +import numpy as np + + +def get_batch(data_dir, block_size, batch_size, device): + data = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r") + ix = torch.randint(len(data) - block_size - 1, (batch_size,)) + x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix]) + return x.to(device), y.to(device) + + +def main(): + device = "cuda" if torch.cuda.is_available() else "cpu" + data_dir = Path("data/shakespeare_char") + torch.manual_seed(1337) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + + # Train a softmax FA model for 500 steps to get meaningful attention patterns + cfg = LocalGPTConfig( + block_size=64, vocab_size=meta["vocab_size"], + n_layer=4, n_head=4, n_embd=128, dropout=0.0, + attn_mode="softmax", method="fa", + ) + model = LocalGPT(cfg).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + model.train() + for step in range(500): + X, Y = get_batch(data_dir, cfg.block_size, 32, device) + _, loss = model(X, Y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + print(f"Trained 500 steps, final loss: {loss.item():.3f}") + + # Hook into attention forward to capture scores and attention weights + attn_data = {} + + def make_attn_hook(name, module): + original_forward = module.forward + + def hooked_forward(x): + B, T, C = x.shape + q = module.q_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2) + k = module.k_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2) + v = module.v_proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2) + + scores = (q @ k.transpose(-2, -1)) * (module.head_dim ** -0.5) + mask = module.causal_mask[:T, :T] + scores = scores.masked_fill(~mask, float("-inf")) + + attn = F.softmax(scores, dim=-1) + attn_data[name] = { + "scores": scores.detach(), + "attn": attn.detach(), + } + + # Need grad wrt attention output for Jacobian analysis + attn_for_grad = attn.clone().requires_grad_(True) + out = (attn_for_grad @ v).transpose(1, 2).contiguous().view(B, T, C) + out = module.resid_drop(module.o_proj(out)) + + attn_data[name]["attn_for_grad"] = attn_for_grad + return out + + module.forward = hooked_forward + return module + + # Install hooks + for name, module in model.named_modules(): + if hasattr(module, "q_proj") and hasattr(module, "k_proj"): + make_attn_hook(name, module) + + # Forward + backward on diagnostic batch + model.eval() + X, Y = get_batch(data_dir, cfg.block_size, 32, device) + logits, loss = model(X, Y) + loss.backward() + + # Analyze each attention layer + print(f"\n{'layer':30s} {'A_mean':>8s} {'A_entropy':>10s} {'r_diag':>8s} {'r_offdiag':>10s} " + f"{'cos_diag':>9s} {'cos_ste':>8s}") + print("-" * 100) + + for name, d in sorted(attn_data.items()): + A = d["attn"] # (B, n_head, T, T) + attn_ref = d.get("attn_for_grad") + + if attn_ref is None or attn_ref.grad is None: + print(f"{name:30s} (no grad captured)") + continue + + g = attn_ref.grad.detach() # (B, n_head, T, T) = dL/dA + B_size, n_head, T, _ = A.shape + + # Per-head analysis + for h in range(n_head): + A_h = A[:, h, :, :] # (B, T, T) + g_h = g[:, h, :, :] # (B, T, T) + + # Full softmax backward: g_S = A * (g - A @ g sum along last dim) + Ag_sum = (A_h * g_h).sum(dim=-1, keepdim=True) # (B, T, 1) = sum_j A_j g_j per query + g_full = A_h * (g_h - Ag_sum) # (B, T, T) + + # Diagonal-only (element-wise, sigmoid-like): g_diag = A*(1-A)*g + g_diag = A_h * (1 - A_h) * g_h # (B, T, T) + + # STE: g_ste = g + g_ste = g_h + + # Energy fractions + g_full_norm = (g_full * g_full).sum((-2, -1)).mean() + g_diag_norm = (g_diag * g_diag).sum((-2, -1)).mean() + diff_norm = ((g_full - g_diag) * (g_full - g_diag)).sum((-2, -1)).mean() + + # Cosines (flatten per-sample) + def cos(a, b): + af = a.reshape(B_size, -1) + bf = b.reshape(B_size, -1) + return F.cosine_similarity(af, bf, dim=-1).mean().item() + + cos_diag = cos(g_diag, g_full) + cos_ste = cos(g_ste, g_full) + + # Attention statistics + # Mask out -inf positions for stats + valid_mask = A_h > 0 + A_valid = A_h[valid_mask] + A_mean = A_valid.mean().item() + + # Entropy per query row + entropy = -(A_h * (A_h + 1e-10).log()).sum(-1).mean().item() + + r_diag = g_diag_norm / (g_full_norm + 1e-12) + + print(f"{name}.head{h:d} " + f" {A_mean:8.4f} {entropy:10.3f} {r_diag.item():8.3f} " + f"{(1-r_diag).item():10.3f} {cos_diag:9.4f} {cos_ste:8.4f}") + + # Summary + print(f"\nKey: r_diag = ||g_diag||^2 / ||g_full||^2 (energy in diagonal/element-wise part)") + print(f" cos_diag = cosine(diagonal-only, full softmax backward)") + print(f" cos_ste = cosine(identity STE, full softmax backward)") + print(f"\nIf cos_diag ≈ 1: diagonal-only (sigmoid-like) approximation is good → lateral sum not needed") + print(f"If cos_diag << 1: off-diagonal (lateral sum) is critical → need to keep or find local surrogate") + + +if __name__ == "__main__": + main() |
