summaryrefslogtreecommitdiff
path: root/ep_run/sample_eq.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/sample_eq.py')
-rw-r--r--ep_run/sample_eq.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/ep_run/sample_eq.py b/ep_run/sample_eq.py
new file mode 100644
index 0000000..9a6552a
--- /dev/null
+++ b/ep_run/sample_eq.py
@@ -0,0 +1,70 @@
+"""Autoregressive sampling from an equilibrium LM checkpoint: causal coupling means the prefix
+settles independently of padding, so we relax the padded sequence and read the last position's
+logits each step (simple full-re-settle variant; incremental KV-style settling is the optimized
+version, not needed at this scale)."""
+import argparse, pickle, torch
+import lt_ep_train as M
+from pathlib import Path
+
+ap = argparse.ArgumentParser()
+ap.add_argument('--ckpt', required=True)
+ap.add_argument('--data', default='/tmp/lt_ep/data/tinystories')
+ap.add_argument('--C', type=int, default=256)
+ap.add_argument('--H', type=int, default=8)
+ap.add_argument('--T', type=int, default=256)
+ap.add_argument('--T1', type=int, default=150)
+ap.add_argument('--eps', type=float, default=0.1)
+ap.add_argument('--temp', type=float, default=0.8)
+ap.add_argument('--topk', type=int, default=40)
+ap.add_argument('--c', type=float, default=1.0)
+ap.add_argument('--qknorm', action='store_true')
+ap.add_argument('--n', type=int, default=3)
+ap.add_argument('--prompt', default='Once upon a time')
+ap.add_argument('--use_pema', action='store_true')
+cfg = ap.parse_args()
+
+M.DD = Path(cfg.data)
+meta = pickle.load(open(M.DD / 'meta.pkl', 'rb'))
+M.vocab = meta['vocab_size']
+_tokj = M.DD / 'tokenizer.json'
+if _tokj.exists(): # BPE: decode via tokenizers
+ from tokenizers import Tokenizer
+ _tk = Tokenizer.from_file(str(_tokj))
+ def encode(s): return _tk.encode(s).ids
+ def decode(ids): return _tk.decode(ids)
+else: # char: decode via stoi/itos
+ stoi = meta['stoi']; itos = {i: c for c, i in stoi.items()}
+ def encode(s): return [stoi.get(c, 0) for c in s]
+ def decode(ids): return ''.join(itos.get(int(i), '?') for i in ids)
+from lt_ep_train import EQBlock, relax
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(1234)
+blk = EQBlock(cfg.C, cfg.H, 256, cfg.T, attn_mode='thick')
+blk.c = cfg.c # MUST match training c (force identity)
+blk.qknorm = cfg.qknorm # MUST match training qknorm (else wrong fixed point)
+ck = torch.load(cfg.ckpt)
+src = ck['pema'] if (cfg.use_pema and ck.get('pema') is not None) else ck['allp']
+with torch.no_grad():
+ for p, w in zip(blk.allp, src):
+ p.copy_(w.to(dev))
+print(f"loaded {cfg.ckpt} (step {ck.get('step')}, best {ck.get('best'):.4f}, "
+ f"{'pema' if cfg.use_pema else 'raw'} weights)\n", flush=True)
+
+for s in range(cfg.n):
+ ids = encode(cfg.prompt)
+ idx = torch.zeros(1, cfg.T, dtype=torch.long, device=dev)
+ idx[0, :len(ids)] = torch.tensor(ids, device=dev)
+ pos = len(ids)
+ while pos < cfg.T:
+ xin = blk.embed(idx).detach()
+ z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps)
+ logits = (z[0, pos - 1] @ blk.Wh) / cfg.temp
+ if cfg.topk > 0:
+ kth = torch.topk(logits, cfg.topk).values[-1]
+ logits[logits < kth] = float('-inf')
+ nxt = torch.multinomial(torch.softmax(logits, -1), 1).item()
+ idx[0, pos] = nxt
+ pos += 1
+ text = decode(idx[0, :pos].tolist())
+ print(f"--- sample {s+1} ---\n{text}\n", flush=True)