summaryrefslogtreecommitdiff
path: root/scripts/aep_option1.py
blob: 65583d43a82453b21b437398f953b4f7273705ee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""option 1: CORRECT gradient for non-conservative attention UNDER the token-norm constraint.

Implicit differentiation of the projected fixed-point map  G(x) = Pi(x + eps F(x)):
  adjoint   a <- J_G^T a + g ,   J_G^T = (I + eps J_F^T) Pi'^T ,  g = dC/dx*
  gradient  dL/dtheta = eps * < Pi'^T a , dF/dtheta(x*) >

Built from LOCAL pieces only (this is the projected analogue of EP's nudged adjoint):
  Pi'^T  : vjp(token_norm, u, .)              (the LN/projection Jacobian = LayerNormProjectedSurrogate)
  J_F^T  : -Hess(E_rest).b  (symmetric, via HVP)  +  s * vjp(real_attn, z*, .)  (the non-conservative bit)
Validation: cosine vs BPTT-through-the-projected-relaxation (ground truth). C lost fidelity; this should recover it.
"""
import torch, torch.nn.functional as F, math
from cet_mvp import token_norm, make_patch_mask, masked_cost, get_loaders
from cet_aep import CETReal

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'
# Parameter names treated as "attention" weights when cosines() averages per-tensor agreement.
ATTN = ('WQ', 'WK', 'WV', 'WO')


def force(model, xbar, z, y, s, cg=False):
    """Return the (z, y) force of the non-conservative dynamics at (z, y).

    The z-force is the negative gradient of ``model.E_rest`` plus ``s`` times the
    attention drive ``model.real_attn(z)``; the y-force is the negative gradient of
    ``model.E_rest`` alone. ``z`` and ``y`` must require grad. With ``cg=True`` the
    energy gradients keep their graph, so the result can be differentiated again.
    """
    energy = model.E_rest(xbar, z, y)
    dE_dz, dE_dy = torch.autograd.grad(energy, [z, y], create_graph=cg)
    attn_drive = model.real_attn(z)
    force_z = -dE_dz + s * attn_drive
    force_y = -dE_dy
    return force_z, force_y


def relax_free(model, xbar, z, y, s, T1, eps):
    """Run T1 projected Euler steps of the free (un-nudged) dynamics.

    Each step evaluates the force at the current state and applies
    ``z <- token_norm(z + eps * F_z)`` and ``y <- y + eps * F_y``. No graph is kept
    across steps, so the result carries no gradient history. Note that
    ``requires_grad_`` is switched on in place on the incoming ``z``/``y``.

    Returns the detached final ``(z, y)``.
    """
    for _step in range(T1):
        # Force evaluation needs autograd even if the caller is under no_grad.
        with torch.enable_grad():
            z_in = z.requires_grad_(True)
            y_in = y.requires_grad_(True)
            step_z, step_y = force(model, xbar, z_in, y_in, s)
            step_z = step_z.detach()
            step_y = step_y.detach()
        # The update itself is not differentiated through.
        with torch.no_grad():
            z_next = token_norm(z + eps * step_z)
            y_next = y + eps * step_y
        z, y = z_next, y_next
    return z.detach(), y.detach()


def adjoint_grad(model, xbar, s, T1, eps, Tadj):
    """Implicit (projected-adjoint) gradient of the masked cost w.r.t. the parameters.

    1. Relax T1 steps to the free fixed point x* = (z*, y*) of
       G(x) = Pi(x + eps F(x)), where Pi = token_norm on z and identity on y.
    2. Iterate the adjoint a <- J_G^T a + g for Tadj steps, with
       J_G^T = (I + eps J_F^T) Pi'(u)^T, u = x* + eps F(x*), and g = (0, dC/dy).
    3. Return eps * d/dtheta < Pi'(u)^T a , F(x*, theta) >.

    The cost uses the module globals ``X`` and ``M`` (assigned in ``main``).
    Returns a tuple aligned with ``model.parameters()``. An entry is None for a
    parameter the force does not depend on.
    """
    zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps)

    # One differentiable evaluation of F at x* is reused for three things:
    # the pre-projection point u, every J_F^T product in the adjoint loop, and the
    # final parameter contraction. create_graph=True keeps dF/dx and dF/dtheta.
    zr, yr = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True)
    gz, gy = torch.autograd.grad(model.E_rest(xbar, zr, yr), [zr, yr], create_graph=True)
    Fz = -gz + s * model.real_attn(zr)
    Fy = -gy

    # Projection graph at u, built once; Pi'(u)^T v is a vjp through it.
    uz = (zs + eps * Fz.detach()).detach().requires_grad_(True)
    pz = token_norm(uz)

    def pi_T(v):
        # Pi'(u)^T v for the z block (y is not projected); retain_graph allows reuse.
        return torch.autograd.grad(pz, uz, grad_outputs=v, retain_graph=True)[0]

    # Cost gradient g = (0, dC/dy); the cost does not depend on z.
    yc = ys.detach().requires_grad_(True)
    gy_c, = torch.autograd.grad(masked_cost(yc, X, M) / M.sum(), yc)
    gy_c = gy_c.detach()

    az, ay = torch.zeros_like(zs), gy_c.clone()      # init adjoint at g (cost grad)
    for _ in range(Tadj):
        bz, by = pi_T(az), ay                        # b = Pi'^T a
        # J_F^T b = grad_x <F(x), b>  ==  -Hess(E_rest) b + s * J_attn^T b_z
        JFt_z, JFt_y = torch.autograd.grad(
            (Fz * bz).sum() + (Fy * by).sum(), [zr, yr], retain_graph=True)
        az = (bz + eps * JFt_z).detach()             # g_z = 0
        ay = (by + eps * JFt_y + gy_c).detach()

    # gradient: eps * d/dtheta < Pi'^T a , F(x*, theta) >
    bz = pi_T(az).detach()
    by = ay.detach()
    contr = eps * ((bz * Fz).sum() + (by * Fy).sum())
    return torch.autograd.grad(contr, list(model.parameters()), allow_unused=True)


def bptt_grad(model, xbar, s, T1, eps):
    """Ground-truth gradient by backpropagating through the whole projected relaxation.

    Starts from ``model.init_state(xbar)`` and unrolls T1 steps of
    ``z <- token_norm(z + eps F_z)``, ``y <- y + eps F_y``, keeping the full graph
    (the force is built with ``cg=True``). It then differentiates the masked cost on
    the final ``y``, normalised by ``M.sum()``, w.r.t. every model parameter.
    The cost uses the module globals ``X`` and ``M``.
    Returns a tuple aligned with ``model.parameters()`` (None for unused parameters).
    """
    z_init, y_init = model.init_state(xbar)
    z = z_init.requires_grad_(True)
    y = y_init.requires_grad_(True)
    for _step in range(T1):
        step_z, step_y = force(model, xbar, z, y, s, cg=True)
        z = token_norm(z + eps * step_z)
        y = y + eps * step_y
    loss = masked_cost(y, X, M) / M.sum()
    params = list(model.parameters())
    return torch.autograd.grad(loss, params, allow_unused=True)


def cosines(g, gb, names):
    """Cosine agreement between two gradient tuples aligned with ``names``.

    Returns ``(attn_cos, global_cos)``:
      * attn_cos   - mean per-tensor cosine over parameters whose name is in ATTN;
      * global_cos - cosine between the concatenations of all parameters.
    Positions where either gradient is None are skipped in both summaries. This
    keeps the two concatenated vectors aligned element for element. Either value
    is nan when nothing is left to compare.
    """
    def cos(a, b):
        return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()

    attn = [cos(a, b) for n, a, b in zip(names, g, gb)
            if n in ATTN and a is not None and b is not None]
    attn_cos = sum(attn) / len(attn) if attn else float('nan')

    # Filter pairs jointly: filtering g and gb independently would misalign them.
    both = [(a, b) for a, b in zip(g, gb) if a is not None and b is not None]
    if not both:
        return attn_cos, float('nan')
    A = torch.cat([a.flatten() for a, _ in both])
    B = torch.cat([b.flatten() for _, b in both])
    return attn_cos, cos(A, B)


def main():
    """Compare the projected-adjoint gradient against BPTT on one FashionMNIST batch.

    For each attention scale ``s`` and horizon ``T1`` (also used as ``Tadj``), it
    prints three columns:
      * the mean cosine over the ATTN parameters;
      * the cosine over all parameters;
      * the forward residual of one more projected step from the relaxed state.
    """
    eps = 0.2  # one step size shared by BPTT, the adjoint solve and the residual check
    torch.manual_seed(0)
    model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
    names = [n for n, _ in model.named_parameters()]
    trl, _ = get_loaders(32, dataset='fashionmnist')
    # adjoint_grad / bptt_grad read X and M as module globals for the masked cost.
    global X, M, XBAR
    X, _ = next(iter(trl))
    X = X.to(dev)
    M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
    XBAR = X * (1 - M)

    def resid(s, T1):
        """Relative size of one more projected step from the relaxed state.

        Returns the larger of the z- and y-residuals. The adjoint assumes a joint
        (z, y) fixed point, so a still-moving y must not be hidden by a settled z.
        """
        zs, ys = relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps)
        with torch.enable_grad():
            zr, yr = zs.requires_grad_(True), ys.requires_grad_(True)
            fz, fy = force(model, XBAR, zr, yr, s)
        with torch.no_grad():
            rz = (token_norm(zs + eps * fz) - zs).norm() / (zs.norm() + 1e-9)
            ry = (eps * fy).norm() / (ys.norm() + 1e-9)
        return max(rz.item(), ry.item())

    print("PROJECTED-ADJOINT (option 1) vs BPTT  — is the s>=2 break convergence or no-fixed-point?")
    print(f"{'s':>5} {'T1=Tadj':>8} | {'attn cos':>9} {'glob cos':>9} | {'fwd resid':>9}")
    for s in [0.5, 1.0, 2.0]:
        for it in [120, 400]:
            gb = bptt_grad(model, XBAR, s, it, eps)
            ga = adjoint_grad(model, XBAR, s, it, eps, it)
            a, g = cosines(ga, gb, names)
            print(f"{s:5.1f} {it:>8} | {a:>9.3f} {g:>9.3f} | {resid(s, it):>9.2e}")


if __name__ == '__main__':
    # Running the file directly performs the adjoint-vs-BPTT comparison sweep in main().
    main()