summaryrefslogtreecommitdiff
path: root/scripts/cet_aep.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/cet_aep.py')
-rw-r--r--scripts/cet_aep.py272
1 files changed, 272 insertions, 0 deletions
diff --git a/scripts/cet_aep.py b/scripts/cet_aep.py
new file mode 100644
index 0000000..44e8922
--- /dev/null
+++ b/scripts/cet_aep.py
@@ -0,0 +1,272 @@
+"""
+AEP applied to CET's attention: replace CET's conservative energy-attention E^att
+with a REAL (non-conservative) transformer attention inside the CET, and use the
+AEP correction so EP still recovers the true gradient.
+
+CET state = (tokens z, reconstruction y). The conservative part
+ E_rest = E_enc + E_pos + E_mem + E_dec (scalar -> symmetric Jacobian)
+keeps its energy-gradient force. The token force gets its attention term from:
+ energy mode : -dE^att/dz (conservative; tied value; this is plain CET)
+ real mode : RealAttn(z) = WO softmax(QK^T/sqrt dh)(z WV) (non-conservative)
+
+Because only RealAttn is non-conservative, the full-force antisymmetric Jacobian
+A_J reduces to the antisymmetric part of dRealAttn/dz alone -> AEP correction is
+ force_z += -(J_A z~ - J_A^T z~) , z~=z-z* , J_A = dRealAttn/dz at z*
+(clean jvp/vjp on RealAttn, no nested autograd).
+
+We compare parameter-gradient quality vs ground-truth BPTT for:
+ energy / naive-EP (conservative CET; sanity, should be ~BPTT)
+ real / naive-EP (non-conservative; expected biased)
+ real / AEP (non-conservative + correction; expected ~BPTT)
+"""
+import argparse, math, time, json, os, torch, torch.nn as nn, torch.nn.functional as F
+from cet_mvp import token_norm, make_patch_mask, masked_cost, masked_mse, get_loaders
+
+
+class CETReal(nn.Module):
+ def __init__(self, img=32, ch=3, patch=8, stride=8, D=64, heads=4, dh=16,
+ mem=128, gamma=0.25):
+ super().__init__()
+ self.ch, self.patch, self.stride, self.D = ch, patch, stride, D
+ self.heads, self.dh, self.gamma = heads, dh, gamma
+ gh = (img - patch) // stride + 1
+ self.gh, self.N = gh, gh * gh
+ self.damp = 0.0 # contraction damping c: real_attn returns attn(z) - c*z
+ self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch))
+ self.benc = nn.Parameter(torch.zeros(D))
+ self.bpos = nn.Parameter(torch.zeros(self.N, D))
+ self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch))
+ self.bdec = nn.Parameter(torch.zeros(ch))
+ self.Wmem = nn.Parameter(torch.empty(D, mem))
+ # attention: WQ/WK used by both; WV/WO only by the real (non-conservative) path
+ self.WQ = nn.Parameter(torch.empty(heads, dh, D))
+ self.WK = nn.Parameter(torch.empty(heads, dh, D))
+ self.WV = nn.Parameter(torch.empty(heads, dh, D))
+ self.WO = nn.Parameter(torch.empty(D, heads * dh))
+ nn.init.kaiming_normal_(self.Wenc); self.Wenc.data *= 0.5
+ nn.init.kaiming_normal_(self.Wdec); self.Wdec.data *= 0.5
+ for w in (self.WQ, self.WK, self.WV):
+ nn.init.normal_(w, std=1.0 / math.sqrt(D))
+ nn.init.normal_(self.Wmem, std=0.3 / math.sqrt(D)) # small: keep energy bounded-below
+ nn.init.normal_(self.WO, std=1.0 / math.sqrt(heads * dh))
+
+ def encode(self, xbar):
+ return F.conv2d(xbar, self.Wenc, stride=self.stride).flatten(2).transpose(1, 2)
+
+ def decode_conv(self, y):
+ return F.conv2d(y, self.Wdec, stride=self.stride).flatten(2).transpose(1, 2)
+
+ def E_rest(self, xbar, z, y): # conservative scalar (no attention)
+ enc = self.encode(xbar)
+ E = 2.0 * (z ** 2).sum() - (enc * z).sum() - (z * self.benc).sum() - (z * self.bpos).sum()
+ proj = torch.einsum('bnd,dm->bnm', z, self.Wmem)
+ E = E - (F.relu(proj) ** 2).sum()
+ dc = self.decode_conv(y)
+ E = E + 0.5 * (y ** 2).sum() - (dc * z).sum() - (y * self.bdec[None, :, None, None]).sum()
+ return E
+
+ def E_att(self, z): # conservative LogSumExp energy (tied value)
+ Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ)
+ K = torch.einsum('bnd,hjd->bhnj', z, self.WK)
+ A = torch.einsum('bhmj,bhnj->bhmn', Q, K)
+ return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * A, dim=-1).sum()
+
+ def real_attn(self, z): # NON-conservative real attention force
+ B = z.size(0)
+ q = torch.einsum('bnd,hjd->bhnj', z, self.WQ)
+ k = torch.einsum('bnd,hjd->bhnj', z, self.WK)
+ v = torch.einsum('bnd,hjd->bhnj', z, self.WV)
+ A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.dh), dim=-1)
+ o = (A @ v).transpose(1, 2).reshape(B, self.N, self.heads * self.dh)
+ return o @ self.WO.t() - self.damp * z # -c*z: symmetric -> contraction, A_J unchanged
+
+ def force(self, xbar, z, y, mode):
+ """Return (force_z, force_y). force = -dE/dstate (+ real attention if mode='real')."""
+ z = z.requires_grad_(True); y = y.requires_grad_(True)
+ if mode == 'energy':
+ E = self.E_rest(xbar, z, y) + self.E_att(z)
+ gz, gy = torch.autograd.grad(E, [z, y], create_graph=True)
+ return -gz, -gy
+ else:
+ E = self.E_rest(xbar, z, y)
+ gz, gy = torch.autograd.grad(E, [z, y], create_graph=True)
+ return -gz + self.real_attn(z), -gy
+
+ def init_state(self, xbar):
+ return token_norm(self.encode(xbar)).detach(), xbar.clone().detach()
+
+
+def relax(model, xbar, z, y, steps, eps, mode, x=None, M=None, beta=0.0, aep=False, zstar=None):
+ for _ in range(steps):
+ with torch.enable_grad():
+ fz, fy = model.force(xbar, z, y, mode)
+ fz, fy = fz.detach(), fy.detach()
+ if beta != 0.0: # nudge on the output y
+ yy = y.detach().requires_grad_(True)
+ gy, = torch.autograd.grad(masked_cost(yy, x, M), yy)
+ fy = fy - beta * gy
+ if aep: # AEP correction on z (attention block only)
+ v = (z - zstar).detach()
+ fa = lambda zz: model.real_attn(zz)
+ Jv = torch.autograd.functional.jvp(fa, zstar, v)[1]
+ JTv = torch.autograd.functional.vjp(fa, zstar, v)[1]
+ corr = Jv - JTv # = 2 * 0.5 (J v - J^T v)
+ cn, fn = corr.norm(), fz.norm() + 1e-8 # clip so correction can't dominate -> no blow-up
+ if cn > fn:
+ corr = corr * (fn / cn)
+ fz = fz - corr
+ with torch.no_grad():
+ z = z + eps * fz # unconstrained (0.5||z||^2 in E_rest keeps it bounded)
+ y = y + eps * fy
+ return z.detach(), y.detach()
+
+
+def vf_param_grad(model, xbar, x, M, mode, T1, T2, eps, beta, aep):
+ z0, y0 = model.init_state(xbar)
+ zs, ys = relax(model, xbar, z0, y0, T1, eps, mode)
+ zp, yp = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, +beta, aep, zs)
+ zm, ym = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, -beta, aep, zs)
+ az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach()
+ with torch.enable_grad():
+ fz, fy = model.force(xbar, zs.detach(), ys.detach(), mode)
+ s = (az * fz).sum() + (ay * fy).sum()
+ grads = torch.autograd.grad(s, list(model.parameters()), allow_unused=True, retain_graph=False)
+ return grads
+
+
+def bptt_param_grad(model, xbar, x, M, mode, T1, eps):
+ z, y = model.init_state(xbar)
+ z, y = z.requires_grad_(True), y.requires_grad_(True)
+ for _ in range(T1):
+ fz, fy = model.force(xbar, z, y, mode)
+ z = z + eps * fz
+ y = y + eps * fy
+ L = masked_cost(y, x, M) / M.sum()
+ return torch.autograd.grad(L, list(model.parameters()), allow_unused=True)
+
+
+def cos(ga, gb, names):
+ fa, fb = [], []
+ per = {}
+ for n, a, b in zip(names, ga, gb):
+ if a is None or b is None:
+ continue
+ fa.append(a.flatten()); fb.append(b.flatten())
+ per[n] = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ g = F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item()
+ return g, per
+
+
+def evaluate(model, loader, cfg, dev, mode='real', max_batches=40):
+ tot, n = 0.0, 0
+ gen = torch.Generator(device=dev).manual_seed(0)
+ for i, (x, _) in enumerate(loader):
+ if i >= max_batches:
+ break
+ x = x.to(dev)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev, gen)
+ xbar = x * (1 - M)
+ z, y = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, mode)
+ tot += masked_mse(y, x, M) * x.size(0); n += x.size(0)
+ return tot / n
+
+
+def fidelity(cfg, model, dev):
+ names = [n for n, _ in model.named_parameters()]
+ trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset)
+ x, _ = next(iter(trl)); x = x.to(dev)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
+ xbar = x * (1 - M)
+ zs, ys = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, 'real')
+ v = torch.randn_like(zs)
+ Jv = torch.autograd.functional.jvp(lambda z: model.real_attn(z), zs, v)[1]
+ JTv = torch.autograd.functional.vjp(lambda z: model.real_attn(z), zs, v)[1]
+ asym = (0.5 * (Jv - JTv)).norm().item() / (Jv.norm().item() + 1e-8)
+ print(f"real-attention Jacobian antisymmetry = {asym:.3f}\n")
+ for mode, aep, label in [('energy', False, 'energy/naive (sanity)'),
+ ('real', False, 'real/naive (biased)'),
+ ('real', True, 'real/AEP (fixed)')]:
+ gb = bptt_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.eps)
+ gv = vf_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.T2, cfg.eps, cfg.beta, aep)
+ g, per = cos(gv, gb, names)
+ att = " ".join(f"{k}={per[k]:+.3f}" for k in ('WQ', 'WK', 'WV', 'WO') if k in per)
+ print(f"[{label}] global={g:+.4f} attn: {att}")
+
+
+def train(cfg, model, dev):
+ tag = 'aep' if cfg.aep else 'naive'
+ opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.01)
+ trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
+ print(f"[real-attn EP, {tag}] params={sum(p.numel() for p in model.parameters())/1e3:.1f}K "
+ f"T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}", flush=True)
+ # stay in the stable+faithful regime: cap weight norms (Wmem for bounded-below energy,
+ # attention WV/WO/WQ/WK so the non-conservative force can't grow into the unstable s>=4 regime)
+ caps = {n: p.detach().norm().item() * 1.5 for n, p in model.named_parameters()
+ if n in ('Wmem', 'WQ', 'WK', 'WV', 'WO')}
+ cap_params = {n: p for n, p in model.named_parameters() if n in caps}
+ step, t0, best = 0, time.time(), float('inf')
+ while step < cfg.steps:
+ for x, _ in trl:
+ if step >= cfg.steps:
+ break
+ x = x.to(dev, non_blocking=True)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
+ xbar = x * (1 - M)
+ grads = vf_param_grad(model, xbar, x, M, 'real', cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.aep)
+ opt.zero_grad(set_to_none=True)
+ bad = False
+ for p, g in zip(model.parameters(), grads):
+ if g is None or not torch.isfinite(g).all():
+ bad = True; break
+ p.grad = g
+ if bad:
+ print(f" step {step}: non-finite grad, skip", flush=True); step += 1; continue
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
+ opt.step(); sched.step()
+ with torch.no_grad(): # stay in stable+faithful regime
+ for n, p in cap_params.items():
+ pn = p.norm()
+ if pn > caps[n]:
+ p.mul_(caps[n] / pn)
+ step += 1
+ if step % cfg.log_every == 0:
+ te = evaluate(model, tel, cfg, dev, 'real', 15)
+ best = min(best, te)
+ print(f"step {step:4d}/{cfg.steps} | test masked-MSE {te:.5f} (best {best:.5f}) "
+ f"| {step/(time.time()-t0):.2f} it/s", flush=True)
+ final = evaluate(model, tel, cfg, dev, 'real', 60)
+ best = min(best, final)
+ os.makedirs(cfg.out, exist_ok=True)
+ json.dump({'tag': tag, 'final_test_masked_mse': final, 'best_test_masked_mse': best,
+ 'steps': cfg.steps}, open(os.path.join(cfg.out, f'aep_train_{tag}.json'), 'w'), indent=2)
+ print(f"[real-attn EP, {tag}] DONE final={final:.5f} best={best:.5f}", flush=True)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--cmd', choices=['fidelity', 'train'], default='fidelity')
+ ap.add_argument('--aep', action='store_true')
+ ap.add_argument('--damp', type=float, default=0.0)
+ ap.add_argument('--dataset', default='fashionmnist')
+ ap.add_argument('--img', type=int, default=28); ap.add_argument('--ch', type=int, default=1)
+ ap.add_argument('--patch', type=int, default=7); ap.add_argument('--stride', type=int, default=7)
+ ap.add_argument('--D', type=int, default=64); ap.add_argument('--heads', type=int, default=4)
+ ap.add_argument('--dh', type=int, default=16); ap.add_argument('--mem', type=int, default=128)
+ ap.add_argument('--T1', type=int, default=100); ap.add_argument('--T2', type=int, default=15)
+ ap.add_argument('--eps', type=float, default=0.2); ap.add_argument('--beta', type=float, default=0.02)
+ ap.add_argument('--batch', type=int, default=64); ap.add_argument('--steps', type=int, default=1500)
+ ap.add_argument('--lr', type=float, default=4e-4); ap.add_argument('--wd', type=float, default=1e-4)
+ ap.add_argument('--log_every', type=int, default=100)
+ ap.add_argument('--out', default='/home/yurenh2/ept/runs')
+ cfg = ap.parse_args()
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ torch.manual_seed(0)
+ model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh, cfg.mem).to(dev)
+ model.damp = cfg.damp
+ print('config:', vars(cfg), flush=True)
+ (train if cfg.cmd == 'train' else fidelity)(cfg, model, dev)
+
+
+if __name__ == '__main__':
+ main()