"""Same-param standard BP transformer char-LM (reference ceiling).

Standard pre-LN block: MHA + FFN, trained with normal backprop.

``--data`` must be a directory containing ``meta.pkl`` (a pickled dict with a
``vocab_size`` key) and ``train.bin`` / ``val.bin`` token streams stored as
uint16. Run as a script, e.g.::

    python <this file> --data /path/to/shakespeare_char --steps 3000
"""
import argparse
import math
import pickle
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def build_arg_parser():
    """Return the command-line parser (flags and defaults of the reference run)."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--data', default='/tmp/lt_ep/data/shakespeare_char')
    ap.add_argument('--B', type=int, default=32)
    ap.add_argument('--T', type=int, default=64)
    ap.add_argument('--C', type=int, default=128)
    ap.add_argument('--H', type=int, default=4)
    ap.add_argument('--depth', type=int, default=1)
    ap.add_argument('--mlp', type=int, default=4)
    ap.add_argument('--steps', type=int, default=3000)
    ap.add_argument('--lr', type=float, default=3e-3)
    ap.add_argument('--seed', type=int, default=0)
    return ap


def load_vocab_size(data_dir):
    """Return ``vocab_size`` from ``<data_dir>/meta.pkl``.

    NOTE: ``pickle.load`` can execute arbitrary code; only point ``--data`` at
    directories you trust.
    """
    with open(Path(data_dir) / 'meta.pkl', 'rb') as f:
        return pickle.load(f)['vocab_size']


def get_batch(data_dir, split, B, T, device):
    """Sample ``B`` random length-``T`` windows from one split.

    Reads ``train.bin`` when ``split == 'train'`` and ``val.bin`` otherwise.
    Returns ``(x, y)``: int64 tensors of shape ``(B, T)`` on ``device``, where
    ``y`` is ``x`` shifted by one token (next-token targets).

    Raises:
        ValueError: if the split holds too few tokens for a length-``T`` window.
    """
    fname = 'train.bin' if split == 'train' else 'val.bin'
    # The memmap is re-opened on every call, as in the original script.
    data = np.memmap(Path(data_dir) / fname, dtype=np.uint16, mode='r')
    if len(data) <= T + 1:
        raise ValueError(f"{fname} has {len(data)} tokens; need more than T + 1 = {T + 1}")
    ix = torch.randint(len(data) - T - 1, (B,))
    x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)


class Block(nn.Module):
    """Pre-LN transformer block: ``x + MHA(LN(x))`` then ``x + FFN(LN(x))``.

    Args:
        C: model width (``nn.MultiheadAttention`` requires ``C % H == 0``).
        H: number of attention heads.
        T: maximum sequence length; sizes the causal-mask buffer.
        mlp: FFN expansion factor (hidden width ``mlp * C``).
    """

    def __init__(self, C, H, T, mlp):
        super().__init__()
        self.ln1 = nn.LayerNorm(C)
        self.ln2 = nn.LayerNorm(C)
        self.attn = nn.MultiheadAttention(C, H, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(C, mlp * C), nn.GELU(), nn.Linear(mlp * C, C))
        # Additive float mask: 0 on/below the diagonal, -inf above it, so each
        # position attends only to itself and earlier positions.
        self.register_buffer('m', torch.triu(torch.full((T, T), float('-inf')), 1))

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq, C)`` with ``seq <= T``."""
        t = x.size(1)
        h = self.ln1(x)
        x = x + self.attn(h, h, h, attn_mask=self.m[:t, :t], need_weights=False)[0]
        return x + self.mlp(self.ln2(x))


class GPT(nn.Module):
    """Decoder-only LM: token + learned position embeddings, ``depth`` pre-LN
    blocks, a final LayerNorm and a bias-free linear head to ``vocab`` logits."""

    def __init__(self, vocab, C, H, T, depth, mlp):
        super().__init__()
        self.vocab = vocab
        self.T = T
        # Creation order matches the original script so seeded init is identical.
        self.tok = nn.Embedding(vocab, C)
        self.pos = nn.Embedding(T, C)
        self.blocks = nn.ModuleList([Block(C, H, T, mlp) for _ in range(depth)])
        self.lnf = nn.LayerNorm(C)
        self.head = nn.Linear(C, vocab, bias=False)

    def forward(self, idx, y=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape ``(batch, seq)``.

        ``logits`` has shape ``(batch, seq, vocab)``. ``loss`` is the mean
        cross-entropy against ``y`` (same shape as ``idx``), or ``None`` when
        ``y`` is not given.

        Raises:
            ValueError: if ``seq`` exceeds the context length ``T``.
        """
        t = idx.size(1)
        if t > self.T:
            raise ValueError(f"sequence length {t} exceeds context length T={self.T}")
        x = self.tok(idx) + self.pos(torch.arange(t, device=idx.device))
        for b in self.blocks:
            x = b(x)
        logits = self.head(self.lnf(x))
        loss = None if y is None else F.cross_entropy(logits.reshape(-1, self.vocab), y.reshape(-1))
        return logits, loss


@torch.no_grad()
def estimate_val_loss(model, sample_batch, iters=20):
    """Return the mean loss of ``model`` over ``iters`` batches from ``sample_batch()``.

    Runs in eval mode without gradients and switches the model back to train
    mode afterwards.
    """
    model.eval()
    total = sum(model(*sample_batch())[1].item() for _ in range(iters))
    model.train()
    return total / iters


def main(argv=None):
    """Train the model with the CLI options in ``argv`` and print validation CE."""
    cfg = build_arg_parser().parse_args(argv)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(cfg.seed)
    DD = Path(cfg.data)
    vocab = load_vocab_size(DD)
    B, T, C, H, DEPTH, MLP = cfg.B, cfg.T, cfg.C, cfg.H, cfg.depth, cfg.mlp

    def batch(split):
        return get_batch(DD, split, B, T, dev)

    m = GPT(vocab, C, H, T, DEPTH, MLP).to(dev)
    opt = torch.optim.AdamW(m.parameters(), lr=cfg.lr, weight_decay=1e-4)
    STEPS = cfg.steps
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, STEPS, eta_min=cfg.lr * 0.1)
    n_params = sum(p.numel() for p in m.parameters())
    print(f"[bp-charlm] params={n_params/1e3:.1f}K depth={DEPTH} C={C} H={H} mlp={MLP}", flush=True)

    # Start from +inf so `best` is always a loss that was actually observed
    # (a finite sentinel such as 9.9 would be reported when every eval is worse).
    best, t0 = float('inf'), time.time()
    for step in range(1, STEPS + 1):
        _, loss = m(*batch('train'))
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        sched.step()
        if step % 200 == 0 or step == STEPS:
            v = estimate_val_loss(m, lambda: batch('val'))
            best = min(best, v)
            # it/s is measured from t0, so it includes time spent in evaluation.
            print(f"step {step:4d}/{STEPS} | val CE {v:.4f} (best {best:.4f}) | {step/(time.time()-t0):.1f} it/s", flush=True)
    print(f"[bp-charlm] DONE best val CE {best:.4f} (random ln({vocab})={math.log(vocab):.3f})", flush=True)


if __name__ == '__main__':
    main()