summaryrefslogtreecommitdiff
path: root/dataset/build_arc_dataset.py
diff options
context:
space:
mode:
authorOne <imone@tuta.io>2025-07-09 10:13:51 +0800
committerOne <imone@tuta.io>2025-07-09 10:13:51 +0800
commitbd6222774edcec1608a6842d0b06a637a4acef59 (patch)
tree3b95517044286d82a9166bcce3134bbea099fcfe /dataset/build_arc_dataset.py
parentcaa00bb77fbfce0cc14a45d5ec1c754394f96c1a (diff)
Release
Diffstat (limited to 'dataset/build_arc_dataset.py')
-rw-r--r--dataset/build_arc_dataset.py291
1 files changed, 291 insertions, 0 deletions
diff --git a/dataset/build_arc_dataset.py b/dataset/build_arc_dataset.py
new file mode 100644
index 0000000..2da5703
--- /dev/null
+++ b/dataset/build_arc_dataset.py
@@ -0,0 +1,291 @@
+from typing import List, Optional, Tuple, Dict
+from dataclasses import dataclass
+from pathlib import Path
+import os
+import json
+import hashlib
+import numpy as np
+from glob import glob
+
+from argdantic import ArgParser
+from pydantic import BaseModel
+
+from common import PuzzleDatasetMetadata, dihedral_transform
+
+
+cli = ArgParser()
+
+
+class DataProcessConfig(BaseModel):
+ # ARC-1
+ dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
+ output_dir: str = "data/arc-aug-1000"
+
+ # ARC-2
+ # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"]
+ # output_dir: str = "data/arc-2-aug-1000"
+
+ seed: int = 42
+ num_aug: int = 1000
+
+
+ARCMaxGridSize = 30
+ARCAugmentRetriesFactor = 5
+
+
+@dataclass
+class ARCPuzzle:
+ id: str
+
+ examples: List[Tuple[np.ndarray, np.ndarray]]
+
+
+def arc_grid_to_np(grid: List[List[int]]):
+ arr = np.array(grid)
+
+ # Shape check
+ assert arr.ndim == 2
+ assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
+ # Element check
+ assert np.all((arr >= 0) & (arr <= 9))
+ return arr.astype(np.uint8)
+
+
+def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
+ # PAD: 0, <eos>: 1, digits: 2 ... 11
+ # Compute random top-left pad
+ if do_translation:
+ pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
+ pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
+ else:
+ pad_r = pad_c = 0
+
+ # Pad grid
+ result = []
+ for grid in [inp, out]:
+ nrow, ncol = grid.shape
+ grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)
+
+ # Add <eos>
+ eos_row, eos_col = pad_r + nrow, pad_c + ncol
+ if eos_row < ARCMaxGridSize:
+ grid[eos_row, pad_c:eos_col] = 1
+ if eos_col < ARCMaxGridSize:
+ grid[pad_r:eos_row, eos_col] = 1
+
+ result.append(grid.flatten())
+
+ return result
+
+
+def puzzle_hash(puzzle: dict):
+ # Hash the puzzle for checking equivalence
+ def _grid_hash(grid: np.ndarray):
+ buffer = [x.to_bytes(1) for x in grid.shape]
+ buffer.append(grid.tobytes())
+
+ return hashlib.sha256(b"".join(buffer)).hexdigest()
+
+ hashes = []
+ for example_type, example in puzzle.items():
+ for input, label in example.examples:
+ hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}")
+
+ hashes.sort()
+ return hashlib.sha256("|".join(hashes).encode()).hexdigest()
+
+
+def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
+ # Remove "name"
+ name = puzzle.pop("name", default_name)
+
+ # Convert
+ dests = set(dest_mapping.values())
+ converted = {dest: ARCPuzzle(name, []) for dest in dests}
+ for example_type, examples in puzzle.items():
+ dest = dest_mapping[example_type]
+ converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])
+
+ group = [converted]
+
+ # Augment
+ if aug_count > 0:
+ hashes = {puzzle_hash(converted)}
+
+ for _trial in range(ARCAugmentRetriesFactor * aug_count):
+ # Augment plan
+ trans_id = np.random.randint(0, 8)
+ mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black)
+
+ aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}"
+
+ def _map_grid(grid: np.ndarray):
+ return dihedral_transform(mapping[grid], trans_id)
+
+ # Check duplicate
+ augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
+ h = puzzle_hash(augmented)
+ if h not in hashes:
+ hashes.add(h)
+ group.append(augmented)
+
+ if len(group) >= aug_count + 1:
+ break
+
+ if len(group) < aug_count + 1:
+ print (f"[Puzzle {name}] augmentation not full, only {len(group)}")
+
+ # Append
+ for dest in dests:
+ # Convert the examples
+ dest_split, dest_set = dest
+
+ results.setdefault(dest_split, {})
+ results[dest_split].setdefault(dest_set, [])
+ results[dest_split][dest_set].append([converted[dest] for converted in group])
+
+
+def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):
+ train_examples_dest = ("train", "all")
+ test_examples_map = {
+ "evaluation": [(1.0, ("test", "all"))],
+ "_default": [(1.0, ("train", "all"))]
+ }
+
+ total_puzzles = 0
+ for subdir in os.scandir(dataset_path):
+ if subdir.is_dir():
+ # Load all puzzles in this directory
+ puzzles = []
+ for filename in glob(os.path.join(subdir.path, "*.json")):
+ with open(filename, "r") as f:
+ puzzles.append((Path(filename).stem, json.load(f)))
+
+ # Shuffle puzzles
+ np.random.shuffle(puzzles)
+
+ # Assign by fraction
+ for idx, (default_name, puzzle) in enumerate(puzzles):
+ fraction = idx / len(puzzles)
+ test_examples_dest = None
+ for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]):
+ if fraction < f:
+ test_examples_dest = dest
+ break
+
+ assert test_examples_dest is not None
+
+ convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
+ total_puzzles += 1
+
+ print (f"[{dataset_path}] total puzzles: {total_puzzles}")
+
+
+def convert_dataset(config: DataProcessConfig):
+ np.random.seed(config.seed)
+
+ # Read dataset
+ data = {}
+ for dataset_dir in config.dataset_dirs:
+ load_puzzles_arcagi(data, dataset_dir, config)
+
+ # Map global puzzle identifiers
+ num_identifiers = 1 # 0 is blank
+ identifier_map = {}
+ for split_name, split in data.items():
+ for subset_name, subset in split.items():
+ for group in subset:
+ for puzzle in group:
+ if puzzle.id not in identifier_map:
+ identifier_map[puzzle.id] = num_identifiers
+ num_identifiers += 1
+
+ print (f"Total puzzle IDs (including <blank>): {num_identifiers}")
+
+ # Save
+ for split_name, split in data.items():
+ os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
+
+ # Translational augmentations
+ enable_translational_augment = split_name == "train"
+
+ # Statistics
+ total_examples = 0
+ total_puzzles = 0
+ total_groups = 0
+
+ for subset_name, subset in split.items():
+ # Construct subset
+ results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
+ results["puzzle_indices"].append(0)
+ results["group_indices"].append(0)
+
+ example_id = 0
+ puzzle_id = 0
+
+ for group in subset:
+ for puzzle in group:
+ # Push puzzle
+ no_aug_id = np.random.randint(0, len(puzzle.examples))
+ for _idx_ex, (inp, out) in enumerate(puzzle.examples):
+ inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
+
+ results["inputs"].append(inp)
+ results["labels"].append(out)
+ example_id += 1
+
+ total_examples += 1
+
+ results["puzzle_indices"].append(example_id)
+ results["puzzle_identifiers"].append(identifier_map[puzzle.id])
+
+ puzzle_id += 1
+
+ total_puzzles += 1
+
+ # Push group
+ results["group_indices"].append(puzzle_id)
+ total_groups += 1
+
+ for k, v in results.items():
+ if k in {"inputs", "labels"}:
+ v = np.stack(v, 0)
+ else:
+ v = np.array(v, dtype=np.int32)
+
+ np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
+
+ # Metadata
+ metadata = PuzzleDatasetMetadata(
+ seq_len=ARCMaxGridSize * ARCMaxGridSize,
+ vocab_size=10 + 2, # PAD + EOS + "0" ... "9"
+
+ pad_id=0,
+ ignore_label_id=0,
+
+ blank_identifier_id=0,
+ num_puzzle_identifiers=num_identifiers,
+
+ total_groups=total_groups,
+ mean_puzzle_examples=total_examples / total_puzzles,
+ sets=list(split.keys())
+ )
+
+ # Save metadata as JSON.
+ with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
+ json.dump(metadata.model_dump(), f)
+
+ # Save IDs mapping
+ with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
+ ids_mapping = {v: k for k, v in identifier_map.items()}
+
+ json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)
+
+
+@cli.command(singleton=True)
+def main(config: DataProcessConfig):
+ convert_dataset(config)
+
+
+if __name__ == "__main__":
+ cli()