1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
|
"""Analyze LN Jacobian decomposition: how much does each component (scaling, mean-center,
radial removal) contribute to the gradient at each LN layer?
Trains a small FA model for 250 steps, then on one diagnostic batch:
1. Forward with hooks to capture each LN's (x, z, sigma)
2. Backward to get g_tilde = dL/dz (gradient wrt LN output)
3. Decompose: true J_LN @ g_tilde vs center_scale vs projected vs identity(STE)
4. Report per-layer cosines and energy fractions
Run for both softmax and sigmoid to explain why center_scale costs more on softmax.
"""
import json
import pickle
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from model_local import LocalGPT, LocalGPTConfig
from local_layers import initialize_dfa_targets
import numpy as np
def get_batch(data_dir, block_size, batch_size, device):
    """Draw a random batch of (input, next-token target) windows from ``train.bin``.

    Args:
        data_dir: Directory object supporting ``/`` (e.g. ``pathlib.Path``) that
            contains ``train.bin``, read as a flat array of uint16 values.
        block_size: Number of tokens in every sampled window.
        batch_size: Number of windows to sample.
        device: Destination passed to ``Tensor.to`` for both returned tensors.

    Returns:
        Tuple ``(x, y)`` of int64 tensors, each of shape
        ``(batch_size, block_size)``; ``y`` is the window starting one token
        after the corresponding row of ``x``.
    """
    tokens = np.memmap(data_dir / "train.bin", dtype=np.uint16, mode="r")
    # Single RNG draw for all start offsets of the batch.
    starts = torch.randint(len(tokens) - block_size - 1, (batch_size,)).tolist()

    def window(begin):
        # Copy the slice out of the memmap as int64 so it can back a tensor.
        return torch.from_numpy(tokens[begin:begin + block_size].astype(np.int64))

    inputs = torch.stack([window(s) for s in starts])
    targets = torch.stack([window(s + 1) for s in starts])
    return inputs.to(device), targets.to(device)
def _ln_backward_stats(g, z, sigma):
    """Compare surrogate LayerNorm backward passes against the exact one.

    Args:
        g: Gradient with respect to the normalized activations ``z``; the last
            axis is the feature axis.
        z: Normalized activations, same shape as ``g``.
        sigma: Per-position standard deviation with a trailing axis of size 1,
            broadcastable against ``g``.

    Returns:
        Dict of Python floats with keys ``r_mean``, ``r_radial``,
        ``sigma_mean``, ``sigma_std``, ``cos_center_vs_true``,
        ``cos_ste_vs_true`` and ``cos_center_vs_ste``.
    """
    dim = g.shape[-1]

    g_mean = g.mean(dim=-1, keepdim=True)
    gz_mean = (g * z).mean(dim=-1, keepdim=True)
    mean_component = g_mean.expand_as(g)
    radial_component = z * gz_mean

    # Exact LN backward: g_x = (1/sigma) * (g - mean(g) - z * mean(g * z)).
    g_true = (g - g_mean - radial_component) / sigma
    # center_scale surrogate: keeps mean removal and 1/sigma, drops the radial term.
    g_center = (g - g_mean) / sigma
    # Identity straight-through estimator: the gradient passes through unchanged.
    g_ste = g

    # Fraction of ||g||^2 carried by each component the exact backward removes.
    g_norm_sq = (g * g).sum(-1).mean()
    r_mean = (mean_component * mean_component).sum(-1).mean() / (g_norm_sq + 1e-12)
    r_radial = (radial_component * radial_component).sum(-1).mean() / (g_norm_sq + 1e-12)

    def mean_cosine(a, b):
        # Cosine per position (row of length ``dim``), averaged over all positions.
        cos = F.cosine_similarity(a.reshape(-1, dim), b.reshape(-1, dim), dim=-1)
        return cos.mean().item()

    return {
        "r_mean": r_mean.item(),
        "r_radial": r_radial.item(),
        "sigma_mean": sigma.mean().item(),
        "sigma_std": sigma.std().item(),
        "cos_center_vs_true": mean_cosine(g_center, g_true),
        "cos_ste_vs_true": mean_cosine(g_ste, g_true),
        "cos_center_vs_ste": mean_cosine(g_center, g_ste),
    }


def analyze_one_config(attn_mode, device, data_dir):
    """Train an FA model for 250 steps, then analyze the LN Jacobian on one batch.

    Args:
        attn_mode: Value forwarded to ``LocalGPTConfig(attn_mode=...)``.
        device: Device the model and the batches are moved to.
        data_dir: Directory object supporting ``/`` that contains ``meta.pkl``
            and the ``train.bin`` read by ``get_batch``.

    Returns:
        Dict mapping each ``nn.LayerNorm`` module name (as reported by
        ``model.named_modules()``) to its statistics dict. Layers whose output
        received no gradient are omitted.
    """
    torch.manual_seed(1337)

    # NOTE(review): pickle.load can execute arbitrary code; only point data_dir
    # at a dataset directory you produced yourself.
    with open(data_dir / "meta.pkl", "rb") as f:
        meta = pickle.load(f)

    cfg = LocalGPTConfig(
        block_size=64, vocab_size=meta["vocab_size"],
        n_layer=4, n_head=4, n_embd=128, dropout=0.0,
        attn_mode=attn_mode, method="fa",
    )
    model = LocalGPT(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    model.train()
    for _ in range(250):
        X, Y = get_batch(data_dir, cfg.block_size, 32, device)
        _, loss = model(X, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Diagnostic pass: record per-LayerNorm forward quantities via hooks.
    ln_data = {}  # name -> {x, z, sigma, [weight], [output_ref]}

    def make_forward_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].detach()
            mu = x.mean(dim=-1, keepdim=True)
            xc = x - mu
            var = (xc * xc).mean(dim=-1, keepdim=True)
            # Use the module's own epsilon so sigma matches what the layer computed.
            sigma = torch.sqrt(var + module.eps)
            record = {"x": x, "z": xc / sigma, "sigma": sigma}
            weight = getattr(module, "weight", None)
            if weight is not None:
                record["weight"] = weight.detach()
            # retain_grad() raises on tensors that do not require grad; skip those.
            if output.requires_grad:
                output.retain_grad()
                record["output_ref"] = output
            ln_data[name] = record
        return hook

    hooks = [
        module.register_forward_hook(make_forward_hook(name))
        for name, module in model.named_modules()
        if isinstance(module, nn.LayerNorm)
    ]
    try:
        # Forward + backward on the diagnostic batch.
        model.eval()
        X, Y = get_batch(data_dir, cfg.block_size, 32, device)
        _, loss = model(X, Y)
        loss.backward()
    finally:
        # Always detach the hooks, even if the diagnostic pass fails.
        for h in hooks:
            h.remove()

    results = {}
    for name, d in ln_data.items():
        out_ref = d.get("output_ref")
        if out_ref is None or out_ref.grad is None:
            continue
        g = out_ref.grad.detach()  # gradient w.r.t. the module output
        weight = d.get("weight")
        if weight is not None:
            # An affine LayerNorm outputs z * weight + bias, so dL/dz is the
            # output gradient scaled elementwise by weight.
            g = g * weight
        results[name] = _ln_backward_stats(g, d["z"], d["sigma"])
    return results
def main():
    """Run the LN Jacobian analysis for softmax and sigmoid attention and print a report.

    For each attention mode this prints one row per analyzed LayerNorm module
    followed by the averages over all rows. When no layer produced statistics,
    a notice is printed instead of the table, so the averages are never
    computed over an empty set.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = Path("data/shakespeare_char")
    for attn in ["softmax", "sigmoid"]:
        print(f"\n{'='*60}")
        print(f" Attention: {attn}")
        print(f"{'='*60}")
        results = analyze_one_config(attn, device, data_dir)

        if not results:
            # Guard: averaging below would otherwise divide by zero.
            print(" No LayerNorm gradients were captured; nothing to report.")
            continue

        print(f"{'name':30s} {'r_mean':>8s} {'r_rad':>8s} {'σ_μ':>8s} {'cos_c/t':>8s} {'cos_s/t':>8s}")
        print("-" * 80)
        for name, r in sorted(results.items()):
            print(f"{name:30s} {r['r_mean']:8.4f} {r['r_radial']:8.4f} "
                  f"{r['sigma_mean']:8.3f} {r['cos_center_vs_true']:8.4f} {r['cos_ste_vs_true']:8.4f}")

        # Summary: unweighted average of each statistic over the reported layers.
        rows = list(results.values())
        n_rows = len(rows)
        avg_r_mean = sum(r["r_mean"] for r in rows) / n_rows
        avg_r_radial = sum(r["r_radial"] for r in rows) / n_rows
        avg_cos_center = sum(r["cos_center_vs_true"] for r in rows) / n_rows
        avg_cos_ste = sum(r["cos_ste_vs_true"] for r in rows) / n_rows
        print(f"\n AVG r_mean={avg_r_mean:.4f} r_radial={avg_r_radial:.4f} "
              f"cos_center={avg_cos_center:.4f} cos_ste={avg_cos_ste:.4f}")
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|