1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
"""option 2: compare local attention-gradient quality vs true BP on the LM's attention.
Uses the project's own LocalCausalSelfAttention in two local-gradient modes: FA (feedback
alignment) and fuse_attn_local (the hand-derived SoftmaxValueMixLocalFn local backward).
Reports the cosine similarity of each mode's attention-parameter gradients to the BP
gradients, to set against the AEP result (0.993) from the equilibrium reformulation."""
import math, pickle, numpy as np, torch, torch.nn.functional as F
from pathlib import Path
from model_local import LocalCausalSelfAttention, LocalGPTConfig
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed torch's RNG up front: the sampled batch and every random tensor created
# below are drawn after this call.
torch.manual_seed(0)
# Directory holding the character-level Shakespeare data (meta.pkl, train.bin).
DD = Path('/tmp/lt_ep/data/shakespeare_char')
# Read the vocabulary size from the dataset metadata. The file is opened in a
# `with` block so the handle is closed as soon as the pickle has been read.
# NOTE(review): pickle.load can execute arbitrary code from the file; this is
# only acceptable if meta.pkl is a locally generated, trusted artifact - confirm.
with open(DD / 'meta.pkl', 'rb') as meta_file:
    vocab = pickle.load(meta_file)['vocab_size']
# B: windows sampled per batch, T: window length (also the config block_size),
# C: embedding width (n_embd), H: number of attention heads (n_head).
B, T, C, H = 16, 64, 128, 4
def get_batch():
    """Sample one batch of B random length-T windows from train.bin.

    Returns:
        A pair ``(x, y)`` of int64 tensors on ``dev``. Row ``k`` of ``y`` is
        the same window as row ``k`` of ``x`` shifted one position to the
        right, i.e. the next-token targets.
    """
    data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
    # One random start offset per batch row, drawn in a single RNG call.
    starts = torch.randint(len(data) - T - 1, (B,)).tolist()

    def window(lo):
        # Copy T consecutive uint16 tokens out of the memmap as int64.
        return torch.from_numpy(data[lo:lo + T].astype(np.int64))

    x = torch.stack([window(s) for s in starts])
    y = torch.stack([window(s + 1) for s in starts])
    return x.to(dev), y.to(dev)
# Draw the single batch used for every comparison below. This is the first
# consumer of torch's RNG in this block, so the order of these statements
# determines every tensor that follows.
idx, y = get_batch()
# Random token and position embedding tables: standard-normal draws scaled by 0.02.
tok = torch.randn(vocab, C, device=dev) * 0.02
pos = torch.randn(T, C, device=dev) * 0.02
# Input fed to each attention module: token embedding of every index plus the
# position table broadcast over the batch dimension, detached from any graph.
EMB = (tok[idx] + pos[None]).detach()
# Random linear read-out from C features to vocab logits, scaled by 1/sqrt(C).
Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C)
def make(method, fuse):
    """Build a LocalCausalSelfAttention module on ``dev``.

    Args:
        method: value passed as the config's ``method`` field
            (this script uses 'bp' and 'fa').
        fuse: value passed as the config's ``fuse_attn_local`` field.

    Returns:
        The constructed module, moved to ``dev``.
    """
    # Everything except `method` and `fuse_attn_local` is shared by all
    # variants built in this script.
    settings = dict(
        block_size=T,
        vocab_size=vocab,
        n_head=H,
        n_embd=C,
        attn_mode='softmax',
        method=method,
        fuse_attn_local=fuse,
        dropout=0.0,
        bias=False,
    )
    attention = LocalCausalSelfAttention(LocalGPTConfig(**settings))
    return attention.to(dev)
# Three attention modules built from the same config apart from the gradient
# settings: method='bp' (reference), method='fa', and method='bp' with
# fuse_attn_local=True.
bp = make('bp', False)
fa = make('fa', False)
fuse = make('bp', True)
# identical weights across the three so gradients are comparable
# The copy is done in place under no_grad rather than through `.data`, so
# autograd is aware of the write but does not record it.
with torch.no_grad():
    for m in (fa, fuse):
        for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj'):
            getattr(m, p).weight.copy_(getattr(bp, p).weight)
def grads(model):
    """Run one forward/backward pass and return the projection-weight gradients.

    The module output for the module-level ``EMB`` is projected through
    ``Whead``, scored with cross-entropy against the module-level targets
    ``y``, and backpropagated.

    Args:
        model: module exposing ``q_proj``, ``k_proj``, ``v_proj`` and
            ``o_proj`` sub-modules, each with a ``weight`` parameter.

    Returns:
        list: the ``.weight.grad`` of q_proj, k_proj, v_proj and o_proj,
        in that order.
    """
    # Drop any stale gradients so this backward pass is the only contribution.
    for param in model.parameters():
        param.grad = None
    logits = model(EMB) @ Whead
    loss = F.cross_entropy(logits.reshape(-1, vocab), y.reshape(-1))
    loss.backward()
    projections = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    return [getattr(model, name).weight.grad for name in projections]
def cos(g, gb):
    """Compare two parallel lists of gradient tensors by cosine similarity.

    Args:
        g: gradient tensors under test.
        gb: reference gradient tensors, paired positionally with ``g``.

    Returns:
        tuple: ``(mean, per_pair)`` where ``per_pair`` is the list of cosine
        similarities (Python floats) between each flattened pair and ``mean``
        is their arithmetic mean.
    """
    per_pair = []
    for ours, ref in zip(g, gb):
        # Flatten so each tensor is treated as one long vector.
        similarity = F.cosine_similarity(ours.flatten(), ref.flatten(), dim=0)
        per_pair.append(similarity.item())
    return sum(per_pair) / len(per_pair), per_pair
# Gradients for the BP reference and the two local variants, evaluated in
# that order.
gb, gfa, gfuse = grads(bp), grads(fa), grads(fuse)
# Display labels for the q/k/v/o projection gradients, in the order grads() returns them.
names = ('WQ', 'WK', 'WV', 'WO')
mfa, cfa = cos(gfa, gb)
mfu, cfu = cos(gfuse, gb)


def _per_matrix(cosines):
    """Render one method's per-matrix cosines as 'WQ=+0.12 WK=...'."""
    return " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cosines))


print("FEEDFORWARD attention (project's own code), grad cosine vs BP on Shakespeare:")
print(f" FA (feedback align) : mean {mfa:+.3f} " + _per_matrix(cfa))
print(f" fuse_attn_local (SoftmaxValueMixLocalFn): mean {mfu:+.3f} " + _per_matrix(cfu))
print("\n(compare to AEP on equilibrium attention: mean +0.993)")
|