diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /scripts/bp_transformer.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'scripts/bp_transformer.py')
| -rw-r--r-- | scripts/bp_transformer.py | 141 |
1 files changed, 141 insertions, 0 deletions
"""
Vanilla backprop Transformer baseline for the SAME masked-image-completion task,
so we can compare against CET-EP and CET-TBPTE on the identical metric
(masked-patch pixel MSE on CIFAR-10, images in [-1,1]).

Standard recipe: conv patch-embed + learned pos-embed, MAE-style learned mask
token on occluded patches, N standard pre-LN transformer blocks (MHA + FFN),
linear pixel head, MSE loss on masked patches only. Trained with normal Adam/BP.
"""
import argparse
import json
import math
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from cet_mvp import get_loaders, make_patch_mask  # reuse data + masking


class Block(nn.Module):
    """One pre-LN transformer block: self-attention and an MLP, each with a residual."""

    def __init__(self, D, heads, mlp_ratio):
        """Build the block for width ``D``, ``heads`` attention heads and an
        MLP hidden width of ``int(mlp_ratio * D)``."""
        super().__init__()
        hidden = int(mlp_ratio * D)
        self.ln1 = nn.LayerNorm(D)
        self.attn = nn.MultiheadAttention(D, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(D)
        self.mlp = nn.Sequential(nn.Linear(D, hidden), nn.GELU(),
                                 nn.Linear(hidden, D))

    def forward(self, x):
        """Apply attention then MLP to ``x`` (batch-first tokens); shape is preserved."""
        h = self.ln1(x)
        # need_weights=False: only the attended values are used, not the attention map.
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.ln2(x))
        return x


class BPTransformer(nn.Module):
    """Patch-based transformer that predicts the pixels of every patch.

    Images are embedded with a strided conv, masked patches are replaced by a
    learned mask token, a learned positional embedding is added, and ``depth``
    ``Block`` layers followed by LayerNorm and a linear head produce
    ``ch * patch * patch`` values per patch.
    """

    def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4,
                 depth=1, mlp_ratio=2.0):
        super().__init__()
        self.ch, self.patch, self.stride = ch, patch, stride
        # Patches per side; this matches the output size of the embedding conv below.
        gh = (img - patch) // stride + 1
        self.gh, self.N, self.pdim = gh, gh * gh, ch * patch * patch
        self.embed = nn.Conv2d(ch, D, patch, stride=stride)
        self.pos = nn.Parameter(torch.zeros(1, self.N, D))
        nn.init.normal_(self.pos, std=0.02)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D))
        nn.init.normal_(self.mask_token, std=0.02)
        self.blocks = nn.ModuleList([Block(D, heads, mlp_ratio) for _ in range(depth)])
        self.ln = nn.LayerNorm(D)
        self.head = nn.Linear(D, self.pdim)

    def patchify(self, x):  # (B,C,H,W)->(B,N,pdim)
        """Cut ``x`` into patches and flatten each one in (C, p, p) order."""
        p, s = self.patch, self.stride
        u = x.unfold(2, p, s).unfold(3, p, s)  # B,C,gh,gw,p,p
        return u.permute(0, 2, 3, 1, 4, 5).reshape(x.size(0), self.N, self.pdim)

    def forward(self, xbar, pm):  # pm: (B,N) 1=masked
        """Predict per-patch pixels from the occluded image ``xbar``.

        Tokens at positions where ``pm`` is non-zero are replaced by the
        learned mask token before the positional embedding is added.
        """
        t = self.embed(xbar).flatten(2).transpose(1, 2)  # (B,N,D)
        t = torch.where(pm.unsqueeze(-1).bool(), self.mask_token, t) + self.pos
        for b in self.blocks:
            t = b(t)
        return self.head(self.ln(t))  # (B,N,pdim)


def patch_mask_bool(B, gh, ratio, device, gen=None):
    """Sample a random patch mask with a fixed number of masked patches per row.

    Returns a float tensor of shape (B, gh*gh) holding 1.0 at masked patches
    and 0.0 elsewhere; ``int(round(ratio * gh * gh))`` patches are masked in
    every row. ``gen`` optionally supplies the ``torch.Generator`` to draw from.
    """
    npatch = gh * gh
    nmask = int(round(ratio * npatch))
    # Argsort of i.i.d. uniforms is a random permutation; its first nmask entries are masked.
    idx = torch.rand(B, npatch, device=device, generator=gen).argsort(1)
    pm = torch.zeros(B, npatch, device=device)
    pm.scatter_(1, idx[:, :nmask], 1.0)
    return pm  # (B,N)


def masked_patch_mse(pred, true, pm):
    """Mean squared error over the values of masked patches only.

    The denominator (masked patches times values per patch) is clamped to at
    least 1, so a mask with nothing masked yields 0 rather than a division by zero.
    """
    m = pm.unsqueeze(-1)
    return ((pred - true) ** 2 * m).sum() / (m.sum() * pred.size(-1)).clamp_min(1.0)


def _occlude_pixels(x, pm, gh, patch):
    """Zero out the pixels of masked patches in ``x``.

    The (B, gh*gh) patch mask is expanded to a (B, 1, gh*patch, gh*patch)
    pixel mask, which only lines up with the image when the patch grid tiles it.
    """
    M = pm.view(-1, gh, gh).repeat_interleave(patch, 1).repeat_interleave(patch, 2).unsqueeze(1)
    if M.shape[-2:] != x.shape[-2:]:
        # Without this check a mismatch surfaces as an opaque broadcast error.
        raise ValueError(
            f"pixel mask {tuple(M.shape[-2:])} does not cover image {tuple(x.shape[-2:])}; "
            "the patch grid must tile the image (e.g. stride == patch, img divisible by patch)")
    return x * (1 - M)


@torch.no_grad()
def evaluate(model, loader, cfg, device, max_batches=100):
    """Return the sample-weighted mean masked-patch MSE over up to ``max_batches`` batches.

    Masks are drawn from a generator seeded with 0, so the same masks are
    sampled on every call. The model is put in eval mode for the pass and its
    previous train/eval mode is restored afterwards. Returns NaN if no batch
    was evaluated (empty loader or ``max_batches <= 0``).
    """
    was_training = model.training
    model.eval()
    tot, n = 0.0, 0
    gen = torch.Generator(device=device).manual_seed(0)
    try:
        for i, (x, _) in enumerate(loader):
            if i >= max_batches:
                break
            x = x.to(device)
            pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device, gen)
            xbar = _occlude_pixels(x, pm, model.gh, cfg.patch)
            pred = model(xbar, pm)
            tot += masked_patch_mse(pred, model.patchify(x), pm).item() * x.size(0)
            n += x.size(0)
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)
    return tot / n if n else float('nan')


def train(cfg):
    """Train the baseline for ``cfg.steps`` optimizer steps and write the final metric.

    Logs the running train loss every ``cfg.log_every`` steps, runs a
    20-batch test evaluation every ``cfg.eval_every`` steps (and at the last
    step), then runs a 100-batch final evaluation and writes it to
    ``<cfg.out>/result_bp_transformer.json``.

    Raises:
        RuntimeError: if the training loader yields no batches.
    """
    device = cfg.device
    torch.manual_seed(cfg.seed)
    model = BPTransformer(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads,
                          cfg.depth, cfg.mlp_ratio).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
    trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
    params_k = sum(p.numel() for p in model.parameters()) / 1e3
    print(f"[bp] params={params_k:.1f}K "
          f"depth={cfg.depth} D={cfg.D} mlp={cfg.mlp_ratio}", flush=True)
    step, t0, run = 0, time.time(), 0.0
    while step < cfg.steps:
        step_at_epoch_start = step
        for x, _ in trl:
            if step >= cfg.steps:
                break
            x = x.to(device, non_blocking=True)
            pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device)
            xbar = _occlude_pixels(x, pm, model.gh, cfg.patch)
            pred = model(xbar, pm)
            loss = masked_patch_mse(pred, model.patchify(x), pm)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            sched.step()
            run += loss.item()
            step += 1
            if step % cfg.log_every == 0:
                print(f"step {step:5d}/{cfg.steps} | train masked-MSE {run/cfg.log_every:.5f} "
                      f"| {step/(time.time()-t0):.1f} it/s", flush=True)
                run = 0.0
            if step % cfg.eval_every == 0 or step == cfg.steps:
                print(f" >> [eval] step {step} test masked-MSE {evaluate(model, tel, cfg, device, 20):.5f}", flush=True)
        if step == step_at_epoch_start:
            # An empty loader would otherwise spin in the outer loop forever.
            raise RuntimeError("training loader yielded no batches")
    final = evaluate(model, tel, cfg, device, 100)
    os.makedirs(cfg.out, exist_ok=True)
    result = {'mode': 'bp_transformer', 'final_test_masked_mse': final, 'steps': cfg.steps,
              'params_K': params_k}
    # Context manager so the result file is flushed and closed deterministically.
    with open(os.path.join(cfg.out, 'result_bp_transformer.json'), 'w') as fh:
        json.dump(result, fh, indent=2)
    print(f"[bp] DONE final test masked-MSE = {final:.5f}", flush=True)


def main():
    """Parse command-line options, print the resolved config and run training."""
    p = argparse.ArgumentParser()
    p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
    p.add_argument('--steps', type=int, default=3000)
    p.add_argument('--batch', type=int, default=128)
    p.add_argument('--img', type=int, default=32)
    p.add_argument('--ch', type=int, default=3)
    p.add_argument('--patch', type=int, default=8)
    p.add_argument('--stride', type=int, default=8)
    p.add_argument('--D', type=int, default=128)
    p.add_argument('--heads', type=int, default=4)
    p.add_argument('--depth', type=int, default=1)
    p.add_argument('--mlp_ratio', type=float, default=2.0)
    p.add_argument('--mask_ratio', type=float, default=0.5)
    p.add_argument('--lr', type=float, default=4e-4)
    p.add_argument('--lr_min', type=float, default=1e-6)
    p.add_argument('--wd', type=float, default=3e-5)
    p.add_argument('--log_every', type=int, default=100)
    p.add_argument('--eval_every', type=int, default=500)
    p.add_argument('--seed', type=int, default=0)
    # NOTE(review): the default output directory is a user-specific absolute path; pass --out elsewhere.
    p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs')
    p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    cfg = p.parse_args()
    print('config:', vars(cfg), flush=True)
    train(cfg)


if __name__ == '__main__':
    main()
