summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_compare.py
blob: 01598554ab86c43d16c4e7dc953eb0e258e6be93 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Option 2: compare the quality of local attention gradients against true BP on the LM's attention.
Uses the project's OWN LocalCausalSelfAttention with two local rules: FA (feedback alignment) and
fuse_attn_local (the hand-derived SoftmaxValueMixLocalFn local backward). Reports the cosine between
each rule's attention-parameter gradients and the BP gradients, to set against the AEP result (0.993)
obtained from the equilibrium reformulation."""
import math, pickle, numpy as np, torch, torch.nn.functional as F
from pathlib import Path
from model_local import LocalCausalSelfAttention, LocalGPTConfig

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)  # seed before any sampling / weight init below
DD = Path('/tmp/lt_ep/data/shakespeare_char')
# Read the vocab size and close the metadata file deterministically (the original leaked the handle).
# NOTE(review): pickle.load can execute arbitrary code; meta.pkl is assumed to be a trusted local
# artifact — never point DD at untrusted data.
with open(DD / 'meta.pkl', 'rb') as _meta_file:
    vocab = pickle.load(_meta_file)['vocab_size']
B, T, C, H = 16, 64, 128, 4  # batch size, window length, embedding width, number of heads


def get_batch():
    """Draw B random training windows of T tokens together with their next-token targets.

    Returns:
        (x, y): int64 tensors of shape (B, T) on `dev`; y[b, t] is the token following x[b, t].
    """
    tokens = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
    starts = torch.randint(len(tokens) - T - 1, (B,))

    def windows(shift):
        # one length-T slice per start offset, widened from uint16 to int64
        return torch.stack([torch.from_numpy(tokens[s + shift:s + shift + T].astype(np.int64)) for s in starts])

    x, y = windows(0), windows(1)
    return x.to(dev), y.to(dev)


# Fixed inputs shared by every model: one token batch, random token/position embeddings, random readout.
# The RNG draw order (batch, tok, pos, Whead) determines the values under seed 0; keep it.
idx, y = get_batch()
tok = torch.randn(vocab, C, device=dev).mul(0.02)
pos = torch.randn(T, C, device=dev).mul(0.02)
# (B, T, C) input activations: token-embedding lookup plus position embedding broadcast over the batch
EMB = (F.embedding(idx, tok) + pos.unsqueeze(0)).detach()
Whead = torch.randn(C, vocab, device=dev).div(math.sqrt(C))


def make(method, fuse):
    """Build one softmax attention layer for this comparison.

    Args:
        method: passed through as LocalGPTConfig.method ('bp' or 'fa' in this script).
        fuse: passed through as LocalGPTConfig.fuse_attn_local.

    Returns:
        A LocalCausalSelfAttention (no dropout, no bias) moved to `dev`.
    """
    settings = dict(
        block_size=T, vocab_size=vocab, n_head=H, n_embd=C,
        attn_mode='softmax', method=method, fuse_attn_local=fuse,
        dropout=0.0, bias=False,
    )
    layer = LocalCausalSelfAttention(LocalGPTConfig(**settings))
    return layer.to(dev)


bp = make('bp', False)
fa = make('fa', False)
fuse = make('bp', True)
# Give all three models identical projection weights so their gradients are comparable.
# Only the q/k/v/o projection weights are synced; any other parameters a model owns (e.g. FA
# feedback matrices, if present) are left as initialised.
# Copy under no_grad rather than through `.data`, which bypasses autograd's version tracking.
with torch.no_grad():
    for m in (fa, fuse):
        for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj'):
            getattr(m, p).weight.copy_(getattr(bp, p).weight)


def grads(model, x=None, targets=None, head=None):
    """Run one forward/backward pass and return the attention projection-weight gradients.

    The loss is token-level cross-entropy of ``model(x) @ head`` against ``targets``. The returned
    gradients are whatever ``model``'s own backward produces (true BP or its local rule).

    Args:
        model: module exposing ``q_proj``/``k_proj``/``v_proj``/``o_proj`` submodules with ``.weight``.
        x: input activations fed to ``model``; defaults to the module-level ``EMB``.
        targets: target token ids; defaults to the module-level ``y``.
        head: fixed readout matrix applied to the model output; defaults to the module-level ``Whead``.

    Returns:
        List of the four ``.weight.grad`` tensors in (q, k, v, o) order.
    """
    # Defaults resolve at call time, so existing callers keep using the shared batch.
    x = EMB if x is None else x
    targets = y if targets is None else targets
    head = Whead if head is None else head
    # Drop stale grads so this backward is not accumulated onto a previous call's.
    model.zero_grad(set_to_none=True)
    logits = model(x) @ head
    F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)).backward()
    return [getattr(model, p).weight.grad for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj')]


def cos(g, gb):
    """Compare two gradient lists tensor by tensor using flattened cosine similarity.

    Args:
        g: gradients under test.
        gb: reference gradients, paired with ``g`` positionally.

    Returns:
        (mean, per_tensor): the arithmetic mean of the cosines, and the list of per-pair cosines
        as Python floats.
    """
    per_tensor = []
    for ours, ref in zip(g, gb):
        sim = F.cosine_similarity(ours.reshape(-1), ref.reshape(-1), dim=0)
        per_tensor.append(sim.item())
    return sum(per_tensor) / len(per_tensor), per_tensor


# One gradient set per backward rule, all from the same loss on the same batch (BP is the reference).
gb, gfa, gfuse = grads(bp), grads(fa), grads(fuse)
names = ('WQ', 'WK', 'WV', 'WO')
(mfa, cfa), (mfu, cfu) = cos(gfa, gb), cos(gfuse, gb)
# per-weight cosines rendered as "WQ=+x.xx WK=..." for each rule
fa_cols = " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cfa))
fu_cols = " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cfu))
print("FEEDFORWARD attention (project's own code), grad cosine vs BP on Shakespeare:")
print(f"  FA (feedback align)        : mean {mfa:+.3f}   " + fa_cols)
print(f"  fuse_attn_local (SoftmaxValueMixLocalFn): mean {mfu:+.3f}   " + fu_cols)
print("\n(compare to AEP on equilibrium attention: mean +0.993)")