diff options
Diffstat (limited to 'ep_run/lt_ep_compare.py')
| -rw-r--r-- | ep_run/lt_ep_compare.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/ep_run/lt_ep_compare.py b/ep_run/lt_ep_compare.py new file mode 100644 index 0000000..0159855 --- /dev/null +++ b/ep_run/lt_ep_compare.py @@ -0,0 +1,69 @@ +"""option 2: compare local attention-gradient quality vs true BP on the LM's attention. +Uses the project's OWN LocalCausalSelfAttention: FA (feedback alignment) and fuse_attn_local +(the hand-derived SoftmaxValueMixLocalFn local backward). Reports attention-param grad cosine +vs BP, to set against the AEP result (0.993) from the equilibrium reformulation.""" +import math, pickle, numpy as np, torch, torch.nn.functional as F +from pathlib import Path +from model_local import LocalCausalSelfAttention, LocalGPTConfig + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +DD = Path('/tmp/lt_ep/data/shakespeare_char') +vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] +B, T, C, H = 16, 64, 128, 4 + + +def get_batch(): + data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r') + ix = torch.randint(len(data) - T - 1, (B,)) + x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) + return x.to(dev), y.to(dev) + + +idx, y = get_batch() +tok = torch.randn(vocab, C, device=dev) * 0.02 +pos = torch.randn(T, C, device=dev) * 0.02 +EMB = (tok[idx] + pos[None]).detach() +Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C) + + +def make(method, fuse): + cfg = LocalGPTConfig(block_size=T, vocab_size=vocab, n_head=H, n_embd=C, + attn_mode='softmax', method=method, fuse_attn_local=fuse, dropout=0.0, bias=False) + return LocalCausalSelfAttention(cfg).to(dev) + + +bp = make('bp', False) +fa = make('fa', False) +fuse = make('bp', True) +# identical weights across the three so gradients are comparable +for m in (fa, fuse): + for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj'): + getattr(m, p).weight.data.copy_(getattr(bp, p).weight.data) + + +def grads(model): + for p in model.parameters(): + if p.grad is not None: + p.grad = None + o = model(EMB) + F.cross_entropy((o @ Whead).reshape(-1, vocab), y.reshape(-1)).backward() + return [getattr(model, p).weight.grad for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj')] + + +def cos(g, gb): + cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() for a, b in zip(g, gb)] + return sum(cs) / len(cs), cs + + +gb = grads(bp) +gfa = grads(fa) +gfuse = grads(fuse) +names = ('WQ', 'WK', 'WV', 'WO') +mfa, cfa = cos(gfa, gb) +mfu, cfu = cos(gfuse, gb) +print("FEEDFORWARD attention (project's own code), grad cosine vs BP on Shakespeare:") +print(f" FA (feedback align) : mean {mfa:+.3f} " + " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cfa))) +print(f" fuse_attn_local (SoftmaxValueMixLocalFn): mean {mfu:+.3f} " + " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cfu))) +print("\n(compare to AEP on equilibrium attention: mean +0.993)") |
