summary | refs | log | tree | commit | diff
path: root/dataset/build_maze_dataset.py
diff options
context:
space:
mode:
author: One <imone@tuta.io> 2025-07-09 10:13:51 +0800
committer: One <imone@tuta.io> 2025-07-09 10:13:51 +0800
commit: bd6222774edcec1608a6842d0b06a637a4acef59 (patch)
tree: 3b95517044286d82a9166bcce3134bbea099fcfe /dataset/build_maze_dataset.py
parent: caa00bb77fbfce0cc14a45d5ec1c754394f96c1a (diff)
Release
Diffstat (limited to 'dataset/build_maze_dataset.py')
-rw-r--r-- dataset/build_maze_dataset.py 142
1 files changed, 142 insertions, 0 deletions
diff --git a/dataset/build_maze_dataset.py b/dataset/build_maze_dataset.py
new file mode 100644
index 0000000..e99baf2
--- /dev/null
+++ b/dataset/build_maze_dataset.py
@@ -0,0 +1,142 @@
+from typing import Optional
+import math
+import os
+import csv
+import json
+import numpy as np
+
+from argdantic import ArgParser
+from pydantic import BaseModel
+from tqdm import tqdm
+from huggingface_hub import hf_hub_download
+
+from common import PuzzleDatasetMetadata, dihedral_transform
+
+
+CHARSET = "# SGo"  # Allowed grid characters; a character's index + 1 is its token id.
+
+
+cli = ArgParser()  # argdantic parser; commands are registered with @cli.command.
+
+
+class DataProcessConfig(BaseModel):
+ source_repo: str = "imone/small-sample-challenge-maze-30x30-hard"
+ output_dir: str = "data/maze-30x30-hard-1k"
+
+ subsample_size: Optional[int] = None
+ aug: bool = False
+
+
+def convert_subset(set_name: str, config: DataProcessConfig):
+ # Read CSV
+ all_chars = set()
+ grid_size = None
+ inputs = []
+ labels = []
+
+ with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore
+ reader = csv.reader(csvfile)
+ next(reader) # Skip header
+ for source, q, a, rating in reader:
+ all_chars.update(q)
+ all_chars.update(a)
+
+ if grid_size is None:
+ n = int(len(q) ** 0.5)
+ grid_size = (n, n)
+
+ inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
+ labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))
+
+ # If subsample_size is specified for the training set,
+ # randomly sample the desired number of examples.
+ if set_name == "train" and config.subsample_size is not None:
+ total_samples = len(inputs)
+ if config.subsample_size < total_samples:
+ indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
+ inputs = [inputs[i] for i in indices]
+ labels = [labels[i] for i in indices]
+
+ # Generate dataset
+ results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
+ puzzle_id = 0
+ example_id = 0
+
+ results["puzzle_indices"].append(0)
+ results["group_indices"].append(0)
+
+ for inp, out in zip(tqdm(inputs), labels):
+ # Dihedral transformations for augmentation
+ for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
+ results["inputs"].append(dihedral_transform(inp, aug_idx))
+ results["labels"].append(dihedral_transform(out, aug_idx))
+ example_id += 1
+ puzzle_id += 1
+
+ results["puzzle_indices"].append(example_id)
+ results["puzzle_identifiers"].append(0)
+
+ # Push group
+ results["group_indices"].append(puzzle_id)
+
+ # Char mappings
+ assert len(all_chars - set(CHARSET)) == 0
+
+ char2id = np.zeros(256, np.uint8)
+ char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1
+
+ # To Numpy
+ def _seq_to_numpy(seq):
+ arr = np.vstack([char2id[s.reshape(-1)] for s in seq])
+
+ return arr
+
+ results = {
+ "inputs": _seq_to_numpy(results["inputs"]),
+ "labels": _seq_to_numpy(results["labels"]),
+
+ "group_indices": np.array(results["group_indices"], dtype=np.int32),
+ "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
+ "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
+ }
+
+ # Metadata
+ metadata = PuzzleDatasetMetadata(
+ seq_len=int(math.prod(grid_size)), # type: ignore
+ vocab_size=len(CHARSET) + 1, # PAD + Charset
+
+ pad_id=0,
+ ignore_label_id=0,
+
+ blank_identifier_id=0,
+ num_puzzle_identifiers=1,
+
+ total_groups=len(results["group_indices"]) - 1,
+ mean_puzzle_examples=1,
+ sets=["all"]
+ )
+
+ # Save metadata as JSON.
+ save_dir = os.path.join(config.output_dir, set_name)
+ os.makedirs(save_dir, exist_ok=True)
+
+ with open(os.path.join(save_dir, "dataset.json"), "w") as f:
+ json.dump(metadata.model_dump(), f)
+
+ # Save data
+ for k, v in results.items():
+ np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
+
+ # Save IDs mapping (for visualization only)
+ with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
+ json.dump(["<blank>"], f)
+
+
+@cli.command(singleton=True)
+def preprocess_data(config: DataProcessConfig):
+ convert_subset("train", config)
+ convert_subset("test", config)
+
+
+if __name__ == "__main__":
+ cli()