1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
"""
Characterize AEP (non-conservative EP) on CET's attention, before porting to the LM.
Controlled knob: attention scale s in force_z = -dE_rest/dz + s * RealAttn(z).
s=0 -> pure conservative reconstruction (A_J=0; EP exact)
s up -> attention dominates the force -> more non-conservative -> naive EP biased.
Metric: cosine(EP-grad, BPTT-grad) on the attention params {WQ,WK,WV,WO} (the global
cosine is diluted by the dominant conservative params, so we look at attention itself).
The AEP correction on z is -2*s*(A_J v) = -s*(J v - J^T v), with v = z - z_free, where
A_J = 0.5*(J - J^T) is the antisymmetric part of the Jacobian J of RealAttn at the free eq.
Sweeps: (1) s [non-conservativeness], (2) beta [nudge size], (3) T2 [nudge steps],
(4) T1 [free-phase convergence]. Plus: free-eq identical naive vs AEP, and cost.
"""
import argparse, math, time, torch, torch.nn.functional as F
from cet_mvp import make_patch_mask, masked_cost, get_loaders
from cet_aep import CETReal
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Parameter names that attn_cos keeps when it filters gradients by name.
ATTN = ('WQ', 'WK', 'WV', 'WO')
def force(model, xbar, z, y, s):
    """Return the force pair ``(f_z, f_y)`` acting on the state ``(z, y)``.

    ``f_z = -dE_rest/dz + s * model.real_attn(z)`` and ``f_y = -dE_rest/dy``,
    where ``E_rest = model.E_rest(xbar, z, y)``.

    Both ``z`` and ``y`` are flagged ``requires_grad`` in place.  The energy
    gradients are taken with ``create_graph=True`` so the returned forces
    remain differentiable (callers differentiate them again, e.g. with
    respect to the model parameters).
    """
    # In-place flag: the caller's tensors become differentiable inputs.
    z.requires_grad_(True)
    y.requires_grad_(True)

    energy = model.E_rest(xbar, z, y)
    dE_dz, dE_dy = torch.autograd.grad(energy, [z, y], create_graph=True)

    attn_term = s * model.real_attn(z)
    return -dE_dz + attn_term, -dE_dy
def relax_free(model, xbar, z, y, s, T1, eps):
    """Integrate the free (un-nudged) dynamics for ``T1`` explicit Euler steps.

    Every step evaluates ``force`` with autograd enabled, detaches the result
    and advances the state by ``eps`` times the force under ``no_grad``, so no
    graph is carried from one iteration to the next.

    Returns the final ``(z, y)`` as detached tensors.
    """
    for _step in range(T1):
        # force() differentiates the energy, so autograd must be on here even
        # if the caller runs inside a no_grad context.
        with torch.enable_grad():
            force_z, force_y = force(model, xbar, z, y, s)
        force_z = force_z.detach()
        force_y = force_y.detach()

        # Plain state update; nothing to record for backprop.
        with torch.no_grad():
            z = z + eps * force_z
            y = y + eps * force_y
    return z.detach(), y.detach()
def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep):
    """Integrate the nudged dynamics for ``T2`` Euler steps, starting at ``(zs, ys)``.

    On top of the free force, the y-force receives ``-sign * beta * dC/dy``
    with ``C = masked_cost(y, X, M)`` (``X`` and ``M`` are module globals set
    in ``main``).  When ``aep`` is true, the z-force is additionally corrected
    by ``-s * (J v - J^T v)``, where ``J`` is the Jacobian of
    ``model.real_attn`` evaluated at ``zs`` and ``v = z - zs``.

    Returns the final ``(z, y)`` as detached tensors; ``zs`` and ``ys`` are
    cloned first and not overwritten.
    """
    z = zs.clone()
    y = ys.clone()
    for _step in range(T2):
        with torch.enable_grad():
            force_z, force_y = force(model, xbar, z, y, s)
            force_z = force_z.detach()
            force_y = force_y.detach()
            # Gradient of the cost w.r.t. a fresh leaf copy of y.
            y_leaf = y.detach().requires_grad_(True)
            cost = masked_cost(y_leaf, X, M)
            (cost_grad,) = torch.autograd.grad(cost, y_leaf)
        force_y = force_y - sign * beta * cost_grad

        if aep:
            # Displacement from the free equilibrium.
            disp = (z - zs).detach()
            _, j_v = torch.autograd.functional.jvp(model.real_attn, zs, disp)
            _, jt_v = torch.autograd.functional.vjp(model.real_attn, zs, disp)
            # -2 * s * 0.5 (J v - J^T v)
            force_z = force_z - s * (j_v - jt_v)

        with torch.no_grad():
            z = z + eps * force_z
            y = y + eps * force_y
    return z.detach(), y.detach()
def vf_grad(model, xbar, s, T1, T2, eps, beta, aep):
    """Estimate parameter gradients from a free phase and two nudged phases.

    Steps:
      1. relax the free dynamics from ``model.init_state(xbar)`` for ``T1`` steps;
      2. run the nudged dynamics from that point with ``sign=+1`` and
         ``sign=-1`` (``T2`` steps each, optionally with the AEP correction);
      3. form the symmetric differences ``(minus - plus) / (2 * beta)`` of the
         nudged states and contract them with the force at the free point;
      4. differentiate that scalar with respect to every model parameter.

    Returns ``(zs, grads)``: the free-phase ``z`` and a tuple of gradients
    aligned with ``model.parameters()`` (``None`` for unused parameters).
    """
    z_init, y_init = model.init_state(xbar)
    zs, ys = relax_free(model, xbar, z_init, y_init, s, T1, eps)

    z_plus, y_plus = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep)
    z_minus, y_minus = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep)

    # Constant weights for the force contraction; never differentiated through.
    weight_z = ((z_minus - z_plus) / (2 * beta)).detach()
    weight_y = ((y_minus - y_plus) / (2 * beta)).detach()

    with torch.enable_grad():
        force_z, force_y = force(model, xbar, zs.detach(), ys.detach(), s)
        contraction = (weight_z * force_z).sum() + (weight_y * force_y).sum()
        params = list(model.parameters())
        grads = torch.autograd.grad(contraction, params, allow_unused=True)
    return zs, grads
def bptt_grad(model, xbar, s, T1, eps):
    """Reference gradients obtained by backpropagating through the unrolled dynamics.

    Unrolls ``T1`` Euler steps of ``force`` from ``model.init_state(xbar)``
    while keeping the autograd graph, then differentiates
    ``masked_cost(y, X, M) / M.sum()`` with respect to every model parameter
    (``X`` and ``M`` are module globals set in ``main``).

    Returns a tuple aligned with ``model.parameters()``; entries are ``None``
    for parameters the loss does not depend on.
    """
    z, y = model.init_state(xbar)
    z = z.requires_grad_(True)
    y = y.requires_grad_(True)

    for _step in range(T1):
        # No detach here: the whole trajectory stays in the graph.
        force_z, force_y = force(model, xbar, z, y, s)
        z = z + eps * force_z
        y = y + eps * force_y

    loss = masked_cost(y, X, M) / M.sum()
    params = list(model.parameters())
    return torch.autograd.grad(loss, params, allow_unused=True)
def attn_cos(g, gb, names):
    """Mean per-parameter cosine similarity restricted to the attention parameters.

    ``names``, ``g`` and ``gb`` are iterated in lockstep.  A pair contributes
    only when its name is listed in ``ATTN`` and both gradients are present
    (not ``None``).  Each contributing pair is flattened and compared with
    ``F.cosine_similarity``; the result is the plain average of those values.

    Returns NaN when no pair qualifies (for example when none of ``names`` is
    in ``ATTN``), instead of failing with ``ZeroDivisionError``.
    """
    sims = [
        F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
        for n, a, b in zip(names, g, gb)
        if n in ATTN and a is not None and b is not None
    ]
    if not sims:
        # Nothing to average: report "undefined" rather than crashing the sweep.
        return float('nan')
    return sum(sims) / len(sims)
def global_cos(g, gb):
    """Cosine similarity between two gradient lists, flattened into single vectors.

    ``g`` and ``gb`` are iterated in lockstep.  Only positions where *both*
    entries are present (not ``None``) are used, so the two concatenated
    vectors always have the same length and line up entry by entry.

    Returns NaN when no position has both gradients.
    """
    # Filter on both sides at once so the two vectors stay aligned.
    pairs = [(x, y) for x, y in zip(g, gb) if x is not None and y is not None]
    if not pairs:
        return float('nan')
    a = torch.cat([x.flatten() for x, _ in pairs])
    # Second vector comes from gb; building it from g would compare g with itself.
    b = torch.cat([y.flatten() for _, y in pairs])
    return F.cosine_similarity(a, b, dim=0).item()
def measure(model, names, s, T1, T2, eps, beta):
    """Compare the naive and AEP gradient estimates against BPTT for one setting.

    All three gradients are computed on the module-global batch ``XBAR``.

    Returns a dict with:
      ``naive`` / ``aep``   - ``attn_cos`` of each estimate against BPTT;
      ``gnaive`` / ``gaep`` - ``global_cos`` of each estimate against BPTT;
      ``eq_id``             - relative norm of the difference between the free
                              ``z`` states returned by the two ``vf_grad`` runs.
    """
    grad_bptt = bptt_grad(model, XBAR, s, T1, eps)
    z_naive, grad_naive = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=False)
    z_aep, grad_aep = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=True)

    # free eq identical?  (small constant guards against a zero norm)
    gap = (z_naive - z_aep).norm().item()
    scale = z_naive.norm().item() + 1e-9

    return {
        'naive': attn_cos(grad_naive, grad_bptt, names),
        'aep': attn_cos(grad_aep, grad_bptt, names),
        'gnaive': global_cos(grad_naive, grad_bptt),
        'gaep': global_cos(grad_aep, grad_bptt),
        'eq_id': gap / scale,
    }
def main():
    """Run the AEP-vs-naive-EP characterization and print the result tables.

    Builds a ``CETReal`` model and draws one fixed batch, publishing it through
    the module globals ``X``, ``M`` and ``XBAR`` that the gradient helpers
    read.  It then prints, in order:

      * the antisymmetry ratio of the attention Jacobian at a free-phase state;
      * sweep [1] over the attention scale ``s``;
      * sweep [2] over the nudge strength ``beta``;
      * sweep [3] over the number of nudge steps ``T2``;
      * sweep [4] over the number of free-phase steps ``T1``;
      * [5] the per-gradient wall-clock cost of naive EP versus AEP.

    Sweeps [2]-[5] hold ``s`` at 2.0 and take every non-swept setting from
    ``base``.
    """
    global X, M, XBAR

    ap = argparse.ArgumentParser()
    ap.add_argument('--dataset', default='fashionmnist')
    ap.add_argument('--img', type=int, default=28)
    ap.add_argument('--ch', type=int, default=1)
    ap.add_argument('--patch', type=int, default=7)
    ap.add_argument('--stride', type=int, default=7)
    ap.add_argument('--batch', type=int, default=32)
    cfg = ap.parse_args()

    torch.manual_seed(0)
    model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=64, heads=4, dh=16, mem=128).to(dev)
    names = [n for n, _ in model.named_parameters()]
    trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset)

    # A single batch is reused for every measurement below.
    X, _ = next(iter(trl))
    X = X.to(dev)
    M = make_patch_mask(X.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
    # NOTE(review): presumably M is a 0/1 mask, so XBAR is X with the masked
    # entries zeroed - confirm against make_patch_mask.
    XBAR = X * (1 - M)

    # Single source of truth for the default settings used by all sweeps.
    base = dict(T1=120, T2=20, eps=0.2, beta=0.02)
    s_ref = 2.0  # attention scale held fixed in sweeps [2]-[5]

    def run(s=s_ref, **overrides):
        """Call ``measure`` with ``base`` settings, overriding the swept one."""
        hp = {**base, **overrides}
        return measure(model, names, s, hp['T1'], hp['T2'], hp['eps'], hp['beta'])

    def sync():
        """Wait for queued CUDA work so wall-clock timings are meaningful."""
        if dev == 'cuda':
            torch.cuda.synchronize()

    def seconds_per_grad(aep, reps=3):
        """Average wall-clock seconds of one ``vf_grad`` call over ``reps`` runs."""
        sync()
        start = time.perf_counter()
        for _ in range(reps):
            vf_grad(model, XBAR, s_ref, base['T1'], base['T2'], base['eps'], base['beta'], aep=aep)
        sync()
        return (time.perf_counter() - start) / reps

    # intrinsic non-conservativeness of the attention map itself
    zs, _ = relax_free(model, XBAR, *model.init_state(XBAR), 1.0, base['T1'], base['eps'])
    v = torch.randn_like(zs)
    Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
    JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
    ratio = (0.5 * (Jv - JTv)).norm().item() / (Jv.norm().item() + 1e-9)
    print(f"intrinsic attention-map antisymmetry ||A_J v||/||J v|| = {ratio:.3f}")

    print("\n[1] ATTENTION SCALE s (s=0 conservative -> larger = more non-conservative)")
    print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'naive(glob)':>11} {'AEP(glob)':>10} | free-eq id")
    for s in [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]:
        r = run(s=s)
        print(f"{s:6.2f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {r['gnaive']:>11.4f} {r['gaep']:>10.4f} | {r['eq_id']:.1e}")

    print("\n[2] NUDGE STRENGTH beta (s=2, T2=20)")
    print(f"{'beta':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for beta in [0.005, 0.01, 0.02, 0.05, 0.1, 0.2]:
        r = run(beta=beta)
        print(f"{beta:6.3f} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[3] NUDGE STEPS T2 (s=2, beta=0.02)")
    print(f"{'T2':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for T2 in [3, 5, 10, 20, 40]:
        r = run(T2=T2)
        print(f"{T2:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[4] FREE-PHASE STEPS T1 (s=2; AEP uses A_J at the free eq)")
    print(f"{'T1':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for T1 in [20, 40, 80, 120, 200]:
        r = run(T1=T1)
        print(f"{T1:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[5] COST (s=2, T1=120, T2=20)")
    tn = seconds_per_grad(aep=False)
    ta = seconds_per_grad(aep=True)
    print(f" naive {tn*1000:.0f} ms/grad AEP {ta*1000:.0f} ms/grad overhead {ta/tn:.2f}x")
# Run the characterization only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|