summaryrefslogtreecommitdiff
path: root/scripts/aep_characterize.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/aep_characterize.py')
-rw-r--r--scripts/aep_characterize.py157
1 files changed, 157 insertions, 0 deletions
diff --git a/scripts/aep_characterize.py b/scripts/aep_characterize.py
new file mode 100644
index 0000000..3e642c1
--- /dev/null
+++ b/scripts/aep_characterize.py
@@ -0,0 +1,157 @@
+"""
+Characterize AEP (non-conservative EP) on CET's attention, before porting to the LM.
+
+Controlled knob: attention scale s in force_z = -dE_rest/dz + s * RealAttn(z).
+ s=0 -> pure conservative reconstruction (A_J=0; EP exact)
+ s up -> attention dominates the force -> more non-conservative -> naive EP biased.
+Metric: cosine(EP-grad, BPTT-grad) on the attention params {WQ,WK,WV,WO} (the global
+cosine is diluted by the dominant conservative params, so we look at attention itself).
+The AEP correction is -s*(J_A v) on z, J_A = antisym Jacobian of RealAttn at the free eq.
+
+Sweeps: (1) s [non-conservativeness], (2) beta [nudge size], (3) T2 [nudge steps],
+ (4) T1 [free-phase convergence]. Plus: free-eq identical naive vs AEP, and cost.
+"""
+import argparse, math, time, torch, torch.nn.functional as F
+from cet_mvp import make_patch_mask, masked_cost, get_loaders
+from cet_aep import CETReal
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ATTN = ('WQ', 'WK', 'WV', 'WO')
+
+
+def force(model, xbar, z, y, s):
+ z = z.requires_grad_(True); y = y.requires_grad_(True)
+ gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=True)
+ return -gz + s * model.real_attn(z), -gy
+
+
+def relax_free(model, xbar, z, y, s, T1, eps):
+ for _ in range(T1):
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, z, y, s)
+ fz, fy = fz.detach(), fy.detach()
+ with torch.no_grad():
+ z, y = z + eps * fz, y + eps * fy
+ return z.detach(), y.detach()
+
+
+def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep):
+ z, y = zs.clone(), ys.clone()
+ for _ in range(T2):
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, z, y, s)
+ fz, fy = fz.detach(), fy.detach()
+ yy = y.detach().requires_grad_(True)
+ gy, = torch.autograd.grad(masked_cost(yy, X, M), yy)
+ fy = fy - sign * beta * gy
+ if aep:
+ v = (z - zs).detach()
+ Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
+ JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
+ fz = fz - s * (Jv - JTv) # -2 * s * 0.5 (J v - J^T v)
+ with torch.no_grad():
+ z, y = z + eps * fz, y + eps * fy
+ return z.detach(), y.detach()
+
+
+def vf_grad(model, xbar, s, T1, T2, eps, beta, aep):
+ zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps)
+ zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep)
+ zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep)
+ az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach()
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, zs.detach(), ys.detach(), s)
+ g = torch.autograd.grad((az * fz).sum() + (ay * fy).sum(),
+ list(model.parameters()), allow_unused=True)
+ return zs, g
+
+
+def bptt_grad(model, xbar, s, T1, eps):
+ z, y = model.init_state(xbar); z, y = z.requires_grad_(True), y.requires_grad_(True)
+ for _ in range(T1):
+ fz, fy = force(model, xbar, z, y, s)
+ z, y = z + eps * fz, y + eps * fy
+ return torch.autograd.grad(masked_cost(y, X, M) / M.sum(),
+ list(model.parameters()), allow_unused=True)
+
+
+def attn_cos(g, gb, names):
+ cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ for n, a, b in zip(names, g, gb) if n in ATTN and a is not None and b is not None]
+ return sum(cs) / len(cs)
+
+
+def global_cos(g, gb):
+ a = torch.cat([x.flatten() for x in g if x is not None])
+ b = torch.cat([x.flatten() for x, y in zip(g, gb) if x is not None and y is not None])
+ return F.cosine_similarity(a, b, dim=0).item()
+
+
+def measure(model, names, s, T1, T2, eps, beta):
+ gb = bptt_grad(model, XBAR, s, T1, eps)
+ zsn, gn = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=False)
+ zsa, ga = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=True)
+ eq_id = (zsn - zsa).norm().item() / (zsn.norm().item() + 1e-9) # free eq identical?
+ return dict(naive=attn_cos(gn, gb, names), aep=attn_cos(ga, gb, names),
+ gnaive=global_cos(gn, gb), gaep=global_cos(ga, gb), eq_id=eq_id)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--dataset', default='fashionmnist')
+ ap.add_argument('--img', type=int, default=28); ap.add_argument('--ch', type=int, default=1)
+ ap.add_argument('--patch', type=int, default=7); ap.add_argument('--stride', type=int, default=7)
+ ap.add_argument('--batch', type=int, default=32)
+ cfg = ap.parse_args()
+ torch.manual_seed(0)
+ model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=64, heads=4, dh=16, mem=128).to(dev)
+ names = [n for n, _ in model.named_parameters()]
+ trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset)
+ global X, M, XBAR
+ X, _ = next(iter(trl)); X = X.to(dev)
+ M = make_patch_mask(X.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
+ XBAR = X * (1 - M)
+
+ # intrinsic non-conservativeness of the attention map itself
+ zs, _ = relax_free(model, XBAR, *model.init_state(XBAR), 1.0, 120, 0.2)
+ v = torch.randn_like(zs)
+ Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
+ JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
+ print(f"intrinsic attention-map antisymmetry ||A_J v||/||J v|| = "
+ f"{(0.5*(Jv-JTv)).norm().item()/(Jv.norm().item()+1e-9):.3f}")
+
+ base = dict(T1=120, T2=20, eps=0.2, beta=0.02)
+ print("\n[1] ATTENTION SCALE s (s=0 conservative -> larger = more non-conservative)")
+ print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'naive(glob)':>11} {'AEP(glob)':>10} | free-eq id")
+ for s in [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]:
+ r = measure(model, names, s, base['T1'], base['T2'], base['eps'], base['beta'])
+ print(f"{s:6.2f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {r['gnaive']:>11.4f} {r['gaep']:>10.4f} | {r['eq_id']:.1e}")
+
+ print("\n[2] NUDGE STRENGTH beta (s=2, T2=20)")
+ print(f"{'beta':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
+ for beta in [0.005, 0.01, 0.02, 0.05, 0.1, 0.2]:
+ r = measure(model, names, 2.0, 120, 20, 0.2, beta)
+ print(f"{beta:6.3f} | {r['naive']:>11.3f} {r['aep']:>10.3f}")
+
+ print("\n[3] NUDGE STEPS T2 (s=2, beta=0.02)")
+ print(f"{'T2':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
+ for T2 in [3, 5, 10, 20, 40]:
+ r = measure(model, names, 2.0, 120, T2, 0.2, 0.02)
+ print(f"{T2:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")
+
+ print("\n[4] FREE-PHASE STEPS T1 (s=2; AEP uses A_J at the free eq)")
+ print(f"{'T1':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
+ for T1 in [20, 40, 80, 120, 200]:
+ r = measure(model, names, 2.0, T1, 20, 0.2, 0.02)
+ print(f"{T1:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")
+
+ print("\n[5] COST (s=2, T1=120, T2=20)")
+ t = time.time(); [vf_grad(model, XBAR, 2.0, 120, 20, 0.2, 0.02, aep=False) for _ in range(3)]
+ torch.cuda.synchronize() if dev == 'cuda' else None; tn = (time.time()-t)/3
+ t = time.time(); [vf_grad(model, XBAR, 2.0, 120, 20, 0.2, 0.02, aep=True) for _ in range(3)]
+ torch.cuda.synchronize() if dev == 'cuda' else None; ta = (time.time()-t)/3
+ print(f" naive {tn*1000:.0f} ms/grad AEP {ta*1000:.0f} ms/grad overhead {ta/tn:.2f}x")
+
+
+if __name__ == '__main__':
+ main()