summaryrefslogtreecommitdiff
path: root/ep_run/gen_ept.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/gen_ept.py')
-rw-r--r--ep_run/gen_ept.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/ep_run/gen_ept.py b/ep_run/gen_ept.py
new file mode 100644
index 0000000..2a59ff5
--- /dev/null
+++ b/ep_run/gen_ept.py
@@ -0,0 +1,32 @@
+import torch, math, torch.nn.functional as F
+from pathlib import Path
+from tokenizers import Tokenizer
+import lt_ep_train as LT
+dev='cuda'
+DD=Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe')
+tok=Tokenizer.from_file(str(DD/'tokenizer.json'))
+C,H,Mm,T,c,T1,eps = 512,16,256,256,1.0,150,0.1
+blk=LT.EQBlock(C,H,Mm,T,c=c,attn_mode='thick')
+ck=torch.load('/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt',map_location=dev)
+params = ck['pema'] if ck.get('pema') is not None else ck['allp']
+with torch.no_grad():
+ for p,s in zip(blk.allp, params): p.copy_(s.to(dev))
+print(f"loaded resreg_warm pema | best val CE {ck['best']:.4f} | step {ck['step']}", flush=True)
+@torch.no_grad()
+def gen(prompt, n_new=150, temp=0.8, topk=40, seed=0):
+ torch.manual_seed(seed)
+ ids=tok.encode(prompt).ids
+ idx=torch.zeros(1,T,dtype=torch.long,device=dev); L=len(ids)
+ idx[0,:L]=torch.tensor(ids,device=dev)
+ for _ in range(n_new):
+ if L>=T: break
+ xin=blk.embed(idx)
+ z=LT.relax(blk, xin.clone(), xin, T1, eps)
+ lg=(z@blk.Wh)[0,L-1]/temp
+ v,_=torch.topk(lg,topk); lg[lg<v[-1]]=-float('inf')
+ p=F.softmax(lg,-1); nt=torch.multinomial(p,1).item()
+ idx[0,L]=nt; L+=1
+ return tok.decode(idx[0,:L].tolist())
+for seed in range(3, 11):
+ print(f"\n===== seed={seed} temp=0.7 =====", flush=True)
+ print(gen("Once upon a time,", 135, temp=0.7, topk=40, seed=seed), flush=True)