diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/sample_eq.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/sample_eq.py')
| -rw-r--r-- | ep_run/sample_eq.py | 70 |
1 files changed, 70 insertions, 0 deletions
"""Autoregressive sampling from an equilibrium LM checkpoint.

Causal coupling means the prefix settles independently of padding, so we relax the padded
sequence and read the last position's logits each step (simple full-re-settle variant;
incremental KV-style settling is the optimized version, not needed at this scale).
"""
import argparse
import pickle
from pathlib import Path

import torch

import lt_ep_train as M
from lt_ep_train import EQBlock, relax

# ---- command line ---------------------------------------------------------------------------
ap = argparse.ArgumentParser()
ap.add_argument('--ckpt', required=True, help='checkpoint file to sample from')
ap.add_argument('--data', default='/tmp/lt_ep/data/tinystories',
                help='directory holding meta.pkl (and optionally tokenizer.json)')
# --C and --H are handed positionally to EQBlock; presumably channel width and head count
# (TODO confirm against lt_ep_train.EQBlock).
ap.add_argument('--C', type=int, default=256)
ap.add_argument('--H', type=int, default=8)
ap.add_argument('--T', type=int, default=256,
                help='length of the padded token sequence; generation stops at this many tokens')
# --T1 and --eps are forwarded unchanged to relax(); presumably the number of settling steps
# and the step size (TODO confirm against lt_ep_train.relax).
ap.add_argument('--T1', type=int, default=150)
ap.add_argument('--eps', type=float, default=0.1)
ap.add_argument('--temp', type=float, default=0.8, help='sampling temperature')
ap.add_argument('--topk', type=int, default=40,
                help='sample only from the k highest-logit tokens; 0 disables the filter')
ap.add_argument('--c', type=float, default=1.0, help='must match the value used in training')
ap.add_argument('--qknorm', action='store_true', help='must match the training setting')
ap.add_argument('--n', type=int, default=3, help='number of samples to generate')
ap.add_argument('--prompt', default='Once upon a time', help='text prefix to continue')
ap.add_argument('--use_pema', action='store_true',
                help="use the checkpoint's 'pema' weights when present (otherwise 'allp')")
cfg = ap.parse_args()

# ---- dataset metadata -----------------------------------------------------------------------
# The data directory and vocabulary size are published on the training module's globals
# (presumably read by EQBlock / relax -- confirm in lt_ep_train).
M.DD = Path(cfg.data)
# NOTE(security): unpickling executes arbitrary code embedded in the file; only point --data
# at directories you trust.
with open(M.DD / 'meta.pkl', 'rb') as meta_file:
    meta = pickle.load(meta_file)
M.vocab = meta['vocab_size']

# ---- tokenizer ------------------------------------------------------------------------------
_tokj = M.DD / 'tokenizer.json'
if _tokj.exists():  # BPE: encode/decode via the `tokenizers` package (imported only if needed)
    from tokenizers import Tokenizer
    _tk = Tokenizer.from_file(str(_tokj))

    def encode(s):
        """Return the BPE token ids for the string *s*."""
        return _tk.encode(s).ids

    def decode(ids):
        """Return the text for a sequence of BPE token ids."""
        return _tk.decode(ids)
else:  # char-level: encode/decode via the stoi table stored in meta.pkl
    stoi = meta['stoi']
    itos = {i: c for c, i in stoi.items()}

    def encode(s):
        """Return one token id per character of *s*; unknown characters map to id 0."""
        return [stoi.get(c, 0) for c in s]

    def decode(ids):
        """Return the text for a sequence of character ids; unknown ids render as '?'."""
        return ''.join(itos.get(int(i), '?') for i in ids)

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1234)  # fixed seed: repeated runs with the same arguments give the same samples


def _pick_next_token(logits, temp, topk):
    """Choose one token id from a 1-D tensor of unnormalized logits.

    temp > 0 divides the logits before sampling; temp <= 0 selects the argmax (greedy) instead
    of dividing by zero. topk > 0 restricts sampling to the topk largest logits (clamped to the
    number of logits; entries tied with the k-th value are kept); topk <= 0 samples from the
    full distribution.
    """
    if temp <= 0:
        return int(torch.argmax(logits))
    scaled = logits / temp
    if topk > 0:
        k = min(topk, scaled.size(-1))  # torch.topk raises if k exceeds the vocabulary size
        kth = torch.topk(scaled, k).values[-1]
        scaled = scaled.masked_fill(scaled < kth, float('-inf'))
    return torch.multinomial(torch.softmax(scaled, -1), 1).item()


# ---- model + checkpoint ---------------------------------------------------------------------
# NOTE(review): the third constructor argument is hard-coded to 256 rather than taken from cfg;
# confirm it matches the training configuration.
blk = EQBlock(cfg.C, cfg.H, 256, cfg.T, attn_mode='thick')
blk.c = cfg.c  # MUST match training c (force identity)
blk.qknorm = cfg.qknorm  # MUST match training qknorm (else wrong fixed point)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
ck = torch.load(cfg.ckpt, map_location=dev)
pema = ck.get('pema') if cfg.use_pema else None
used_pema = pema is not None  # --use_pema falls back to the raw weights when 'pema' is absent
src = list(pema if used_pema else ck['allp'])
params = list(blk.allp)
if len(src) != len(params):
    # zip() below stops at the shorter list, so a mismatch would otherwise go unnoticed.
    print(f"WARNING: checkpoint holds {len(src)} tensors but the model has {len(params)} "
          f"parameters; check --C/--H/--T against the training configuration", flush=True)
with torch.no_grad():
    for p, w in zip(params, src):
        p.copy_(w.to(dev))
best = ck.get('best')
best_txt = f"{best:.4f}" if best is not None else "n/a"
print(f"loaded {cfg.ckpt} (step {ck.get('step')}, best {best_txt}, "
      f"{'pema' if used_pema else 'raw'} weights)\n", flush=True)

# ---- sampling -------------------------------------------------------------------------------
prompt_ids = encode(cfg.prompt)  # identical for every sample, so encode once
n_prompt = len(prompt_ids)
if n_prompt == 0:
    # With no prefix the loop would read position -1, i.e. the last padded position.
    raise SystemExit("--prompt must encode to at least one token")
if n_prompt > cfg.T:
    raise SystemExit(f"--prompt encodes to {n_prompt} tokens, more than --T={cfg.T}")
prompt_t = torch.tensor(prompt_ids, dtype=torch.long, device=dev)

for s in range(cfg.n):
    # Zero-padded token buffer of full length T; generated tokens overwrite the padding.
    idx = torch.zeros(1, cfg.T, dtype=torch.long, device=dev)
    idx[0, :n_prompt] = prompt_t
    for pos in range(n_prompt, cfg.T):
        # Re-settle the whole padded sequence, then read the state at the last real token.
        xin = blk.embed(idx).detach()
        z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
        idx[0, pos] = _pick_next_token(z[0, pos - 1] @ blk.Wh, cfg.temp, cfg.topk)
    # n_prompt <= T is guaranteed above, so the buffer is completely filled at this point.
    text = decode(idx[0].tolist())
    print(f"--- sample {s+1} ---\n{text}\n", flush=True)
