diff options
Diffstat (limited to 'ep_run/bp_charlm.py')
| -rw-r--r-- | ep_run/bp_charlm.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/ep_run/bp_charlm.py b/ep_run/bp_charlm.py new file mode 100644 index 0000000..410d812 --- /dev/null +++ b/ep_run/bp_charlm.py @@ -0,0 +1,78 @@ +"""Same-param standard BP transformer char-LM (reference ceiling). +Standard pre-LN block: MHA + FFN, trained with normal backprop.""" +import argparse, math, pickle, time, numpy as np, torch, torch.nn as nn, torch.nn.functional as F +from pathlib import Path +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +ap = argparse.ArgumentParser() +ap.add_argument('--data', default='/tmp/lt_ep/data/shakespeare_char') +ap.add_argument('--B', type=int, default=32); ap.add_argument('--T', type=int, default=64) +ap.add_argument('--C', type=int, default=128); ap.add_argument('--H', type=int, default=4) +ap.add_argument('--depth', type=int, default=1); ap.add_argument('--mlp', type=int, default=4) +ap.add_argument('--steps', type=int, default=3000); ap.add_argument('--lr', type=float, default=3e-3) +ap.add_argument('--seed', type=int, default=0) +cfg = ap.parse_args() +torch.manual_seed(cfg.seed) +DD = Path(cfg.data) +vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] +B, T, C, H, DEPTH, MLP = cfg.B, cfg.T, cfg.C, cfg.H, cfg.depth, cfg.mlp + + +def get_batch(split): + data = np.memmap(DD / ('train.bin' if split == 'train' else 'val.bin'), dtype=np.uint16, mode='r') + ix = torch.randint(len(data) - T - 1, (B,)) + x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) + return x.to(dev), y.to(dev) + + +class Block(nn.Module): + def __init__(s): + super().__init__() + s.ln1 = nn.LayerNorm(C); s.ln2 = nn.LayerNorm(C) + s.attn = nn.MultiheadAttention(C, H, batch_first=True) + s.mlp = nn.Sequential(nn.Linear(C, MLP * C), nn.GELU(), nn.Linear(MLP * C, C)) + s.register_buffer('m', torch.triu(torch.ones(T, T) * float('-inf'), 1)) + + def forward(s, x): + h = s.ln1(x) + x = x + s.attn(h, h, h, attn_mask=s.m[:x.size(1), :x.size(1)], need_weights=False)[0] + return x + s.mlp(s.ln2(x)) + + +class GPT(nn.Module): + def __init__(s): + super().__init__() + s.tok = nn.Embedding(vocab, C); s.pos = nn.Embedding(T, C) + s.blocks = nn.ModuleList([Block() for _ in range(DEPTH)]) + s.lnf = nn.LayerNorm(C); s.head = nn.Linear(C, vocab, bias=False) + + def forward(s, idx, y=None): + x = s.tok(idx) + s.pos(torch.arange(idx.size(1), device=dev)) + for b in s.blocks: + x = b(x) + logits = s.head(s.lnf(x)) + loss = None if y is None else F.cross_entropy(logits.reshape(-1, vocab), y.reshape(-1)) + return logits, loss + + +m = GPT().to(dev) +opt = torch.optim.AdamW(m.parameters(), lr=cfg.lr, weight_decay=1e-4) +STEPS = cfg.steps +sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, STEPS, eta_min=cfg.lr * 0.1) +np_ = sum(p.numel() for p in m.parameters()) +print(f"[bp-charlm] params={np_/1e3:.1f}K depth={DEPTH} C={C} H={H} mlp={MLP}", flush=True) + + +@torch.no_grad() +def ev(): + m.eval(); t = sum(m(*get_batch('val'))[1].item() for _ in range(20)) / 20; m.train(); return t + + +best, t0 = 9.9, time.time() +for step in range(1, STEPS + 1): + _, loss = m(*get_batch('train')) + opt.zero_grad(set_to_none=True); loss.backward(); opt.step(); sched.step() + if step % 200 == 0 or step == STEPS: + v = ev(); best = min(best, v) + print(f"step {step:4d}/{STEPS} | val CE {v:.4f} (best {best:.4f}) | {step/(time.time()-t0):.1f} it/s", flush=True) +print(f"[bp-charlm] DONE best val CE {best:.4f} (random ln({vocab})={math.log(vocab):.3f})", flush=True) |
