1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
|
"""Autoregressive sampling from an equilibrium LM checkpoint.

Causal coupling means the prefix settles independently of the padding, so at each step we relax
the whole padded sequence and read the logits at the last filled position (pos - 1). This is the
simple full-re-settle variant; incremental KV-style settling is the optimized version and is not
needed at this scale.
"""
import argparse, pickle, torch
import lt_ep_train as M
from pathlib import Path
# Command-line interface, declared as (flag, add_argument keyword arguments)
# pairs. The order here is the order the options are registered with argparse.
_CLI_OPTIONS = (
    # Checkpoint file handed to torch.load (mandatory).
    ('--ckpt', dict(required=True)),
    # Dataset directory; meta.pkl (and optionally tokenizer.json) is read from it.
    ('--data', dict(default='/tmp/lt_ep/data/tinystories')),
    # First, second and fourth positional arguments of the EQBlock constructor.
    ('--C', dict(type=int, default=256)),
    ('--H', dict(type=int, default=8)),
    # --T is also the length of the padded token buffer used while sampling.
    ('--T', dict(type=int, default=256)),
    # Forwarded to relax() as its last two arguments.
    ('--T1', dict(type=int, default=150)),
    ('--eps', dict(type=float, default=0.1)),
    # Logits are divided by --temp before sampling.
    ('--temp', dict(type=float, default=0.8)),
    # Only the --topk largest logits are kept; values <= 0 disable the filter.
    ('--topk', dict(type=int, default=40)),
    # Copied onto the block as blk.c / blk.qknorm; must match the training run.
    ('--c', dict(type=float, default=1.0)),
    ('--qknorm', dict(action='store_true')),
    # Number of samples to generate.
    ('--n', dict(type=int, default=3)),
    # Text every sample starts from.
    ('--prompt', dict(default='Once upon a time')),
    # Load the checkpoint's 'pema' weights instead of 'allp' when available.
    ('--use_pema', dict(action='store_true')),
)

ap = argparse.ArgumentParser()
for _flag, _flag_kwargs in _CLI_OPTIONS:
    ap.add_argument(_flag, **_flag_kwargs)
cfg = ap.parse_args()
# Point the training module at the dataset directory and read its metadata.
M.DD = Path(cfg.data)
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only point --data at dataset directories you trust.
with open(M.DD / 'meta.pkl', 'rb') as _meta_fh:  # close the handle deterministically
    meta = pickle.load(_meta_fh)
M.vocab = meta['vocab_size']

# Pick the text <-> token-id codec: a tokenizer.json next to meta.pkl selects
# the `tokenizers` library, otherwise fall back to the character table in meta.
_tokj = M.DD / 'tokenizer.json'
if _tokj.exists():  # BPE: encode/decode via tokenizers
    from tokenizers import Tokenizer
    _tk = Tokenizer.from_file(str(_tokj))

    def encode(s):
        """Return the token ids the loaded tokenizer assigns to string *s*."""
        return _tk.encode(s).ids

    def decode(ids):
        """Return the text the loaded tokenizer decodes from token ids *ids*."""
        return _tk.decode(ids)
else:  # char: encode/decode via stoi/itos
    stoi = meta['stoi']
    itos = {i: c for c, i in stoi.items()}

    def encode(s):
        """Return one id per character of *s*; characters missing from stoi map to id 0."""
        return [stoi.get(c, 0) for c in s]

    def decode(ids):
        """Return the string spelled by *ids*; ids missing from itos render as '?'."""
        return ''.join(itos.get(int(i), '?') for i in ids)
from lt_ep_train import EQBlock, relax
# Build the block, then overwrite its parameters with the checkpoint's weights.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1234)  # seed torch's RNG before any sampling happens
blk = EQBlock(cfg.C, cfg.H, 256, cfg.T, attn_mode='thick')
blk.c = cfg.c  # MUST match training c (force identity)
blk.qknorm = cfg.qknorm  # MUST match training qknorm (else wrong fixed point)

# map_location remaps stored tensors onto `dev`, so a checkpoint saved on a GPU
# still loads on a CPU-only machine.
ck = torch.load(cfg.ckpt, map_location=dev)

# Use the 'pema' weights only when they were requested AND the checkpoint has
# them; remember the decision so the status line reports what was really loaded.
_pema_used = cfg.use_pema and ck.get('pema') is not None
src = ck['pema'] if _pema_used else ck['allp']

# zip() would silently stop at the shorter sequence and leave the remaining
# parameters at their initial values, so check the counts up front.
_params, _weights = list(blk.allp), list(src)
if len(_params) != len(_weights):
    raise ValueError(
        f"checkpoint {cfg.ckpt} holds {len(_weights)} tensors but the model has "
        f"{len(_params)} parameters; check that --C/--H/--T match the training run")
with torch.no_grad():
    for p, w in zip(_params, _weights):
        p.copy_(w.to(dev))

# 'best' may be absent from the checkpoint; formatting None with :.4f would raise.
_best = ck.get('best')
_best_txt = f"{_best:.4f}" if _best is not None else "n/a"
print(f"loaded {cfg.ckpt} (step {ck.get('step')}, best {_best_txt}, "
      f"{'pema' if _pema_used else 'raw'} weights)\n", flush=True)
# Generate cfg.n samples. Every sample starts from the same prompt, so encode
# and validate it once instead of once per sample.
ids = encode(cfg.prompt)
if not ids:
    # With no prompt token, pos - 1 would be -1 and the readout below would
    # silently use the last padded position instead of a real context.
    raise ValueError("--prompt must encode to at least one token")
if len(ids) > cfg.T:
    raise ValueError(
        f"--prompt encodes to {len(ids)} tokens, which exceeds --T={cfg.T}")

for s in range(cfg.n):
    # Fixed-length token buffer: prompt at the front, id 0 as padding after it.
    idx = torch.zeros(1, cfg.T, dtype=torch.long, device=dev)
    idx[0, :len(ids)] = torch.tensor(ids, device=dev)
    pos = len(ids)  # index of the next position to fill
    while pos < cfg.T:
        # Re-settle the whole padded sequence from its embedding every step.
        xin = blk.embed(idx).detach()
        z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
        with torch.no_grad():  # the readout and sampling need no autograd graph
            # Read out the state at the last filled position.
            logits = z[0, pos - 1] @ blk.Wh
            if cfg.temp == 0:
                # Temperature 0 would divide by zero; treat it as greedy decoding.
                nxt = int(torch.argmax(logits))
            else:
                logits = logits / cfg.temp
                if cfg.topk > 0:
                    # torch.topk raises when k exceeds the number of logits.
                    k = min(cfg.topk, logits.size(-1))
                    kth = torch.topk(logits, k).values[-1]
                    logits = logits.masked_fill(logits < kth, float('-inf'))
                nxt = torch.multinomial(torch.softmax(logits, -1), 1).item()
        idx[0, pos] = nxt
        pos += 1
    text = decode(idx[0, :pos].tolist())
    print(f"--- sample {s+1} ---\n{text}\n", flush=True)
|