diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/gen_ept.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/gen_ept.py')
| -rw-r--r-- | ep_run/gen_ept.py | 32 |
1 files changed, 32 insertions, 0 deletions
"""Sample story continuations from a trained EP equilibrium block.

Script form of ``ep_run/gen_ept.py``.  At import time it loads the tokenizer
stored under ``data/tinystories_bpe`` and the ``ep_resreg_warm`` checkpoint,
copies the checkpoint parameters into a fresh ``LT.EQBlock`` and then prints
one sampled continuation per seed.
"""
import math
from pathlib import Path

import torch
import torch.nn.functional as F
from tokenizers import Tokenizer

import lt_ep_train as LT

dev = 'cuda'
DD = Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe')
tok = Tokenizer.from_file(str(DD / 'tokenizer.json'))

# C, H, Mm, T and c are forwarded to LT.EQBlock; T is also the length of the
# token buffer used by gen(); T1 and eps are forwarded to LT.relax.
# NOTE(review): the precise meaning of each value is defined in lt_ep_train.
C, H, Mm, T, c, T1, eps = 512, 16, 256, 256, 1.0, 150, 0.1
blk = LT.EQBlock(C, H, Mm, T, c=c, attn_mode='thick')

# torch.load unpickles the file, so only point this at checkpoints you trust.
ck = torch.load('/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt', map_location=dev)

# Prefer the 'pema' entry when the checkpoint has one, otherwise use 'allp'.
# The key is remembered so the log line below reports what was really loaded.
param_key = 'pema' if ck.get('pema') is not None else 'allp'
params = ck[param_key]
with torch.no_grad():
    # NOTE(review): zip() stops at the shorter sequence, so a checkpoint whose
    # parameter list does not match blk.allp is loaded partially without error.
    for p, s in zip(blk.allp, params):
        p.copy_(s.to(dev))
print(f"loaded resreg_warm {param_key} | best val CE {ck['best']:.4f} | step {ck['step']}", flush=True)


@torch.no_grad()
def gen(prompt, n_new=150, temp=0.8, topk=40, seed=0):
    """Sample a continuation of ``prompt`` and return the decoded text.

    Args:
        prompt: Text encoded with ``tok``; must encode to between 1 and ``T``
            tokens.
        n_new: Maximum number of tokens to append.  Generation also stops as
            soon as the ``T``-token buffer is full.
        temp: Temperature the logits are divided by; must be positive.
        topk: Number of highest-logit candidates kept before sampling; must be
            at least 1 and is clamped to the size of the logit vector.
        seed: Value passed to ``torch.manual_seed`` before sampling.

    Returns:
        The decoded prompt tokens followed by the sampled tokens.

    Raises:
        ValueError: If ``temp`` is not positive, ``topk`` is below 1, or the
            prompt encodes to zero tokens or to more than ``T`` tokens.
    """
    if temp <= 0:
        raise ValueError(f"temp must be > 0, got {temp}")
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")

    torch.manual_seed(seed)
    ids = tok.encode(prompt).ids
    L = len(ids)
    if L == 0:
        # Without a prompt token the logits index L-1 would wrap around to -1
        # and silently read the last (padding) position.
        raise ValueError("prompt encodes to zero tokens; at least one is required")
    if L > T:
        raise ValueError(f"prompt encodes to {L} tokens but the buffer holds only {T}")

    # Fixed-size buffer: prompt tokens first, zeros in the unused tail.
    idx = torch.zeros(1, T, dtype=torch.long, device=dev)
    idx[0, :L] = torch.tensor(ids, device=dev)

    for _ in range(n_new):
        if L >= T:
            break
        # The whole buffer is embedded and relaxed on every step; the logits
        # for the next token are read at the last filled position, L-1.
        xin = blk.embed(idx)
        z = LT.relax(blk, xin.clone(), xin, T1, eps)
        lg = (z @ blk.Wh)[0, L - 1] / temp

        # Top-k filtering: everything below the k-th largest logit is masked.
        k = min(topk, lg.size(-1))
        v, _ = torch.topk(lg, k)
        lg[lg < v[-1]] = -float('inf')

        p = F.softmax(lg, -1)
        nt = torch.multinomial(p, 1).item()
        idx[0, L] = nt
        L += 1

    return tok.decode(idx[0, :L].tolist())


# One constant feeds both the banner and the call, so the printed temperature
# cannot drift from the one actually used.
SAMPLE_TEMP = 0.7
for seed in range(3, 11):
    print(f"\n===== seed={seed} temp={SAMPLE_TEMP} =====", flush=True)
    print(gen("Once upon a time,", 135, temp=SAMPLE_TEMP, topk=40, seed=seed), flush=True)
