summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOne <imone@tuta.io>2025-07-09 10:13:51 +0800
committerOne <imone@tuta.io>2025-07-09 10:13:51 +0800
commitbd6222774edcec1608a6842d0b06a637a4acef59 (patch)
tree3b95517044286d82a9166bcce3134bbea099fcfe
parentcaa00bb77fbfce0cc14a45d5ec1c754394f96c1a (diff)
Release
-rw-r--r--.gitignore169
-rw-r--r--.vscode/launch.json26
-rw-r--r--.vscode/settings.json3
-rw-r--r--README.md169
-rw-r--r--arc_eval.ipynb252
-rw-r--r--assets/hrm.pngbin0 -> 99852 bytes
-rw-r--r--assets/npyjs.js176
-rw-r--r--config/arch/hrm_v1.yaml21
-rw-r--r--config/cfg_pretrain.yaml31
-rw-r--r--dataset/build_arc_dataset.py291
-rw-r--r--dataset/build_maze_dataset.py142
-rw-r--r--dataset/build_sudoku_dataset.py169
-rw-r--r--dataset/common.py51
-rw-r--r--evaluate.py68
-rw-r--r--models/common.py32
-rw-r--r--models/hrm/hrm_act_v1.py283
-rw-r--r--models/layers.py150
-rw-r--r--models/losses.py101
-rw-r--r--models/sparse_embedding.py132
-rw-r--r--pretrain.py454
-rw-r--r--puzzle_dataset.py199
-rw-r--r--puzzle_visualizer.html426
-rw-r--r--requirements.txt11
-rw-r--r--utils/functions.py19
24 files changed, 3375 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..644b422
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,169 @@
+# WandB
+/wandb/
+# checkpoints
+/checkpoints/
+# cache
+/cache/
+# data
+/data/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/ \ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..d49f941
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,26 @@
+{
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Python Debugger: Current File",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "${file}",
+ "console": "integratedTerminal"
+ },
+ {
+ "name": "Debug: Single GPU",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "pretrain.py",
+ "args": [],
+ "env": {
+ "OMP_NUM_THREADS": "1",
+ "DISABLE_COMPILE": "true"
+ }
+ }
+ ]
+} \ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..8aef7b1
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+ "python.analysis.typeCheckingMode": "standard"
+} \ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..87620eb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,169 @@
+# Hierarchical Reasoning Model
+
+![](./assets/hrm.png)
+
+Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI.
+Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency.
+HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes.
+Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities.
+These results underscore HRM’s potential as a transformative advancement toward universal computation and general-purpose reasoning systems.
+
+## Quick Start Guide 🚀
+
+### Prerequisites ⚙️
+
+Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If not present, run the following commands:
+
+```bash
+# Install CUDA 12.4
+CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run
+
+wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
+sudo sh cuda_installer.run --silent --toolkit --override
+
+export CUDA_HOME=/usr/local/cuda-12.4
+
+# Install PyTorch with CUDA 12.4
+PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu124
+
+pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
+
+# Additional packages for building extensions
+pip3 install packaging ninja wheel setuptools setuptools-scm
+```
+
+## Install Python Dependencies 🐍
+
+```bash
+pip install -r requirements.txt
+```
+
+## W&B Integration 📈
+
+This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in:
+
+```bash
+wandb login
+```
+
+## Run Experiments
+
+### Quick Demo: Sudoku Solver 💻🗲
+
+Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 🧩
+
+```bash
+# Download and build Sudoku dataset
+python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000
+
+# Start training (single GPU, smaller batch size)
+OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0
+```
+
+Runtime: ~10 hours on a RTX 4070 laptop GPU
+
+## Full-scale Experiments 🔵
+
+Experiments below assume an 8-GPU setup.
+
+### Dataset Preparation
+
+```bash
+# Initialize submodules
+git submodule update --init --recursive
+
+# ARC-1
+python dataset/build_arc_dataset.py # ARC offical + ConceptARC, 960 examples
+# ARC-2
+python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000 # ARC-2 official, 1120 examples
+
+# Sudoku-Extreme
+python dataset/build_sudoku_dataset.py # Full version
+python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples
+
+# Maze
+python dataset/build_maze_dataset.py # 1000 examples
+```
+
+### Dataset Visualization
+
+Explore the puzzles visually:
+
+* Open `puzzle_visualizer.html` in your browser.
+* Upload the generated dataset folder located in `data/...`.
+
+## Launch experiments
+
+### Small-sample (1K)
+
+ARC-1:
+
+```bash
+OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py
+```
+
+*Runtime:* ~24 hours
+
+ARC-2:
+
+```bash
+OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
+```
+
+*Runtime:* ~24 hours (checkpoint after 8 hours is often sufficient)
+
+Sudoku Extreme (1k):
+
+```bash
+OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
+```
+
+*Runtime:* ~10 minutes
+
+Maze 30x30 Hard (1k):
+
+```bash
+OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
+```
+
+*Runtime:* ~1 hour
+
+### Full Sudoku-Hard
+
+```bash
+OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned
+```
+
+*Runtime:* ~2 hours
+
+## Evaluation
+
+Evaluate your trained models:
+
+* Check `eval/exact_accuracy` in W&B.
+* For ARC-AGI, follow these additional steps:
+
+```bash
+OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH>
+```
+
+* Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results.
+
+## Notes
+
+ - Small-sample learning typically exhibits accuracy variance of around ±2 points.
+ - For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once the training accuracy approaches 100%.
+
+## Citation 📜
+
+```
+@misc{wang2025hierarchicalreasoningmodel,
+ title={Hierarchical Reasoning Model},
+ author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
+ year={2025},
+ eprint={2506.21734},
+ archivePrefix={arXiv},
+ primaryClass={cs.AI},
+ url={https://arxiv.org/abs/2506.21734},
+}
+``` \ No newline at end of file
diff --git a/arc_eval.ipynb b/arc_eval.ipynb
new file mode 100644
index 0000000..b2786b8
--- /dev/null
+++ b/arc_eval.ipynb
@@ -0,0 +1,252 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "from glob import glob\n",
+ "import hashlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.colors as mcolors\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import numpy as np\n",
+ "from numba import njit\n",
+ "\n",
+ "from dataset.common import inverse_dihedral_transform\n",
+ "\n",
+ "\n",
+ "DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n",
+ "# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n",
+ "\n",
+ "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n",
+ "\n",
+ "\n",
+ "PAD_PUZZLE_IDENTIFIER = 0\n",
+ "\n",
+ "# Visualization\n",
+ "ARC_COLOR_MAP = mcolors.ListedColormap([\n",
+ " \"#000000\", # symbol_0: black\n",
+ " \"#0074D9\", # symbol_1: blue\n",
+ " \"#FF4136\", # symbol_2: red\n",
+ " \"#2ECC40\", # symbol_3: green\n",
+ " \"#FFDC00\", # symbol_4: yellow\n",
+ " \"#AAAAAA\", # symbol_5: grey\n",
+ " \"#F012BE\", # symbol_6: fuschia\n",
+ " \"#FF851B\", # symbol_7: orange\n",
+ " \"#7FDBFF\", # symbol_8: teal\n",
+ " \"#870C25\" # symbol_9: brown\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n",
+ " # Load puzzle identifiers\n",
+ " with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n",
+ " identifier_map = json.load(f)\n",
+ " \n",
+ " # Load preds\n",
+ " all_preds = {}\n",
+ " for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n",
+ " preds = torch.load(filename)\n",
+ " for k, v in preds.items():\n",
+ " all_preds.setdefault(k, [])\n",
+ " all_preds[k].append(v)\n",
+ " \n",
+ " del preds\n",
+ "\n",
+ " all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n",
+ " \n",
+ " # Remove paddings\n",
+ " mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n",
+ " all_preds = {k: v[mask] for k, v in all_preds.items()}\n",
+ "\n",
+ " return identifier_map, all_preds\n",
+ "\n",
+ "\n",
+ "def inverse_aug(name: str, grid: np.ndarray):\n",
+ " if \"_\" not in name:\n",
+ " return grid\n",
+ "\n",
+ " trans_id, perm = name.split(\"_\")[-2:]\n",
+ " trans_id = int(trans_id[1:]) # Remove \"t\" letter\n",
+ " inv_perm = np.argsort(list(perm))\n",
+ " \n",
+ " return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n",
+ "\n",
+ "\n",
+ "def grid_hash(grid: np.ndarray):\n",
+ " return hash((grid.tobytes(), grid.shape))\n",
+ "\n",
+ "\n",
+ "@njit\n",
+ "def crop(grid: np.ndarray):\n",
+ " # Find maximum-sized rectangle without any EOS token inside.\n",
+ " grid = grid.reshape(30, 30)\n",
+ "\n",
+ " max_area = 0\n",
+ " max_size = (0, 0)\n",
+ " nr, nc = grid.shape\n",
+ " \n",
+ " num_c = nc\n",
+ " for num_r in range(1, nr + 1):\n",
+ " # Scan for maximum c\n",
+ " for c in range(1, num_c + 1):\n",
+ " x = grid[num_r - 1, c - 1]\n",
+ " if (x < 2) | (x > 11):\n",
+ " num_c = c - 1\n",
+ " break\n",
+ " \n",
+ " area = num_r * num_c\n",
+ " if area > max_area:\n",
+ " max_area = area\n",
+ " max_size = (num_r, num_c)\n",
+ "\n",
+ " return grid[:max_size[0], :max_size[1]] - 2\n",
+ "\n",
+ "\n",
+ "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n",
+ " identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n",
+ " \n",
+ " global_hmap = {}\n",
+ " \n",
+ " # Get puzzles and corresponding answers\n",
+ " puzzle_labels = {}\n",
+ " for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n",
+ " name = identifier_map[identifier]\n",
+ " if \"_\" not in name: # Not-augmented\n",
+ " puzzle_labels.setdefault(name, {})\n",
+ " \n",
+ " input = crop(input.numpy())\n",
+ " label = crop(label.numpy())\n",
+ "\n",
+ " input_hash = grid_hash(input)\n",
+ " label_hash = grid_hash(label)\n",
+ "\n",
+ " global_hmap[input_hash] = input\n",
+ " global_hmap[label_hash] = label\n",
+ "\n",
+ " assert input_hash not in puzzle_labels[name]\n",
+ " puzzle_labels[name][input_hash] = label_hash\n",
+ " \n",
+ " print (\"Number of puzzles\", len(puzzle_labels))\n",
+ " \n",
+ " # Argmax prediction\n",
+ " preds = all_preds[\"logits\"].argmax(-1)\n",
+ "\n",
+ " # Collate\n",
+ " pred_answers = {}\n",
+ " for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n",
+ " name = identifier_map[identifier]\n",
+ " orig_name = name.split(\"_\")[0]\n",
+ " \n",
+ " input = input.numpy()\n",
+ " input_hash = grid_hash(inverse_aug(name, crop(input)))\n",
+ " assert input_hash in puzzle_labels[orig_name]\n",
+ " \n",
+ " pred = inverse_aug(name, crop(pred.numpy()))\n",
+ " pred_hash = grid_hash(pred)\n",
+ " global_hmap[pred_hash] = pred\n",
+ " \n",
+ " pred_answers.setdefault(orig_name, {})\n",
+ " pred_answers[orig_name].setdefault(input_hash, [])\n",
+ " pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n",
+ "\n",
+ " # test-1\n",
+ " if visualize:\n",
+ " num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n",
+ " fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n",
+ " \n",
+ " fig_id = 0\n",
+ " \n",
+ " correct = [0 for _ in range(len(Ks))]\n",
+ " for name, tests in puzzle_labels.items():\n",
+ " num_test_correct = [0 for _ in range(len(Ks))]\n",
+ " for input_hash, label_hash in tests.items():\n",
+ " p = pred_answers[name][input_hash]\n",
+ " p_map = {}\n",
+ " \n",
+ " for h, q in p:\n",
+ " p_map.setdefault(h, [0, 0])\n",
+ " p_map[h][0] += 1\n",
+ " p_map[h][1] += q\n",
+ " \n",
+ " for h, stats in p_map.items():\n",
+ " stats[1] /= stats[0]\n",
+ " \n",
+ " p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n",
+ "\n",
+ " # 2-vote\n",
+ " for i, k in enumerate(Ks):\n",
+ " ok = False\n",
+ " for h, stats in p_map[:k]:\n",
+ " ok |= h == label_hash\n",
+ " \n",
+ " num_test_correct[i] += ok\n",
+ "\n",
+ " if visualize:\n",
+ " # Show input and ground truth\n",
+ " axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n",
+ " axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n",
+ " axes[fig_id, 0].axis('off')\n",
+ " \n",
+ " axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n",
+ " axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n",
+ " axes[fig_id, 1].axis('off')\n",
+ " \n",
+ " trial_id = 2\n",
+ " for h, stats in p_map[:2]:\n",
+ " ans = global_hmap[h]\n",
+ " \n",
+ " axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n",
+ " axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n",
+ " axes[fig_id, trial_id].axis('off')\n",
+ " \n",
+ " trial_id += 1\n",
+ " \n",
+ " fig_id += 1\n",
+ " \n",
+ " # Total correctness\n",
+ " for i in range(len(Ks)):\n",
+ " correct[i] += num_test_correct[i] == len(tests)\n",
+ "\n",
+ " for i, k in enumerate(Ks):\n",
+ " print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n",
+ "\n",
+ "\n",
+ "test(visualize=False)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/assets/hrm.png b/assets/hrm.png
new file mode 100644
index 0000000..e6c2f10
--- /dev/null
+++ b/assets/hrm.png
Binary files differ
diff --git a/assets/npyjs.js b/assets/npyjs.js
new file mode 100644
index 0000000..b474575
--- /dev/null
+++ b/assets/npyjs.js
@@ -0,0 +1,176 @@
+class npyjs {
+
+ constructor(opts) {
+ if (opts && !('convertFloat16' in opts)) {
+ console.warn([
+ "npyjs constructor now accepts {convertFloat16?: boolean}.",
+ "For usage, go to https://github.com/jhuapl-boss/npyjs."
+ ].join(" "));
+ }
+
+ this.convertFloat16 = opts?.convertFloat16 ?? true;
+
+ this.dtypes = {
+ "<u1": {
+ name: "uint8",
+ size: 8,
+ arrayConstructor: Uint8Array,
+ },
+ "|u1": {
+ name: "uint8",
+ size: 8,
+ arrayConstructor: Uint8Array,
+ },
+ "<u2": {
+ name: "uint16",
+ size: 16,
+ arrayConstructor: Uint16Array,
+ },
+ "|i1": {
+ name: "int8",
+ size: 8,
+ arrayConstructor: Int8Array,
+ },
+ "<i2": {
+ name: "int16",
+ size: 16,
+ arrayConstructor: Int16Array,
+ },
+ "<u4": {
+ name: "uint32",
+ size: 32,
+ arrayConstructor: Uint32Array,
+ },
+ "<i4": {
+ name: "int32",
+ size: 32,
+ arrayConstructor: Int32Array,
+ },
+ "<u8": {
+ name: "uint64",
+ size: 64,
+ arrayConstructor: BigUint64Array,
+ },
+ "<i8": {
+ name: "int64",
+ size: 64,
+ arrayConstructor: BigInt64Array,
+ },
+ "<f4": {
+ name: "float32",
+ size: 32,
+ arrayConstructor: Float32Array
+ },
+ "<f8": {
+ name: "float64",
+ size: 64,
+ arrayConstructor: Float64Array
+ },
+ "<f2": {
+ name: "float16",
+ size: 16,
+ arrayConstructor: Uint16Array,
+ converter: this.convertFloat16 ? this.float16ToFloat32Array : undefined
+ },
+ };
+ }
+
+ float16ToFloat32Array(float16Array) {
+ const length = float16Array.length;
+ const float32Array = new Float32Array(length);
+
+ for (let i = 0; i < length; i++) {
+ float32Array[i] = npyjs.float16ToFloat32(float16Array[i]);
+ }
+
+ return float32Array;
+ }
+
+ static float16ToFloat32(float16) {
+ // Extract the parts of the float16
+ const sign = (float16 >> 15) & 0x1;
+ const exponent = (float16 >> 10) & 0x1f;
+ const fraction = float16 & 0x3ff;
+
+ // Handle special cases
+ if (exponent === 0) {
+ if (fraction === 0) {
+ // Zero
+ return sign ? -0 : 0;
+ }
+ // Denormalized number
+ return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400);
+ } else if (exponent === 0x1f) {
+ if (fraction === 0) {
+ // Infinity
+ return sign ? -Infinity : Infinity;
+ }
+ // NaN
+ return NaN;
+ }
+
+ // Normalized number
+ return (sign ? -1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400);
+ }
+
+ parse(arrayBufferContents) {
+ // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded
+ const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0);
+ const offsetBytes = 10 + headerLength;
+
+ const hcontents = new TextDecoder("utf-8").decode(
+ new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength))
+ );
+ const header = JSON.parse(
+ hcontents
+ .toLowerCase() // True -> true
+ .replace(/'/g, '"')
+ .replace("(", "[")
+ .replace(/,*\),*/g, "]")
+ );
+ const shape = header.shape;
+ const dtype = this.dtypes[header.descr];
+
+ if (!dtype) {
+ console.error(`Unsupported dtype: ${header.descr}`);
+ return null;
+ }
+
+ const nums = new dtype.arrayConstructor(
+ arrayBufferContents,
+ offsetBytes
+ );
+
+ // Convert float16 to float32 if converter exists
+ const data = dtype.converter ? dtype.converter.call(this, nums) : nums;
+
+ return {
+ dtype: dtype.name,
+ data: data,
+ shape,
+ fortranOrder: header.fortran_order
+ };
+ }
+
+ async load(filename, callback, fetchArgs) {
+ /*
+ Loads an array from a stream of bytes.
+ */
+ fetchArgs = fetchArgs || {};
+ let arrayBuf;
+ // If filename is ArrayBuffer
+ if (filename instanceof ArrayBuffer) {
+ arrayBuf = filename;
+ }
+ // If filename is a file path
+ else {
+ const resp = await fetch(filename, { ...fetchArgs });
+ arrayBuf = await resp.arrayBuffer();
+ }
+ const result = this.parse(arrayBuf);
+ if (callback) {
+ return callback(result);
+ }
+ return result;
+ }
+}
diff --git a/config/arch/hrm_v1.yaml b/config/arch/hrm_v1.yaml
new file mode 100644
index 0000000..a5646b8
--- /dev/null
+++ b/config/arch/hrm_v1.yaml
@@ -0,0 +1,21 @@
+name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
+loss:
+ name: losses@ACTLossHead
+ loss_type: stablemax_cross_entropy
+
+halt_exploration_prob: 0.1
+halt_max_steps: 16
+
+H_cycles: 2
+L_cycles: 2
+
+H_layers: 4
+L_layers: 4
+
+hidden_size: 512
+num_heads: 8 # min(2, hidden_size // 64)
+expansion: 4
+
+puzzle_emb_ndim: ${.hidden_size}
+
+pos_encodings: rope
diff --git a/config/cfg_pretrain.yaml b/config/cfg_pretrain.yaml
new file mode 100644
index 0000000..51c55a0
--- /dev/null
+++ b/config/cfg_pretrain.yaml
@@ -0,0 +1,31 @@
+# ARC training config
+
+defaults:
+ - arch: hrm_v1
+ - _self_
+
+hydra:
+ output_subdir: null
+
+# Data path
+data_path: data/arc-aug-1000
+
+# Hyperparams - Training
+global_batch_size: 768
+
+epochs: 100000
+eval_interval: 10000
+checkpoint_every_eval: True
+
+lr: 1e-4
+lr_min_ratio: 1.0
+lr_warmup_steps: 2000
+
+# Standard hyperparameter settings for LM, as used in Llama
+beta1: 0.9
+beta2: 0.95
+weight_decay: 0.1
+puzzle_emb_weight_decay: 0.1
+
+# Hyperparams - Puzzle embeddings training
+puzzle_emb_lr: 1e-2
diff --git a/dataset/build_arc_dataset.py b/dataset/build_arc_dataset.py
new file mode 100644
index 0000000..2da5703
--- /dev/null
+++ b/dataset/build_arc_dataset.py
@@ -0,0 +1,291 @@
+from typing import List, Optional, Tuple, Dict
+from dataclasses import dataclass
+from pathlib import Path
+import os
+import json
+import hashlib
+import numpy as np
+from glob import glob
+
+from argdantic import ArgParser
+from pydantic import BaseModel
+
+from common import PuzzleDatasetMetadata, dihedral_transform
+
+
+cli = ArgParser()
+
+
+class DataProcessConfig(BaseModel):
+ # ARC-1
+ dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
+ output_dir: str = "data/arc-aug-1000"
+
+ # ARC-2
+ # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"]
+ # output_dir: str = "data/arc-2-aug-1000"
+
+ seed: int = 42
+ num_aug: int = 1000
+
+
+ARCMaxGridSize = 30
+ARCAugmentRetriesFactor = 5
+
+
+@dataclass
+class ARCPuzzle:
+ id: str
+
+ examples: List[Tuple[np.ndarray, np.ndarray]]
+
+
+def arc_grid_to_np(grid: List[List[int]]):
+ arr = np.array(grid)
+
+ # Shape check
+ assert arr.ndim == 2
+ assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
+ # Element check
+ assert np.all((arr >= 0) & (arr <= 9))
+ return arr.astype(np.uint8)
+
+
+def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
+ # PAD: 0, <eos>: 1, digits: 2 ... 11
+ # Compute random top-left pad
+ if do_translation:
+ pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
+ pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
+ else:
+ pad_r = pad_c = 0
+
+ # Pad grid
+ result = []
+ for grid in [inp, out]:
+ nrow, ncol = grid.shape
+ grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)
+
+ # Add <eos>
+ eos_row, eos_col = pad_r + nrow, pad_c + ncol
+ if eos_row < ARCMaxGridSize:
+ grid[eos_row, pad_c:eos_col] = 1
+ if eos_col < ARCMaxGridSize:
+ grid[pad_r:eos_row, eos_col] = 1
+
+ result.append(grid.flatten())
+
+ return result
+
+
+def puzzle_hash(puzzle: dict):
+ # Hash the puzzle for checking equivalence
+ def _grid_hash(grid: np.ndarray):
+ buffer = [x.to_bytes(1) for x in grid.shape]
+ buffer.append(grid.tobytes())
+
+ return hashlib.sha256(b"".join(buffer)).hexdigest()
+
+ hashes = []
+ for example_type, example in puzzle.items():
+ for input, label in example.examples:
+ hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}")
+
+ hashes.sort()
+ return hashlib.sha256("|".join(hashes).encode()).hexdigest()
+
+
+def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
+ # Remove "name"
+ name = puzzle.pop("name", default_name)
+
+ # Convert
+ dests = set(dest_mapping.values())
+ converted = {dest: ARCPuzzle(name, []) for dest in dests}
+ for example_type, examples in puzzle.items():
+ dest = dest_mapping[example_type]
+ converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])
+
+ group = [converted]
+
+ # Augment
+ if aug_count > 0:
+ hashes = {puzzle_hash(converted)}
+
+ for _trial in range(ARCAugmentRetriesFactor * aug_count):
+ # Augment plan
+ trans_id = np.random.randint(0, 8)
+ mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black)
+
+ aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}"
+
+ def _map_grid(grid: np.ndarray):
+ return dihedral_transform(mapping[grid], trans_id)
+
+ # Check duplicate
+ augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
+ h = puzzle_hash(augmented)
+ if h not in hashes:
+ hashes.add(h)
+ group.append(augmented)
+
+ if len(group) >= aug_count + 1:
+ break
+
+ if len(group) < aug_count + 1:
+ print (f"[Puzzle {name}] augmentation not full, only {len(group)}")
+
+ # Append
+ for dest in dests:
+ # Convert the examples
+ dest_split, dest_set = dest
+
+ results.setdefault(dest_split, {})
+ results[dest_split].setdefault(dest_set, [])
+ results[dest_split][dest_set].append([converted[dest] for converted in group])
+
+
+def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):
+ train_examples_dest = ("train", "all")
+ test_examples_map = {
+ "evaluation": [(1.0, ("test", "all"))],
+ "_default": [(1.0, ("train", "all"))]
+ }
+
+ total_puzzles = 0
+ for subdir in os.scandir(dataset_path):
+ if subdir.is_dir():
+ # Load all puzzles in this directory
+ puzzles = []
+ for filename in glob(os.path.join(subdir.path, "*.json")):
+ with open(filename, "r") as f:
+ puzzles.append((Path(filename).stem, json.load(f)))
+
+ # Shuffle puzzles
+ np.random.shuffle(puzzles)
+
+ # Assign by fraction
+ for idx, (default_name, puzzle) in enumerate(puzzles):
+ fraction = idx / len(puzzles)
+ test_examples_dest = None
+ for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]):
+ if fraction < f:
+ test_examples_dest = dest
+ break
+
+ assert test_examples_dest is not None
+
+ convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
+ total_puzzles += 1
+
+ print (f"[{dataset_path}] total puzzles: {total_puzzles}")
+
+
+def convert_dataset(config: DataProcessConfig):
+ np.random.seed(config.seed)
+
+ # Read dataset
+ data = {}
+ for dataset_dir in config.dataset_dirs:
+ load_puzzles_arcagi(data, dataset_dir, config)
+
+ # Map global puzzle identifiers
+ num_identifiers = 1 # 0 is blank
+ identifier_map = {}
+ for split_name, split in data.items():
+ for subset_name, subset in split.items():
+ for group in subset:
+ for puzzle in group:
+ if puzzle.id not in identifier_map:
+ identifier_map[puzzle.id] = num_identifiers
+ num_identifiers += 1
+
+ print (f"Total puzzle IDs (including <blank>): {num_identifiers}")
+
+ # Save
+ for split_name, split in data.items():
+ os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
+
+ # Translational augmentations
+ enable_translational_augment = split_name == "train"
+
+ # Statistics
+ total_examples = 0
+ total_puzzles = 0
+ total_groups = 0
+
+ for subset_name, subset in split.items():
+ # Construct subset
+ results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
+ results["puzzle_indices"].append(0)
+ results["group_indices"].append(0)
+
+ example_id = 0
+ puzzle_id = 0
+
+ for group in subset:
+ for puzzle in group:
+ # Push puzzle
+ no_aug_id = np.random.randint(0, len(puzzle.examples))
+ for _idx_ex, (inp, out) in enumerate(puzzle.examples):
+ inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
+
+ results["inputs"].append(inp)
+ results["labels"].append(out)
+ example_id += 1
+
+ total_examples += 1
+
+ results["puzzle_indices"].append(example_id)
+ results["puzzle_identifiers"].append(identifier_map[puzzle.id])
+
+ puzzle_id += 1
+
+ total_puzzles += 1
+
+ # Push group
+ results["group_indices"].append(puzzle_id)
+ total_groups += 1
+
+ for k, v in results.items():
+ if k in {"inputs", "labels"}:
+ v = np.stack(v, 0)
+ else:
+ v = np.array(v, dtype=np.int32)
+
+ np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
+
+ # Metadata
+ metadata = PuzzleDatasetMetadata(
+ seq_len=ARCMaxGridSize * ARCMaxGridSize,
+ vocab_size=10 + 2, # PAD + EOS + "0" ... "9"
+
+ pad_id=0,
+ ignore_label_id=0,
+
+ blank_identifier_id=0,
+ num_puzzle_identifiers=num_identifiers,
+
+ total_groups=total_groups,
+ mean_puzzle_examples=total_examples / total_puzzles,
+ sets=list(split.keys())
+ )
+
+ # Save metadata as JSON.
+ with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
+ json.dump(metadata.model_dump(), f)
+
+ # Save IDs mapping
+ with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
+ ids_mapping = {v: k for k, v in identifier_map.items()}
+
+ json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)
+
+
+@cli.command(singleton=True)
+def main(config: DataProcessConfig):
+ convert_dataset(config)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/dataset/build_maze_dataset.py b/dataset/build_maze_dataset.py
new file mode 100644
index 0000000..e99baf2
--- /dev/null
+++ b/dataset/build_maze_dataset.py
@@ -0,0 +1,142 @@
+from typing import Optional
+import math
+import os
+import csv
+import json
+import numpy as np
+
+from argdantic import ArgParser
+from pydantic import BaseModel
+from tqdm import tqdm
+from huggingface_hub import hf_hub_download
+
+from common import PuzzleDatasetMetadata, dihedral_transform
+
+
+CHARSET = "# SGo"
+
+
+cli = ArgParser()
+
+
+class DataProcessConfig(BaseModel):
+ source_repo: str = "imone/small-sample-challenge-maze-30x30-hard"
+ output_dir: str = "data/maze-30x30-hard-1k"
+
+ subsample_size: Optional[int] = None
+ aug: bool = False
+
+
+def convert_subset(set_name: str, config: DataProcessConfig):
+ # Read CSV
+ all_chars = set()
+ grid_size = None
+ inputs = []
+ labels = []
+
+ with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore
+ reader = csv.reader(csvfile)
+ next(reader) # Skip header
+ for source, q, a, rating in reader:
+ all_chars.update(q)
+ all_chars.update(a)
+
+ if grid_size is None:
+ n = int(len(q) ** 0.5)
+ grid_size = (n, n)
+
+ inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
+ labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))
+
+ # If subsample_size is specified for the training set,
+ # randomly sample the desired number of examples.
+ if set_name == "train" and config.subsample_size is not None:
+ total_samples = len(inputs)
+ if config.subsample_size < total_samples:
+ indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
+ inputs = [inputs[i] for i in indices]
+ labels = [labels[i] for i in indices]
+
+ # Generate dataset
+ results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
+ puzzle_id = 0
+ example_id = 0
+
+ results["puzzle_indices"].append(0)
+ results["group_indices"].append(0)
+
+ for inp, out in zip(tqdm(inputs), labels):
+ # Dihedral transformations for augmentation
+ for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
+ results["inputs"].append(dihedral_transform(inp, aug_idx))
+ results["labels"].append(dihedral_transform(out, aug_idx))
+ example_id += 1
+ puzzle_id += 1
+
+ results["puzzle_indices"].append(example_id)
+ results["puzzle_identifiers"].append(0)
+
+ # Push group
+ results["group_indices"].append(puzzle_id)
+
+ # Char mappings
+ assert len(all_chars - set(CHARSET)) == 0
+
+ char2id = np.zeros(256, np.uint8)
+ char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1
+
+ # To Numpy
+ def _seq_to_numpy(seq):
+ arr = np.vstack([char2id[s.reshape(-1)] for s in seq])
+
+ return arr
+
+ results = {
+ "inputs": _seq_to_numpy(results["inputs"]),
+ "labels": _seq_to_numpy(results["labels"]),
+
+ "group_indices": np.array(results["group_indices"], dtype=np.int32),
+ "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
+ "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
+ }
+
+ # Metadata
+ metadata = PuzzleDatasetMetadata(
+ seq_len=int(math.prod(grid_size)), # type: ignore
+ vocab_size=len(CHARSET) + 1, # PAD + Charset
+
+ pad_id=0,
+ ignore_label_id=0,
+
+ blank_identifier_id=0,
+ num_puzzle_identifiers=1,
+
+ total_groups=len(results["group_indices"]) - 1,
+ mean_puzzle_examples=1,
+ sets=["all"]
+ )
+
+ # Save metadata as JSON.
+ save_dir = os.path.join(config.output_dir, set_name)
+ os.makedirs(save_dir, exist_ok=True)
+
+ with open(os.path.join(save_dir, "dataset.json"), "w") as f:
+ json.dump(metadata.model_dump(), f)
+
+ # Save data
+ for k, v in results.items():
+ np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
+
+ # Save IDs mapping (for visualization only)
+ with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
+ json.dump(["<blank>"], f)
+
+
+@cli.command(singleton=True)
+def preprocess_data(config: DataProcessConfig):
+ convert_subset("train", config)
+ convert_subset("test", config)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/dataset/build_sudoku_dataset.py b/dataset/build_sudoku_dataset.py
new file mode 100644
index 0000000..5d5b50c
--- /dev/null
+++ b/dataset/build_sudoku_dataset.py
@@ -0,0 +1,169 @@
+from typing import Optional
+import os
+import csv
+import json
+import numpy as np
+
+from argdantic import ArgParser
+from pydantic import BaseModel
+from tqdm import tqdm
+from huggingface_hub import hf_hub_download
+
+from common import PuzzleDatasetMetadata
+
+
+cli = ArgParser()
+
+
+class DataProcessConfig(BaseModel):
+ source_repo: str = "imone/sudoku-hard-v2"
+ output_dir: str = "data/sudoku-extreme-full"
+
+ subsample_size: Optional[int] = None
+ min_difficulty: Optional[int] = None
+ num_aug: int = 0
+
+
+def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
+ # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
+ digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
+
+ # Randomly decide whether to transpose.
+ transpose_flag = np.random.rand() < 0.5
+
+ # Generate a valid row permutation:
+ # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
+ bands = np.random.permutation(3)
+ row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])
+
+ # Similarly for columns (stacks).
+ stacks = np.random.permutation(3)
+ col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])
+
+ # Build an 81->81 mapping. For each new cell at (i, j)
+ # (row index = i // 9, col index = i % 9),
+ # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].
+ mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])
+
+ def apply_transformation(x: np.ndarray) -> np.ndarray:
+ # Apply transpose flag
+ if transpose_flag:
+ x = x.T
+ # Apply the position mapping.
+ new_board = x.flatten()[mapping].reshape(9, 9).copy()
+ # Apply digit mapping
+ return digit_map[new_board]
+
+ return apply_transformation(board), apply_transformation(solution)
+
+
+def convert_subset(set_name: str, config: DataProcessConfig):
+ # Read CSV
+ inputs = []
+ labels = []
+
+ with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
+ reader = csv.reader(csvfile)
+ next(reader) # Skip header
+ for source, q, a, rating in reader:
+ if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
+ assert len(q) == 81 and len(a) == 81
+
+ inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
+ labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
+
+ # If subsample_size is specified for the training set,
+ # randomly sample the desired number of examples.
+ if set_name == "train" and config.subsample_size is not None:
+ total_samples = len(inputs)
+ if config.subsample_size < total_samples:
+ indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
+ inputs = [inputs[i] for i in indices]
+ labels = [labels[i] for i in indices]
+
+ # Generate dataset
+ num_augments = config.num_aug if set_name == "train" else 0
+
+ results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
+ puzzle_id = 0
+ example_id = 0
+
+ results["puzzle_indices"].append(0)
+ results["group_indices"].append(0)
+
+ for orig_inp, orig_out in zip(tqdm(inputs), labels):
+ for aug_idx in range(1 + num_augments):
+ # First index is not augmented
+ if aug_idx == 0:
+ inp, out = orig_inp, orig_out
+ else:
+ inp, out = shuffle_sudoku(orig_inp, orig_out)
+
+ # Push puzzle (only single example)
+ results["inputs"].append(inp)
+ results["labels"].append(out)
+ example_id += 1
+ puzzle_id += 1
+
+ results["puzzle_indices"].append(example_id)
+ results["puzzle_identifiers"].append(0)
+
+ # Push group
+ results["group_indices"].append(puzzle_id)
+
+ # To Numpy
+ def _seq_to_numpy(seq):
+ arr = np.concatenate(seq).reshape(len(seq), -1)
+
+ assert np.all((arr >= 0) & (arr <= 9))
+ return arr + 1
+
+ results = {
+ "inputs": _seq_to_numpy(results["inputs"]),
+ "labels": _seq_to_numpy(results["labels"]),
+
+ "group_indices": np.array(results["group_indices"], dtype=np.int32),
+ "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
+ "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
+ }
+
+ # Metadata
+ metadata = PuzzleDatasetMetadata(
+ seq_len=81,
+ vocab_size=10 + 1, # PAD + "0" ... "9"
+
+ pad_id=0,
+ ignore_label_id=0,
+
+ blank_identifier_id=0,
+ num_puzzle_identifiers=1,
+
+ total_groups=len(results["group_indices"]) - 1,
+ mean_puzzle_examples=1,
+ sets=["all"]
+ )
+
+ # Save metadata as JSON.
+ save_dir = os.path.join(config.output_dir, set_name)
+ os.makedirs(save_dir, exist_ok=True)
+
+ with open(os.path.join(save_dir, "dataset.json"), "w") as f:
+ json.dump(metadata.model_dump(), f)
+
+ # Save data
+ for k, v in results.items():
+ np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
+
+ # Save IDs mapping (for visualization only)
+ with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
+ json.dump(["<blank>"], f)
+
+
+@cli.command(singleton=True)
+def preprocess_data(config: DataProcessConfig):
+ convert_subset("train", config)
+ convert_subset("test", config)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/dataset/common.py b/dataset/common.py
new file mode 100644
index 0000000..7bc51c6
--- /dev/null
+++ b/dataset/common.py
@@ -0,0 +1,51 @@
+from typing import List, Optional
+
+import pydantic
+import numpy as np
+
+
+# Global list mapping each dihedral transform id to its inverse.
+# Index corresponds to the original tid, and the value is its inverse.
+DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]
+
+
+class PuzzleDatasetMetadata(pydantic.BaseModel):
+ pad_id: int
+ ignore_label_id: Optional[int]
+ blank_identifier_id: int
+
+ vocab_size: int
+ seq_len: int
+ num_puzzle_identifiers: int
+
+ total_groups: int
+ mean_puzzle_examples: float
+
+ sets: List[str]
+
+
+def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
+ """8 dihedral symmetries by rotate, flip and mirror"""
+
+ if tid == 0:
+ return arr # identity
+ elif tid == 1:
+ return np.rot90(arr, k=1)
+ elif tid == 2:
+ return np.rot90(arr, k=2)
+ elif tid == 3:
+ return np.rot90(arr, k=3)
+ elif tid == 4:
+ return np.fliplr(arr) # horizontal flip
+ elif tid == 5:
+ return np.flipud(arr) # vertical flip
+ elif tid == 6:
+ return arr.T # transpose (reflection along main diagonal)
+ elif tid == 7:
+ return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection
+ else:
+ return arr
+
+
+def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
+ return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 0000000..9bc6ba0
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,68 @@
+from typing import List
+import yaml
+import os
+
+import torch
+import torch.distributed as dist
+
+import pydantic
+from omegaconf import OmegaConf
+from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader
+
+
+class EvalConfig(pydantic.BaseModel):
+ checkpoint: str
+
+ save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]
+
+
+def launch():
+ eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore
+
+ RANK = 0
+ WORLD_SIZE = 1
+ # Initialize distributed training if in distributed environment (e.g. torchrun)
+ if "LOCAL_RANK" in os.environ:
+ # Initialize distributed, default device and dtype
+ dist.init_process_group(backend="nccl")
+
+ RANK = dist.get_rank()
+ WORLD_SIZE = dist.get_world_size()
+
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+ with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
+ config = PretrainConfig(**yaml.safe_load(f))
+
+ config.eval_save_outputs = eval_cfg.save_outputs
+ config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)
+
+ # Dataloader
+ train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
+ eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, test_set_limit_examples=LIMIT_EXAMPLES, rank=RANK, world_size=WORLD_SIZE)
+
+ # Models
+ train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
+ # Try unwrap torch.compile
+ try:
+ train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
+ except:
+ train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)
+
+ train_state.step = 0
+ ckpt_filename = os.path.basename(eval_cfg.checkpoint)
+ if ckpt_filename.startswith("step_"):
+ train_state.step = int(ckpt_filename.removeprefix("step_"))
+
+ # Evaluate
+ print ("Starting evaluation")
+
+ train_state.model.eval()
+ metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
+
+ if metrics is not None:
+ print (metrics)
+
+
+if __name__ == "__main__":
+ launch()
diff --git a/models/common.py b/models/common.py
new file mode 100644
index 0000000..1a04505
--- /dev/null
+++ b/models/common.py
@@ -0,0 +1,32 @@
+import math
+
+import torch
+from torch import nn
+
+
+def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
+ # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
+ # This function is a PyTorch version of jax truncated normal init (default init method in flax)
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
+
+ with torch.no_grad():
+ if std == 0:
+ tensor.zero_()
+ else:
+ sqrt2 = math.sqrt(2)
+ a = math.erf(lower / sqrt2)
+ b = math.erf(upper / sqrt2)
+ z = (b - a) / 2
+
+ c = (2 * math.pi) ** -0.5
+ pdf_u = c * math.exp(-0.5 * lower ** 2)
+ pdf_l = c * math.exp(-0.5 * upper ** 2)
+ comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
+
+ tensor.uniform_(a, b)
+ tensor.erfinv_()
+ tensor.mul_(sqrt2 * comp_std)
+ tensor.clip_(lower * comp_std, upper * comp_std)
+
+ return tensor
diff --git a/models/hrm/hrm_act_v1.py b/models/hrm/hrm_act_v1.py
new file mode 100644
index 0000000..e91c7d1
--- /dev/null
+++ b/models/hrm/hrm_act_v1.py
@@ -0,0 +1,283 @@
+from typing import Tuple, List, Dict, Optional
+from dataclasses import dataclass
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from pydantic import BaseModel
+
+from models.common import trunc_normal_init_
+from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
+from models.sparse_embedding import CastedSparseEmbedding
+
+
+@dataclass
+class HierarchicalReasoningModel_ACTV1InnerCarry:
+ z_H: torch.Tensor
+ z_L: torch.Tensor
+
+
+@dataclass
+class HierarchicalReasoningModel_ACTV1Carry:
+ inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
+
+ steps: torch.Tensor
+ halted: torch.Tensor
+
+ current_data: Dict[str, torch.Tensor]
+
+
+class HierarchicalReasoningModel_ACTV1Config(BaseModel):
+ batch_size: int
+ seq_len: int
+ puzzle_emb_ndim: int = 0
+ num_puzzle_identifiers: int
+ vocab_size: int
+
+ H_cycles: int
+ L_cycles: int
+
+ H_layers: int
+ L_layers: int
+
+ # Transformer config
+ hidden_size: int
+ expansion: float
+ num_heads: int
+ pos_encodings: str
+
+ rms_norm_eps: float = 1e-5
+ rope_theta: float = 10000.0
+
+ # Halting Q-learning config
+ halt_max_steps: int
+ halt_exploration_prob: float
+
+ forward_dtype: str = "bfloat16"
+
+
+class HierarchicalReasoningModel_ACTV1Block(nn.Module):
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
+ super().__init__()
+
+ self.self_attn = Attention(
+ hidden_size=config.hidden_size,
+ head_dim=config.hidden_size // config.num_heads,
+ num_heads=config.num_heads,
+ num_key_value_heads=config.num_heads,
+ causal=False
+ )
+ self.mlp = SwiGLU(
+ hidden_size=config.hidden_size,
+ expansion=config.expansion,
+ )
+ self.norm_eps = config.rms_norm_eps
+
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+ # Post Norm
+ # Self Attention
+ hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
+ # Fully Connected
+ hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
+ return hidden_states
+
+
+class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
+ def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
+ super().__init__()
+
+ self.layers = torch.nn.ModuleList(layers)
+
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
+ # Input injection (add)
+ hidden_states = hidden_states + input_injection
+ # Layers
+ for layer in self.layers:
+ hidden_states = layer(hidden_states=hidden_states, **kwargs)
+
+ return hidden_states
+
+
+class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
+ super().__init__()
+ self.config = config
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
+
+ # I/O
+ self.embed_scale = math.sqrt(self.config.hidden_size)
+ embed_init_std = 1.0 / self.embed_scale
+
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
+
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
+ if self.config.puzzle_emb_ndim > 0:
+ # Zero init puzzle embeddings
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
+ batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
+
+ # LM Blocks
+ if self.config.pos_encodings == "rope":
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
+ max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
+ base=self.config.rope_theta)
+ elif self.config.pos_encodings == "learned":
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+ else:
+ raise NotImplementedError()
+
+ # Reasoning Layers
+ self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
+ self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
+
+ # Initial states
+ self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+ self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+
+ # Q head special init
+ # Init Q to (almost) zero for faster learning during bootstrapping
+ with torch.no_grad():
+ self.q_head.weight.zero_()
+ self.q_head.bias.fill_(-5) # type: ignore
+
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
+ # Token embedding
+ embedding = self.embed_tokens(input.to(torch.int32))
+
+ # Puzzle embeddings
+ if self.config.puzzle_emb_ndim > 0:
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
+
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
+ if pad_count > 0:
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
+
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
+
+ # Position embeddings
+ if self.config.pos_encodings == "learned":
+ # scale by 1/sqrt(2) to maintain forward variance
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
+
+ # Scale
+ return self.embed_scale * embedding
+
+ def empty_carry(self, batch_size: int):
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+ )
+
+ def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
+ z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
+ z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
+ )
+
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ seq_info = dict(
+ cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
+ )
+
+ # Input encoding
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+ # Forward iterations
+ with torch.no_grad():
+ z_H, z_L = carry.z_H, carry.z_L
+
+ for _H_step in range(self.config.H_cycles):
+ for _L_step in range(self.config.L_cycles):
+ if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+
+ if not (_H_step == self.config.H_cycles - 1):
+ z_H = self.H_level(z_H, z_L, **seq_info)
+
+ assert not z_H.requires_grad and not z_L.requires_grad
+
+ # 1-step grad
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+ z_H = self.H_level(z_H, z_L, **seq_info)
+
+ # LM Outputs
+ new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
+ output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
+
+ # Q head
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
+
+ return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
+
+
+class HierarchicalReasoningModel_ACTV1(nn.Module):
+ """ACT wrapper."""
+
+ def __init__(self, config_dict: dict):
+ super().__init__()
+ self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
+ self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
+
+ @property
+ def puzzle_emb(self):
+ return self.inner.puzzle_emb
+
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
+ batch_size = batch["inputs"].shape[0]
+
+ return HierarchicalReasoningModel_ACTV1Carry(
+ inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
+
+ steps=torch.zeros((batch_size, ), dtype=torch.int32),
+ halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
+
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
+ )
+
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
+ # Update data, carry (removing halted sequences)
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
+
+ new_steps = torch.where(carry.halted, 0, carry.steps)
+
+ new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
+
+ # Forward inner model
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
+
+ outputs = {
+ "logits": logits,
+ "q_halt_logits": q_halt_logits,
+ "q_continue_logits": q_continue_logits
+ }
+
+ with torch.no_grad():
+ # Step
+ new_steps = new_steps + 1
+ is_last_step = new_steps >= self.config.halt_max_steps
+
+ halted = is_last_step
+
+ # if training, and ACT is enabled
+ if self.training and (self.config.halt_max_steps > 1):
+ # Halt signal
+ # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
+ halted = halted | (q_halt_logits > q_continue_logits)
+
+ # Exploration
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
+
+ halted = halted & (new_steps >= min_halt_steps)
+
+ # Compute target Q
+ # NOTE: No replay buffer and target networks for computing target Q-value.
+ # As batch_size is large, there're many parallel envs.
+ # Similar concept as PQN https://arxiv.org/abs/2407.04811
+ next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
+
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
+
+ return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
diff --git a/models/layers.py b/models/layers.py
new file mode 100644
index 0000000..4f7dee4
--- /dev/null
+++ b/models/layers.py
@@ -0,0 +1,150 @@
+from typing import Tuple
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.common import trunc_normal_init_
+
+
+CosSin = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _find_multiple(a, b):
+ return (-(a // -b)) * b
+
+
+def rotate_half(x: torch.Tensor):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+ # q, k: [bs, num_heads, seq_len, head_dim]
+ # cos, sin: [seq_len, head_dim]
+ orig_dtype = q.dtype
+ q = q.to(cos.dtype)
+ k = k.to(cos.dtype)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+
+ return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
+
+
+class CastedLinear(nn.Module):
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ bias: bool):
+ super().__init__()
+ # Truncated LeCun normal init
+ self.weight = nn.Parameter(
+ trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
+ )
+ self.bias = None
+ if bias:
+ # Zero init bias
+ self.bias = nn.Parameter(torch.zeros((out_features, )))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
+
+
+class CastedEmbedding(nn.Module):
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ init_std: float,
+ cast_to: torch.dtype):
+ super().__init__()
+ self.cast_to = cast_to
+
+ # Truncated LeCun normal init
+ self.embedding_weight = nn.Parameter(
+ trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.embedding(input, self.embedding_weight.to(self.cast_to))
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings, base, device=None):
+ super().__init__()
+
+ # RoPE
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+ t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
+ freqs = torch.outer(t, inv_freq)
+
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
+ self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
+
+ def forward(self):
+ return self.cos_cached, self.sin_cached
+
+
+class Attention(nn.Module):
+ def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.head_dim = head_dim
+ self.output_size = head_dim * num_heads
+ self.num_heads = num_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.causal = causal
+
+ self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
+ self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
+
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, seq_len, _ = hidden_states.shape
+
+ # hidden_states: [bs, seq_len, num_heads, head_dim]
+ qkv = self.qkv_proj(hidden_states)
+
+ # Split head
+ qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim).transpose(-2, -3)
+ query = qkv[:, :self.num_heads]
+ key = qkv[:, self.num_heads: self.num_heads + self.num_key_value_heads]
+ value = qkv[:, self.num_heads + self.num_key_value_heads:]
+
+ # RoPE
+ if cos_sin is not None:
+ cos, sin = cos_sin
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+ # flash attn
+ attn_output = F.scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
+
+ # attn_output: [batch_size, num_heads, seq_len, head_dim]
+ attn_output = attn_output.transpose(-2, -3).view(batch_size, seq_len, self.output_size) # type: ignore
+ return self.o_proj(attn_output)
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, hidden_size: int, expansion: float):
+ super().__init__()
+ inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
+
+ self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
+ self.down_proj = CastedLinear(inter, hidden_size, bias=False)
+
+ def forward(self, x):
+ gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
+ return self.down_proj(F.silu(gate) * up)
+
+
+def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+
+ variance = hidden_states.square().mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+ return hidden_states.to(input_dtype)
diff --git a/models/losses.py b/models/losses.py
new file mode 100644
index 0000000..b3118e7
--- /dev/null
+++ b/models/losses.py
@@ -0,0 +1,101 @@
+from typing import Any, Tuple, Dict, Sequence, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+IGNORE_LABEL_ID = -100
+
+
+def s(x, epsilon=1e-30):
+ return torch.where(
+ x<0,
+ 1/(1-x+ epsilon),
+ x + 1
+ )
+
+
+def log_stablemax(x, dim=-1):
+ s_x = s(x)
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
+
+
+def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
+
+ valid_mask = labels != ignore_index
+ transformed_labels = torch.where(valid_mask, labels, 0)
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
+
+ return -torch.where(valid_mask, prediction_logprobs, 0)
+
+
+def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
+ # Cast logits to f32
+ # Flatten logits
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
+
+
+class ACTLossHead(nn.Module):
+ def __init__(self, model: nn.Module, loss_type: str):
+ super().__init__()
+ self.model = model
+ self.loss_fn = globals()[loss_type]
+
+ def initial_carry(self, *args, **kwargs):
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
+
+ def forward(
+ self,
+ return_keys: Sequence[str],
+ # Model args
+ **model_kwargs,
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
+ # Model logits
+ # B x SeqLen x D
+ new_carry, outputs = self.model(**model_kwargs)
+ labels = new_carry.current_data["labels"]
+
+ # Correctness
+ with torch.no_grad():
+ mask = labels != IGNORE_LABEL_ID
+ loss_counts = mask.sum(-1)
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
+
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
+ seq_is_correct = is_correct.sum(-1) == loss_counts
+
+ # Metrics (halted)
+ valid_metrics = new_carry.halted & (loss_counts > 0)
+ metrics = {
+ "count": valid_metrics.sum(),
+
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
+
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
+ }
+
+ # Losses
+ # FIXME: Assuming the batch is always full
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
+
+ metrics.update({
+ "lm_loss": lm_loss.detach(),
+ "q_halt_loss": q_halt_loss.detach(),
+ })
+
+ # Q continue (bootstrapping target loss)
+ q_continue_loss = 0
+ if "target_q_continue" in outputs:
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
+
+ metrics["q_continue_loss"] = q_continue_loss.detach()
+
+ # Filter outputs for return
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
+
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
diff --git a/models/sparse_embedding.py b/models/sparse_embedding.py
new file mode 100644
index 0000000..c701524
--- /dev/null
+++ b/models/sparse_embedding.py
@@ -0,0 +1,132 @@
+from typing import Union
+
+import torch
+from torch import nn
+import torch.distributed as dist
+from torch.optim.optimizer import Optimizer, ParamsT
+
+from models.common import trunc_normal_init_
+
+
+class CastedSparseEmbedding(nn.Module):
+ def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
+ super().__init__()
+ self.cast_to = cast_to
+
+ # Real Weights
+ # Truncated LeCun normal init
+ self.weights = nn.Buffer(
+ trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
+ )
+
+ # Local weights and IDs
+ # Local embeddings, with gradient, not persistent
+ self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
+ # Local embedding IDs, not persistent
+ self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ if not self.training:
+ # Test mode, no gradient
+ return self.weights[inputs].to(self.cast_to)
+
+ # Training mode, fill puzzle embedding from weights
+ with torch.no_grad():
+ self.local_weights.copy_(self.weights[inputs])
+ self.local_ids.copy_(inputs)
+
+ return self.local_weights.to(self.cast_to)
+
+
+class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
+ def __init__(
+ self,
+ params: ParamsT,
+
+ world_size: int,
+ lr: Union[float, torch.Tensor] = 1e-3,
+ weight_decay: float = 1e-2,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if not 0.0 <= weight_decay:
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+ defaults = dict(
+ lr=lr,
+ weight_decay=weight_decay,
+ world_size=world_size
+ )
+ super().__init__(params, defaults)
+
+ @torch.no_grad
+ def step(self, closure=None): # type: ignore
+ for group in self.param_groups:
+ # Find the sparse embedding weights
+ local_weights_grad = None
+ local_ids = None
+ weights = None
+
+ assert len(group["params"]) == 3
+ for p in group["params"]:
+ if p.requires_grad:
+ local_weights_grad = p.grad
+ elif p.ndim == 1:
+ local_ids = p
+ elif p.ndim == 2:
+ weights = p
+ else:
+ assert False
+
+ assert local_weights_grad is not None
+ assert local_ids is not None
+ assert weights is not None
+
+ # Apply SignSGD
+ # Adam ≈ SignSGD if gradient is very sparse
+ _sparse_emb_signsgd_dist(
+ local_weights_grad,
+ local_ids,
+ weights,
+
+ lr=group["lr"],
+ weight_decay=group["weight_decay"],
+ world_size=group["world_size"]
+ )
+
+
+def _sparse_emb_signsgd_dist(
+ local_weights_grad: torch.Tensor,
+ local_ids: torch.Tensor,
+ weights: torch.Tensor,
+
+ lr: float,
+ weight_decay: float,
+ world_size: int
+) -> None:
+ N, D = local_weights_grad.shape
+
+ # All-gather
+ all_weights_grad = local_weights_grad
+ all_ids = local_ids
+
+ if world_size > 1:
+ all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
+ all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
+
+ dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
+ dist.all_gather_into_tensor(all_ids, local_ids)
+
+ # Unique
+ grad_ids, inv = all_ids.unique(return_inverse=True)
+
+ grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
+ grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
+
+ # SignSGD with decoupled weight decay
+ p = weights[grad_ids]
+
+ p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
+
+ # Write updated slices back
+ weights[grad_ids] = p
diff --git a/pretrain.py b/pretrain.py
new file mode 100644
index 0000000..b939318
--- /dev/null
+++ b/pretrain.py
@@ -0,0 +1,454 @@
+from typing import Optional, Any, Sequence, List
+from dataclasses import dataclass
+import os
+import math
+import yaml
+import shutil
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.utils.data import DataLoader
+
+import tqdm
+import wandb
+import coolname
+import hydra
+import pydantic
+from omegaconf import DictConfig
+from wandb.util import make_artifact_name_safe
+from adam_atan2 import AdamATan2
+
+from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
+from utils.functions import load_model_class, get_model_source_path
+from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
+
+
+class LossConfig(pydantic.BaseModel):
+ model_config = pydantic.ConfigDict(extra='allow')
+
+ name: str
+
+
+class ArchConfig(pydantic.BaseModel):
+ model_config = pydantic.ConfigDict(extra='allow')
+
+ name: str
+ loss: LossConfig
+
+
+class PretrainConfig(pydantic.BaseModel):
+ # Config
+ arch: ArchConfig
+ # Data
+ data_path: str
+
+ # Hyperparams
+ global_batch_size: int
+ epochs: int
+
+ lr: float
+ lr_min_ratio: float
+ lr_warmup_steps: int
+
+ weight_decay: float
+ beta1: float
+ beta2: float
+
+ # Puzzle embedding
+ puzzle_emb_lr: float
+ puzzle_emb_weight_decay: float
+
+ # Names
+ project_name: Optional[str] = None
+ run_name: Optional[str] = None
+ checkpoint_path: Optional[str] = None
+
+ # Extras
+ seed: int = 0
+ checkpoint_every_eval: bool = False
+ eval_interval: Optional[int] = None
+ eval_save_outputs: List[str] = []
+
+
+@dataclass
+class TrainState:
+ model: nn.Module
+ optimizers: Sequence[torch.optim.Optimizer]
+ optimizer_lrs: Sequence[float]
+ carry: Any
+
+ step: int
+ total_steps: int
+
+
+def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
+ dataset = PuzzleDataset(PuzzleDatasetConfig(
+ seed=config.seed,
+
+ dataset_path=config.data_path,
+
+ rank=rank,
+ num_replicas=world_size,
+
+ **kwargs
+ ), split=split)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=None,
+
+ num_workers=1,
+ prefetch_factor=8,
+
+ pin_memory=True,
+ persistent_workers=True
+ )
+ return dataloader, dataset.metadata
+
+
+def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
+ model_cfg = dict(
+ **config.arch.__pydantic_extra__, # type: ignore
+
+ batch_size=config.global_batch_size // world_size,
+
+ vocab_size=train_metadata.vocab_size,
+ seq_len=train_metadata.seq_len,
+ num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
+ causal=False # Non-autoregressive
+ )
+
+ # Instantiate model with loss head
+ model_cls = load_model_class(config.arch.name)
+ loss_head_cls = load_model_class(config.arch.loss.name)
+
+ with torch.device("cuda"):
+ model: nn.Module = model_cls(model_cfg)
+ model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore
+ if "DISABLE_COMPILE" not in os.environ:
+ model = torch.compile(model, dynamic=False, fullgraph=True) # type: ignore
+
+ # Broadcast parameters from rank 0
+ if world_size > 1:
+ with torch.no_grad():
+ for param in list(model.parameters()) + list(model.buffers()):
+ dist.broadcast(param, src=0)
+
+ # Optimizers and lr
+ optimizers = [
+ CastedSparseEmbeddingSignSGD_Distributed(
+ model.model.puzzle_emb.buffers(), # type: ignore
+
+ lr=0, # Needs to be set by scheduler
+ weight_decay=config.puzzle_emb_weight_decay,
+
+ world_size=world_size
+ ),
+ AdamATan2(
+ model.parameters(),
+
+ lr=0, # Needs to be set by scheduler
+ weight_decay=config.weight_decay,
+ betas=(config.beta1, config.beta2)
+ )
+ ]
+ optimizer_lrs = [
+ config.puzzle_emb_lr,
+ config.lr
+ ]
+
+ return model, optimizers, optimizer_lrs
+
+
+def cosine_schedule_with_warmup_lr_lambda(
+ current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
+):
+ if current_step < num_warmup_steps:
+ return base_lr * float(current_step) / float(max(1, num_warmup_steps))
+
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
+
+
+def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
+ # Estimated total training steps
+ total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
+
+ # Model
+ model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)
+
+ return TrainState(
+ step=0,
+ total_steps=total_steps,
+
+ model=model,
+ optimizers=optimizers,
+ optimizer_lrs=optimizer_lrs,
+ carry=None
+ )
+
+
+def save_train_state(config: PretrainConfig, train_state: TrainState):
+ # FIXME: Only saved model.
+ if config.checkpoint_path is None:
+ return
+
+ os.makedirs(config.checkpoint_path, exist_ok=True)
+ torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
+
+
+def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
+ return cosine_schedule_with_warmup_lr_lambda(
+ current_step=train_state.step,
+ base_lr=base_lr,
+ num_warmup_steps=round(config.lr_warmup_steps),
+ num_training_steps=train_state.total_steps,
+ min_ratio=config.lr_min_ratio
+ )
+
+
+def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
+ train_state.step += 1
+ if train_state.step > train_state.total_steps: # At most train_total_steps
+ return
+
+ # To device
+ batch = {k: v.cuda() for k, v in batch.items()}
+
+ # Init carry if it is None
+ if train_state.carry is None:
+ with torch.device("cuda"):
+ train_state.carry = train_state.model.initial_carry(batch) # type: ignore
+
+ # Forward
+ train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
+
+ ((1 / global_batch_size) * loss).backward()
+
+ # Allreduce
+ if world_size > 1:
+ for param in train_state.model.parameters():
+ if param.grad is not None:
+ dist.all_reduce(param.grad)
+
+ # Apply optimizer
+ lr_this_step = None
+ for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
+ lr_this_step = compute_lr(base_lr, config, train_state)
+
+ for param_group in optim.param_groups:
+ param_group['lr'] = lr_this_step
+
+ optim.step()
+ optim.zero_grad()
+
+ # Reduce metrics
+ if len(metrics):
+ assert not any(v.requires_grad for v in metrics.values())
+
+ metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
+ # Reduce and reconstruct
+ metric_values = torch.stack([metrics[k] for k in metric_keys])
+ if world_size > 1:
+ dist.reduce(metric_values, dst=0)
+
+ if rank == 0:
+ metric_values = metric_values.cpu().numpy()
+ reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
+
+ # Postprocess
+ count = max(reduced_metrics["count"], 1) # Avoid NaNs
+ reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
+
+ reduced_metrics["train/lr"] = lr_this_step
+ return reduced_metrics
+
+
+def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
+ with torch.inference_mode():
+ set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
+
+ all_preds = {}
+
+ metric_keys = []
+ metric_values = None
+ metric_global_batch_size = [0 for _ in range(len(set_ids))]
+
+ carry = None
+ for set_name, batch, global_batch_size in eval_loader:
+ # To device
+ batch = {k: v.cuda() for k, v in batch.items()}
+ with torch.device("cuda"):
+ carry = train_state.model.initial_carry(batch) # type: ignore
+
+ # Forward
+ while True:
+ carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)
+
+ if all_finish:
+ break
+
+ for collection in (batch, preds):
+ for k, v in collection.items():
+ if k in config.eval_save_outputs:
+ all_preds.setdefault(k, [])
+ all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory
+
+ del carry, preds, batch, all_finish
+
+ # Aggregate
+ set_id = set_ids[set_name]
+
+ if metric_values is None:
+ metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
+ metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")
+
+ metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
+ metric_global_batch_size[set_id] += global_batch_size
+
+ if len(all_preds) and config.checkpoint_path is not None:
+ all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}
+
+ os.makedirs(config.checkpoint_path, exist_ok=True)
+ torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))
+
+ # Logging
+ # Reduce to rank 0
+ if metric_values is not None:
+ if world_size > 1:
+ dist.reduce(metric_values, dst=0)
+
+ if rank == 0:
+ reduced_metrics = metric_values.cpu().numpy()
+ reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
+ for set_id, set_name in enumerate(set_ids)}
+
+ # Postprocess
+ for set_name, metrics in reduced_metrics.items():
+ count = metrics.pop("count")
+ reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}
+
+ return reduced_metrics
+
+
+def save_code_and_config(config: PretrainConfig):
+ if config.checkpoint_path is None or wandb.run is None:
+ return
+
+ os.makedirs(config.checkpoint_path, exist_ok=True)
+
+ # Copy code
+ code_list = [
+ get_model_source_path(config.arch.name),
+ get_model_source_path(config.arch.loss.name)
+ ]
+ for code_file in code_list:
+ if code_file is not None:
+ code_name = os.path.basename(code_file)
+
+ shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))
+
+ # Dump config as yaml
+ config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
+ with open(config_file, "wt") as f:
+ yaml.dump(config.model_dump(), f)
+
+ # Log code
+ wandb.run.log_code(config.checkpoint_path)
+
+
+def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
+ objects = [None]
+ if rank == 0:
+ config = PretrainConfig(**hydra_config) # type: ignore
+
+ # Naming
+ if config.project_name is None:
+ config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch"
+ if config.run_name is None:
+ config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
+ if config.checkpoint_path is None:
+ config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)
+
+ objects = [config]
+
+ if world_size > 1:
+ dist.broadcast_object_list(objects, src=0)
+
+ return objects[0] # type: ignore
+
+
+@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
+def launch(hydra_config: DictConfig):
+ RANK = 0
+ WORLD_SIZE = 1
+
+ # Initialize distributed training if in distributed environment (e.g. torchrun)
+ if "LOCAL_RANK" in os.environ:
+ # Initialize distributed, default device and dtype
+ dist.init_process_group(backend="nccl")
+
+ RANK = dist.get_rank()
+ WORLD_SIZE = dist.get_world_size()
+
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+ # Load sync'ed config
+ config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)
+
+ # Seed RNGs to ensure consistency
+ torch.random.manual_seed(config.seed + RANK)
+
+ # Dataset
+ train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
+ total_iters = config.epochs // train_epochs_per_iter
+
+ assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."
+
+ train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
+ eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
+
+ # Train state
+ train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
+
+ # Progress bar and logger
+ progress_bar = None
+ if RANK == 0:
+ progress_bar = tqdm.tqdm(total=train_state.total_steps)
+
+ wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore
+ wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
+ save_code_and_config(config)
+
+ # Training Loop
+ for _iter_id in range(total_iters):
+ print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")
+
+ ############ Train Iter
+ train_state.model.train()
+ for set_name, batch, global_batch_size in train_loader:
+ metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
+
+ if RANK == 0 and metrics is not None:
+ wandb.log(metrics, step=train_state.step)
+ progress_bar.update(train_state.step - progress_bar.n) # type: ignore
+
+ ############ Evaluation
+ train_state.model.eval()
+ metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
+
+ if RANK == 0 and metrics is not None:
+ wandb.log(metrics, step=train_state.step)
+
+ ############ Checkpointing
+ if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
+ save_train_state(config, train_state)
+
+ # finalize
+ if dist.is_initialized():
+ dist.destroy_process_group()
+ wandb.finish()
+
+
+if __name__ == "__main__":
+ launch()
diff --git a/puzzle_dataset.py b/puzzle_dataset.py
new file mode 100644
index 0000000..2782403
--- /dev/null
+++ b/puzzle_dataset.py
@@ -0,0 +1,199 @@
+import os
+import json
+
+import numpy as np
+import pydantic
+
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
+
+from models.losses import IGNORE_LABEL_ID
+from dataset.common import PuzzleDatasetMetadata
+
+
+def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
+ # Pack examples into a full batch
+ batch = []
+ batch_puzzle_indices = []
+ current_size = 0
+
+ while (start_index < group_order.size) and (current_size < global_batch_size):
+ # Pick a group and a puzzle from that group
+ group_id = group_order[start_index]
+ puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
+ start_index += 1
+
+ # Get range of the puzzle
+ puzzle_start = puzzle_indices[puzzle_id]
+ puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
+
+ append_size = min(puzzle_size, global_batch_size - current_size)
+
+ # Put into batch
+ batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
+ batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))
+
+ current_size += append_size
+
+ return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
+
+
+class PuzzleDatasetConfig(pydantic.BaseModel):
+ seed: int
+ dataset_path: str
+ global_batch_size: int
+ test_set_mode: bool
+
+ epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead.
+
+ rank: int
+ num_replicas: int
+
+
+class PuzzleDataset(IterableDataset):
+ def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
+ super().__init__()
+ self.config = config
+ self.split = split
+ self.metadata = self._load_metadata()
+
+ # Checks
+ assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
+ self.local_batch_size = self.config.global_batch_size // self.config.num_replicas
+
+ # State
+ self._data = None
+ self._iters = 0
+
+ def _load_metadata(self) -> PuzzleDatasetMetadata:
+ with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f:
+ return PuzzleDatasetMetadata(**json.load(f))
+
+ def _lazy_load_dataset(self):
+ if self._data is not None:
+ return
+
+ field_mmap_modes = {
+ "inputs": "r",
+ "labels": "r",
+
+ # Keep indices in memory
+ "puzzle_identifiers": None,
+ "puzzle_indices": None,
+ "group_indices": None
+ }
+
+ # Load data
+ self._data = {}
+ for set_name in self.metadata.sets:
+ # Load subset
+ self._data[set_name] = {
+ field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
+ for field_name, mmap_mode in field_mmap_modes.items()
+ }
+
+ def _collate_batch(self, batch):
+ # Convert dtype
+ batch = {k: v.astype(np.int32) for k, v in batch.items()}
+
+ # Convert ignore label IDs
+ if self.metadata.ignore_label_id is not None:
+ batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID
+
+ # Pad
+ if batch["puzzle_identifiers"].size < self.local_batch_size:
+ pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
+
+ pad_values = {
+ "inputs": self.metadata.pad_id,
+ "labels": IGNORE_LABEL_ID,
+
+ "puzzle_identifiers": self.metadata.blank_identifier_id
+ }
+ batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}
+
+ # To tensor
+ return {k: torch.from_numpy(v) for k, v in batch.items()}
+
+ def _iter_test(self):
+ for set_name, dataset in self._data.items(): # type: ignore
+ total_examples = len(dataset["inputs"])
+
+ # Load examples one by one
+ start_index = 0
+ while start_index < total_examples:
+ # Compute indices
+ end_index = min(total_examples, start_index + self.config.global_batch_size)
+
+ local_start = start_index + self.config.rank * self.local_batch_size
+ local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)
+
+ # Get batch of examples, and also puzzle IDs
+ puzzle_indices = []
+ puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
+ for i in range(local_start, local_end):
+ while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
+ puzzle_index += 1
+
+ puzzle_indices.append(puzzle_index)
+
+ batch = self._collate_batch({
+ "inputs": dataset["inputs"][local_start: local_end],
+ "labels": dataset["labels"][local_start: local_end],
+ "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
+ })
+
+ yield set_name, batch, end_index - start_index
+
+ # Advance to next batch
+ start_index += self.config.global_batch_size
+
+ def _iter_train(self):
+ for set_name, dataset in self._data.items(): # type: ignore
+ # Increase epoch count
+ self._iters += 1
+
+ # Randomly shuffle groups
+ rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))
+
+ group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
+ start_index = 0
+
+ while start_index < group_order.size:
+ start_index, batch_indices, batch_puzzle_indices = _sample_batch(
+ rng,
+ group_order=group_order,
+ puzzle_indices=dataset["puzzle_indices"],
+ group_indices=dataset["group_indices"],
+ start_index=start_index,
+ global_batch_size=self.config.global_batch_size,
+ )
+
+ # Select current rank and collate
+ global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads
+
+ # Drop last batch
+ if global_effective_batch_size < self.config.global_batch_size:
+ break
+
+ batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
+ batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
+ batch = self._collate_batch({
+ "inputs": dataset["inputs"][batch_indices],
+ "labels": dataset["labels"][batch_indices],
+ "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
+ })
+
+ yield set_name, batch, global_effective_batch_size
+
+ def __iter__(self):
+ worker_info = get_worker_info()
+ assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."
+
+ self._lazy_load_dataset()
+
+ # Iterate using specified mode
+ if self.config.test_set_mode:
+ yield from self._iter_test()
+ else:
+ yield from self._iter_train()
diff --git a/puzzle_visualizer.html b/puzzle_visualizer.html
new file mode 100644
index 0000000..bcefdf1
--- /dev/null
+++ b/puzzle_visualizer.html
@@ -0,0 +1,426 @@
+<!DOCTYPE html>
+<html>
+<head>
+ <meta charset="UTF-8" />
+ <title>ARC‐Converted Dataset Visualizer (Upload Local Folder)</title>
+ <style>
+ body {
+ font-family: sans-serif;
+ margin: 16px;
+ }
+ .selector-area {
+ margin-bottom: 1rem;
+ }
+ .grid-canvas {
+ margin: 4px;
+ border: 1px solid #ccc;
+ }
+ .example-container {
+ display: inline-block;
+ margin: 0 16px 16px 0;
+ vertical-align: top;
+ }
+ .puzzle-display {
+ margin-top: 1rem;
+ }
+ .puzzle-id {
+ font-weight: bold;
+ margin-bottom: 0.5rem;
+ }
+ #groupList, #puzzleList {
+ margin: 1rem 0;
+ }
+ .group-item, .puzzle-item {
+ cursor: pointer;
+ margin: 4px 8px 4px 0;
+ padding: 2px 6px;
+ border: 1px solid #aaa;
+ display: inline-block;
+ }
+ .group-item:hover, .puzzle-item:hover {
+ background: #eef;
+ }
+ </style>
+</head>
+<body>
+<h1>ARC‐Converted Dataset Visualizer (Local Directory)</h1>
+
+<div class="selector-area">
+ <!-- 1) Directory input with webkitdirectory, mozdirectory -->
+ <label>Upload ARC Folder:</label>
+ <input type="file" id="folderInput"
+ webkitdirectory mozdirectory multiple
+ onchange="onFolderSelected(event)" />
+ <br><br>
+
+ <!-- 2) We'll enable set/subset selection after user chooses a folder and data is validated -->
+ <label>Set:</label>
+ <select id="setSelect" disabled>
+ <option value="train">train</option>
+ <option value="test">test</option>
+ </select>
+
+ <label> Subset:</label>
+ <select id="subsetSelect" disabled>
+ <option value="all">all</option>
+ </select>
+
+ <button id="loadBtn" disabled>Load</button>
+</div>
+
+<div>
+ <div id="groupList"></div>
+ <div id="puzzleList"></div>
+ <div class="puzzle-display" id="puzzleView"></div>
+</div>
+
+<!--
+ 3) Use local 'assets/npyjs.js' from your project folder instead of a CDN.
+ Make sure 'assets/npyjs.js' is the unbundled or UMD version that doesn't
+ contain "import" statements.
+-->
+<script src="assets/npyjs.js"></script>
+
+<script>
+/***************************************************************************
+ * Global Maps & Variables
+ ***************************************************************************/
+
+// Map from "train/all__inputs.npy" => File, etc.
+let filesByPath = {};
+
+// Once loaded, we store typed arrays for the chosen set/subset
+let inputsArr, labelsArr;
+let puzzleIndicesArr, groupIndicesArr, puzzleIdentifiersArr;
+let identifiersJson;
+
+// The shape of inputs is [N_examples, seqLen], so we discover seqLen & gridSize
+let seqLen = 0;
+let gridSize = 0;
+
+
+/***************************************************************************
+ * 1) Handle folder selection: read all files, find identifiers.json,
+ * remove topmost folder from each file path, validate.
+ ***************************************************************************/
+function onFolderSelected(event) {
+ filesByPath = {};
+ const fileList = event.target.files;
+ if (!fileList || fileList.length === 0) {
+ alert("No files selected!");
+ return;
+ }
+
+ // We'll gather all webkitRelativePaths
+ const paths = [];
+ for (let i = 0; i < fileList.length; i++) {
+ // Typically "arc-aug-10/train/all__inputs.npy", etc.
+ const file = fileList[i];
+ const relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
+ paths.push(relPath);
+ }
+
+ // 1. Check if we have "identifiers.json" somewhere.
+ const idPath = paths.find(p => p.endsWith("identifiers.json"));
+ if (!idPath) {
+ alert("Error: No 'identifiers.json' found in the uploaded folder.");
+ return;
+ }
+
+ // 2. Derive the top-level directory from that file's path
+ // e.g. if idPath = "arc-aug-10/identifiers.json", topDir = "arc-aug-10"
+ // If there's no slash, topDir = "" => do nothing
+ let topDir = "";
+ const lastSlash = idPath.lastIndexOf("/");
+ if (lastSlash >= 0) {
+ topDir = idPath.substring(0, lastSlash);
+ }
+
+ // 3. Rebuild filesByPath with the top folder removed.
+ // For example, if topDir = "arc-aug-10", then "arc-aug-10/train/all__inputs.npy"
+ // becomes "train/all__inputs.npy"
+ for (let i = 0; i < fileList.length; i++) {
+ const file = fileList[i];
+ let relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
+ // If relPath starts with "arc-aug-10/", remove that prefix
+ if (topDir && relPath.startsWith(topDir + "/")) {
+ relPath = relPath.substring(topDir.length + 1);
+ }
+ filesByPath[relPath] = file;
+ }
+
+ // Enable set/subset selection and "Load"
+ document.getElementById("setSelect").disabled = false;
+ document.getElementById("subsetSelect").disabled = false;
+ document.getElementById("loadBtn").disabled = false;
+}
+
+// When user clicks "Load," parse the .npy for the chosen set/subset
+document.getElementById("loadBtn").addEventListener("click", async () => {
+ document.getElementById("groupList").innerHTML = "";
+ document.getElementById("puzzleList").innerHTML = "";
+ document.getElementById("puzzleView").innerHTML = "";
+
+ const setName = document.getElementById("setSelect").value; // e.g. "train"
+ const subsetName = document.getElementById("subsetSelect").value; // e.g. "all"
+
+ try {
+ await loadDataset(setName, subsetName);
+ buildGroupList(); // show groups
+ } catch (err) {
+ console.error(err);
+ alert("Error while loading dataset: " + err);
+ }
+});
+
+
+/***************************************************************************
+ * 2) Load .npy from local files using Npyjs + FileReader (ArrayBuffer)
+ ***************************************************************************/
+async function loadDataset(setName, subsetName) {
+ const prefix = `${setName}/${subsetName}__`;
+ // e.g. "train/all__inputs.npy"
+ const inputsPath = prefix + "inputs.npy";
+ const labelsPath = prefix + "labels.npy";
+ const pIdxPath = prefix + "puzzle_indices.npy";
+ const gIdxPath = prefix + "group_indices.npy";
+ const pIdsPath = prefix + "puzzle_identifiers.npy";
+ const identifiersPath = "identifiers.json";
+
+ // Check existence
+ const needed = [inputsPath, labelsPath, pIdxPath, gIdxPath, pIdsPath, identifiersPath];
+ for (const f of needed) {
+ if (!filesByPath[f]) {
+ throw new Error(`Missing file: ${f}`);
+ }
+ }
+
+ // parseNpy => read from File -> ArrayBuffer -> Npyjs => typed array
+ const inputsNpy = await parseNpy(filesByPath[inputsPath]);
+ const labelsNpy = await parseNpy(filesByPath[labelsPath]);
+ const puzzleIndicesNpy= await parseNpy(filesByPath[pIdxPath]);
+ const groupIndicesNpy = await parseNpy(filesByPath[gIdxPath]);
+ const puzzleIdsNpy = await parseNpy(filesByPath[pIdsPath]);
+
+ inputsArr = inputsNpy.data;
+ labelsArr = labelsNpy.data;
+ puzzleIndicesArr = puzzleIndicesNpy.data;
+ groupIndicesArr = groupIndicesNpy.data;
+ puzzleIdentifiersArr = puzzleIdsNpy.data;
+
+ // shape e.g. [N_examples, seqLen]
+ seqLen = inputsNpy.shape[1];
+ gridSize = Math.sqrt(seqLen);
+
+ // read JSON
+ identifiersJson = await readJsonFile(filesByPath[identifiersPath]);
+}
+
+/***************************************************************************
+ * parseNpy => read a File as ArrayBuffer, parse with npyjs
+ ***************************************************************************/
+function parseNpy(file) {
+ return new Promise((resolve, reject) => {
+ const reader = new FileReader();
+ reader.onload = async () => {
+ try {
+ const arrayBuffer = reader.result;
+ const npy = new npyjs();
+ resolve(await npy.parse(arrayBuffer));
+ } catch (err) {
+ reject(err);
+ }
+ };
+ reader.onerror = err => reject(err);
+ reader.readAsArrayBuffer(file);
+ });
+}
+
+/***************************************************************************
+ * readJsonFile => read a local JSON file into object
+ ***************************************************************************/
+function readJsonFile(file) {
+ return new Promise((resolve, reject) => {
+ const reader = new FileReader();
+ reader.onload = () => {
+ try {
+ const obj = JSON.parse(reader.result);
+ resolve(obj);
+ } catch (err) {
+ reject(err);
+ }
+ };
+ reader.onerror = (err) => reject(err);
+ reader.readAsText(file);
+ });
+}
+
+/***************************************************************************
+ * 3) Build group list in UI
+ ***************************************************************************/
+function buildGroupList() {
+ document.getElementById("groupList").innerHTML = "<h3>Groups</h3>";
+ const groupListDiv = document.getElementById("groupList");
+
+ const nGroups = groupIndicesArr.length - 1;
+ for (let g = 0; g < nGroups; g++) {
+ const div = document.createElement("span");
+ div.className = "group-item";
+ div.textContent = `Group ${g}`;
+ div.onclick = () => onSelectGroup(g);
+ groupListDiv.appendChild(div);
+ }
+}
+
+/***************************************************************************
+ * onSelectGroup => show puzzles in that group
+ ***************************************************************************/
+function onSelectGroup(groupIndex) {
+ document.getElementById("puzzleList").innerHTML = "";
+ document.getElementById("puzzleView").innerHTML = "";
+
+ const puzzleListDiv = document.getElementById("puzzleList");
+ puzzleListDiv.innerHTML = `<h4>Puzzles in Group ${groupIndex}</h4>`;
+
+ const firstPuzzle = groupIndicesArr[groupIndex];
+ const lastPuzzle = groupIndicesArr[groupIndex + 1];
+
+ for (let p = firstPuzzle; p < lastPuzzle; p++) {
+ const puzzleIntId = puzzleIdentifiersArr[p];
+ const puzzleStrId = (puzzleIntId < identifiersJson.length)
+ ? identifiersJson[puzzleIntId]
+ : "<unknown>";
+
+ const div = document.createElement("span");
+ div.className = "puzzle-item";
+ div.textContent = `Puzzle #${p} [ID=${puzzleIntId}: ${puzzleStrId}]`;
+ div.onclick = () => onSelectPuzzle(p);
+ puzzleListDiv.appendChild(div);
+ }
+}
+
+/***************************************************************************
+ * onSelectPuzzle => show each example
+ ***************************************************************************/
+function onSelectPuzzle(puzzleIndex) {
+ const puzzleView = document.getElementById("puzzleView");
+ puzzleView.innerHTML = "";
+
+ // puzzle ID
+ const puzzleIntId = puzzleIdentifiersArr[puzzleIndex];
+ const puzzleStrId = (puzzleIntId < identifiersJson.length)
+ ? identifiersJson[puzzleIntId]
+ : "<unknown>";
+
+ const titleDiv = document.createElement("div");
+ titleDiv.className = "puzzle-id";
+ titleDiv.textContent = `Puzzle #${puzzleIndex} — ID: ${puzzleStrId}`;
+ puzzleView.appendChild(titleDiv);
+
+ // Examples are [puzzleIndicesArr[p], puzzleIndicesArr[p+1])
+ const firstExample = puzzleIndicesArr[puzzleIndex];
+ const lastExample = puzzleIndicesArr[puzzleIndex + 1];
+
+ for (let e = firstExample; e < lastExample; e++) {
+ const inputSeq = slice1D(inputsArr, e*seqLen, (e+1)*seqLen);
+ const outputSeq = slice1D(labelsArr, e*seqLen, (e+1)*seqLen);
+
+ const inputGrid = decodeGrid(inputSeq);
+ const outputGrid = decodeGrid(outputSeq);
+
+ const exDiv = document.createElement("div");
+ exDiv.className = "example-container";
+ exDiv.appendChild(document.createTextNode(`Example ${e}`));
+ exDiv.appendChild(document.createElement("br"));
+
+ exDiv.appendChild(renderGrid(inputGrid));
+ exDiv.appendChild(renderGrid(outputGrid));
+
+ puzzleView.appendChild(exDiv);
+ }
+}
+
+/***************************************************************************
+ * slice1D => typed array slicing
+ ***************************************************************************/
+function slice1D(arr, start, end) {
+ const result = new Uint32Array(end - start);
+ for (let i = start; i < end; i++) {
+ result[i - start] = Number(arr[i]);
+ }
+ return result;
+}
+
+/***************************************************************************
+ * decodeGrid => turn the flattened seq of length=gridSize^2 into 2D
+ ***************************************************************************/
+function decodeGrid(seq) {
+ const grid = [];
+ let idx = 0;
+ for (let r = 0; r < gridSize; r++) {
+ const row = [];
+ for (let c = 0; c < gridSize; c++) {
+ row.push(seq[idx]);
+ idx++;
+ }
+ grid.push(row);
+ }
+ return grid;
+}
+
+/***************************************************************************
+ * renderGrid => draws a 2D grid to <canvas>
+ ***************************************************************************/
+function renderGrid(grid2d) {
+ const rows = grid2d.length;
+ const cols = grid2d[0].length;
+ const scale = 10;
+
+ const canvas = document.createElement("canvas");
+ canvas.width = cols * scale;
+ canvas.height = rows * scale;
+ canvas.className = "grid-canvas";
+ const ctx = canvas.getContext("2d");
+
+ for (let r = 0; r < rows; r++) {
+ for (let c = 0; c < cols; c++) {
+ const val = grid2d[r][c];
+ ctx.fillStyle = indexToColor(val);
+ ctx.fillRect(c * scale, r * scale, scale, scale);
+ }
+ }
+ return canvas;
+}
+
+/***************************************************************************
+ * indexToColor => color palette:
+ * 0 => pad => white
+ * 1 => eos => light gray
+ * 2..11 => original color(0..9)
+ ***************************************************************************/
+function indexToColor(value) {
+ if (value === 0) return "#FFFFFF"; // pad => white
+ if (value === 1) return "#DDDDDD"; // eos => light gray
+
+ // shift by 2 => original color in [0..9]
+ const colorIdx = value - 2;
+ const palette = [
+ "#000000", // color0 => black
+ "#FF0000", // color1 => red
+ "#00FF00", // color2 => green
+ "#0000FF", // color3 => blue
+ "#FFFF00", // color4 => yellow
+ "#FFA500", // color5 => orange
+ "#800080", // color6 => purple
+ "#00FFFF", // color7 => cyan
+ "#FFC0CB", // color8 => pink
+ "#808080" // color9 => gray
+ ];
+ if (colorIdx >= 0 && colorIdx < palette.length) {
+ return palette[colorIdx];
+ }
+ return "#FFFFFF"; // fallback
+}
+</script>
+</body>
+</html>
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..8c90d6f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+torch
+adam-atan2
+einops
+tqdm
+coolname
+pydantic
+argdantic
+wandb
+omegaconf
+hydra-core
+huggingface_hub
diff --git a/utils/functions.py b/utils/functions.py
new file mode 100644
index 0000000..b123636
--- /dev/null
+++ b/utils/functions.py
@@ -0,0 +1,19 @@
+import importlib
+import inspect
+
+
+def load_model_class(identifier: str, prefix: str = "models."):
+ module_path, class_name = identifier.split('@')
+
+ # Import the module
+ module = importlib.import_module(prefix + module_path)
+ cls = getattr(module, class_name)
+
+ return cls
+
+
+def get_model_source_path(identifier: str, prefix: str = "models."):
+ module_path, class_name = identifier.split('@')
+
+ module = importlib.import_module(prefix + module_path)
+ return inspect.getsourcefile(module)