1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
"""Same-param standard BP transformer char-LM (reference ceiling).
Standard pre-LN block: MHA + FFN, trained with normal backprop."""
import argparse, math, pickle, time, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from pathlib import Path
# Run on the GPU when one is available, otherwise fall back to CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
ap = argparse.ArgumentParser()
ap.add_argument('--data', default='/tmp/lt_ep/data/shakespeare_char')  # dir holding train.bin / val.bin / meta.pkl
ap.add_argument('--B', type=int, default=32); ap.add_argument('--T', type=int, default=64)      # batch size, context length
ap.add_argument('--C', type=int, default=128); ap.add_argument('--H', type=int, default=4)      # model width, attention heads
ap.add_argument('--depth', type=int, default=1); ap.add_argument('--mlp', type=int, default=4)  # number of blocks, FFN expansion factor
ap.add_argument('--steps', type=int, default=3000); ap.add_argument('--lr', type=float, default=3e-3)
ap.add_argument('--seed', type=int, default=0)
cfg = ap.parse_args()
# nn.MultiheadAttention needs the width to split evenly across heads; fail here
# with a readable CLI error instead of an assertion deep inside model construction.
if cfg.H < 1 or cfg.C % cfg.H != 0:
    ap.error(f'--C ({cfg.C}) must be a positive multiple of --H ({cfg.H})')
torch.manual_seed(cfg.seed)
DD = Path(cfg.data)
# Context manager so the metadata file handle is closed promptly.
# NOTE: pickle.load can execute arbitrary code -- only point --data at trusted directories.
with open(DD / 'meta.pkl', 'rb') as fh:
    vocab = pickle.load(fh)['vocab_size']
B, T, C, H, DEPTH, MLP = cfg.B, cfg.T, cfg.C, cfg.H, cfg.depth, cfg.mlp
def get_batch(split):
    """Sample a random batch of (input, next-token target) windows.

    Args:
        split: 'train' reads train.bin from DD; any other value reads val.bin.

    Returns:
        (x, y): int64 tensors of shape (B, T) moved to `dev`, where each row of
        y is the corresponding row of x shifted one token ahead.
    """
    fname = 'train.bin' if split == 'train' else 'val.bin'
    # The memmap is re-opened on every call rather than cached.
    tokens = np.memmap(DD / fname, dtype=np.uint16, mode='r')
    # NOTE(review): the exclusive upper bound len - T - 1 omits the last valid
    # start index (len - T - 1); harmless, and kept so seeded runs sample identically.
    starts = torch.randint(len(tokens) - T - 1, (B,))
    # One T+1 slice per start provides both the input and its shifted target.
    windows = [tokens[k:k + T + 1].astype(np.int64) for k in starts]
    x = torch.stack([torch.from_numpy(w[:-1]) for w in windows])
    y = torch.stack([torch.from_numpy(w[1:]) for w in windows])
    return x.to(dev), y.to(dev)
class Block(nn.Module):
    """Pre-LayerNorm transformer block: causal multi-head self-attention and a
    GELU feed-forward network, each wrapped in a residual connection."""

    def __init__(self):
        super().__init__()
        # Submodules are created in this exact order so that, under a fixed seed,
        # parameter initialisation consumes the RNG the same way.
        self.ln1 = nn.LayerNorm(C)
        self.ln2 = nn.LayerNorm(C)
        self.attn = nn.MultiheadAttention(C, H, batch_first=True)
        hidden = MLP * C
        self.mlp = nn.Sequential(nn.Linear(C, hidden), nn.GELU(), nn.Linear(hidden, C))
        # Additive causal mask: -inf strictly above the diagonal (future positions), 0 elsewhere.
        causal = torch.full((T, T), float('-inf')).triu(1)
        self.register_buffer('m', causal)

    def forward(self, x):
        """Apply the block to x, laid out (batch, seq, C) since batch_first=True."""
        n = x.size(1)
        normed = self.ln1(x)
        attn_out, _ = self.attn(normed, normed, normed,
                                attn_mask=self.m[:n, :n], need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.ln2(x))
class GPT(nn.Module):
    """Character-level decoder-only LM: token + learned position embeddings,
    DEPTH pre-LN blocks, a final LayerNorm and a bias-free vocabulary projection."""

    def __init__(s):
        super().__init__()
        # Construction order is unchanged so seeded initialisation stays reproducible.
        s.tok = nn.Embedding(vocab, C); s.pos = nn.Embedding(T, C)
        s.blocks = nn.ModuleList([Block() for _ in range(DEPTH)])
        s.lnf = nn.LayerNorm(C); s.head = nn.Linear(C, vocab, bias=False)

    def forward(s, idx, y=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: integer token ids of shape (batch, seq_len), seq_len <= T.
            y: optional target ids with the same shape as idx.

        Returns:
            (logits, loss): logits of shape (batch, seq_len, vocab); loss is the
            mean cross-entropy over all positions, or None when y is None.

        Raises:
            ValueError: if seq_len exceeds the context length T.
        """
        n = idx.size(1)
        if n > T:
            # Without this check the position-embedding lookup fails with an
            # opaque index error (on CUDA, typically a device-side assert).
            raise ValueError(f'sequence length {n} exceeds context length T={T}')
        # Position ids are created on idx's device rather than the global `dev`,
        # so the model still works when it and its inputs live on another device.
        x = s.tok(idx) + s.pos(torch.arange(n, device=idx.device))
        for b in s.blocks:
            x = b(x)
        logits = s.head(s.lnf(x))
        loss = None if y is None else F.cross_entropy(logits.reshape(-1, vocab), y.reshape(-1))
        return logits, loss
# Model, optimiser and schedule: AdamW with the same small weight decay on every
# parameter, and cosine LR decay over the whole run down to 10% of the peak rate.
STEPS = cfg.steps
m = GPT().to(dev)
opt = torch.optim.AdamW(m.parameters(), lr=cfg.lr, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=STEPS, eta_min=0.1 * cfg.lr)
np_ = sum(map(torch.Tensor.numel, m.parameters()))  # total parameter count
print(f"[bp-charlm] params={np_/1e3:.1f}K depth={DEPTH} C={C} H={H} mlp={MLP}", flush=True)
@torch.no_grad()
def ev(iters=20):
    """Estimate validation cross-entropy as the mean loss over `iters` random val batches.

    The model is switched to eval mode for the estimate. Afterwards it is restored
    to whichever mode it was in before the call, even if evaluation raises.

    Args:
        iters: number of validation batches to average (default 20).

    Returns:
        The average validation loss as a Python float.

    Raises:
        ValueError: if iters < 1.
    """
    if iters < 1:
        raise ValueError('iters must be >= 1')
    was_training = m.training
    m.eval()
    try:
        total = sum(m(*get_batch('val'))[1].item() for _ in range(iters))
    finally:
        m.train(was_training)
    return total / iters
# Main loop: one optimiser + LR-schedule step per training batch; estimate
# validation CE every 200 steps and at the final step.
# best starts at +inf (not a magic constant) so it only ever reports a measured value.
best, t0 = math.inf, time.time()
for step in range(1, STEPS + 1):
    _, loss = m(*get_batch('train'))
    opt.zero_grad(set_to_none=True); loss.backward(); opt.step(); sched.step()
    if step % 200 == 0 or step == STEPS:
        v = ev(); best = min(best, v)
        # it/s is averaged over the run so far and includes time spent in ev().
        print(f"step {step:4d}/{STEPS} | val CE {v:.4f} (best {best:.4f}) | {step/(time.time()-t0):.1f} it/s", flush=True)
print(f"[bp-charlm] DONE best val CE {best:.4f} (random ln({vocab})={math.log(vocab):.3f})", flush=True)
|