summaryrefslogtreecommitdiff
path: root/ep_run/gen_ept.py
blob: 2a59ff500dbefaeec4f21bbbee2ed3b971c7b12b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch, math, torch.nn.functional as F
from pathlib import Path
from tokenizers import Tokenizer
import lt_ep_train as LT
dev='cuda'
DD=Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe')
tok=Tokenizer.from_file(str(DD/'tokenizer.json'))
# Hyperparameters passed to LT.EQBlock / LT.relax. T is also the fixed
# generation window length used by gen(); T1 and eps are passed to LT.relax.
# NOTE(review): these must match the training run that produced the checkpoint.
C,H,Mm,T,c,T1,eps = 512,16,256,256,1.0,150,0.1
blk=LT.EQBlock(C,H,Mm,T,c=c,attn_mode='thick')
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ck=torch.load('/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt',map_location=dev)
# Prefer the 'pema' parameter set when present; otherwise fall back to 'allp'.
src = 'pema' if ck.get('pema') is not None else 'allp'
params = list(ck[src])
tgt = list(blk.allp)
# zip() would silently stop at the shorter list, leaving some model parameters
# at their freshly-initialized values, so insist on an exact count match.
if len(params) != len(tgt):
    raise RuntimeError(
        f"checkpoint '{src}' has {len(params)} tensors but model has {len(tgt)}")
with torch.no_grad():
    for i, (p, s) in enumerate(zip(tgt, params)):
        # copy_ broadcasts, so a shape mismatch could otherwise load silently.
        if p.shape != s.shape:
            raise RuntimeError(
                f"param {i}: checkpoint shape {tuple(s.shape)} != model shape {tuple(p.shape)}")
        p.copy_(s.to(dev))
print(f"loaded resreg_warm {src} | best val CE {ck['best']:.4f} | step {ck['step']}", flush=True)
@torch.no_grad()
def gen(prompt, n_new=150, temp=0.8, topk=40, seed=0):
    """Sample a continuation of ``prompt`` token by token with top-k sampling.

    Tokens live in a fixed-length window ``idx`` of length ``T``, padded with
    token id 0 past the current length. Each step does the following:

    - re-embeds the whole window with ``blk.embed``;
    - runs ``LT.relax(blk, x, x, T1, eps)`` on it (semantics defined in
      lt_ep_train; NOTE(review): presumably an iterative relaxation);
    - projects with ``blk.Wh`` and takes the row at the last filled position;
    - samples the next token from the temperature-scaled top-k distribution.

    Generation stops early once the window is full.

    Args:
        prompt: text to continue; must encode to between 1 and ``T`` tokens.
        n_new: maximum number of tokens to append.
        temp: temperature the scores are divided by; must be > 0.
        topk: number of highest-scoring tokens kept before sampling; must be
            >= 1 and is clamped to the vocabulary size.
        seed: passed to ``torch.manual_seed`` so sampling is reproducible.

    Returns:
        The decoded text of the prompt tokens plus the generated tokens.

    Raises:
        ValueError: if ``temp`` <= 0, ``topk`` < 1, or the prompt encodes to
            zero tokens or more than ``T`` tokens.
    """
    if temp <= 0:
        # Dividing by 0 (or a negative) yields inf/NaN or inverted scores.
        raise ValueError(f"temp must be > 0, got {temp}")
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")
    torch.manual_seed(seed)
    ids = tok.encode(prompt).ids
    L = len(ids)
    if L == 0:
        # Position L-1 would be -1, i.e. the last (padding) slot of the window.
        raise ValueError("prompt encodes to zero tokens")
    if L > T:
        raise ValueError(f"prompt encodes to {L} tokens; window length is {T}")
    idx = torch.zeros(1, T, dtype=torch.long, device=dev)
    idx[0, :L] = torch.tensor(ids, device=dev)
    for _ in range(n_new):
        if L >= T:
            break  # window full; no slot left for another token
        xin = blk.embed(idx)
        z = LT.relax(blk, xin.clone(), xin, T1, eps)
        # Scores for the token following the last filled position.
        lg = (z @ blk.Wh)[0, L - 1] / temp
        # torch.topk errors if k exceeds the number of entries.
        k = min(topk, lg.size(-1))
        v, _ = torch.topk(lg, k)
        lg[lg < v[-1]] = -float('inf')  # mask everything outside the top-k
        p = F.softmax(lg, -1)
        nt = torch.multinomial(p, 1).item()
        idx[0, L] = nt
        L += 1
    return tok.decode(idx[0, :L].tolist())
# Print one sampled continuation of a fixed prompt for each seed 3..10,
# using temperature 0.7, top-k 40 and at most 135 new tokens.
for seed in range(3, 11):
    print(f"\n===== seed={seed} temp=0.7 =====", flush=True)
    story = gen(
        "Once upon a time,",
        n_new=135,
        temp=0.7,
        topk=40,
        seed=seed,
    )
    print(story, flush=True)