summaryrefslogtreecommitdiff
path: root/scripts/aep_projected.py
blob: af8891e7b7c1a9aad9facce6d58e949524cf017d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""C / option 1: PROJECTED AEP — non-conservative EP on the token-norm constraint manifold.

Two fixes over the unconstrained version:
 (1) STABILITY: relax with the token-norm projection  z <- Pi(z + eps F)  (this bounds z,
     and is what made plain CET stable). It keeps large-s / deep attention from diverging.
 (2) CORRECT GRADIENT under the constraint: the VF contraction must be projected onto the
     TANGENT space of the manifold. The tangent projector at a normalized token z is
        P_z(v) = v - mean(v) - mean(v*z) * z
     (exactly the local-transformer's LayerNormProjectedSurrogate). Without it the VF
     estimator picks up the normal force and collapses (energy-mode cosine ~0.002).

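Param-gradient:  dL/dtheta = <a_z, P_z*( dF_z/dtheta )> + <a_y, dF_y/dtheta>,
                 a = (state(-beta) - state(+beta)) / (2 beta).
AEP correction (nudged phase, on z): -s (J v - J^T v), where J is the Jacobian of RealAttn
at the free state z* and v = z - z*; the corrected step is then projected by token_norm.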
"""
import argparse, math, torch, torch.nn.functional as F
from cet_mvp import token_norm, make_patch_mask, masked_cost, get_loaders
from cet_aep import CETReal

# Device string for the model and data: CUDA when available, otherwise CPU.
dev = 'cpu'
if torch.cuda.is_available():
    dev = 'cuda'

# Parameter names (as yielded by model.named_parameters()) whose per-tensor
# gradient cosines are averaged into the "attention" column of the report.
ATTN = 'WQ', 'WK', 'WV', 'WO'


def P_tan(z, v):
    """Project ``v`` onto the tangent space of the token-norm manifold at ``z``.

    Works per token along the last dimension. It first removes the mean of
    ``v``, then removes the remaining component along ``z``, with coefficient
    mean(v*z) / mean(z*z). mean(z*z) is clamped to >= 1e-6, so an all-zero
    ``z`` does not divide by zero.

    The result is mean-free only if ``z`` is itself mean-free. That is
    presumably true for token-normalized states; TODO confirm ``token_norm``
    centres its output.
    """
    centred = v - v.mean(dim=-1, keepdim=True)
    z_power = (z * z).mean(dim=-1, keepdim=True).clamp_min(1e-6)
    along_z = (centred * z).mean(dim=-1, keepdim=True) / z_power
    return centred - along_z * z


def force(model, xbar, z, y, s):
    """Return the forces (F_z, F_y) at state (z, y).

        F_z = -dE_rest/dz + s * real_attn(z)   (non-conservative attention term)
        F_y = -dE_rest/dy

    ``create_graph=True`` keeps the forces differentiable, so callers can
    back-propagate through them to the model parameters (vf_grad, bptt_grad).
    Note: ``requires_grad_`` is applied to the caller's ``z`` / ``y`` tensors
    in place.
    """
    z = z.requires_grad_(True)
    y = y.requires_grad_(True)
    energy = model.E_rest(xbar, z, y)
    dE_dz, dE_dy = torch.autograd.grad(energy, [z, y], create_graph=True)
    attention = model.real_attn(z)
    return -dE_dz + s * attention, -dE_dy


def relax_free(model, xbar, z, y, s, T1, eps):
    """Free-phase relaxation: ``T1`` explicit steps of size ``eps`` along ``force``.

    After every step ``z`` is re-projected with ``token_norm``; ``y`` is
    updated without any constraint. Returns the detached final (z, y).
    """
    for _step in range(T1):
        # Forces need autograd (force differentiates E_rest) even if the
        # caller is running under no_grad.
        with torch.enable_grad():
            step_z, step_y = (f.detach() for f in force(model, xbar, z, y, s))
        with torch.no_grad():
            z = token_norm(z + eps * step_z)
            y = y + eps * step_y
    return z.detach(), y.detach()


def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep):
    """Nudged-phase relaxation starting from the free state (zs, ys).

    Runs ``T2`` steps like ``relax_free``. On each step ``y`` is additionally
    pushed by ``-sign * beta * d masked_cost / dy``. The cost uses the
    module-level ``X`` and ``M`` assigned in ``main``.

    With ``aep`` set, ``-s (J v - J^T v)`` is added to the z-force, where J is
    the Jacobian of ``model.real_attn`` at ``zs`` and ``v = z - zs``. To first
    order in v, this swaps the attention term's Jacobian for its transpose.

    Returns the detached final (z, y).
    """
    nudge = sign * beta
    z, y = zs.clone(), ys.clone()
    for _step in range(T2):
        with torch.enable_grad():
            step_z, step_y = (f.detach() for f in force(model, xbar, z, y, s))
            y_leaf = y.detach().requires_grad_(True)
            (cost_grad,) = torch.autograd.grad(masked_cost(y_leaf, X, M), y_leaf)
            step_y = step_y - nudge * cost_grad
            if aep:
                v = (z - zs).detach()
                _, Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)
                _, JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)
                step_z = step_z - s * (Jv - JTv)
        with torch.no_grad():
            z = token_norm(z + eps * step_z)
            y = y + eps * step_y
    return z.detach(), y.detach()


def vf_grad(model, xbar, s, T1, T2, eps, beta, aep):
    """Vector-field (EP-style) parameter-gradient estimate.

    Steps:
      1. Relax freely from ``model.init_state(xbar)``.
      2. Run two nudged phases, with +beta and -beta (AEP correction if ``aep``).
      3. Form the adjoint a = (state(-beta) - state(+beta)) / (2 beta), with
         its z part projected onto the tangent space at the free state.
      4. Differentiate <a_z, P_tan(F_z)> + <a_y, F_y> w.r.t. the parameters.

    Returns ``(zs, grads)``, where ``zs`` is the free-phase z state and
    ``grads`` is a tuple with one entry per parameter (None for unused ones).
    """
    z0, y0 = model.init_state(xbar)
    zs, ys = relax_free(model, xbar, z0, y0, s, T1, eps)
    zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep)
    zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep)

    two_beta = 2 * beta
    adj_z = P_tan(zs, (zm - zp) / two_beta).detach()     # adjoint in tangent space
    adj_y = ((ym - yp) / two_beta).detach()

    with torch.enable_grad():
        fz, fy = force(model, xbar, zs.detach(), ys.detach(), s)
        contraction = (adj_z * P_tan(zs, fz)).sum() + (adj_y * fy).sum()
        params = list(model.parameters())
        grads = torch.autograd.grad(contraction, params, allow_unused=True)
    return zs, grads


def bptt_grad(model, xbar, s, T1, eps):
    """Reference gradient: back-propagate through ``T1`` free-phase steps.

    Unrolls the same projected updates as ``relax_free``, keeping the graph.
    The loss is ``masked_cost(y, X, M) / M.sum()`` on the final ``y``, using
    the module-level ``X`` and ``M`` assigned in ``main``. Returns one gradient
    per parameter (None for parameters the loss does not depend on).
    """
    z0, y0 = model.init_state(xbar)
    z, y = z0.requires_grad_(True), y0.requires_grad_(True)
    for _step in range(T1):
        step_z, step_y = force(model, xbar, z, y, s)
        z = token_norm(z + eps * step_z)
        y = y + eps * step_y
    loss = masked_cost(y, X, M) / M.sum()
    return torch.autograd.grad(loss, list(model.parameters()), allow_unused=True)


def cosines(g, gb, names, attn_names=None):
    """Compare a gradient estimate ``g`` against a reference gradient ``gb``.

    ``g`` and ``gb`` are parallel sequences of per-parameter gradients. Entries
    may be None, e.g. from ``torch.autograd.grad(..., allow_unused=True)``.
    ``names`` gives the parameter name at each position.

    Returns ``(attn_cos, global_cos)``:
      * attn_cos   -- mean per-tensor cosine over parameters whose name is in
                      ``attn_names`` (defaults to the module-level ATTN).
                      NaN if no such pair has both gradients.
      * global_cos -- cosine between the concatenations of every gradient pair
                      where BOTH entries are present. NaN if there is none.

    Both concatenations are built from the same (g, gb) pairs, so they stay
    aligned even when exactly one side of a pair is None.
    """
    if attn_names is None:
        attn_names = ATTN

    def c(a, b):
        return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()

    attn = [c(a, b) for n, a, b in zip(names, g, gb)
            if n in attn_names and a is not None and b is not None]
    attn_cos = sum(attn) / len(attn) if attn else float('nan')

    # Filter jointly so A[i] and B[i] always refer to the same parameter entry.
    pairs = [(a, b) for a, b in zip(g, gb) if a is not None and b is not None]
    if not pairs:
        return attn_cos, float('nan')
    A = torch.cat([a.flatten() for a, _ in pairs])
    B = torch.cat([b.flatten() for _, b in pairs])
    return attn_cos, c(A, B)


def measure(model, names, s, T1, T2, eps, beta):
    """Score the naive and AEP VF gradients against BPTT at attention scale ``s``.

    Uses the module-level ``XBAR`` assigned in ``main``.

    Returns ``(naive_attn, aep_attn, naive_global, aep_global, finite)``:
      * the attention-mean and global cosines from ``cosines`` for each estimator;
      * ``finite`` -- whether the free-phase z state of the AEP run is all
        finite. The free phase itself does not depend on the ``aep`` flag.
    """
    reference = bptt_grad(model, XBAR, s, T1, eps)
    _, naive = vf_grad(model, XBAR, s, T1, T2, eps, beta, False)
    z_free, corrected = vf_grad(model, XBAR, s, T1, T2, eps, beta, True)
    naive_attn, naive_global = cosines(naive, reference, names)
    aep_attn, aep_global = cosines(corrected, reference, names)
    finite = torch.isfinite(z_free).all().item()
    return naive_attn, aep_attn, naive_global, aep_global, finite


def main():
    """Run the s=0 sanity check, then sweep the attention scale s.

    Steps:
      1. Build a seeded CETReal model and take the first training batch from
         ``get_loaders(32, dataset='fashionmnist')``.
      2. Build a patch mask for that batch.
      3. Publish batch, mask and masked input as the module globals
         ``X`` / ``M`` / ``XBAR``, which ``relax_nudged``, ``bptt_grad`` and
         ``measure`` read.
      4. Print the naive and AEP attention cosines for each s.
    """
    global X, M, XBAR
    torch.manual_seed(0)
    model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
    names = [name for name, _param in model.named_parameters()]

    train_loader, _ = get_loaders(32, dataset='fashionmnist')
    images, _labels = next(iter(train_loader))
    X = images.to(dev)
    M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
    XBAR = X * (1 - M)

    print("SANITY s=0 (pure conservative): projected-VF global cosine should be ~1")
    _, _, sanity_global, _, _ = measure(model, names, 0.0, 120, 20, 0.2, 0.02)
    print(f"  s=0 global cosine = {sanity_global:.4f}\n")

    print("PROJECTED AEP across attention scale s  (T1=120 T2=30 beta=0.02)")
    print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'finite?':>7}  (unproj. broke at s>=4)")
    for scale in (0.5, 1.0, 2.0, 4.0, 8.0, 16.0):
        naive_attn, aep_attn, _, _, finite = measure(model, names, scale, 120, 30, 0.2, 0.02)
        print(f"{scale:6.2f} | {naive_attn:>11.3f} {aep_attn:>10.3f} | {str(bool(finite)):>7}")


if __name__ == '__main__':
    main()