diff options
Diffstat (limited to 'scripts/cet_aep.py')
| -rw-r--r-- | scripts/cet_aep.py | 272 |
1 files changed, 272 insertions, 0 deletions
diff --git a/scripts/cet_aep.py b/scripts/cet_aep.py new file mode 100644 index 0000000..44e8922 --- /dev/null +++ b/scripts/cet_aep.py @@ -0,0 +1,272 @@ +""" +AEP applied to CET's attention: replace CET's conservative energy-attention E^att +with a REAL (non-conservative) transformer attention inside the CET, and use the +AEP correction so EP still recovers the true gradient. + +CET state = (tokens z, reconstruction y). The conservative part + E_rest = E_enc + E_pos + E_mem + E_dec (scalar -> symmetric Jacobian) +keeps its energy-gradient force. The token force gets its attention term from: + energy mode : -dE^att/dz (conservative; tied value; this is plain CET) + real mode : RealAttn(z) = WO softmax(QK^T/sqrt dh)(z WV) (non-conservative) + +Because only RealAttn is non-conservative, the full-force antisymmetric Jacobian +A_J reduces to the antisymmetric part of dRealAttn/dz alone -> AEP correction is + force_z += -(J_A z~ - J_A^T z~) , z~=z-z* , J_A = dRealAttn/dz at z* +(clean jvp/vjp on RealAttn, no nested autograd). + +We compare parameter-gradient quality vs ground-truth BPTT for: + energy / naive-EP (conservative CET; sanity, should be ~BPTT) + real / naive-EP (non-conservative; expected biased) + real / AEP (non-conservative + correction; expected ~BPTT) +""" +import argparse, math, time, json, os, torch, torch.nn as nn, torch.nn.functional as F +from cet_mvp import token_norm, make_patch_mask, masked_cost, masked_mse, get_loaders + + +class CETReal(nn.Module): + def __init__(self, img=32, ch=3, patch=8, stride=8, D=64, heads=4, dh=16, + mem=128, gamma=0.25): + super().__init__() + self.ch, self.patch, self.stride, self.D = ch, patch, stride, D + self.heads, self.dh, self.gamma = heads, dh, gamma + gh = (img - patch) // stride + 1 + self.gh, self.N = gh, gh * gh + self.damp = 0.0 # contraction damping c: real_attn returns attn(z) - c*z + self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch)) + self.benc = nn.Parameter(torch.zeros(D)) + self.bpos = nn.Parameter(torch.zeros(self.N, D)) + self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch)) + self.bdec = nn.Parameter(torch.zeros(ch)) + self.Wmem = nn.Parameter(torch.empty(D, mem)) + # attention: WQ/WK used by both; WV/WO only by the real (non-conservative) path + self.WQ = nn.Parameter(torch.empty(heads, dh, D)) + self.WK = nn.Parameter(torch.empty(heads, dh, D)) + self.WV = nn.Parameter(torch.empty(heads, dh, D)) + self.WO = nn.Parameter(torch.empty(D, heads * dh)) + nn.init.kaiming_normal_(self.Wenc); self.Wenc.data *= 0.5 + nn.init.kaiming_normal_(self.Wdec); self.Wdec.data *= 0.5 + for w in (self.WQ, self.WK, self.WV): + nn.init.normal_(w, std=1.0 / math.sqrt(D)) + nn.init.normal_(self.Wmem, std=0.3 / math.sqrt(D)) # small: keep energy bounded-below + nn.init.normal_(self.WO, std=1.0 / math.sqrt(heads * dh)) + + def encode(self, xbar): + return F.conv2d(xbar, self.Wenc, stride=self.stride).flatten(2).transpose(1, 2) + + def decode_conv(self, y): + return F.conv2d(y, self.Wdec, stride=self.stride).flatten(2).transpose(1, 2) + + def E_rest(self, xbar, z, y): # conservative scalar (no attention) + enc = self.encode(xbar) + E = 2.0 * (z ** 2).sum() - (enc * z).sum() - (z * self.benc).sum() - (z * self.bpos).sum() + proj = torch.einsum('bnd,dm->bnm', z, self.Wmem) + E = E - (F.relu(proj) ** 2).sum() + dc = self.decode_conv(y) + E = E + 0.5 * (y ** 2).sum() - (dc * z).sum() - (y * self.bdec[None, :, None, None]).sum() + return E + + def E_att(self, z): # conservative LogSumExp energy (tied value) + Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ) + K = torch.einsum('bnd,hjd->bhnj', z, self.WK) + A = torch.einsum('bhmj,bhnj->bhmn', Q, K) + return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * A, dim=-1).sum() + + def real_attn(self, z): # NON-conservative real attention force + B = z.size(0) + q = torch.einsum('bnd,hjd->bhnj', z, self.WQ) + k = torch.einsum('bnd,hjd->bhnj', z, self.WK) + v = torch.einsum('bnd,hjd->bhnj', z, self.WV) + A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.dh), dim=-1) + o = (A @ v).transpose(1, 2).reshape(B, self.N, self.heads * self.dh) + return o @ self.WO.t() - self.damp * z # -c*z: symmetric -> contraction, A_J unchanged + + def force(self, xbar, z, y, mode): + """Return (force_z, force_y). force = -dE/dstate (+ real attention if mode='real').""" + z = z.requires_grad_(True); y = y.requires_grad_(True) + if mode == 'energy': + E = self.E_rest(xbar, z, y) + self.E_att(z) + gz, gy = torch.autograd.grad(E, [z, y], create_graph=True) + return -gz, -gy + else: + E = self.E_rest(xbar, z, y) + gz, gy = torch.autograd.grad(E, [z, y], create_graph=True) + return -gz + self.real_attn(z), -gy + + def init_state(self, xbar): + return token_norm(self.encode(xbar)).detach(), xbar.clone().detach() + + +def relax(model, xbar, z, y, steps, eps, mode, x=None, M=None, beta=0.0, aep=False, zstar=None): + for _ in range(steps): + with torch.enable_grad(): + fz, fy = model.force(xbar, z, y, mode) + fz, fy = fz.detach(), fy.detach() + if beta != 0.0: # nudge on the output y + yy = y.detach().requires_grad_(True) + gy, = torch.autograd.grad(masked_cost(yy, x, M), yy) + fy = fy - beta * gy + if aep: # AEP correction on z (attention block only) + v = (z - zstar).detach() + fa = lambda zz: model.real_attn(zz) + Jv = torch.autograd.functional.jvp(fa, zstar, v)[1] + JTv = torch.autograd.functional.vjp(fa, zstar, v)[1] + corr = Jv - JTv # = 2 * 0.5 (J v - J^T v) + cn, fn = corr.norm(), fz.norm() + 1e-8 # clip so correction can't dominate -> no blow-up + if cn > fn: + corr = corr * (fn / cn) + fz = fz - corr + with torch.no_grad(): + z = z + eps * fz # unconstrained (0.5||z||^2 in E_rest keeps it bounded) + y = y + eps * fy + return z.detach(), y.detach() + + +def vf_param_grad(model, xbar, x, M, mode, T1, T2, eps, beta, aep): + z0, y0 = model.init_state(xbar) + zs, ys = relax(model, xbar, z0, y0, T1, eps, mode) + zp, yp = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, +beta, aep, zs) + zm, ym = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, -beta, aep, zs) + az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach() + with torch.enable_grad(): + fz, fy = model.force(xbar, zs.detach(), ys.detach(), mode) + s = (az * fz).sum() + (ay * fy).sum() + grads = torch.autograd.grad(s, list(model.parameters()), allow_unused=True, retain_graph=False) + return grads + + +def bptt_param_grad(model, xbar, x, M, mode, T1, eps): + z, y = model.init_state(xbar) + z, y = z.requires_grad_(True), y.requires_grad_(True) + for _ in range(T1): + fz, fy = model.force(xbar, z, y, mode) + z = z + eps * fz + y = y + eps * fy + L = masked_cost(y, x, M) / M.sum() + return torch.autograd.grad(L, list(model.parameters()), allow_unused=True) + + +def cos(ga, gb, names): + fa, fb = [], [] + per = {} + for n, a, b in zip(names, ga, gb): + if a is None or b is None: + continue + fa.append(a.flatten()); fb.append(b.flatten()) + per[n] = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + g = F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item() + return g, per + + +def evaluate(model, loader, cfg, dev, mode='real', max_batches=40): + tot, n = 0.0, 0 + gen = torch.Generator(device=dev).manual_seed(0) + for i, (x, _) in enumerate(loader): + if i >= max_batches: + break + x = x.to(dev) + M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev, gen) + xbar = x * (1 - M) + z, y = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, mode) + tot += masked_mse(y, x, M) * x.size(0); n += x.size(0) + return tot / n + + +def fidelity(cfg, model, dev): + names = [n for n, _ in model.named_parameters()] + trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset) + x, _ = next(iter(trl)); x = x.to(dev) + M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev) + xbar = x * (1 - M) + zs, ys = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, 'real') + v = torch.randn_like(zs) + Jv = torch.autograd.functional.jvp(lambda z: model.real_attn(z), zs, v)[1] + JTv = torch.autograd.functional.vjp(lambda z: model.real_attn(z), zs, v)[1] + asym = (0.5 * (Jv - JTv)).norm().item() / (Jv.norm().item() + 1e-8) + print(f"real-attention Jacobian antisymmetry = {asym:.3f}\n") + for mode, aep, label in [('energy', False, 'energy/naive (sanity)'), + ('real', False, 'real/naive (biased)'), + ('real', True, 'real/AEP (fixed)')]: + gb = bptt_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.eps) + gv = vf_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.T2, cfg.eps, cfg.beta, aep) + g, per = cos(gv, gb, names) + att = " ".join(f"{k}={per[k]:+.3f}" for k in ('WQ', 'WK', 'WV', 'WO') if k in per) + print(f"[{label}] global={g:+.4f} attn: {att}") + + +def train(cfg, model, dev): + tag = 'aep' if cfg.aep else 'naive' + opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.01) + trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset) + print(f"[real-attn EP, {tag}] params={sum(p.numel() for p in model.parameters())/1e3:.1f}K " + f"T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}", flush=True) + # stay in the stable+faithful regime: cap weight norms (Wmem for bounded-below energy, + # attention WV/WO/WQ/WK so the non-conservative force can't grow into the unstable s>=4 regime) + caps = {n: p.detach().norm().item() * 1.5 for n, p in model.named_parameters() + if n in ('Wmem', 'WQ', 'WK', 'WV', 'WO')} + cap_params = {n: p for n, p in model.named_parameters() if n in caps} + step, t0, best = 0, time.time(), float('inf') + while step < cfg.steps: + for x, _ in trl: + if step >= cfg.steps: + break + x = x.to(dev, non_blocking=True) + M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev) + xbar = x * (1 - M) + grads = vf_param_grad(model, xbar, x, M, 'real', cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.aep) + opt.zero_grad(set_to_none=True) + bad = False + for p, g in zip(model.parameters(), grads): + if g is None or not torch.isfinite(g).all(): + bad = True; break + p.grad = g + if bad: + print(f" step {step}: non-finite grad, skip", flush=True); step += 1; continue + torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) + opt.step(); sched.step() + with torch.no_grad(): # stay in stable+faithful regime + for n, p in cap_params.items(): + pn = p.norm() + if pn > caps[n]: + p.mul_(caps[n] / pn) + step += 1 + if step % cfg.log_every == 0: + te = evaluate(model, tel, cfg, dev, 'real', 15) + best = min(best, te) + print(f"step {step:4d}/{cfg.steps} | test masked-MSE {te:.5f} (best {best:.5f}) " + f"| {step/(time.time()-t0):.2f} it/s", flush=True) + final = evaluate(model, tel, cfg, dev, 'real', 60) + best = min(best, final) + os.makedirs(cfg.out, exist_ok=True) + json.dump({'tag': tag, 'final_test_masked_mse': final, 'best_test_masked_mse': best, + 'steps': cfg.steps}, open(os.path.join(cfg.out, f'aep_train_{tag}.json'), 'w'), indent=2) + print(f"[real-attn EP, {tag}] DONE final={final:.5f} best={best:.5f}", flush=True) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--cmd', choices=['fidelity', 'train'], default='fidelity') + ap.add_argument('--aep', action='store_true') + ap.add_argument('--damp', type=float, default=0.0) + ap.add_argument('--dataset', default='fashionmnist') + ap.add_argument('--img', type=int, default=28); ap.add_argument('--ch', type=int, default=1) + ap.add_argument('--patch', type=int, default=7); ap.add_argument('--stride', type=int, default=7) + ap.add_argument('--D', type=int, default=64); ap.add_argument('--heads', type=int, default=4) + ap.add_argument('--dh', type=int, default=16); ap.add_argument('--mem', type=int, default=128) + ap.add_argument('--T1', type=int, default=100); ap.add_argument('--T2', type=int, default=15) + ap.add_argument('--eps', type=float, default=0.2); ap.add_argument('--beta', type=float, default=0.02) + ap.add_argument('--batch', type=int, default=64); ap.add_argument('--steps', type=int, default=1500) + ap.add_argument('--lr', type=float, default=4e-4); ap.add_argument('--wd', type=float, default=1e-4) + ap.add_argument('--log_every', type=int, default=100) + ap.add_argument('--out', default='/home/yurenh2/ept/runs') + cfg = ap.parse_args() + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + torch.manual_seed(0) + model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh, cfg.mem).to(dev) + model.damp = cfg.damp + print('config:', vars(cfg), flush=True) + (train if cfg.cmd == 'train' else fidelity)(cfg, model, dev) + + +if __name__ == '__main__': + main() |
