diff options
| author | One <imone@tuta.io> | 2025-07-09 10:13:51 +0800 |
|---|---|---|
| committer | One <imone@tuta.io> | 2025-07-09 10:13:51 +0800 |
| commit | bd6222774edcec1608a6842d0b06a637a4acef59 (patch) | |
| tree | 3b95517044286d82a9166bcce3134bbea099fcfe /dataset/build_sudoku_dataset.py | |
| parent | caa00bb77fbfce0cc14a45d5ec1c754394f96c1a (diff) | |
Release
Diffstat (limited to 'dataset/build_sudoku_dataset.py')
| -rw-r--r-- | dataset/build_sudoku_dataset.py | 169 |
1 files changed, 169 insertions, 0 deletions
diff --git a/dataset/build_sudoku_dataset.py b/dataset/build_sudoku_dataset.py new file mode 100644 index 0000000..5d5b50c --- /dev/null +++ b/dataset/build_sudoku_dataset.py @@ -0,0 +1,169 @@ +from typing import Optional +import os +import csv +import json +import numpy as np + +from argdantic import ArgParser +from pydantic import BaseModel +from tqdm import tqdm +from huggingface_hub import hf_hub_download + +from common import PuzzleDatasetMetadata + + +cli = ArgParser() + + +class DataProcessConfig(BaseModel): + source_repo: str = "imone/sudoku-hard-v2" + output_dir: str = "data/sudoku-extreme-full" + + subsample_size: Optional[int] = None + min_difficulty: Optional[int] = None + num_aug: int = 0 + + +def shuffle_sudoku(board: np.ndarray, solution: np.ndarray): + # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged + digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0)) + + # Randomly decide whether to transpose. + transpose_flag = np.random.rand() < 0.5 + + # Generate a valid row permutation: + # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows. + bands = np.random.permutation(3) + row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands]) + + # Similarly for columns (stacks). + stacks = np.random.permutation(3) + col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks]) + + # Build an 81->81 mapping. For each new cell at (i, j) + # (row index = i // 9, col index = i % 9), + # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9]. + mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)]) + + def apply_transformation(x: np.ndarray) -> np.ndarray: + # Apply transpose flag + if transpose_flag: + x = x.T + # Apply the position mapping. + new_board = x.flatten()[mapping].reshape(9, 9).copy() + # Apply digit mapping + return digit_map[new_board] + + return apply_transformation(board), apply_transformation(solution) + + +def convert_subset(set_name: str, config: DataProcessConfig): + # Read CSV + inputs = [] + labels = [] + + with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: + reader = csv.reader(csvfile) + next(reader) # Skip header + for source, q, a, rating in reader: + if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty): + assert len(q) == 81 and len(a) == 81 + + inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0')) + labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0')) + + # If subsample_size is specified for the training set, + # randomly sample the desired number of examples. + if set_name == "train" and config.subsample_size is not None: + total_samples = len(inputs) + if config.subsample_size < total_samples: + indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) + inputs = [inputs[i] for i in indices] + labels = [labels[i] for i in indices] + + # Generate dataset + num_augments = config.num_aug if set_name == "train" else 0 + + results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} + puzzle_id = 0 + example_id = 0 + + results["puzzle_indices"].append(0) + results["group_indices"].append(0) + + for orig_inp, orig_out in zip(tqdm(inputs), labels): + for aug_idx in range(1 + num_augments): + # First index is not augmented + if aug_idx == 0: + inp, out = orig_inp, orig_out + else: + inp, out = shuffle_sudoku(orig_inp, orig_out) + + # Push puzzle (only single example) + results["inputs"].append(inp) + results["labels"].append(out) + example_id += 1 + puzzle_id += 1 + + results["puzzle_indices"].append(example_id) + results["puzzle_identifiers"].append(0) + + # Push group + results["group_indices"].append(puzzle_id) + + # To Numpy + def _seq_to_numpy(seq): + arr = np.concatenate(seq).reshape(len(seq), -1) + + assert np.all((arr >= 0) & (arr <= 9)) + return arr + 1 + + results = { + "inputs": _seq_to_numpy(results["inputs"]), + "labels": _seq_to_numpy(results["labels"]), + + "group_indices": np.array(results["group_indices"], dtype=np.int32), + "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), + "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), + } + + # Metadata + metadata = PuzzleDatasetMetadata( + seq_len=81, + vocab_size=10 + 1, # PAD + "0" ... "9" + + pad_id=0, + ignore_label_id=0, + + blank_identifier_id=0, + num_puzzle_identifiers=1, + + total_groups=len(results["group_indices"]) - 1, + mean_puzzle_examples=1, + sets=["all"] + ) + + # Save metadata as JSON. + save_dir = os.path.join(config.output_dir, set_name) + os.makedirs(save_dir, exist_ok=True) + + with open(os.path.join(save_dir, "dataset.json"), "w") as f: + json.dump(metadata.model_dump(), f) + + # Save data + for k, v in results.items(): + np.save(os.path.join(save_dir, f"all__{k}.npy"), v) + + # Save IDs mapping (for visualization only) + with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: + json.dump(["<blank>"], f) + + +@cli.command(singleton=True) +def preprocess_data(config: DataProcessConfig): + convert_subset("train", config) + convert_subset("test", config) + + +if __name__ == "__main__": + cli() |
