summaryrefslogtreecommitdiff
path: root/ep_run/analyze_softmax_jacobian.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/analyze_softmax_jacobian.py')
-rw-r--r--ep_run/analyze_softmax_jacobian.py168
1 files changed, 168 insertions, 0 deletions
diff --git a/ep_run/analyze_softmax_jacobian.py b/ep_run/analyze_softmax_jacobian.py
new file mode 100644
index 0000000..91ebd70
--- /dev/null
+++ b/ep_run/analyze_softmax_jacobian.py
@@ -0,0 +1,168 @@
+"""Analyze softmax attention Jacobian: decompose into diagonal (local) vs off-diagonal (lateral).
+
+The softmax Jacobian J = diag(A) - AA^T acts on gradient g as:
+ g_S = A ⊙ g - A * (A^T g) (full, has lateral sum)
+ g_S_diag = A ⊙ (1-A) ⊙ g (diagonal-only, element-wise, same formula as sigmoid)
+ g_S_ste = g (identity STE)
+
+This script measures:
+1. How much energy is in diagonal vs off-diagonal components
+2. Cosine between full vs diagonal-only vs STE on real FA training data
+3. Per-head, per-layer breakdown
+4. Whether removing the lateral sum is catastrophic or tolerable
+"""
+import pickle
+from pathlib import Path
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model_local import LocalGPT, LocalGPTConfig
+import numpy as np
+
+
def get_batch(data_dir, block_size, batch_size, device):
    """Sample a random batch of next-token prediction pairs from ``train.bin``.

    Args:
        data_dir: ``pathlib.Path`` of the directory holding ``train.bin``; the
            file is read as a flat array of ``uint16`` token ids.
        block_size: Number of tokens in each sampled sequence.
        batch_size: Number of sequences to sample.
        device: Device (string or ``torch.device``) the batch is moved to.

    Returns:
        Tuple ``(x, y)`` of ``int64`` tensors, each of shape
        ``(batch_size, block_size)``; ``y`` is ``x`` shifted one token ahead.

    Raises:
        ValueError: If the file holds fewer than ``block_size + 1`` tokens, so
            not even one (input, target) window fits.
    """
    # The memmap is opened per call, so only the sampled windows are read
    # from disk rather than the whole file.
    data = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r")

    # Valid start offsets are 0 .. len(data) - block_size - 1 inclusive: the
    # target window for start i ends at i + 1 + block_size <= len(data).
    # torch.randint's upper bound is exclusive, hence len(data) - block_size.
    n_starts = len(data) - block_size
    if n_starts < 1:
        raise ValueError(
            f"train.bin has {len(data)} tokens; need at least "
            f"{block_size + 1} for block_size={block_size}"
        )

    # Plain Python ints keep the memmap slicing free of tensor indices.
    starts = torch.randint(n_starts, (batch_size,)).tolist()
    x = torch.stack(
        [torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in starts]
    )
    y = torch.stack(
        [torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in starts]
    )
    return x.to(device), y.to(device)
+
+
def _install_attention_capture(name, module, attn_data):
    """Replace ``module.forward`` with a version that records softmax attention.

    The replacement recomputes causal softmax attention from the module's own
    projections and stores, under ``attn_data[name]``:

    * ``"scores"``: masked pre-softmax scores (detached),
    * ``"attn"``: attention weights (detached),
    * ``"attn_for_grad"``: the attention tensor actually used to produce the
      output, whose ``.grad`` holds dL/dA after ``backward()``.

    The module must expose ``q_proj``, ``k_proj``, ``v_proj``, ``o_proj``,
    ``resid_drop``, ``n_head``, ``head_dim`` and a boolean ``causal_mask``.
    NOTE(review): this assumes the module's original forward takes a single
    ``(B, T, C)`` tensor and computes the same attention -- confirm against
    ``model_local``.

    Returns:
        The same ``module`` with its ``forward`` attribute replaced.
    """

    def hooked_forward(x):
        B, T, C = x.shape

        def split_heads(proj):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return proj(x).view(B, T, module.n_head, module.head_dim).transpose(1, 2)

        q = split_heads(module.q_proj)
        k = split_heads(module.k_proj)
        v = split_heads(module.v_proj)

        scores = (q @ k.transpose(-2, -1)) * (module.head_dim ** -0.5)
        mask = module.causal_mask[:T, :T]
        scores = scores.masked_fill(~mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)

        # dL/dA is needed for the Jacobian analysis. When the parameters
        # require grad the clone is a non-leaf tensor, and autograd only
        # populates .grad on non-leaf tensors after retain_grad(); without it
        # .grad stays None. If no grad is being tracked upstream, the clone
        # is a leaf and can simply be flagged as requiring grad.
        attn_for_grad = attn.clone()
        if attn_for_grad.requires_grad:
            attn_for_grad.retain_grad()
        else:
            attn_for_grad.requires_grad_(True)

        attn_data[name] = {
            "scores": scores.detach(),
            "attn": attn.detach(),
            "attn_for_grad": attn_for_grad,
        }

        out = (attn_for_grad @ v).transpose(1, 2).contiguous().view(B, T, C)
        return module.resid_drop(module.o_proj(out))

    module.forward = hooked_forward
    return module


def _head_stats(A_h, g_h):
    """Compare softmax-backward variants for one attention head.

    Args:
        A_h: Attention weights, shape ``(B, T, T)``, rows produced by softmax.
        g_h: Gradient of the loss w.r.t. ``A_h``, same shape.

    Returns:
        Dict of Python floats:
        ``A_mean`` (mean of the strictly positive attention weights),
        ``entropy`` (mean per-query-row entropy),
        ``r_diag`` (``||g_diag||^2 / ||g_full||^2``),
        ``r_offdiag`` (``||g_full - g_diag||^2 / ||g_full||^2``),
        ``cos_diag`` and ``cos_ste`` (per-sample cosine against the full
        softmax backward, averaged over the batch).
    """
    batch = A_h.shape[0]

    # Full softmax backward: g_S = A * (g - sum_j A_j g_j), per query row.
    row_dot = (A_h * g_h).sum(dim=-1, keepdim=True)
    g_full = A_h * (g_h - row_dot)
    # Diagonal of the Jacobian only (element-wise, sigmoid-like).
    g_diag = A_h * (1 - A_h) * g_h
    # What remains is exactly the lateral-sum term: -A_i * sum_{j != i} A_j g_j.
    g_off = g_full - g_diag

    def energy(t):
        return (t * t).sum((-2, -1)).mean()

    def cosine(a, b):
        # Flatten each sample to a vector, then average the cosines.
        return F.cosine_similarity(
            a.reshape(batch, -1), b.reshape(batch, -1), dim=-1
        ).mean().item()

    full_energy = energy(g_full) + 1e-12

    return {
        # Causally masked positions have weight exactly 0; exclude them.
        "A_mean": A_h[A_h > 0].mean().item(),
        "entropy": -(A_h * (A_h + 1e-10).log()).sum(-1).mean().item(),
        "r_diag": (energy(g_diag) / full_energy).item(),
        # The diagonal and off-diagonal parts are not orthogonal, so this is
        # measured directly rather than as 1 - r_diag (which can go negative).
        "r_offdiag": (energy(g_off) / full_energy).item(),
        "cos_diag": cosine(g_diag, g_full),
        # NOTE(review): the identity STE uses the raw g, including causally
        # masked positions where g_full is zero -- confirm this is intended.
        "cos_ste": cosine(g_h, g_full),
    }


def main():
    """Train a small softmax-attention model, then analyze its softmax Jacobian.

    Steps:
      1. Train a 4-layer ``LocalGPT`` (``attn_mode="softmax"``, ``method="fa"``)
         for 500 steps on ``data/shakespeare_char``.
      2. Swap each attention module's forward for a capturing version.
      3. Run one forward/backward pass on a diagnostic batch.
      4. Print, per layer and head, how the full softmax backward compares
         with the diagonal-only and identity-STE approximations.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = Path("data/shakespeare_char")
    torch.manual_seed(1337)
    # SECURITY: pickle.load can execute arbitrary code; only run this against
    # a meta.pkl produced by a trusted data-preparation step.
    with open(data_dir / "meta.pkl", "rb") as f:
        meta = pickle.load(f)

    # Train a softmax FA model for 500 steps to get meaningful attention patterns
    train_steps = 500
    batch_size = 32
    cfg = LocalGPTConfig(
        block_size=64, vocab_size=meta["vocab_size"],
        n_layer=4, n_head=4, n_embd=128, dropout=0.0,
        attn_mode="softmax", method="fa",
    )
    model = LocalGPT(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.train()
    for _ in range(train_steps):
        X, Y = get_batch(data_dir, cfg.block_size, batch_size, device)
        _, loss = model(X, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Trained {train_steps} steps, final loss: {loss.item():.3f}")

    # Capture scores / attention weights (and dL/dA) from every attention module.
    attn_data = {}
    for name, module in model.named_modules():
        if hasattr(module, "q_proj") and hasattr(module, "k_proj"):
            _install_attention_capture(name, module, attn_data)

    # Forward + backward on diagnostic batch
    model.eval()
    X, Y = get_batch(data_dir, cfg.block_size, batch_size, device)
    _, loss = model(X, Y)
    loss.backward()

    # Analyze each attention layer
    print(f"\n{'layer':30s} {'A_mean':>8s} {'A_entropy':>10s} {'r_diag':>8s} {'r_offdiag':>10s} "
          f"{'cos_diag':>9s} {'cos_ste':>8s}")
    print("-" * 100)

    for name, d in sorted(attn_data.items()):
        A = d["attn"]  # (B, n_head, T, T)
        attn_ref = d.get("attn_for_grad")

        if attn_ref is None or attn_ref.grad is None:
            print(f"{name:30s} (no grad captured)")
            continue

        g = attn_ref.grad.detach()  # (B, n_head, T, T) = dL/dA

        # Per-head analysis
        for h in range(A.shape[1]):
            s = _head_stats(A[:, h, :, :], g[:, h, :, :])
            print(f"{name}.head{h:d} "
                  f" {s['A_mean']:8.4f} {s['entropy']:10.3f} {s['r_diag']:8.3f} "
                  f"{s['r_offdiag']:10.3f} {s['cos_diag']:9.4f} {s['cos_ste']:8.4f}")

    # Summary
    print(f"\nKey: r_diag = ||g_diag||^2 / ||g_full||^2 (energy in diagonal/element-wise part)")
    print(" r_offdiag = ||g_full - g_diag||^2 / ||g_full||^2 (energy in lateral-sum part)")
    print(f" cos_diag = cosine(diagonal-only, full softmax backward)")
    print(f" cos_ste = cosine(identity STE, full softmax backward)")
    print(f"\nIf cos_diag ≈ 1: diagonal-only (sigmoid-like) approximation is good → lateral sum not needed")
    print(f"If cos_diag << 1: off-diagonal (lateral sum) is critical → need to keep or find local surrogate")
+
+if __name__ == "__main__":
+ main()