summary refs log tree commit diff
path: root/scripts/bp_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/bp_transformer.py')
-rw-r--r-- scripts/bp_transformer.py 141
1 files changed, 141 insertions, 0 deletions
diff --git a/scripts/bp_transformer.py b/scripts/bp_transformer.py
new file mode 100644
index 0000000..7c9b543
--- /dev/null
+++ b/scripts/bp_transformer.py
@@ -0,0 +1,141 @@
+"""
+Vanilla backprop Transformer baseline for the SAME masked-image-completion task,
+so we can compare against CET-EP and CET-TBPTE on the identical metric
+(masked-patch pixel MSE on CIFAR-10, images in [-1,1]).
+
+Standard recipe: conv patch-embed + learned pos-embed, MAE-style learned mask
+token on occluded patches, N standard pre-LN transformer blocks (MHA + FFN),
+linear pixel head, MSE loss on masked patches only. Trained with normal Adam/BP.
+"""
+import argparse, os, time, json, math
+import torch, torch.nn as nn, torch.nn.functional as F
+from cet_mvp import get_loaders, make_patch_mask # reuse data + masking
+
+
class Block(nn.Module):
    """Pre-LN transformer block.

    Self-attention and an MLP are each applied to a layer-normalised copy of
    the input and added back onto the residual stream.
    """

    def __init__(self, D, heads, mlp_ratio):
        """Build the sub-layers.

        D: token width. heads: number of attention heads.
        mlp_ratio: hidden MLP width as a multiple of D (truncated to an int).
        """
        super().__init__()
        hidden = int(mlp_ratio * D)
        self.ln1 = nn.LayerNorm(D)
        # batch_first=True: the attention layer takes (batch, tokens, D) input.
        self.attn = nn.MultiheadAttention(D, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(D)
        self.mlp = nn.Sequential(
            nn.Linear(D, hidden),
            nn.GELU(),
            nn.Linear(hidden, D),
        )

    def forward(self, x):
        """Apply the attention and MLP residual updates; output has x's shape."""
        normed = self.ln1(x)
        # need_weights=False: only the attended values are used, not the map.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        after_attn = x + attn_out
        return after_attn + self.mlp(self.ln2(after_attn))
+
+
class BPTransformer(nn.Module):
    """Backprop-trained transformer that reconstructs image patches.

    Pipeline: conv patch embedding -> learned mask token substituted on
    occluded patches -> learned positional embedding added -> `depth` Block
    layers -> LayerNorm -> linear head predicting the pixels of every patch.
    """

    def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4,
                 depth=1, mlp_ratio=2.0):
        """Create the layers; the patch grid is assumed square (N = gh * gh)."""
        super().__init__()
        self.ch = ch
        self.patch = patch
        self.stride = stride
        grid = (img - patch) // stride + 1  # patches along one image side
        self.gh = grid
        self.N = grid * grid                # tokens per image
        self.pdim = ch * patch * patch      # pixel values per patch
        self.embed = nn.Conv2d(ch, D, patch, stride=stride)
        self.pos = nn.Parameter(torch.zeros(1, self.N, D))
        nn.init.normal_(self.pos, std=0.02)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D))
        nn.init.normal_(self.mask_token, std=0.02)
        self.blocks = nn.ModuleList([Block(D, heads, mlp_ratio) for _ in range(depth)])
        self.ln = nn.LayerNorm(D)
        self.head = nn.Linear(D, self.pdim)

    def patchify(self, x):  # (B,C,H,W)->(B,N,pdim)
        """Cut x into patches and flatten each one in (channel, row, col) order."""
        size, step = self.patch, self.stride
        rows = x.unfold(2, size, step)        # B,C,gh,W,p
        windows = rows.unfold(3, size, step)  # B,C,gh,gw,p,p
        per_patch = windows.permute(0, 2, 3, 1, 4, 5)  # B,gh,gw,C,p,p
        return per_patch.reshape(x.size(0), self.N, self.pdim)

    def forward(self, xbar, pm):  # pm: (B,N) 1=masked
        """Predict pixels for every patch of the occluded image xbar -> (B,N,pdim)."""
        tokens = self.embed(xbar).flatten(2).transpose(1, 2)  # (B,N,D)
        occluded = pm.unsqueeze(-1).bool()
        # Masked positions carry only the mask token (plus position), so the
        # embedding of the blanked-out pixels never reaches the blocks.
        tokens = torch.where(occluded, self.mask_token, tokens)
        tokens = tokens + self.pos
        for layer in self.blocks:
            tokens = layer(tokens)
        return self.head(self.ln(tokens))  # (B,N,pdim)
+
+
def patch_mask_bool(B, gh, ratio, device, gen=None):
    """Draw a random patch mask for each of B images.

    Each row marks exactly int(round(ratio * gh * gh)) patches with 1.0 and
    the rest with 0.0. Despite the name, the result is a float tensor, not a
    bool one. Passing `gen` makes the draw reproducible.

    Returns a (B, gh * gh) tensor on `device`.
    """
    n_patches = gh * gh
    n_masked = int(round(ratio * n_patches))
    # Ranking i.i.d. uniform scores gives a random permutation per row; the
    # first n_masked entries of that permutation are the patches to hide.
    scores = torch.rand(B, n_patches, device=device, generator=gen)
    chosen = scores.argsort(1)[:, :n_masked]
    mask = torch.zeros(B, n_patches, device=device)
    mask.scatter_(1, chosen, 1.0)
    return mask  # (B,N)
+
+
def masked_patch_mse(pred, true, pm):
    """Mean squared error over the pixels of masked patches only.

    pred, true: per-patch pixel tensors of identical shape (..., pdim).
    pm: patch mask without the trailing pixel axis; non-zero entries select
        the patches that contribute.

    The denominator (masked patches * pixels per patch) is clamped to at
    least 1 so an all-zero mask yields 0 instead of a division by zero.
    """
    weights = pm.unsqueeze(-1)
    squared_error = (pred - true) ** 2 * weights
    n_pixels = (weights.sum() * pred.size(-1)).clamp_min(1.0)
    return squared_error.sum() / n_pixels
+
+
@torch.no_grad()
def evaluate(model, loader, cfg, device, max_batches=100):
    """Average masked-patch MSE of `model` over up to `max_batches` batches.

    model: provides `gh`, `patchify(x)` and `model(xbar, pm)`.
    loader: iterable of (x, _) pairs; the second item is ignored.
    cfg: read for `mask_ratio` and `patch`.
    device: device the batches and masks are placed on.

    Masks come from a generator seeded with 0, so repeated calls score the
    same occlusion patterns. The model's previous train/eval mode is
    restored on exit, even if a batch raises.

    Returns the sample-weighted mean loss as a float, or NaN when no batch
    was evaluated (empty loader or max_batches <= 0).
    """
    was_training = model.training
    model.eval()
    tot, n = 0.0, 0
    gen = torch.Generator(device=device).manual_seed(0)
    try:
        for i, (x, _) in enumerate(loader):
            if i >= max_batches:
                break
            x = x.to(device)
            pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device, gen)
            # Blow the (B, gh, gh) patch mask up to pixel resolution and add a
            # channel axis so it broadcasts against the image batch.
            M = (pm.view(-1, model.gh, model.gh)
                 .repeat_interleave(cfg.patch, 1)
                 .repeat_interleave(cfg.patch, 2)
                 .unsqueeze(1))
            xbar = x * (1 - M)  # zero out the masked pixels
            pred = model(xbar, pm)
            tot += masked_patch_mse(pred, model.patchify(x), pm).item() * x.size(0)
            n += x.size(0)
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)
    return tot / n if n else float('nan')
+
+
def train(cfg):
    """Train a BPTransformer on masked-patch reconstruction and save the result.

    cfg is an argparse-style namespace; this function reads device, seed, img,
    ch, patch, stride, D, heads, depth, mlp_ratio, lr, wd, lr_min, steps,
    batch, dataset, mask_ratio, log_every, eval_every and out.

    Runs cfg.steps AdamW updates with a cosine learning-rate schedule,
    printing the running training loss every cfg.log_every steps and a
    20-batch test loss every cfg.eval_every steps (and at the last step).
    Finally evaluates on up to 100 test batches and writes
    `result_bp_transformer.json` into cfg.out.

    Raises RuntimeError if the training loader yields no batches, which would
    otherwise make the outer loop spin forever.
    """
    device = cfg.device
    torch.manual_seed(cfg.seed)
    model = BPTransformer(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads,
                          cfg.depth, cfg.mlp_ratio).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
    trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
    # The parameter count does not change during training; compute it once.
    params_k = sum(p.numel() for p in model.parameters()) / 1e3
    print(f"[bp] params={params_k:.1f}K "
          f"depth={cfg.depth} D={cfg.D} mlp={cfg.mlp_ratio}", flush=True)

    step, t0, run = 0, time.time(), 0.0
    while step < cfg.steps:  # re-iterate the loader until enough steps are done
        step_at_epoch_start = step
        for x, _ in trl:
            if step >= cfg.steps:
                break
            x = x.to(device, non_blocking=True)
            pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device)
            # Patch mask -> pixel mask with a channel axis for broadcasting.
            M = (pm.view(-1, model.gh, model.gh)
                 .repeat_interleave(cfg.patch, 1)
                 .repeat_interleave(cfg.patch, 2)
                 .unsqueeze(1))
            xbar = x * (1 - M)  # zero out the masked pixels
            pred = model(xbar, pm)
            loss = masked_patch_mse(pred, model.patchify(x), pm)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            sched.step()
            run += loss.item()
            step += 1
            if step % cfg.log_every == 0:
                print(f"step {step:5d}/{cfg.steps} | train masked-MSE {run/cfg.log_every:.5f} "
                      f"| {step/(time.time()-t0):.1f} it/s", flush=True)
                run = 0.0
            if step % cfg.eval_every == 0 or step == cfg.steps:
                print(f" >> [eval] step {step} test masked-MSE {evaluate(model, tel, cfg, device, 20):.5f}", flush=True)
        if step == step_at_epoch_start:
            # A full pass produced no update: the loader is empty.
            raise RuntimeError("training loader yielded no batches; cannot reach cfg.steps")

    final = evaluate(model, tel, cfg, device, 100)
    os.makedirs(cfg.out, exist_ok=True)
    result = {'mode': 'bp_transformer', 'final_test_masked_mse': final, 'steps': cfg.steps,
              'params_K': params_k}
    # Context manager guarantees the file is flushed and closed.
    with open(os.path.join(cfg.out, 'result_bp_transformer.json'), 'w') as fh:
        json.dump(result, fh, indent=2)
    print(f"[bp] DONE final test masked-MSE = {final:.5f}", flush=True)
+
+
def main():
    """Parse the command-line options, print the resulting config and train."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
    # (flag, type, default) for every remaining option, in registration order.
    typed_options = (
        ('--steps', int, 3000),
        ('--batch', int, 128),
        ('--img', int, 32),
        ('--ch', int, 3),
        ('--patch', int, 8),
        ('--stride', int, 8),
        ('--D', int, 128),
        ('--heads', int, 4),
        ('--depth', int, 1),
        ('--mlp_ratio', float, 2.0),
        ('--mask_ratio', float, 0.5),
        ('--lr', float, 4e-4),
        ('--lr_min', float, 1e-6),
        ('--wd', float, 3e-5),
        ('--log_every', int, 100),
        ('--eval_every', int, 500),
        ('--seed', int, 0),
        ('--out', str, '/home/yurenh2/ept/runs'),
        # Use the GPU when one is visible, otherwise fall back to the CPU.
        ('--device', str, 'cuda' if torch.cuda.is_available() else 'cpu'),
    )
    for flag, kind, default in typed_options:
        parser.add_argument(flag, type=kind, default=default)
    cfg = parser.parse_args()
    print('config:', vars(cfg), flush=True)
    train(cfg)
+
+
# Run the training CLI only when this file is executed as a script.
if __name__ == '__main__':
    main()