summaryrefslogtreecommitdiff
path: root/scripts/bp_transformer.py
blob: 7c9b5438b85a66f5cfab24bfee8090f067429750 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Vanilla backprop Transformer baseline for the SAME masked-image-completion task,
so we can compare against CET-EP and CET-TBPTE on the identical metric
(masked-patch pixel MSE on CIFAR-10, images in [-1,1]).

Standard recipe: conv patch-embed + learned pos-embed, MAE-style learned mask
token on occluded patches, N standard pre-LN transformer blocks (MHA + FFN),
linear pixel head, MSE loss on masked patches only. Trained with standard AdamW and backprop.
"""
import argparse, os, time, json, math
import torch, torch.nn as nn, torch.nn.functional as F
from cet_mvp import get_loaders, make_patch_mask  # reuse data + masking


class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then an MLP, each residual.

    Args:
        D: token embedding width.
        heads: number of attention heads passed to ``nn.MultiheadAttention``.
        mlp_ratio: MLP hidden width as a multiple of ``D`` (truncated to int).
    """

    def __init__(self, D, heads, mlp_ratio):
        super().__init__()
        hidden = int(mlp_ratio * D)
        self.ln1 = nn.LayerNorm(D)
        self.attn = nn.MultiheadAttention(D, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(D)
        self.mlp = nn.Sequential(
            nn.Linear(D, hidden),
            nn.GELU(),
            nn.Linear(hidden, D),
        )

    def forward(self, x):
        """Run the block on batch-first tokens ``x`` and return same-shaped tokens."""
        # Attention branch: normalize first, then add the result back onto x.
        normed = self.ln1(x)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        # Feed-forward branch, again normalize-then-residual.
        return x + self.mlp(self.ln2(x))


class BPTransformer(nn.Module):
    """Patch-token transformer that predicts the pixels of every patch.

    The image is embedded with a strided convolution, tokens at masked
    positions are replaced by a learned mask token, a learned positional
    embedding is added, and the tokens pass through ``depth`` ``Block``s, a
    final LayerNorm, and a linear head of width ``ch * patch * patch``.

    Args:
        img: image side length used to size the patch grid.
        ch: number of image channels.
        patch: patch side length (conv kernel size).
        stride: step between patches (conv stride).
        D: token embedding width.
        heads: attention heads per block.
        depth: number of transformer blocks.
        mlp_ratio: MLP hidden width as a multiple of ``D``.
    """

    def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4,
                 depth=1, mlp_ratio=2.0):
        super().__init__()
        self.ch = ch
        self.patch = patch
        self.stride = stride
        # Patches per side; the grid is treated as square (N = gh * gh).
        grid = (img - patch) // stride + 1
        self.gh = grid
        self.N = grid * grid
        self.pdim = ch * patch * patch
        self.embed = nn.Conv2d(ch, D, patch, stride=stride)
        self.pos = nn.Parameter(torch.zeros(1, self.N, D))
        nn.init.normal_(self.pos, std=0.02)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D))
        nn.init.normal_(self.mask_token, std=0.02)
        self.blocks = nn.ModuleList(Block(D, heads, mlp_ratio) for _ in range(depth))
        self.ln = nn.LayerNorm(D)
        self.head = nn.Linear(D, self.pdim)

    def patchify(self, x):
        """Cut ``x`` (B,C,H,W) into flattened patches of shape (B,N,pdim)."""
        size, step = self.patch, self.stride
        windows = x.unfold(2, size, step).unfold(3, size, step)   # B,C,gh,gw,p,p
        # Move the grid axes ahead of channels so each patch flattens as (C,p,p).
        windows = windows.permute(0, 2, 3, 1, 4, 5)
        return windows.reshape(x.size(0), self.N, self.pdim)

    def forward(self, xbar, pm):
        """Predict per-patch pixels from the occluded image.

        Args:
            xbar: input image batch (B,C,H,W).
            pm: (B,N) patch mask, 1 = masked.

        Returns:
            Tensor of shape (B,N,pdim).
        """
        tokens = self.embed(xbar).flatten(2).transpose(1, 2)      # (B,N,D)
        is_masked = pm.unsqueeze(-1).bool()
        # The (1,1,D) mask token broadcasts over every masked position.
        tokens = torch.where(is_masked, self.mask_token, tokens)
        tokens = tokens + self.pos
        for block in self.blocks:
            tokens = block(tokens)
        return self.head(self.ln(tokens))


def patch_mask_bool(B, gh, ratio, device, gen=None):
    """Draw a random patch mask for each image in a batch.

    Args:
        B: batch size.
        gh: patches per side; the mask covers ``gh * gh`` patches.
        ratio: fraction of patches to mask; the masked count is
            ``int(round(ratio * gh * gh))``.
        device: device on which the tensors are created.
        gen: optional ``torch.Generator`` for reproducible draws.

    Returns:
        Float tensor of shape (B, gh*gh) holding 1.0 at masked patches and
        0.0 elsewhere (a float mask, despite the ``_bool`` in the name).
    """
    total = gh * gh
    n_masked = int(round(ratio * total))
    # Ranking uniform random scores yields an independent permutation per row;
    # the first n_masked entries of each permutation are the masked patches.
    scores = torch.rand(B, total, device=device, generator=gen)
    chosen = scores.argsort(1)[:, :n_masked]
    mask = torch.zeros(B, total, device=device)
    return mask.scatter_(1, chosen, 1.0)                     # (B,N)


def masked_patch_mse(pred, true, pm):
    """Mean squared error over the elements of masked patches only.

    Args:
        pred: predictions, shape (B,N,pdim).
        true: targets, same shape as ``pred``.
        pm: (B,N) patch mask weighting each patch (1 = counted, 0 = ignored).

    Returns:
        Scalar tensor: squared error summed over masked patches, divided by
        the number of masked elements (mask sum times ``pdim``). The divisor
        is floored at 1, so an all-zero mask yields 0 instead of NaN.
    """
    weights = pm.unsqueeze(-1)
    numer = ((pred - true) ** 2 * weights).sum()
    denom = (weights.sum() * pred.size(-1)).clamp_min(1.0)
    return numer / denom


@torch.no_grad()
def evaluate(model, loader, cfg, device, max_batches=100):
    """Return the sample-weighted mean masked-patch MSE over up to ``max_batches`` batches.

    The model is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if scoring raises.

    Args:
        model: network exposing ``gh`` and ``patchify`` and callable as
            ``model(xbar, pm)``.
        loader: iterable yielding ``(x, _)`` pairs; only ``x`` is used.
        cfg: config object; ``cfg.mask_ratio`` and ``cfg.patch`` are read.
        device: device the batch, the masks, and the RNG are placed on.
        max_batches: stop after this many batches.

    Returns:
        Mean loss as a Python float, or ``float('nan')`` when no batch was
        scored (empty loader or ``max_batches <= 0``).
    """
    was_training = model.training
    model.eval()
    tot, n = 0.0, 0
    # Freshly seeded generator so the mask draws are reproducible across calls.
    gen = torch.Generator(device=device).manual_seed(0)
    try:
        for i, (x, _) in enumerate(loader):
            if i >= max_batches:
                break
            x = x.to(device)
            pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device, gen)
            # Upsample the (B,N) patch mask to a (B,1,gh*patch,gh*patch) pixel mask.
            # NOTE(review): this matches the image size only when patches tile the
            # image without overlap (stride == patch) — confirm other configs are
            # unsupported.
            M = pm.view(-1, model.gh, model.gh).repeat_interleave(cfg.patch, 1).repeat_interleave(cfg.patch, 2).unsqueeze(1)
            xbar = x * (1 - M)                               # zero out masked pixels
            pred = model(xbar, pm)
            # Weight each batch mean by its size so a short last batch counts fairly.
            tot += masked_patch_mse(pred, model.patchify(x), pm).item() * x.size(0)
            n += x.size(0)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)
    return tot / n if n else float('nan')


def train(cfg):
    """Train a ``BPTransformer`` on masked-patch reconstruction and save the final score.

    Runs ``cfg.steps`` optimizer steps of AdamW with a cosine learning-rate
    schedule, printing the running train loss every ``cfg.log_every`` steps
    and a 20-batch test evaluation every ``cfg.eval_every`` steps (and at the
    last step). Afterwards a 100-batch evaluation is written to
    ``result_bp_transformer.json`` inside ``cfg.out``.

    Args:
        cfg: namespace providing ``device``, ``seed``, ``img``, ``ch``,
            ``patch``, ``stride``, ``D``, ``heads``, ``depth``, ``mlp_ratio``,
            ``lr``, ``wd``, ``lr_min``, ``steps``, ``batch``, ``dataset``,
            ``mask_ratio``, ``log_every``, ``eval_every`` and ``out``.

    Raises:
        RuntimeError: if the training loader yields no batches, which would
            otherwise make the step loop spin forever.
    """
    device = cfg.device
    torch.manual_seed(cfg.seed)
    model = BPTransformer(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads,
                          cfg.depth, cfg.mlp_ratio).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
    # presumably (train loader, test loader) — verify against cet_mvp.get_loaders
    trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
    params_k = sum(p.numel() for p in model.parameters()) / 1e3
    print(f"[bp] params={params_k:.1f}K "
          f"depth={cfg.depth} D={cfg.D} mlp={cfg.mlp_ratio}", flush=True)
    step, t0, run = 0, time.time(), 0.0
    while step < cfg.steps:
        step_at_epoch_start = step
        for x, _ in trl:
            if step >= cfg.steps:
                break
            x = x.to(device, non_blocking=True)
            pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device)
            # Upsample the (B,N) patch mask to a (B,1,gh*patch,gh*patch) pixel mask.
            # NOTE(review): this matches the image size only when patches tile the
            # image without overlap (stride == patch) — confirm other configs are
            # unsupported.
            M = pm.view(-1, model.gh, model.gh).repeat_interleave(cfg.patch, 1).repeat_interleave(cfg.patch, 2).unsqueeze(1)
            xbar = x * (1 - M)                               # zero out masked pixels
            pred = model(xbar, pm)
            loss = masked_patch_mse(pred, model.patchify(x), pm)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            sched.step()
            run += loss.item()
            step += 1
            if step % cfg.log_every == 0:
                print(f"step {step:5d}/{cfg.steps} | train masked-MSE {run/cfg.log_every:.5f} "
                      f"| {step/(time.time()-t0):.1f} it/s", flush=True)
                run = 0.0
            if step % cfg.eval_every == 0 or step == cfg.steps:
                print(f"  >> [eval] step {step} test masked-MSE {evaluate(model, tel, cfg, device, 20):.5f}", flush=True)
        if step == step_at_epoch_start:
            # A full pass made no progress: the loader is empty, so fail loudly
            # rather than loop forever.
            raise RuntimeError("training loader yielded no batches; cannot reach cfg.steps")
    final = evaluate(model, tel, cfg, device, 100)
    os.makedirs(cfg.out, exist_ok=True)
    result = {'mode': 'bp_transformer', 'final_test_masked_mse': final, 'steps': cfg.steps,
              'params_K': params_k}
    # Context manager so the result file is flushed and closed deterministically.
    with open(os.path.join(cfg.out, 'result_bp_transformer.json'), 'w') as fh:
        json.dump(result, fh, indent=2)
    print(f"[bp] DONE final test masked-MSE = {final:.5f}", flush=True)


def main():
    """Parse command-line options, print the resulting config, and start training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
    # (flag, type, default), registered in this order so the printed config
    # keeps a stable key order.
    options = (
        ('--steps', int, 3000),
        ('--batch', int, 128),
        ('--img', int, 32),
        ('--ch', int, 3),
        ('--patch', int, 8),
        ('--stride', int, 8),
        ('--D', int, 128),
        ('--heads', int, 4),
        ('--depth', int, 1),
        ('--mlp_ratio', float, 2.0),
        ('--mask_ratio', float, 0.5),
        ('--lr', float, 4e-4),
        ('--lr_min', float, 1e-6),
        ('--wd', float, 3e-5),
        ('--log_every', int, 100),
        ('--eval_every', int, 500),
        ('--seed', int, 0),
        ('--out', str, '/home/yurenh2/ept/runs'),
        ('--device', str, 'cuda' if torch.cuda.is_available() else 'cpu'),
    )
    for flag, kind, default in options:
        parser.add_argument(flag, type=kind, default=default)
    cfg = parser.parse_args()
    print('config:', vars(cfg), flush=True)
    train(cfg)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()