"""
Vanilla backprop Transformer baseline for the SAME masked-image-completion task,
so we can compare against CET-EP and CET-TBPTE on the identical metric
(masked-patch pixel MSE on CIFAR-10, images in [-1,1] as produced by
cet_mvp.get_loaders).

Standard recipe: conv patch-embed + learned pos-embed, MAE-style learned mask
token on occluded patches, N standard pre-LN transformer blocks (MHA + FFN),
linear pixel head, MSE loss on masked patches only. Trained with ordinary
backprop using AdamW and a cosine-annealed learning rate.
"""
import argparse
import json
import math
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

# Reuse data + masking from the CET code. NOTE(review): make_patch_mask is not
# referenced in this file; masks are drawn by patch_mask_bool below instead.
from cet_mvp import get_loaders, make_patch_mask


class Block(nn.Module):
    """Pre-LN transformer block: ``x + MHA(LN(x))`` then ``x + MLP(LN(x))``.

    Args:
        D: token width.
        heads: number of attention heads (must divide ``D``).
        mlp_ratio: hidden width of the MLP as a multiple of ``D``.
    """

    def __init__(self, D, heads, mlp_ratio):
        super().__init__()
        hidden = int(mlp_ratio * D)
        # Submodule creation order is kept fixed so parameter init consumes the
        # global RNG in the same order (seed reproducibility).
        self.ln1 = nn.LayerNorm(D)
        self.attn = nn.MultiheadAttention(D, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(D)
        self.mlp = nn.Sequential(nn.Linear(D, hidden), nn.GELU(), nn.Linear(hidden, D))

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, N, D); returns the same shape."""
        h = self.ln1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.ln2(x))
        return x


class BPTransformer(nn.Module):
    """Masked-patch pixel predictor trained with plain backprop.

    Images are assumed square (``img x img``): the token grid is ``gh x gh``
    with ``gh = (img - patch) // stride + 1`` and ``N = gh * gh`` tokens.
    Each token predicts ``pdim = ch * patch * patch`` pixel values in the same
    (C, p, p) order produced by :meth:`patchify`.
    """

    def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4, depth=1, mlp_ratio=2.0):
        super().__init__()
        self.ch, self.patch, self.stride = ch, patch, stride
        gh = (img - patch) // stride + 1
        self.gh, self.N, self.pdim = gh, gh * gh, ch * patch * patch
        self.embed = nn.Conv2d(ch, D, patch, stride=stride)
        self.pos = nn.Parameter(torch.zeros(1, self.N, D))
        nn.init.normal_(self.pos, std=0.02)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D))
        nn.init.normal_(self.mask_token, std=0.02)
        self.blocks = nn.ModuleList([Block(D, heads, mlp_ratio) for _ in range(depth)])
        self.ln = nn.LayerNorm(D)
        self.head = nn.Linear(D, self.pdim)

    def patchify(self, x):
        """Split images (B, C, H, W) into patch vectors (B, N, pdim), row-major."""
        p, s = self.patch, self.stride
        u = x.unfold(2, p, s).unfold(3, p, s)  # B, C, gh, gw, p, p
        return u.permute(0, 2, 3, 1, 4, 5).reshape(x.size(0), self.N, self.pdim)

    def forward(self, xbar, pm):
        """Predict per-patch pixels.

        Args:
            xbar: occluded images (B, C, H, W).
            pm: (B, N) patch mask, 1 = masked; masked tokens are replaced by the
                learned mask token before the positional embedding is added.

        Returns:
            (B, N, pdim) pixel predictions for every patch.
        """
        t = self.embed(xbar).flatten(2).transpose(1, 2)  # (B, N, D)
        t = torch.where(pm.unsqueeze(-1).bool(), self.mask_token, t) + self.pos
        for b in self.blocks:
            t = b(t)
        return self.head(self.ln(t))  # (B, N, pdim)


def patch_mask_bool(B, gh, ratio, device, gen=None):
    """Draw a random patch mask with exactly ``round(ratio * gh*gh)`` ones per row.

    Despite the name, the result is a float tensor of 0/1 values, shape (B, N)
    with ``N = gh * gh``. Masked positions are the first ``nmask`` indices of a
    random permutation (argsort of uniform noise drawn from ``gen`` if given).
    """
    npatch = gh * gh
    nmask = int(round(ratio * npatch))
    idx = torch.rand(B, npatch, device=device, generator=gen).argsort(1)
    pm = torch.zeros(B, npatch, device=device)
    pm.scatter_(1, idx[:, :nmask], 1.0)
    return pm  # (B, N)


def patch_to_pixel_mask(pm, patch, stride, hw):
    """Expand a per-patch mask to a per-pixel occlusion mask.

    Works for any patch/stride/image size combination the model accepts. It
    is not limited to the ``stride == patch`` and ``gh * patch == H`` case.

    Args:
        pm: (B, N) mask, 1 = patch masked. N must equal the number of
            ``patch``/``stride`` positions over an image of size ``hw``.
        patch: patch (kernel) size.
        stride: patch stride.
        hw: (H, W) spatial size of the images.

    Returns:
        (B, 1, H, W) tensor with ``pm``'s dtype. It is 1 where a pixel lies
        inside at least one masked patch and 0 elsewhere. Pixels covered by no
        patch (e.g. a border when ``(H - patch) % stride != 0``) stay 0.
    """
    B, N = pm.shape
    # Every pixel of patch l carries pm[:, l]; fold scatters (and sums, for
    # overlapping patches) these columns back onto the image grid.
    cols = pm.float().unsqueeze(1).expand(B, patch * patch, N).contiguous()
    cover = F.fold(cols, output_size=tuple(hw), kernel_size=patch, stride=stride)
    return (cover > 0).to(pm.dtype)


def masked_patch_mse(pred, true, pm):
    """Mean squared error over the elements of masked patches only.

    ``pred``/``true`` are (B, N, pdim); ``pm`` is (B, N) with 1 = masked. The
    denominator is clamped to 1 so an all-visible batch yields 0, not NaN.
    """
    m = pm.unsqueeze(-1)
    return ((pred - true) ** 2 * m).sum() / (m.sum() * pred.size(-1)).clamp_min(1.0)


@torch.no_grad()
def evaluate(model, loader, cfg, device, max_batches=100):
    """Masked-patch MSE on up to ``max_batches`` batches of ``loader``.

    Masks come from a generator seeded with 0 on ``device``, so repeated calls
    see the same masks for the same batches. Per-batch MSEs are averaged
    weighted by batch size. The model's train/eval mode is restored on exit.

    Returns:
        The weighted average, or NaN if no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    tot, n = 0.0, 0
    gen = torch.Generator(device=device).manual_seed(0)
    for i, (x, _) in enumerate(loader):
        if i >= max_batches:
            break
        x = x.to(device)
        pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device, gen)
        M = patch_to_pixel_mask(pm, model.patch, model.stride, x.shape[-2:])
        xbar = x * (1 - M)
        pred = model(xbar, pm)
        tot += masked_patch_mse(pred, model.patchify(x), pm).item() * x.size(0)
        n += x.size(0)
    model.train(was_training)
    return tot / n if n else float('nan')


def train(cfg):
    """Train the baseline for ``cfg.steps`` steps, logging and evaluating periodically.

    Writes ``result_bp_transformer.json`` (final test masked-MSE, step count,
    parameter count in thousands) into ``cfg.out``.

    Raises:
        RuntimeError: if the training loader yields no batches (the step loop
            would otherwise never terminate).
    """
    device = cfg.device
    torch.manual_seed(cfg.seed)
    model = BPTransformer(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads,
                          cfg.depth, cfg.mlp_ratio).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
    trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
    params_K = sum(p.numel() for p in model.parameters()) / 1e3
    print(f"[bp] params={params_K:.1f}K "
          f"depth={cfg.depth} D={cfg.D} mlp={cfg.mlp_ratio}", flush=True)

    step, t0, run = 0, time.time(), 0.0
    while step < cfg.steps:
        seen_batch = False
        for x, _ in trl:
            seen_batch = True
            if step >= cfg.steps:
                break
            x = x.to(device, non_blocking=True)
            pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device)
            M = patch_to_pixel_mask(pm, model.patch, model.stride, x.shape[-2:])
            xbar = x * (1 - M)
            pred = model(xbar, pm)
            loss = masked_patch_mse(pred, model.patchify(x), pm)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            sched.step()
            run += loss.item()
            step += 1
            if step % cfg.log_every == 0:
                print(f"step {step:5d}/{cfg.steps} | train masked-MSE {run/cfg.log_every:.5f} "
                      f"| {step/(time.time()-t0):.1f} it/s", flush=True)
                run = 0.0
            if step % cfg.eval_every == 0 or step == cfg.steps:
                print(f" >> [eval] step {step} test masked-MSE {evaluate(model, tel, cfg, device, 20):.5f}",
                      flush=True)
        if not seen_batch:
            raise RuntimeError("training loader yielded no batches")

    final = evaluate(model, tel, cfg, device, 100)
    os.makedirs(cfg.out, exist_ok=True)
    result = {'mode': 'bp_transformer', 'final_test_masked_mse': final,
              'steps': cfg.steps, 'params_K': params_K}
    # Context manager so the result file is flushed and closed deterministically.
    with open(os.path.join(cfg.out, 'result_bp_transformer.json'), 'w') as fh:
        json.dump(result, fh, indent=2)
    print(f"[bp] DONE final test masked-MSE = {final:.5f}", flush=True)


def main():
    """Parse command-line options and run :func:`train`.

    NOTE(review): --img/--ch are not derived from --dataset; they must be set
    to match the images get_loaders returns for the chosen dataset.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
    p.add_argument('--steps', type=int, default=3000)
    p.add_argument('--batch', type=int, default=128)
    p.add_argument('--img', type=int, default=32)
    p.add_argument('--ch', type=int, default=3)
    p.add_argument('--patch', type=int, default=8)
    p.add_argument('--stride', type=int, default=8)
    p.add_argument('--D', type=int, default=128)
    p.add_argument('--heads', type=int, default=4)
    p.add_argument('--depth', type=int, default=1)
    p.add_argument('--mlp_ratio', type=float, default=2.0)
    p.add_argument('--mask_ratio', type=float, default=0.5)
    p.add_argument('--lr', type=float, default=4e-4)
    p.add_argument('--lr_min', type=float, default=1e-6)
    p.add_argument('--wd', type=float, default=3e-5)
    p.add_argument('--log_every', type=int, default=100)
    p.add_argument('--eval_every', type=int, default=500)
    p.add_argument('--seed', type=int, default=0)
    # NOTE(review): machine-specific default output directory.
    p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs')
    p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    cfg = p.parse_args()
    print('config:', vars(cfg), flush=True)
    train(cfg)


if __name__ == '__main__':
    main()