import torch, math, torch.nn.functional as F from pathlib import Path from tokenizers import Tokenizer import lt_ep_train as LT dev='cuda' DD=Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe') tok=Tokenizer.from_file(str(DD/'tokenizer.json')) C,H,Mm,T,c,T1,eps = 512,16,256,256,1.0,150,0.1 blk=LT.EQBlock(C,H,Mm,T,c=c,attn_mode='thick') ck=torch.load('/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt',map_location=dev) params = ck['pema'] if ck.get('pema') is not None else ck['allp'] with torch.no_grad(): for p,s in zip(blk.allp, params): p.copy_(s.to(dev)) print(f"loaded resreg_warm pema | best val CE {ck['best']:.4f} | step {ck['step']}", flush=True) @torch.no_grad() def gen(prompt, n_new=150, temp=0.8, topk=40, seed=0): torch.manual_seed(seed) ids=tok.encode(prompt).ids idx=torch.zeros(1,T,dtype=torch.long,device=dev); L=len(ids) idx[0,:L]=torch.tensor(ids,device=dev) for _ in range(n_new): if L>=T: break xin=blk.embed(idx) z=LT.relax(blk, xin.clone(), xin, T1, eps) lg=(z@blk.Wh)[0,L-1]/temp v,_=torch.topk(lg,topk); lg[lg