summaryrefslogtreecommitdiff
path: root/ep_run/sample_eq.py
blob: 9a6552acf7a2a28acdbbc4b55b075ddc9453979d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Autoregressive sampling from an equilibrium LM checkpoint: causal coupling means the prefix
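settles independently of the zero padding that follows it. Each step we therefore relax the whole
padded sequence and read the logits at the last filled position (the newest prompt/generated
token). This is the simple full re-settle variant; incremental, KV-cache-style settling would be
the optimized version and is not needed at this scale."""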
import argparse, pickle, torch
import lt_ep_train as M
from pathlib import Path

# Command-line configuration. --C/--H/--T/--c/--qknorm are fed to the EQBlock built below and
# must match the values used at training time; --T1/--eps are passed to relax() each step.
ap = argparse.ArgumentParser()
ap.add_argument('--ckpt', required=True)                       # checkpoint file for torch.load
ap.add_argument('--data', default='/tmp/lt_ep/data/tinystories')  # dir with meta.pkl (+ optional tokenizer.json)
ap.add_argument('--C', type=int, default=256)
ap.add_argument('--H', type=int, default=8)
ap.add_argument('--T', type=int, default=256)                  # padded sequence length = max sample length
ap.add_argument('--T1', type=int, default=150)                 # forwarded to relax() — presumably the step count; confirm
ap.add_argument('--eps', type=float, default=0.1)              # forwarded to relax() — presumably the step size; confirm
ap.add_argument('--temp', type=float, default=0.8)             # logits are divided by this; must be > 0
ap.add_argument('--topk', type=int, default=40)                # top-k filtering; 0 (or negative) disables it
ap.add_argument('--c', type=float, default=1.0)
ap.add_argument('--qknorm', action='store_true')
ap.add_argument('--n', type=int, default=3)                    # number of independent samples to draw
ap.add_argument('--prompt', default='Once upon a time')
ap.add_argument('--use_pema', action='store_true')             # load ck['pema'] instead of ck['allp'] when present
cfg = ap.parse_args()
# Division by a non-positive temperature yields inf/nan logits and an opaque multinomial error.
if cfg.temp <= 0:
    ap.error('--temp must be > 0 (logits are divided by it)')

# Point the training module at the dataset directory and read its vocabulary metadata.
M.DD = Path(cfg.data)
# Use a context manager so the file handle is closed deterministically.
# NOTE: pickle executes arbitrary code on load — only point --data at trusted, locally built data.
with open(M.DD / 'meta.pkl', 'rb') as _metaf:
    meta = pickle.load(_metaf)
M.vocab = meta['vocab_size']
_tokj = M.DD / 'tokenizer.json'
if _tokj.exists():                                   # BPE: encode/decode via the tokenizers library
    from tokenizers import Tokenizer
    _tk = Tokenizer.from_file(str(_tokj))
    def encode(s): return _tk.encode(s).ids
    def decode(ids): return _tk.decode(ids)
else:                                                # char-level: encode/decode via stoi/itos tables
    stoi = meta['stoi']; itos = {i: c for c, i in stoi.items()}
    # Unknown characters map to id 0 on encode; unknown ids render as '?' on decode.
    def encode(s): return [stoi.get(c, 0) for c in s]
    def decode(ids): return ''.join(itos.get(int(i), '?') for i in ids)
from lt_ep_train import EQBlock, relax

# Build the model and load checkpoint weights into it.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1234)                              # fixed seed: same settings/device -> same samples
# NOTE(review): third EQBlock argument is hard-coded to 256 — confirm it matches training.
blk = EQBlock(cfg.C, cfg.H, 256, cfg.T, attn_mode='thick')
blk.c = cfg.c                                        # MUST match training c (force identity)
blk.qknorm = cfg.qknorm                              # MUST match training qknorm (else wrong fixed point)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
ck = torch.load(cfg.ckpt, map_location=dev)
# Track which weight set is actually used, so the banner below does not mislabel a fallback.
use_pema = cfg.use_pema and ck.get('pema') is not None
src = list(ck['pema'] if use_pema else ck['allp'])
params = list(blk.allp)
# zip() would silently drop extra tensors and copy_() would silently broadcast a mismatched
# shape; both indicate --C/--H/--T disagree with training, so fail loudly instead.
if len(params) != len(src):
    raise SystemExit(f"checkpoint has {len(src)} tensors but model has {len(params)} "
                     f"parameters (check --C/--H/--T against training)")
with torch.no_grad():
    for i, (p, w) in enumerate(zip(params, src)):
        if p.shape != w.shape:
            raise SystemExit(f"parameter {i}: checkpoint shape {tuple(w.shape)} != model "
                             f"shape {tuple(p.shape)} (check --C/--H/--T against training)")
        p.copy_(w.to(dev))
if cfg.use_pema and not use_pema:
    print("warning: --use_pema requested but checkpoint has no 'pema'; using raw weights",
          flush=True)
best = ck.get('best')
best_s = f"{best:.4f}" if best is not None else "n/a"   # 'best' may be absent from the checkpoint
print(f"loaded {cfg.ckpt} (step {ck.get('step')}, best {best_s}, "
      f"{'pema' if use_pema else 'raw'} weights)\n", flush=True)

# Autoregressive sampling: re-relax the full padded sequence every step and sample the next
# token from the state at the last filled position.
ids = encode(cfg.prompt)                             # prompt is the same for every sample; encode once
if not ids:
    # With pos == 0 the read below would be z[0, -1], i.e. the last *padding* slot.
    raise SystemExit('--prompt must encode to at least one token')
if len(ids) > cfg.T:
    raise SystemExit(f'--prompt encodes to {len(ids)} tokens, more than --T={cfg.T}')
for s in range(cfg.n):
    idx = torch.zeros(1, cfg.T, dtype=torch.long, device=dev)   # right-pad with token id 0
    idx[0, :len(ids)] = torch.tensor(ids, device=dev)
    pos = len(ids)                                   # index of the next slot to fill
    while pos < cfg.T:
        xin = blk.embed(idx).detach()
        z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
        # The next token (slot pos) is predicted from the settled state at slot pos - 1.
        logits = (z[0, pos - 1] @ blk.Wh) / cfg.temp
        if cfg.topk > 0:
            # torch.topk raises if k exceeds the number of logits; clamp to the vocab size.
            k = min(cfg.topk, logits.size(-1))
            kth = torch.topk(logits, k).values[-1]
            logits[logits < kth] = float('-inf')
        nxt = torch.multinomial(torch.softmax(logits, -1), 1).item()
        idx[0, pos] = nxt
        pos += 1
    text = decode(idx[0, :pos].tolist())
    print(f"--- sample {s+1} ---\n{text}\n", flush=True)