summaryrefslogtreecommitdiff
path: root/ep_run/bp_charlm.py
blob: 410d812cda9be5af136e812d827fa7a14d331509 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Parameter-matched standard backprop (BP) transformer character-level LM, used as a reference ceiling.
Each block is a standard pre-LN block (multi-head attention + feed-forward MLP), trained with ordinary backprop."""
import argparse, math, pickle, time, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from pathlib import Path
# Run configuration: device selection, CLI hyperparameters, seeding and
# vocabulary size (read from the dataset's meta.pkl).
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
ap = argparse.ArgumentParser(description='Standard-backprop transformer char-LM baseline.')
ap.add_argument('--data', default='/tmp/lt_ep/data/shakespeare_char',
                help='directory containing train.bin, val.bin and meta.pkl')
ap.add_argument('--B', type=int, default=32, help='batch size')
ap.add_argument('--T', type=int, default=64, help='context length in tokens')
ap.add_argument('--C', type=int, default=128, help='embedding width')
ap.add_argument('--H', type=int, default=4, help='number of attention heads (must divide C)')
ap.add_argument('--depth', type=int, default=1, help='number of transformer blocks')
ap.add_argument('--mlp', type=int, default=4, help='MLP hidden-width multiplier (hidden = mlp * C)')
ap.add_argument('--steps', type=int, default=3000, help='number of training steps')
ap.add_argument('--lr', type=float, default=3e-3, help='peak learning rate')
ap.add_argument('--seed', type=int, default=0, help='torch RNG seed')
cfg = ap.parse_args()
# Fail fast with a readable CLI error instead of an assertion deep inside
# nn.MultiheadAttention construction.
if cfg.C % cfg.H:
    ap.error(f'--C ({cfg.C}) must be divisible by --H ({cfg.H})')
torch.manual_seed(cfg.seed)
DD = Path(cfg.data)
# NOTE: pickle.load can execute arbitrary code; only point --data at a
# trusted, locally prepared dataset directory.
with open(DD / 'meta.pkl', 'rb') as fh:  # closes the handle deterministically
    vocab = pickle.load(fh)['vocab_size']
B, T, C, H, DEPTH, MLP = cfg.B, cfg.T, cfg.C, cfg.H, cfg.depth, cfg.mlp


def get_batch(split):
    """Sample a random batch of (input, next-token target) windows.

    Args:
        split: 'train' reads train.bin from DD; any other value reads val.bin.

    Returns:
        (x, y): two int64 tensors of shape (B, T) on `dev`. y holds the same
        windows as x shifted one token forward in the stream.
    """
    fname = 'train.bin' if split == 'train' else 'val.bin'
    data = np.memmap(DD / fname, dtype=np.uint16, mode='r')
    # A window starting at i reads tokens i .. i+T (the +1 is for the shifted
    # target), so valid starts are 0 .. len(data)-T-1. torch.randint's upper
    # bound is exclusive, hence len(data) - T (the old "- T - 1" silently
    # dropped the last valid window).
    starts = torch.randint(len(data) - T, (B,)).tolist()
    x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in starts])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in starts])
    return x.to(dev), y.to(dev)


class Block(nn.Module):
    """Pre-LN transformer block with residual connections.

    Applies LayerNorm -> causal multi-head self-attention, then
    LayerNorm -> GELU MLP. Width, head count, MLP multiplier and maximum
    context length come from the module-level C, H, MLP and T.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept identical so seeded parameter init matches.
        self.ln1 = nn.LayerNorm(C)
        self.ln2 = nn.LayerNorm(C)
        self.attn = nn.MultiheadAttention(C, H, batch_first=True)
        hidden = MLP * C
        self.mlp = nn.Sequential(
            nn.Linear(C, hidden),
            nn.GELU(),
            nn.Linear(hidden, C),
        )
        # Additive causal mask: 0 on and below the diagonal, -inf strictly
        # above it, so position t attends only to positions <= t.
        causal = torch.full((T, T), float('-inf')).triu(diagonal=1)
        self.register_buffer('m', causal)

    def forward(self, x):
        n = x.size(1)
        normed = self.ln1(x)
        attended, _ = self.attn(normed, normed, normed,
                                attn_mask=self.m[:n, :n], need_weights=False)
        x = x + attended
        return x + self.mlp(self.ln2(x))


class GPT(nn.Module):
    """Character-level transformer LM.

    Token embedding plus learned positional embedding, DEPTH pre-LN Blocks,
    a final LayerNorm, and a bias-free linear readout to `vocab` logits.
    """

    def __init__(self):
        super().__init__()
        # Submodule creation order is unchanged so seeded initialization matches.
        self.tok = nn.Embedding(vocab, C)
        self.pos = nn.Embedding(T, C)
        self.blocks = nn.ModuleList([Block() for _ in range(DEPTH)])
        self.lnf = nn.LayerNorm(C)
        self.head = nn.Linear(C, vocab, bias=False)

    def forward(self, idx, y=None):
        """Compute logits and, optionally, the mean cross-entropy loss.

        Args:
            idx: integer token ids; dim 1 is the sequence axis and must be <= T
                (the positional table has T rows).
            y: optional integer targets with the same shape as idx.

        Returns:
            (logits, loss): loss is None when y is None.
        """
        # Build position ids on the input's device rather than the global
        # `dev`, so the model keeps working after being moved with .to(...).
        positions = torch.arange(idx.size(1), device=idx.device)
        x = self.tok(idx) + self.pos(positions)
        for block in self.blocks:
            x = block(x)
        logits = self.head(self.lnf(x))
        if y is None:
            return logits, None
        # Flatten to (N, num_classes) using the logits' own last dimension.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        return logits, loss


# Model, optimizer and LR schedule for the run.
m = GPT().to(dev)
STEPS = cfg.steps
opt = torch.optim.AdamW(m.parameters(), lr=cfg.lr, weight_decay=1e-4)
# Cosine decay across all STEPS, bottoming out at 10% of the peak LR.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=STEPS, eta_min=cfg.lr * 0.1)
np_ = sum(map(torch.Tensor.numel, m.parameters()))  # total parameter count
print(f"[bp-charlm] params={np_/1e3:.1f}K depth={DEPTH} C={C} H={H} mlp={MLP}", flush=True)


@torch.no_grad()
def ev(iters=20):
    """Estimate validation cross-entropy as the mean loss over `iters` random val batches.

    The model's previous train/eval mode is restored afterwards, even if
    evaluation raises.

    Raises:
        ValueError: if iters < 1.
    """
    if iters < 1:
        raise ValueError(f'iters must be >= 1, got {iters}')
    was_training = m.training
    m.eval()
    try:
        total = sum(m(*get_batch('val'))[1].item() for _ in range(iters))
    finally:
        m.train(was_training)
    return total / iters


# Main training loop: one AdamW step plus one cosine-LR step per sampled train
# batch, with periodic validation and a running best validation CE.
EVAL_EVERY = 200  # steps between validation passes
# Start from +inf so any real CE beats it; the old 9.9 sentinel was wrong
# whenever val CE exceeded 9.9 (e.g. large vocabularies at initialization).
best = math.inf
t0 = time.perf_counter()  # monotonic clock for the throughput readout
for step in range(1, STEPS + 1):
    _, loss = m(*get_batch('train'))
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
    sched.step()
    if step % EVAL_EVERY == 0 or step == STEPS:
        v = ev()
        best = min(best, v)
        # it/s is averaged since t0 and includes time spent in evaluation.
        print(f"step {step:4d}/{STEPS} | val CE {v:.4f} (best {best:.4f}) | {step/(time.perf_counter()-t0):.1f} it/s", flush=True)
print(f"[bp-charlm] DONE best val CE {best:.4f}  (random ln({vocab})={math.log(vocab):.3f})", flush=True)