"""Autoregressive sampling from an equilibrium LM checkpoint: causal coupling means the prefix settles independently of padding, so we relax the padded sequence and read the last position's logits each step (simple full-re-settle variant; incremental KV-style settling is the optimized version, not needed at this scale).""" import argparse, pickle, torch import lt_ep_train as M from pathlib import Path ap = argparse.ArgumentParser() ap.add_argument('--ckpt', required=True) ap.add_argument('--data', default='/tmp/lt_ep/data/tinystories') ap.add_argument('--C', type=int, default=256) ap.add_argument('--H', type=int, default=8) ap.add_argument('--T', type=int, default=256) ap.add_argument('--T1', type=int, default=150) ap.add_argument('--eps', type=float, default=0.1) ap.add_argument('--temp', type=float, default=0.8) ap.add_argument('--topk', type=int, default=40) ap.add_argument('--c', type=float, default=1.0) ap.add_argument('--qknorm', action='store_true') ap.add_argument('--n', type=int, default=3) ap.add_argument('--prompt', default='Once upon a time') ap.add_argument('--use_pema', action='store_true') cfg = ap.parse_args() M.DD = Path(cfg.data) meta = pickle.load(open(M.DD / 'meta.pkl', 'rb')) M.vocab = meta['vocab_size'] _tokj = M.DD / 'tokenizer.json' if _tokj.exists(): # BPE: decode via tokenizers from tokenizers import Tokenizer _tk = Tokenizer.from_file(str(_tokj)) def encode(s): return _tk.encode(s).ids def decode(ids): return _tk.decode(ids) else: # char: decode via stoi/itos stoi = meta['stoi']; itos = {i: c for c, i in stoi.items()} def encode(s): return [stoi.get(c, 0) for c in s] def decode(ids): return ''.join(itos.get(int(i), '?') for i in ids) from lt_ep_train import EQBlock, relax dev = 'cuda' if torch.cuda.is_available() else 'cpu' torch.manual_seed(1234) blk = EQBlock(cfg.C, cfg.H, 256, cfg.T, attn_mode='thick') blk.c = cfg.c # MUST match training c (force identity) blk.qknorm = cfg.qknorm # MUST match training qknorm (else wrong fixed point) ck = torch.load(cfg.ckpt) src = ck['pema'] if (cfg.use_pema and ck.get('pema') is not None) else ck['allp'] with torch.no_grad(): for p, w in zip(blk.allp, src): p.copy_(w.to(dev)) print(f"loaded {cfg.ckpt} (step {ck.get('step')}, best {ck.get('best'):.4f}, " f"{'pema' if cfg.use_pema else 'raw'} weights)\n", flush=True) for s in range(cfg.n): ids = encode(cfg.prompt) idx = torch.zeros(1, cfg.T, dtype=torch.long, device=dev) idx[0, :len(ids)] = torch.tensor(ids, device=dev) pos = len(ids) while pos < cfg.T: xin = blk.embed(idx).detach() z = relax(blk, xin.clone(), xin, cfg.T1, cfg.eps) logits = (z[0, pos - 1] @ blk.Wh) / cfg.temp if cfg.topk > 0: kth = torch.topk(logits, cfg.topk).values[-1] logits[logits < kth] = float('-inf') nxt = torch.multinomial(torch.softmax(logits, -1), 1).item() idx[0, pos] = nxt pos += 1 text = decode(idx[0, :pos].tolist()) print(f"--- sample {s+1} ---\n{text}\n", flush=True)