summaryrefslogtreecommitdiff
path: root/ep_run/prepare_tinystories.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/prepare_tinystories.py')
-rw-r--r--ep_run/prepare_tinystories.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/ep_run/prepare_tinystories.py b/ep_run/prepare_tinystories.py
new file mode 100644
index 0000000..d7305a3
--- /dev/null
+++ b/ep_run/prepare_tinystories.py
@@ -0,0 +1,40 @@
+"""Char-level TinyStories -> train.bin/val.bin (uint16) + meta.pkl, same format as
+shakespeare_char so lt_ep_train.py consumes it via --data. Top-127 chars by train-set
+frequency; everything else maps to '?' (keeps the vocab clean of rare unicode)."""
+import collections, pickle
+import numpy as np
+from pathlib import Path
+
+D = Path('/tmp/lt_ep/data/tinystories')
+cnt = collections.Counter()
+with open(D / 'train.txt', encoding='utf-8', errors='replace') as f:
+ while True:
+ chunk = f.read(1 << 24)
+ if not chunk:
+ break
+ cnt.update(chunk)
+keep = sorted(c for c, _ in cnt.most_common(127))
+stoi = {c: i for i, c in enumerate(keep)}
+UNK = stoi.get('?', 0)
+table = {ord(c): i for c, i in stoi.items()}
+
+
+def enc_file(src, dst):
+ out = open(dst, 'wb')
+ n = 0
+ with open(src, encoding='utf-8', errors='replace') as f:
+ while True:
+ chunk = f.read(1 << 24)
+ if not chunk:
+ break
+ arr = np.fromiter((table.get(ord(c), UNK) for c in chunk), dtype=np.uint16, count=len(chunk))
+ arr.tofile(out)
+ n += len(arr)
+ out.close()
+ return n
+
+
+nt = enc_file(D / 'train.txt', D / 'train.bin')
+nv = enc_file(D / 'valid.txt', D / 'val.bin')
+pickle.dump({'vocab_size': len(stoi), 'stoi': stoi}, open(D / 'meta.pkl', 'wb'))
+print(f"vocab={len(stoi)} train_tokens={nt / 1e6:.1f}M val_tokens={nv / 1e6:.1f}M", flush=True)