summaryrefslogtreecommitdiff
path: root/ep_run/prepare_tinystories.py
blob: d7305a3c4502facba668f5506ceb68f9d8863858 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Char-level TinyStories -> train.bin/val.bin (uint16) + meta.pkl, same format as
shakespeare_char so lt_ep_train.py consumes it via --data. Top-127 chars by train-set
frequency; everything else maps to '?' (keeps the vocab clean of rare unicode)."""
import collections, pickle
import numpy as np
from pathlib import Path

D = Path('/tmp/lt_ep/data/tinystories')

# Pass 1: character frequencies over the training split, read in 16M-char
# chunks so the whole file is never held in memory at once.
cnt = collections.Counter()
with open(D / 'train.txt', encoding='utf-8', errors='replace') as f:
    for chunk in iter(lambda: f.read(1 << 24), ''):
        cnt.update(chunk)

# Vocab = the 127 most frequent chars, assigned ids in code-point order.
# '?' is forced in (displacing the least frequent survivor if needed) so that
# out-of-vocab chars really map to '?' as the module docstring promises,
# instead of silently aliasing whatever char happens to get id 0.
top = [c for c, _ in cnt.most_common(127)]
if '?' not in top:
    top = top[:126] + ['?']
keep = sorted(top)
stoi = {c: i for i, c in enumerate(keep)}
UNK = stoi['?']
# Code point -> token id, the form enc_file looks characters up in.
table = {ord(c): i for c, i in stoi.items()}


def enc_file(src, dst, mapping=None, unk=None):
    """Encode text file ``src`` into a flat uint16 token file ``dst``.

    Each character becomes one token: its id in ``mapping`` (code point -> id),
    or ``unk`` when the character is not in the vocab. The output is raw
    machine-byte-order uint16 with no header (``ndarray.tofile``).

    Args:
        src: UTF-8 text file to encode. Undecodable bytes become U+FFFD, the
            same decoding used when the vocab was counted.
        dst: path of the ``.bin`` file to (over)write.
        mapping: code point -> token id; defaults to the module-level ``table``.
        unk: id for out-of-vocab chars; defaults to the module-level ``UNK``.

    Returns:
        Number of tokens written (equal to the number of characters read).
    """
    if mapping is None:
        mapping = table
    if unk is None:
        unk = UNK
    # Dense lookup table over the whole Unicode range (0x110000 entries, ~2 MB),
    # so each chunk is encoded with one vectorized gather rather than a
    # per-character Python generator.
    lut = np.full(0x110000, unk, dtype=np.uint16)
    for cp, i in mapping.items():
        lut[cp] = i
    n = 0
    # Open src first so a missing input does not truncate dst. Both files are
    # closed even if decoding or writing raises.
    with open(src, encoding='utf-8', errors='replace') as f, open(dst, 'wb') as out:
        for chunk in iter(lambda: f.read(1 << 24), ''):
            # UTF-32-LE yields exactly one uint32 code point per character.
            # errors='replace' guarantees there are no lone surrogates.
            cps = np.frombuffer(chunk.encode('utf-32-le'), dtype=np.uint32)
            lut[cps].tofile(out)
            n += cps.size
    return n


nt = enc_file(D / 'train.txt', D / 'train.bin')
nv = enc_file(D / 'valid.txt', D / 'val.bin')
# Same keys nanoGPT's shakespeare_char meta.pkl carries. 'itos' (id -> char)
# lets samplers decode generated ids back to text.
meta = {'vocab_size': len(stoi), 'itos': {i: c for c, i in stoi.items()}, 'stoi': stoi}
with open(D / 'meta.pkl', 'wb') as fh:
    pickle.dump(meta, fh)
print(f"vocab={len(stoi)} train_tokens={nt / 1e6:.1f}M val_tokens={nv / 1e6:.1f}M", flush=True)