summaryrefslogtreecommitdiff
path: root/ep_run/bp_charlm.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/bp_charlm.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/bp_charlm.py')
-rw-r--r--ep_run/bp_charlm.py78
1 files changed, 78 insertions, 0 deletions
diff --git a/ep_run/bp_charlm.py b/ep_run/bp_charlm.py
new file mode 100644
index 0000000..410d812
--- /dev/null
+++ b/ep_run/bp_charlm.py
@@ -0,0 +1,78 @@
+"""Same-param standard BP transformer char-LM (reference ceiling).
+Standard pre-LN block: MHA + FFN, trained with normal backprop."""
+import argparse, math, pickle, time, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+from pathlib import Path
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ap = argparse.ArgumentParser()
+ap.add_argument('--data', default='/tmp/lt_ep/data/shakespeare_char')
+ap.add_argument('--B', type=int, default=32); ap.add_argument('--T', type=int, default=64)
+ap.add_argument('--C', type=int, default=128); ap.add_argument('--H', type=int, default=4)
+ap.add_argument('--depth', type=int, default=1); ap.add_argument('--mlp', type=int, default=4)
+ap.add_argument('--steps', type=int, default=3000); ap.add_argument('--lr', type=float, default=3e-3)
+ap.add_argument('--seed', type=int, default=0)
+cfg = ap.parse_args()
+torch.manual_seed(cfg.seed)
+DD = Path(cfg.data)
+vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size']
+B, T, C, H, DEPTH, MLP = cfg.B, cfg.T, cfg.C, cfg.H, cfg.depth, cfg.mlp
+
+
+def get_batch(split):
+ data = np.memmap(DD / ('train.bin' if split == 'train' else 'val.bin'), dtype=np.uint16, mode='r')
+ ix = torch.randint(len(data) - T - 1, (B,))
+ x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix])
+ y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix])
+ return x.to(dev), y.to(dev)
+
+
+class Block(nn.Module):
+ def __init__(s):
+ super().__init__()
+ s.ln1 = nn.LayerNorm(C); s.ln2 = nn.LayerNorm(C)
+ s.attn = nn.MultiheadAttention(C, H, batch_first=True)
+ s.mlp = nn.Sequential(nn.Linear(C, MLP * C), nn.GELU(), nn.Linear(MLP * C, C))
+ s.register_buffer('m', torch.triu(torch.ones(T, T) * float('-inf'), 1))
+
+ def forward(s, x):
+ h = s.ln1(x)
+ x = x + s.attn(h, h, h, attn_mask=s.m[:x.size(1), :x.size(1)], need_weights=False)[0]
+ return x + s.mlp(s.ln2(x))
+
+
+class GPT(nn.Module):
+ def __init__(s):
+ super().__init__()
+ s.tok = nn.Embedding(vocab, C); s.pos = nn.Embedding(T, C)
+ s.blocks = nn.ModuleList([Block() for _ in range(DEPTH)])
+ s.lnf = nn.LayerNorm(C); s.head = nn.Linear(C, vocab, bias=False)
+
+ def forward(s, idx, y=None):
+ x = s.tok(idx) + s.pos(torch.arange(idx.size(1), device=dev))
+ for b in s.blocks:
+ x = b(x)
+ logits = s.head(s.lnf(x))
+ loss = None if y is None else F.cross_entropy(logits.reshape(-1, vocab), y.reshape(-1))
+ return logits, loss
+
+
+m = GPT().to(dev)
+opt = torch.optim.AdamW(m.parameters(), lr=cfg.lr, weight_decay=1e-4)
+STEPS = cfg.steps
+sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, STEPS, eta_min=cfg.lr * 0.1)
+np_ = sum(p.numel() for p in m.parameters())
+print(f"[bp-charlm] params={np_/1e3:.1f}K depth={DEPTH} C={C} H={H} mlp={MLP}", flush=True)
+
+
+@torch.no_grad()
+def ev():
+ m.eval(); t = sum(m(*get_batch('val'))[1].item() for _ in range(20)) / 20; m.train(); return t
+
+
+best, t0 = 9.9, time.time()
+for step in range(1, STEPS + 1):
+ _, loss = m(*get_batch('train'))
+ opt.zero_grad(set_to_none=True); loss.backward(); opt.step(); sched.step()
+ if step % 200 == 0 or step == STEPS:
+ v = ev(); best = min(best, v)
+ print(f"step {step:4d}/{STEPS} | val CE {v:.4f} (best {best:.4f}) | {step/(time.time()-t0):.1f} it/s", flush=True)
+print(f"[bp-charlm] DONE best val CE {best:.4f} (random ln({vocab})={math.log(vocab):.3f})", flush=True)