"""AEP applied to CET's attention.

Replace CET's conservative energy-attention E^att with a REAL (non-conservative)
transformer attention inside the CET, and use the AEP correction so that EP
still recovers the true gradient.

CET state = (tokens z, reconstruction y). The conservative part
E_rest = E_enc + E_pos + E_mem + E_dec is a scalar (hence has a symmetric
Jacobian) and keeps its energy-gradient force. The token force gets its
attention term from:

    energy mode : -dE^att/dz   (conservative; tied value; this is plain CET)
    real mode   : RealAttn(z) = WO softmax(Q K^T / sqrt(dh)) (z WV)   (non-conservative)

Because only RealAttn is non-conservative, the antisymmetric part A_J of the
full-force Jacobian reduces to the antisymmetric part of dRealAttn/dz alone,
so the AEP correction is

    force_z += -(J_A z~ - J_A^T z~),   z~ = z - z*,   J_A = dRealAttn/dz at z*

computed with plain jvp/vjp on RealAttn (no nested autograd).

We compare parameter-gradient quality against ground-truth BPTT for:
    energy / naive-EP  (conservative CET; sanity check, should be ~BPTT)
    real   / naive-EP  (non-conservative; expected to be biased)
    real   / AEP       (non-conservative + correction; expected ~BPTT)
"""
import argparse
import json
import math
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from cet_mvp import token_norm, make_patch_mask, masked_cost, masked_mse, get_loaders

# Attention modes understood by CETReal.force / relax.
_MODES = ('energy', 'real')


class CETReal(nn.Module):
    """CET whose token force uses either conservative energy attention or real softmax attention.

    ``xbar`` is fed to ``F.conv2d``, so it is a 4-D image batch ``(B, ch, H, W)``.
    The state is ``(z, y)``: ``z`` holds ``(B, N, D)`` tokens (the layout returned
    by :meth:`encode`) and ``y`` is a reconstruction with the same shape as
    ``xbar`` (see :meth:`init_state`).

    Args:
        img: image side length; only used to compute the token grid size ``gh``
            (a square ``gh x gh`` grid is assumed, ``N = gh * gh``).
        ch: number of image channels.
        patch, stride: kernel size and stride of the patch encoder/decoder convs.
        D: token dimension.
        heads, dh: number of attention heads and per-head dimension.
        mem: width of the memory projection ``Wmem`` used in ``E_rest``.
        gamma: inverse temperature of the LogSumExp energy attention ``E_att``.
    """

    def __init__(self, img=32, ch=3, patch=8, stride=8, D=64, heads=4, dh=16,
                 mem=128, gamma=0.25):
        super().__init__()
        self.ch, self.patch, self.stride, self.D = ch, patch, stride, D
        self.heads, self.dh, self.gamma = heads, dh, gamma
        gh = (img - patch) // stride + 1
        self.gh, self.N = gh, gh * gh
        self.damp = 0.0  # contraction damping c: real_attn returns attn(z) - c*z
        self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch))
        self.benc = nn.Parameter(torch.zeros(D))
        self.bpos = nn.Parameter(torch.zeros(self.N, D))
        self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch))
        self.bdec = nn.Parameter(torch.zeros(ch))
        self.Wmem = nn.Parameter(torch.empty(D, mem))
        # attention: WQ/WK used by both paths; WV/WO only by the real (non-conservative) path
        self.WQ = nn.Parameter(torch.empty(heads, dh, D))
        self.WK = nn.Parameter(torch.empty(heads, dh, D))
        self.WV = nn.Parameter(torch.empty(heads, dh, D))
        self.WO = nn.Parameter(torch.empty(D, heads * dh))
        # Initialisation order is kept fixed so results are reproducible under a given seed.
        nn.init.kaiming_normal_(self.Wenc)
        self.Wenc.data *= 0.5
        nn.init.kaiming_normal_(self.Wdec)
        self.Wdec.data *= 0.5
        for w in (self.WQ, self.WK, self.WV):
            nn.init.normal_(w, std=1.0 / math.sqrt(D))
        nn.init.normal_(self.Wmem, std=0.3 / math.sqrt(D))  # small: keep energy bounded-below
        nn.init.normal_(self.WO, std=1.0 / math.sqrt(heads * dh))

    def encode(self, xbar):
        """Patch-embed ``xbar`` into ``(B, N, D)`` tokens (conv, flatten spatial dims, D last)."""
        return F.conv2d(xbar, self.Wenc, stride=self.stride).flatten(2).transpose(1, 2)

    def decode_conv(self, y):
        """Apply the decoder conv to ``y`` and return it in the same ``(B, N, D)`` token layout."""
        return F.conv2d(y, self.Wdec, stride=self.stride).flatten(2).transpose(1, 2)

    def E_rest(self, xbar, z, y):
        """Conservative scalar energy without attention: E_enc + E_pos + E_mem + E_dec."""
        enc = self.encode(xbar)
        # quadratic term on z, input drive from the encoder, token and positional biases
        E = 2.0 * (z ** 2).sum() - (enc * z).sum() - (z * self.benc).sum() - (z * self.bpos).sum()
        # memory term -||relu(z Wmem)||^2 (Wmem is initialised small to keep E bounded below)
        proj = torch.einsum('bnd,dm->bnm', z, self.Wmem)
        E = E - (F.relu(proj) ** 2).sum()
        # decoder coupling between tokens z and reconstruction y, plus y's quadratic and bias
        dc = self.decode_conv(y)
        E = E + 0.5 * (y ** 2).sum() - (dc * z).sum() - (y * self.bdec[None, :, None, None]).sum()
        return E

    def E_att(self, z):
        """Conservative LogSumExp attention energy (tied value; uses only WQ/WK).

        E_att = -(1/gamma) * sum_{b,h,m} logsumexp_n(gamma * <Q_m, K_n>)
        """
        Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ)
        K = torch.einsum('bnd,hjd->bhnj', z, self.WK)
        A = torch.einsum('bhmj,bhnj->bhmn', Q, K)
        return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * A, dim=-1).sum()

    def real_attn(self, z):
        """Non-conservative multi-head softmax attention force, minus ``damp * z``.

        Returns a ``(B, N, D)`` tensor: WO applied to the concatenated heads of
        ``softmax(q k^T / sqrt(dh)) v``.
        """
        B = z.size(0)
        q = torch.einsum('bnd,hjd->bhnj', z, self.WQ)
        k = torch.einsum('bnd,hjd->bhnj', z, self.WK)
        v = torch.einsum('bnd,hjd->bhnj', z, self.WV)
        A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.dh), dim=-1)
        o = (A @ v).transpose(1, 2).reshape(B, self.N, self.heads * self.dh)
        # -c*z is symmetric: it adds contraction without changing the antisymmetric part A_J
        return o @ self.WO.t() - self.damp * z

    def force(self, xbar, z, y, mode, create_graph=True):
        """Return ``(force_z, force_y)``: -dE/d(z, y), plus RealAttn(z) on z in ``'real'`` mode.

        Args:
            xbar: masked input image batch.
            z, y: current state; ``requires_grad`` is switched on for them in place.
            mode: ``'energy'`` (E = E_rest + E_att) or ``'real'`` (E = E_rest, force_z += real_attn(z)).
            create_graph: keep the graph of the energy gradient so the returned forces
                can be differentiated again (needed for parameter gradients and BPTT).
                Pass ``False`` when only the force values are needed.

        Raises:
            ValueError: if ``mode`` is not ``'energy'`` or ``'real'``.
        """
        if mode not in _MODES:
            raise ValueError(f"mode must be one of {_MODES}, got {mode!r}")
        z = z.requires_grad_(True)
        y = y.requires_grad_(True)
        if mode == 'energy':
            E = self.E_rest(xbar, z, y) + self.E_att(z)
            gz, gy = torch.autograd.grad(E, [z, y], create_graph=create_graph)
            return -gz, -gy
        E = self.E_rest(xbar, z, y)
        gz, gy = torch.autograd.grad(E, [z, y], create_graph=create_graph)
        return -gz + self.real_attn(z), -gy

    def init_state(self, xbar):
        """Initial state: token-normalised encoding of ``xbar`` and a copy of ``xbar``, both detached."""
        return token_norm(self.encode(xbar)).detach(), xbar.clone().detach()


def relax(model, xbar, z, y, steps, eps, mode, x=None, M=None, beta=0.0,
          aep=False, zstar=None):
    """Run ``steps`` explicit-Euler steps ``state += eps * force`` and return the detached state.

    Args:
        model: a CETReal.
        xbar: masked input image batch.
        z, y: starting state.
        steps, eps: number of steps and step size.
        mode: ``'energy'`` or ``'real'`` (see CETReal.force).
        x, M: target images and mask for the output nudge; required when ``beta != 0``.
        beta: nudge strength; the force on ``y`` gets ``-beta * d masked_cost(y, x, M) / dy``.
        aep: apply the AEP antisymmetric correction on ``z`` (``'real'`` mode only).
        zstar: free-phase fixed point where J_A = dRealAttn/dz is evaluated; required when ``aep``.

    Raises:
        ValueError: on an inconsistent combination of arguments.
    """
    if aep and mode != 'real':
        # In 'energy' mode real_attn is not part of the force, so subtracting its
        # antisymmetric Jacobian would bias the dynamics instead of correcting them.
        raise ValueError("aep=True requires mode='real'")
    if aep and zstar is None:
        raise ValueError("aep=True requires the free fixed point zstar")
    if beta != 0.0 and (x is None or M is None):
        raise ValueError("beta != 0 requires the target x and mask M")
    for _ in range(steps):
        with torch.enable_grad():
            # Only first-order force values are needed here: skip the double-backward graph.
            fz, fy = model.force(xbar, z, y, mode, create_graph=False)
            fz, fy = fz.detach(), fy.detach()
            if beta != 0.0:  # nudge on the output y
                yy = y.detach().requires_grad_(True)
                gy, = torch.autograd.grad(masked_cost(yy, x, M), yy)
                fy = fy - beta * gy
        if aep:  # AEP correction on z (attention block only)
            v = (z - zstar).detach()
            Jv = torch.autograd.functional.jvp(model.real_attn, zstar, v)[1]
            JTv = torch.autograd.functional.vjp(model.real_attn, zstar, v)[1]
            corr = Jv - JTv  # = 2 * 0.5 (J v - J^T v)
            # clip so the correction can't dominate the force -> no blow-up
            cn, fn = corr.norm(), fz.norm() + 1e-8
            if cn > fn:
                corr = corr * (fn / cn)
            fz = fz - corr
        with torch.no_grad():
            # unconstrained update; the quadratic z term in E_rest is what keeps z bounded
            z = z + eps * fz
            y = y + eps * fy
    return z.detach(), y.detach()


def vf_param_grad(model, xbar, x, M, mode, T1, T2, eps, beta, aep):
    """EP parameter-gradient estimate from symmetric (+beta / -beta) nudging.

    1. Free phase: relax T1 steps from ``init_state`` to ``(zs, ys)``.
    2. Two nudged phases of T2 steps from ``(zs, ys)`` with +beta and -beta
       (AEP-corrected around ``zs`` when ``aep``).
    3. ``a = (s(-beta) - s(+beta)) / (2 beta)``; return d/dtheta of <a, force(zs, ys)>.

    Returns a tuple aligned with ``model.parameters()``; entries are ``None`` for
    parameters the force does not depend on (e.g. WV/WO in ``'energy'`` mode).

    Note: the nudge uses ``masked_cost`` without the ``/ M.sum()`` normalisation
    used by bptt_param_grad, so the two estimates can differ by that scalar
    factor (cosine similarity is unaffected).
    """
    z0, y0 = model.init_state(xbar)
    zs, ys = relax(model, xbar, z0, y0, T1, eps, mode)
    zp, yp = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, +beta, aep, zs)
    zm, ym = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, -beta, aep, zs)
    az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach()
    with torch.enable_grad():
        fz, fy = model.force(xbar, zs.detach(), ys.detach(), mode)
        s = (az * fz).sum() + (ay * fy).sum()
        grads = torch.autograd.grad(s, list(model.parameters()), allow_unused=True,
                                    retain_graph=False)
    return grads


def bptt_param_grad(model, xbar, x, M, mode, T1, eps):
    """Ground-truth gradient: backprop through T1 Euler steps of the free dynamics.

    Loss is ``masked_cost(y, x, M) / M.sum()`` on the final reconstruction ``y``.
    Returns a tuple aligned with ``model.parameters()`` (``None`` for unused parameters).
    """
    z, y = model.init_state(xbar)
    z, y = z.requires_grad_(True), y.requires_grad_(True)
    for _ in range(T1):
        fz, fy = model.force(xbar, z, y, mode)
        z = z + eps * fz
        y = y + eps * fy
    L = masked_cost(y, x, M) / M.sum()
    return torch.autograd.grad(L, list(model.parameters()), allow_unused=True)


def cos(ga, gb, names):
    """Cosine similarity between two gradient lists, globally and per parameter name.

    Pairs where either gradient is ``None`` are skipped. Returns ``(global, {name: cos})``;
    ``global`` is NaN when no pair is available.
    """
    fa, fb, per = [], [], {}
    for n, a, b in zip(names, ga, gb):
        if a is None or b is None:
            continue
        a, b = a.flatten(), b.flatten()
        fa.append(a)
        fb.append(b)
        per[n] = F.cosine_similarity(a, b, dim=0).item()
    if not fa:
        return float('nan'), per
    g = F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item()
    return g, per


def evaluate(model, loader, cfg, dev, mode='real', max_batches=40):
    """Average ``masked_mse`` of the free-phase reconstruction over up to ``max_batches`` batches.

    A fixed-seed generator is passed to ``make_patch_mask`` (presumably so every
    call evaluates on the same masks). Returns a Python float.

    Raises:
        ValueError: if the loader yields no batches.
    """
    tot, n = 0.0, 0
    gen = torch.Generator(device=dev).manual_seed(0)
    for i, (x, _) in enumerate(loader):
        if i >= max_batches:
            break
        x = x.to(dev)
        M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev, gen)
        xbar = x * (1 - M)
        _, y = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, mode)
        # float(): keep the running total a plain number (JSON-serialisable downstream)
        tot += float(masked_mse(y, x, M)) * x.size(0)
        n += x.size(0)
    if n == 0:
        raise ValueError("evaluate: loader yielded no batches")
    return tot / n


def fidelity(cfg, model, dev):
    """Print J_A antisymmetry and EP-vs-BPTT gradient cosines for the three estimator variants."""
    names = [n for n, _ in model.named_parameters()]
    trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset)
    x, _ = next(iter(trl))
    x = x.to(dev)
    M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
    xbar = x * (1 - M)
    zs, _ = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, 'real')
    # relative size of the antisymmetric part of J_A along a random direction
    v = torch.randn_like(zs)
    Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
    JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
    asym = (0.5 * (Jv - JTv)).norm().item() / (Jv.norm().item() + 1e-8)
    print(f"real-attention Jacobian antisymmetry = {asym:.3f}\n")
    for mode, aep, label in [('energy', False, 'energy/naive (sanity)'),
                             ('real', False, 'real/naive (biased)'),
                             ('real', True, 'real/AEP (fixed)')]:
        gb = bptt_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.eps)
        gv = vf_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.T2, cfg.eps, cfg.beta, aep)
        g, per = cos(gv, gb, names)
        att = " ".join(f"{k}={per[k]:+.3f}" for k in ('WQ', 'WK', 'WV', 'WO') if k in per)
        print(f"[{label}] global={g:+.4f} attn: {att}")


def train(cfg, model, dev):
    """Train in 'real' mode with EP gradients (AEP-corrected if ``cfg.aep``).

    Logs test masked-MSE every ``cfg.log_every`` steps and writes the final/best
    values to ``<cfg.out>/aep_train_<tag>.json``.
    """
    tag = 'aep' if cfg.aep else 'naive'
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.01)
    trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
    print(f"[real-attn EP, {tag}] params={sum(p.numel() for p in model.parameters())/1e3:.1f}K "
          f"T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}", flush=True)
    # stay in the stable+faithful regime: cap weight norms (Wmem for bounded-below energy,
    # attention WV/WO/WQ/WK so the non-conservative force can't grow into the unstable s>=4 regime)
    caps = {n: p.detach().norm().item() * 1.5 for n, p in model.named_parameters()
            if n in ('Wmem', 'WQ', 'WK', 'WV', 'WO')}
    cap_params = {n: p for n, p in model.named_parameters() if n in caps}
    step, t0, best = 0, time.time(), float('inf')
    while step < cfg.steps:
        for x, _ in trl:
            if step >= cfg.steps:
                break
            x = x.to(dev, non_blocking=True)
            M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
            xbar = x * (1 - M)
            grads = vf_param_grad(model, xbar, x, M, 'real', cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.aep)
            opt.zero_grad(set_to_none=True)
            # validate every gradient before assigning any of them
            if any(g is None or not torch.isfinite(g).all() for g in grads):
                print(f" step {step}: non-finite grad, skip", flush=True)
                step += 1
                continue
            for p, g in zip(model.parameters(), grads):
                p.grad = g
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()
            sched.step()
            with torch.no_grad():  # stay in stable+faithful regime
                for n, p in cap_params.items():
                    pn = p.norm()
                    if pn > caps[n]:
                        p.mul_(caps[n] / pn)
            step += 1
            if step % cfg.log_every == 0:
                te = evaluate(model, tel, cfg, dev, 'real', 15)
                best = min(best, te)
                print(f"step {step:4d}/{cfg.steps} | test masked-MSE {te:.5f} (best {best:.5f}) "
                      f"| {step/(time.time()-t0):.2f} it/s", flush=True)
    final = evaluate(model, tel, cfg, dev, 'real', 60)
    best = min(best, final)
    os.makedirs(cfg.out, exist_ok=True)
    with open(os.path.join(cfg.out, f'aep_train_{tag}.json'), 'w') as fh:
        json.dump({'tag': tag, 'final_test_masked_mse': final, 'best_test_masked_mse': best,
                   'steps': cfg.steps}, fh, indent=2)
    print(f"[real-attn EP, {tag}] DONE final={final:.5f} best={best:.5f}", flush=True)


def main():
    """Parse CLI arguments, build a seeded CETReal and run ``fidelity`` or ``train``."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--cmd', choices=['fidelity', 'train'], default='fidelity')
    ap.add_argument('--aep', action='store_true')
    ap.add_argument('--damp', type=float, default=0.0)
    ap.add_argument('--dataset', default='fashionmnist')
    ap.add_argument('--img', type=int, default=28)
    ap.add_argument('--ch', type=int, default=1)
    ap.add_argument('--patch', type=int, default=7)
    ap.add_argument('--stride', type=int, default=7)
    ap.add_argument('--D', type=int, default=64)
    ap.add_argument('--heads', type=int, default=4)
    ap.add_argument('--dh', type=int, default=16)
    ap.add_argument('--mem', type=int, default=128)
    ap.add_argument('--T1', type=int, default=100)
    ap.add_argument('--T2', type=int, default=15)
    ap.add_argument('--eps', type=float, default=0.2)
    ap.add_argument('--beta', type=float, default=0.02)
    ap.add_argument('--batch', type=int, default=64)
    ap.add_argument('--steps', type=int, default=1500)
    ap.add_argument('--lr', type=float, default=4e-4)
    ap.add_argument('--wd', type=float, default=1e-4)
    ap.add_argument('--log_every', type=int, default=100)
    ap.add_argument('--out', default='/home/yurenh2/ept/runs')
    cfg = ap.parse_args()
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(0)
    model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh, cfg.mem).to(dev)
    model.damp = cfg.damp
    print('config:', vars(cfg), flush=True)
    (train if cfg.cmd == 'train' else fidelity)(cfg, model, dev)


if __name__ == '__main__':
    main()