summaryrefslogtreecommitdiff
path: root/ep_run/prepare_tinystories_bpe.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/prepare_tinystories_bpe.py')
-rw-r--r--ep_run/prepare_tinystories_bpe.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/ep_run/prepare_tinystories_bpe.py b/ep_run/prepare_tinystories_bpe.py
new file mode 100644
index 0000000..9b03a83
--- /dev/null
+++ b/ep_run/prepare_tinystories_bpe.py
@@ -0,0 +1,49 @@
+"""TinyStories -> 4k BPE -> train.bin/val.bin (uint16) + meta.pkl + tokenizer.json.
+Same bin format as the char pipeline so lt_ep_train consumes it via --data."""
+import pickle
+import numpy as np
+from pathlib import Path
+from tokenizers import Tokenizer
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from tokenizers.pre_tokenizers import ByteLevel
+from tokenizers.decoders import ByteLevel as ByteLevelDec
+
+SRC = Path('/home/yurenh2/ept/ep_run/data/tsrc')
+D = Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe')
+D.mkdir(parents=True, exist_ok=True)
+VOCAB = 4096
+
+tok = Tokenizer(BPE(unk_token=None))
+tok.pre_tokenizer = ByteLevel(add_prefix_space=False)
+tok.decoder = ByteLevelDec()
+trainer = BpeTrainer(vocab_size=VOCAB, special_tokens=[], show_progress=True)
+tok.train([str(SRC / 'train.txt')], trainer)
+tok.save(str(D / 'tokenizer.json'))
+print(f"trained BPE vocab={tok.get_vocab_size()}", flush=True)
+
+
+def enc_file(src, dst):
+ out = open(dst, 'wb')
+ n = 0
+ buf = []
+ with open(src, encoding='utf-8', errors='replace') as f:
+ for line in f:
+ buf.append(line)
+ if len(buf) >= 20000:
+ ids = [i for e in tok.encode_batch([''.join(buf)]) for i in e.ids]
+ np.array(ids, dtype=np.uint16).tofile(out)
+ n += len(ids)
+ buf = []
+ if buf:
+ ids = [i for e in tok.encode_batch([''.join(buf)]) for i in e.ids]
+ np.array(ids, dtype=np.uint16).tofile(out)
+ n += len(ids)
+ out.close()
+ return n
+
+
+nt = enc_file(SRC / 'train.txt', D / 'train.bin')
+nv = enc_file(SRC / 'valid.txt', D / 'val.bin')
+pickle.dump({'vocab_size': tok.get_vocab_size()}, open(D / 'meta.pkl', 'wb'))
+print(f"vocab={tok.get_vocab_size()} train_tokens={nt/1e6:.1f}M val_tokens={nv/1e6:.1f}M", flush=True)