summaryrefslogtreecommitdiff
path: root/puzzle_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'puzzle_dataset.py')
-rw-r--r--puzzle_dataset.py199
1 files changed, 199 insertions, 0 deletions
diff --git a/puzzle_dataset.py b/puzzle_dataset.py
new file mode 100644
index 0000000..2782403
--- /dev/null
+++ b/puzzle_dataset.py
@@ -0,0 +1,199 @@
+import os
+import json
+
+import numpy as np
+import pydantic
+
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
+
+from models.losses import IGNORE_LABEL_ID
+from dataset.common import PuzzleDatasetMetadata
+
+
+def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
+ # Pack examples into a full batch
+ batch = []
+ batch_puzzle_indices = []
+ current_size = 0
+
+ while (start_index < group_order.size) and (current_size < global_batch_size):
+ # Pick a group and a puzzle from that group
+ group_id = group_order[start_index]
+ puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
+ start_index += 1
+
+ # Get range of the puzzle
+ puzzle_start = puzzle_indices[puzzle_id]
+ puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
+
+ append_size = min(puzzle_size, global_batch_size - current_size)
+
+ # Put into batch
+ batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
+ batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))
+
+ current_size += append_size
+
+ return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
+
+
class PuzzleDatasetConfig(pydantic.BaseModel):
    """Validated settings that drive ``PuzzleDataset``."""

    # Base seed for the shuffling generator.
    seed: int
    # Root directory; each split lives in a sub-directory of it.
    dataset_path: str
    # Number of examples per batch summed over all replicas.
    global_batch_size: int
    # True: iterate sequentially over every example. False: random sampling.
    test_set_mode: bool

    # Batch X epochs in an iteration to reduce overhead.
    epochs_per_iter: int

    # Index of this replica and the total number of replicas.
    rank: int
    num_replicas: int
+
+
class PuzzleDataset(IterableDataset):
    """Iterable dataset yielding this rank's share of each global batch.

    Every item produced by iteration is a tuple
    ``(set_name, batch, global_effective_batch_size)`` where ``batch`` maps
    ``"inputs"``, ``"labels"`` and ``"puzzle_identifiers"`` to int32 tensors
    whose leading dimension is ``local_batch_size`` (padded when necessary),
    and ``global_effective_batch_size`` counts the real (non-pad) examples of
    the batch across all ranks.
    """

    def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
        """Store the configuration and read the split's metadata.

        Array data is not loaded here; it is loaded on first iteration.

        Args:
            config: Dataset settings (paths, batch size, rank layout, ...).
            split: Name of the sub-directory of ``config.dataset_path`` to read.
        """
        super().__init__()
        self.config = config
        self.split = split
        self.metadata = self._load_metadata()

        # Checks
        assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
        self.local_batch_size = self.config.global_batch_size // self.config.num_replicas

        # State
        self._data = None  # Filled by `_lazy_load_dataset`
        self._iters = 0  # Number of training passes started; offsets the seed

    def _load_metadata(self) -> PuzzleDatasetMetadata:
        """Parse ``<dataset_path>/<split>/dataset.json`` into metadata."""
        metadata_path = os.path.join(self.config.dataset_path, self.split, "dataset.json")
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(metadata_path, "r", encoding="utf-8") as f:
            return PuzzleDatasetMetadata(**json.load(f))

    def _lazy_load_dataset(self):
        """Load the ``.npy`` arrays of every set once; later calls are no-ops."""
        if self._data is not None:
            return

        field_mmap_modes = {
            # Memory-map the example arrays read-only
            "inputs": "r",
            "labels": "r",

            # Keep indices in memory
            "puzzle_identifiers": None,
            "puzzle_indices": None,
            "group_indices": None
        }

        # Load data
        split_dir = os.path.join(self.config.dataset_path, self.split)
        self._data = {}
        for set_name in self.metadata.sets:
            # Load subset
            self._data[set_name] = {
                field_name: np.load(os.path.join(split_dir, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
                for field_name, mmap_mode in field_mmap_modes.items()
            }

    def _collate_batch(self, batch):
        """Turn a dict of numpy arrays into a dict of padded int32 tensors.

        Args:
            batch: Mapping with keys ``"inputs"``, ``"labels"`` and
                ``"puzzle_identifiers"`` to arrays sharing their first axis.

        Returns:
            Mapping with the same keys to tensors whose first dimension is at
            least ``self.local_batch_size``.
        """
        # Convert dtype (this also copies, so the source arrays stay untouched)
        batch = {k: v.astype(np.int32) for k, v in batch.items()}

        # Convert ignore label IDs
        if self.metadata.ignore_label_id is not None:
            batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID

        # Pad the first axis up to the local batch size
        if batch["puzzle_identifiers"].size < self.local_batch_size:
            pad_size = self.local_batch_size - batch["puzzle_identifiers"].size

            pad_values = {
                "inputs": self.metadata.pad_id,
                "labels": IGNORE_LABEL_ID,

                "puzzle_identifiers": self.metadata.blank_identifier_id
            }
            batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}

        # To tensor
        return {k: torch.from_numpy(v) for k, v in batch.items()}

    def _iter_test(self):
        """Yield every example of every set once, in storage order.

        Each global batch is a contiguous window of examples; this rank takes
        the ``rank``-th ``local_batch_size``-sized slice of that window. The
        slice may be short or empty for the final window, in which case
        ``_collate_batch`` pads it.
        """
        for set_name, dataset in self._data.items():  # type: ignore
            total_examples = len(dataset["inputs"])
            puzzle_starts = dataset["puzzle_indices"]

            start_index = 0
            while start_index < total_examples:
                # Compute indices
                end_index = min(total_examples, start_index + self.config.global_batch_size)

                local_start = start_index + self.config.rank * self.local_batch_size
                local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)

                # Map each example to the puzzle that owns it: the last puzzle
                # whose start offset is <= the example index. A single
                # vectorized binary search replaces a per-example Python loop.
                example_ids = np.arange(local_start, local_end)
                puzzle_indices = np.searchsorted(puzzle_starts, example_ids, side="right") - 1

                batch = self._collate_batch({
                    "inputs": dataset["inputs"][local_start: local_end],
                    "labels": dataset["labels"][local_start: local_end],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
                })

                yield set_name, batch, end_index - start_index

                # Advance to next batch
                start_index += self.config.global_batch_size

    def _iter_train(self):
        """Yield randomly sampled full batches for every set.

        For each set, the groups are shuffled ``epochs_per_iter`` times and the
        concatenated order is consumed by ``_sample_batch``. A trailing batch
        smaller than ``global_batch_size`` is dropped.
        """
        rank_slice = slice(self.config.rank * self.local_batch_size, (self.config.rank + 1) * self.local_batch_size)

        for set_name, dataset in self._data.items():  # type: ignore
            # Increase epoch count
            self._iters += 1

            # Randomly shuffle groups; the seed changes with every pass
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))

            num_groups = dataset["group_indices"].size - 1
            group_order = np.concatenate([rng.permutation(num_groups) for _i in range(self.config.epochs_per_iter)])
            start_index = 0

            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Global effective batch size, excluding pads
                global_effective_batch_size = batch_puzzle_indices.size

                # Drop last batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                # Select current rank and collate
                batch_indices = batch_indices[rank_slice]
                batch_puzzle_indices = batch_puzzle_indices[rank_slice]
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
                })

                yield set_name, batch, global_effective_batch_size

    def __iter__(self):
        """Load the data if needed and iterate in the configured mode."""
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."

        self._lazy_load_dataset()

        # Iterate using specified mode
        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()