"""option 2: compare local attention-gradient quality vs true BP on the LM's attention. Uses the project's OWN LocalCausalSelfAttention: FA (feedback alignment) and fuse_attn_local (the hand-derived SoftmaxValueMixLocalFn local backward). Reports attention-param grad cosine vs BP, to set against the AEP result (0.993) from the equilibrium reformulation.""" import math, pickle, numpy as np, torch, torch.nn.functional as F from pathlib import Path from model_local import LocalCausalSelfAttention, LocalGPTConfig dev = 'cuda' if torch.cuda.is_available() else 'cpu' torch.manual_seed(0) DD = Path('/tmp/lt_ep/data/shakespeare_char') vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] B, T, C, H = 16, 64, 128, 4 def get_batch(): data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r') ix = torch.randint(len(data) - T - 1, (B,)) x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) return x.to(dev), y.to(dev) idx, y = get_batch() tok = torch.randn(vocab, C, device=dev) * 0.02 pos = torch.randn(T, C, device=dev) * 0.02 EMB = (tok[idx] + pos[None]).detach() Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C) def make(method, fuse): cfg = LocalGPTConfig(block_size=T, vocab_size=vocab, n_head=H, n_embd=C, attn_mode='softmax', method=method, fuse_attn_local=fuse, dropout=0.0, bias=False) return LocalCausalSelfAttention(cfg).to(dev) bp = make('bp', False) fa = make('fa', False) fuse = make('bp', True) # identical weights across the three so gradients are comparable for m in (fa, fuse): for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj'): getattr(m, p).weight.data.copy_(getattr(bp, p).weight.data) def grads(model): for p in model.parameters(): if p.grad is not None: p.grad = None o = model(EMB) F.cross_entropy((o @ Whead).reshape(-1, vocab), y.reshape(-1)).backward() return [getattr(model, p).weight.grad for p in ('q_proj', 'k_proj', 'v_proj', 'o_proj')] def cos(g, gb): cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() for a, b in zip(g, gb)] return sum(cs) / len(cs), cs gb = grads(bp) gfa = grads(fa) gfuse = grads(fuse) names = ('WQ', 'WK', 'WV', 'WO') mfa, cfa = cos(gfa, gb) mfu, cfu = cos(gfuse, gb) print("FEEDFORWARD attention (project's own code), grad cosine vs BP on Shakespeare:") print(f" FA (feedback align) : mean {mfa:+.3f} " + " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cfa))) print(f" fuse_attn_local (SoftmaxValueMixLocalFn): mean {mfu:+.3f} " + " ".join(f"{n}={c:+.2f}" for n, c in zip(names, cfu))) print("\n(compare to AEP on equilibrium attention: mean +0.993)")