diff options
| author | One <imone@tuta.io> | 2025-07-09 10:13:51 +0800 |
|---|---|---|
| committer | One <imone@tuta.io> | 2025-07-09 10:13:51 +0800 |
| commit | bd6222774edcec1608a6842d0b06a637a4acef59 (patch) | |
| tree | 3b95517044286d82a9166bcce3134bbea099fcfe | |
| parent | caa00bb77fbfce0cc14a45d5ec1c754394f96c1a (diff) | |
Release
| -rw-r--r-- | .gitignore | 169 | ||||
| -rw-r--r-- | .vscode/launch.json | 26 | ||||
| -rw-r--r-- | .vscode/settings.json | 3 | ||||
| -rw-r--r-- | README.md | 169 | ||||
| -rw-r--r-- | arc_eval.ipynb | 252 | ||||
| -rw-r--r-- | assets/hrm.png | bin | 0 -> 99852 bytes | |||
| -rw-r--r-- | assets/npyjs.js | 176 | ||||
| -rw-r--r-- | config/arch/hrm_v1.yaml | 21 | ||||
| -rw-r--r-- | config/cfg_pretrain.yaml | 31 | ||||
| -rw-r--r-- | dataset/build_arc_dataset.py | 291 | ||||
| -rw-r--r-- | dataset/build_maze_dataset.py | 142 | ||||
| -rw-r--r-- | dataset/build_sudoku_dataset.py | 169 | ||||
| -rw-r--r-- | dataset/common.py | 51 | ||||
| -rw-r--r-- | evaluate.py | 68 | ||||
| -rw-r--r-- | models/common.py | 32 | ||||
| -rw-r--r-- | models/hrm/hrm_act_v1.py | 283 | ||||
| -rw-r--r-- | models/layers.py | 150 | ||||
| -rw-r--r-- | models/losses.py | 101 | ||||
| -rw-r--r-- | models/sparse_embedding.py | 132 | ||||
| -rw-r--r-- | pretrain.py | 454 | ||||
| -rw-r--r-- | puzzle_dataset.py | 199 | ||||
| -rw-r--r-- | puzzle_visualizer.html | 426 | ||||
| -rw-r--r-- | requirements.txt | 11 | ||||
| -rw-r--r-- | utils/functions.py | 19 |
24 files changed, 3375 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..644b422 --- /dev/null +++ b/.gitignore @@ -0,0 +1,169 @@ +# WandB +/wandb/ +# checkpoints +/checkpoints/ +# cache +/cache/ +# data +/data/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/
\ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..d49f941 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,26 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + }, + { + "name": "Debug: Single GPU", + "type": "debugpy", + "request": "launch", + "program": "pretrain.py", + "args": [], + "env": { + "OMP_NUM_THREADS": "1", + "DISABLE_COMPILE": "true" + } + } + ] +}
\ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..8aef7b1 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.typeCheckingMode": "standard" +}
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..87620eb --- /dev/null +++ b/README.md @@ -0,0 +1,169 @@ +# Hierarchical Reasoning Model + + + +Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI. +Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency. +HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes. +Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities. +These results underscore HRM’s potential as a transformative advancement toward universal computation and general-purpose reasoning systems. + +## Quick Start Guide 🚀 + +### Prerequisites ⚙️ + +Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If not present, run the following commands: + +```bash +# Install CUDA 12.4 +CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run + +wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL +sudo sh cuda_installer.run --silent --toolkit --override + +export CUDA_HOME=/usr/local/cuda-12.4 + +# Install PyTorch with CUDA 12.4 +PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu124 + +pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL + +# Additional packages for building extensions +pip3 install packaging ninja wheel setuptools setuptools-scm +``` + +## Install Python Dependencies 🐍 + +```bash +pip install -r requirements.txt +``` + +## W&B Integration 📈 + +This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in: + +```bash +wandb login +``` + +## Run Experiments + +### Quick Demo: Sudoku Solver 💻🗲 + +Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 🧩 + +```bash +# Download and build Sudoku dataset +python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 + +# Start training (single GPU, smaller batch size) +OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0 +``` + +Runtime: ~10 hours on a RTX 4070 laptop GPU + +## Full-scale Experiments 🔵 + +Experiments below assume an 8-GPU setup. + +### Dataset Preparation + +```bash +# Initialize submodules +git submodule update --init --recursive + +# ARC-1 +python dataset/build_arc_dataset.py # ARC offical + ConceptARC, 960 examples +# ARC-2 +python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000 # ARC-2 official, 1120 examples + +# Sudoku-Extreme +python dataset/build_sudoku_dataset.py # Full version +python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples + +# Maze +python dataset/build_maze_dataset.py # 1000 examples +``` + +### Dataset Visualization + +Explore the puzzles visually: + +* Open `puzzle_visualizer.html` in your browser. +* Upload the generated dataset folder located in `data/...`. + +## Launch experiments + +### Small-sample (1K) + +ARC-1: + +```bash +OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py +``` + +*Runtime:* ~24 hours + +ARC-2: + +```bash +OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000 +``` + +*Runtime:* ~24 hours (checkpoint after 8 hours is often sufficient) + +Sudoku Extreme (1k): + +```bash +OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 +``` + +*Runtime:* ~10 minutes + +Maze 30x30 Hard (1k): + +```bash +OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 +``` + +*Runtime:* ~1 hour + +### Full Sudoku-Hard + +```bash +OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned +``` + +*Runtime:* ~2 hours + +## Evaluation + +Evaluate your trained models: + +* Check `eval/exact_accuracy` in W&B. +* For ARC-AGI, follow these additional steps: + +```bash +OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH> +``` + +* Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results. + +## Notes + + - Small-sample learning typically exhibits accuracy variance of around ±2 points. + - For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once the training accuracy approaches 100%. + +## Citation 📜 + +``` +@misc{wang2025hierarchicalreasoningmodel, + title={Hierarchical Reasoning Model}, + author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori}, + year={2025}, + eprint={2506.21734}, + archivePrefix={arXiv}, + primaryClass={cs.AI}, + url={https://arxiv.org/abs/2506.21734}, +} +```
\ No newline at end of file diff --git a/arc_eval.ipynb b/arc_eval.ipynb new file mode 100644 index 0000000..b2786b8 --- /dev/null +++ b/arc_eval.ipynb @@ -0,0 +1,252 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "from glob import glob\n", + "import hashlib\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.colors as mcolors\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "from numba import njit\n", + "\n", + "from dataset.common import inverse_dihedral_transform\n", + "\n", + "\n", + "DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n", + "# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n", + "\n", + "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n", + "\n", + "\n", + "PAD_PUZZLE_IDENTIFIER = 0\n", + "\n", + "# Visualization\n", + "ARC_COLOR_MAP = mcolors.ListedColormap([\n", + " \"#000000\", # symbol_0: black\n", + " \"#0074D9\", # symbol_1: blue\n", + " \"#FF4136\", # symbol_2: red\n", + " \"#2ECC40\", # symbol_3: green\n", + " \"#FFDC00\", # symbol_4: yellow\n", + " \"#AAAAAA\", # symbol_5: grey\n", + " \"#F012BE\", # symbol_6: fuschia\n", + " \"#FF851B\", # symbol_7: orange\n", + " \"#7FDBFF\", # symbol_8: teal\n", + " \"#870C25\" # symbol_9: brown\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n", + " # Load puzzle identifiers\n", + " with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n", + " identifier_map = json.load(f)\n", + " \n", + " # Load preds\n", + " all_preds = {}\n", + " for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n", + " preds = torch.load(filename)\n", + " for k, v in preds.items():\n", + " all_preds.setdefault(k, [])\n", + " all_preds[k].append(v)\n", + " \n", + " del preds\n", + "\n", + " all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n", + " \n", + " # Remove paddings\n", + " mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n", + " all_preds = {k: v[mask] for k, v in all_preds.items()}\n", + "\n", + " return identifier_map, all_preds\n", + "\n", + "\n", + "def inverse_aug(name: str, grid: np.ndarray):\n", + " if \"_\" not in name:\n", + " return grid\n", + "\n", + " trans_id, perm = name.split(\"_\")[-2:]\n", + " trans_id = int(trans_id[1:]) # Remove \"t\" letter\n", + " inv_perm = np.argsort(list(perm))\n", + " \n", + " return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n", + "\n", + "\n", + "def grid_hash(grid: np.ndarray):\n", + " return hash((grid.tobytes(), grid.shape))\n", + "\n", + "\n", + "@njit\n", + "def crop(grid: np.ndarray):\n", + " # Find maximum-sized rectangle without any EOS token inside.\n", + " grid = grid.reshape(30, 30)\n", + "\n", + " max_area = 0\n", + " max_size = (0, 0)\n", + " nr, nc = grid.shape\n", + " \n", + " num_c = nc\n", + " for num_r in range(1, nr + 1):\n", + " # Scan for maximum c\n", + " for c in range(1, num_c + 1):\n", + " x = grid[num_r - 1, c - 1]\n", + " if (x < 2) | (x > 11):\n", + " num_c = c - 1\n", + " break\n", + " \n", + " area = num_r * num_c\n", + " if area > max_area:\n", + " max_area = area\n", + " max_size = (num_r, num_c)\n", + "\n", + " return grid[:max_size[0], :max_size[1]] - 2\n", + "\n", + "\n", + "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n", + " identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n", + " \n", + " global_hmap = {}\n", + " \n", + " # Get puzzles and corresponding answers\n", + " puzzle_labels = {}\n", + " for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n", + " name = identifier_map[identifier]\n", + " if \"_\" not in name: # Not-augmented\n", + " puzzle_labels.setdefault(name, {})\n", + " \n", + " input = crop(input.numpy())\n", + " label = crop(label.numpy())\n", + "\n", + " input_hash = grid_hash(input)\n", + " label_hash = grid_hash(label)\n", + "\n", + " global_hmap[input_hash] = input\n", + " global_hmap[label_hash] = label\n", + "\n", + " assert input_hash not in puzzle_labels[name]\n", + " puzzle_labels[name][input_hash] = label_hash\n", + " \n", + " print (\"Number of puzzles\", len(puzzle_labels))\n", + " \n", + " # Argmax prediction\n", + " preds = all_preds[\"logits\"].argmax(-1)\n", + "\n", + " # Collate\n", + " pred_answers = {}\n", + " for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n", + " name = identifier_map[identifier]\n", + " orig_name = name.split(\"_\")[0]\n", + " \n", + " input = input.numpy()\n", + " input_hash = grid_hash(inverse_aug(name, crop(input)))\n", + " assert input_hash in puzzle_labels[orig_name]\n", + " \n", + " pred = inverse_aug(name, crop(pred.numpy()))\n", + " pred_hash = grid_hash(pred)\n", + " global_hmap[pred_hash] = pred\n", + " \n", + " pred_answers.setdefault(orig_name, {})\n", + " pred_answers[orig_name].setdefault(input_hash, [])\n", + " pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n", + "\n", + " # test-1\n", + " if visualize:\n", + " num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n", + " fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n", + " \n", + " fig_id = 0\n", + " \n", + " correct = [0 for _ in range(len(Ks))]\n", + " for name, tests in puzzle_labels.items():\n", + " num_test_correct = [0 for _ in range(len(Ks))]\n", + " for input_hash, label_hash in tests.items():\n", + " p = pred_answers[name][input_hash]\n", + " p_map = {}\n", + " \n", + " for h, q in p:\n", + " p_map.setdefault(h, [0, 0])\n", + " p_map[h][0] += 1\n", + " p_map[h][1] += q\n", + " \n", + " for h, stats in p_map.items():\n", + " stats[1] /= stats[0]\n", + " \n", + " p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n", + "\n", + " # 2-vote\n", + " for i, k in enumerate(Ks):\n", + " ok = False\n", + " for h, stats in p_map[:k]:\n", + " ok |= h == label_hash\n", + " \n", + " num_test_correct[i] += ok\n", + "\n", + " if visualize:\n", + " # Show input and ground truth\n", + " axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n", + " axes[fig_id, 0].axis('off')\n", + " \n", + " axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n", + " axes[fig_id, 1].axis('off')\n", + " \n", + " trial_id = 2\n", + " for h, stats in p_map[:2]:\n", + " ans = global_hmap[h]\n", + " \n", + " axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n", + " axes[fig_id, trial_id].axis('off')\n", + " \n", + " trial_id += 1\n", + " \n", + " fig_id += 1\n", + " \n", + " # Total correctness\n", + " for i in range(len(Ks)):\n", + " correct[i] += num_test_correct[i] == len(tests)\n", + "\n", + " for i, k in enumerate(Ks):\n", + " print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n", + "\n", + "\n", + "test(visualize=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/assets/hrm.png b/assets/hrm.png Binary files differnew file mode 100644 index 0000000..e6c2f10 --- /dev/null +++ b/assets/hrm.png diff --git a/assets/npyjs.js b/assets/npyjs.js new file mode 100644 index 0000000..b474575 --- /dev/null +++ b/assets/npyjs.js @@ -0,0 +1,176 @@ +class npyjs { + + constructor(opts) { + if (opts && !('convertFloat16' in opts)) { + console.warn([ + "npyjs constructor now accepts {convertFloat16?: boolean}.", + "For usage, go to https://github.com/jhuapl-boss/npyjs." + ].join(" ")); + } + + this.convertFloat16 = opts?.convertFloat16 ?? true; + + this.dtypes = { + "<u1": { + name: "uint8", + size: 8, + arrayConstructor: Uint8Array, + }, + "|u1": { + name: "uint8", + size: 8, + arrayConstructor: Uint8Array, + }, + "<u2": { + name: "uint16", + size: 16, + arrayConstructor: Uint16Array, + }, + "|i1": { + name: "int8", + size: 8, + arrayConstructor: Int8Array, + }, + "<i2": { + name: "int16", + size: 16, + arrayConstructor: Int16Array, + }, + "<u4": { + name: "uint32", + size: 32, + arrayConstructor: Uint32Array, + }, + "<i4": { + name: "int32", + size: 32, + arrayConstructor: Int32Array, + }, + "<u8": { + name: "uint64", + size: 64, + arrayConstructor: BigUint64Array, + }, + "<i8": { + name: "int64", + size: 64, + arrayConstructor: BigInt64Array, + }, + "<f4": { + name: "float32", + size: 32, + arrayConstructor: Float32Array + }, + "<f8": { + name: "float64", + size: 64, + arrayConstructor: Float64Array + }, + "<f2": { + name: "float16", + size: 16, + arrayConstructor: Uint16Array, + converter: this.convertFloat16 ? this.float16ToFloat32Array : undefined + }, + }; + } + + float16ToFloat32Array(float16Array) { + const length = float16Array.length; + const float32Array = new Float32Array(length); + + for (let i = 0; i < length; i++) { + float32Array[i] = npyjs.float16ToFloat32(float16Array[i]); + } + + return float32Array; + } + + static float16ToFloat32(float16) { + // Extract the parts of the float16 + const sign = (float16 >> 15) & 0x1; + const exponent = (float16 >> 10) & 0x1f; + const fraction = float16 & 0x3ff; + + // Handle special cases + if (exponent === 0) { + if (fraction === 0) { + // Zero + return sign ? -0 : 0; + } + // Denormalized number + return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400); + } else if (exponent === 0x1f) { + if (fraction === 0) { + // Infinity + return sign ? -Infinity : Infinity; + } + // NaN + return NaN; + } + + // Normalized number + return (sign ? -1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400); + } + + parse(arrayBufferContents) { + // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded + const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0); + const offsetBytes = 10 + headerLength; + + const hcontents = new TextDecoder("utf-8").decode( + new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength)) + ); + const header = JSON.parse( + hcontents + .toLowerCase() // True -> true + .replace(/'/g, '"') + .replace("(", "[") + .replace(/,*\),*/g, "]") + ); + const shape = header.shape; + const dtype = this.dtypes[header.descr]; + + if (!dtype) { + console.error(`Unsupported dtype: ${header.descr}`); + return null; + } + + const nums = new dtype.arrayConstructor( + arrayBufferContents, + offsetBytes + ); + + // Convert float16 to float32 if converter exists + const data = dtype.converter ? dtype.converter.call(this, nums) : nums; + + return { + dtype: dtype.name, + data: data, + shape, + fortranOrder: header.fortran_order + }; + } + + async load(filename, callback, fetchArgs) { + /* + Loads an array from a stream of bytes. + */ + fetchArgs = fetchArgs || {}; + let arrayBuf; + // If filename is ArrayBuffer + if (filename instanceof ArrayBuffer) { + arrayBuf = filename; + } + // If filename is a file path + else { + const resp = await fetch(filename, { ...fetchArgs }); + arrayBuf = await resp.arrayBuffer(); + } + const result = this.parse(arrayBuf); + if (callback) { + return callback(result); + } + return result; + } +} diff --git a/config/arch/hrm_v1.yaml b/config/arch/hrm_v1.yaml new file mode 100644 index 0000000..a5646b8 --- /dev/null +++ b/config/arch/hrm_v1.yaml @@ -0,0 +1,21 @@ +name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1 +loss: + name: losses@ACTLossHead + loss_type: stablemax_cross_entropy + +halt_exploration_prob: 0.1 +halt_max_steps: 16 + +H_cycles: 2 +L_cycles: 2 + +H_layers: 4 +L_layers: 4 + +hidden_size: 512 +num_heads: 8 # min(2, hidden_size // 64) +expansion: 4 + +puzzle_emb_ndim: ${.hidden_size} + +pos_encodings: rope diff --git a/config/cfg_pretrain.yaml b/config/cfg_pretrain.yaml new file mode 100644 index 0000000..51c55a0 --- /dev/null +++ b/config/cfg_pretrain.yaml @@ -0,0 +1,31 @@ +# ARC training config + +defaults: + - arch: hrm_v1 + - _self_ + +hydra: + output_subdir: null + +# Data path +data_path: data/arc-aug-1000 + +# Hyperparams - Training +global_batch_size: 768 + +epochs: 100000 +eval_interval: 10000 +checkpoint_every_eval: True + +lr: 1e-4 +lr_min_ratio: 1.0 +lr_warmup_steps: 2000 + +# Standard hyperparameter settings for LM, as used in Llama +beta1: 0.9 +beta2: 0.95 +weight_decay: 0.1 +puzzle_emb_weight_decay: 0.1 + +# Hyperparams - Puzzle embeddings training +puzzle_emb_lr: 1e-2 diff --git a/dataset/build_arc_dataset.py b/dataset/build_arc_dataset.py new file mode 100644 index 0000000..2da5703 --- /dev/null +++ b/dataset/build_arc_dataset.py @@ -0,0 +1,291 @@ +from typing import List, Optional, Tuple, Dict +from dataclasses import dataclass +from pathlib import Path +import os +import json +import hashlib +import numpy as np +from glob import glob + +from argdantic import ArgParser +from pydantic import BaseModel + +from common import PuzzleDatasetMetadata, dihedral_transform + + +cli = ArgParser() + + +class DataProcessConfig(BaseModel): + # ARC-1 + dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"] + output_dir: str = "data/arc-aug-1000" + + # ARC-2 + # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"] + # output_dir: str = "data/arc-2-aug-1000" + + seed: int = 42 + num_aug: int = 1000 + + +ARCMaxGridSize = 30 +ARCAugmentRetriesFactor = 5 + + +@dataclass +class ARCPuzzle: + id: str + + examples: List[Tuple[np.ndarray, np.ndarray]] + + +def arc_grid_to_np(grid: List[List[int]]): + arr = np.array(grid) + + # Shape check + assert arr.ndim == 2 + assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize + # Element check + assert np.all((arr >= 0) & (arr <= 9)) + return arr.astype(np.uint8) + + +def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool): + # PAD: 0, <eos>: 1, digits: 2 ... 11 + # Compute random top-left pad + if do_translation: + pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1) + pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1) + else: + pad_r = pad_c = 0 + + # Pad grid + result = [] + for grid in [inp, out]: + nrow, ncol = grid.shape + grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0) + + # Add <eos> + eos_row, eos_col = pad_r + nrow, pad_c + ncol + if eos_row < ARCMaxGridSize: + grid[eos_row, pad_c:eos_col] = 1 + if eos_col < ARCMaxGridSize: + grid[pad_r:eos_row, eos_col] = 1 + + result.append(grid.flatten()) + + return result + + +def puzzle_hash(puzzle: dict): + # Hash the puzzle for checking equivalence + def _grid_hash(grid: np.ndarray): + buffer = [x.to_bytes(1) for x in grid.shape] + buffer.append(grid.tobytes()) + + return hashlib.sha256(b"".join(buffer)).hexdigest() + + hashes = [] + for example_type, example in puzzle.items(): + for input, label in example.examples: + hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}") + + hashes.sort() + return hashlib.sha256("|".join(hashes).encode()).hexdigest() + + +def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]): + # Remove "name" + name = puzzle.pop("name", default_name) + + # Convert + dests = set(dest_mapping.values()) + converted = {dest: ARCPuzzle(name, []) for dest in dests} + for example_type, examples in puzzle.items(): + dest = dest_mapping[example_type] + converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples]) + + group = [converted] + + # Augment + if aug_count > 0: + hashes = {puzzle_hash(converted)} + + for _trial in range(ARCAugmentRetriesFactor * aug_count): + # Augment plan + trans_id = np.random.randint(0, 8) + mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black) + + aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}" + + def _map_grid(grid: np.ndarray): + return dihedral_transform(mapping[grid], trans_id) + + # Check duplicate + augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()} + h = puzzle_hash(augmented) + if h not in hashes: + hashes.add(h) + group.append(augmented) + + if len(group) >= aug_count + 1: + break + + if len(group) < aug_count + 1: + print (f"[Puzzle {name}] augmentation not full, only {len(group)}") + + # Append + for dest in dests: + # Convert the examples + dest_split, dest_set = dest + + results.setdefault(dest_split, {}) + results[dest_split].setdefault(dest_set, []) + results[dest_split][dest_set].append([converted[dest] for converted in group]) + + +def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig): + train_examples_dest = ("train", "all") + test_examples_map = { + "evaluation": [(1.0, ("test", "all"))], + "_default": [(1.0, ("train", "all"))] + } + + total_puzzles = 0 + for subdir in os.scandir(dataset_path): + if subdir.is_dir(): + # Load all puzzles in this directory + puzzles = [] + for filename in glob(os.path.join(subdir.path, "*.json")): + with open(filename, "r") as f: + puzzles.append((Path(filename).stem, json.load(f))) + + # Shuffle puzzles + np.random.shuffle(puzzles) + + # Assign by fraction + for idx, (default_name, puzzle) in enumerate(puzzles): + fraction = idx / len(puzzles) + test_examples_dest = None + for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]): + if fraction < f: + test_examples_dest = dest + break + + assert test_examples_dest is not None + + convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest}) + total_puzzles += 1 + + print (f"[{dataset_path}] total puzzles: {total_puzzles}") + + +def convert_dataset(config: DataProcessConfig): + np.random.seed(config.seed) + + # Read dataset + data = {} + for dataset_dir in config.dataset_dirs: + load_puzzles_arcagi(data, dataset_dir, config) + + # Map global puzzle identifiers + num_identifiers = 1 # 0 is blank + identifier_map = {} + for split_name, split in data.items(): + for subset_name, subset in split.items(): + for group in subset: + for puzzle in group: + if puzzle.id not in identifier_map: + identifier_map[puzzle.id] = num_identifiers + num_identifiers += 1 + + print (f"Total puzzle IDs (including <blank>): {num_identifiers}") + + # Save + for split_name, split in data.items(): + os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True) + + # Translational augmentations + enable_translational_augment = split_name == "train" + + # Statistics + total_examples = 0 + total_puzzles = 0 + total_groups = 0 + + for subset_name, subset in split.items(): + # Construct subset + results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} + results["puzzle_indices"].append(0) + results["group_indices"].append(0) + + example_id = 0 + puzzle_id = 0 + + for group in subset: + for puzzle in group: + # Push puzzle + no_aug_id = np.random.randint(0, len(puzzle.examples)) + for _idx_ex, (inp, out) in enumerate(puzzle.examples): + inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id) + + results["inputs"].append(inp) + results["labels"].append(out) + example_id += 1 + + total_examples += 1 + + results["puzzle_indices"].append(example_id) + results["puzzle_identifiers"].append(identifier_map[puzzle.id]) + + puzzle_id += 1 + + total_puzzles += 1 + + # Push group + results["group_indices"].append(puzzle_id) + total_groups += 1 + + for k, v in results.items(): + if k in {"inputs", "labels"}: + v = np.stack(v, 0) + else: + v = np.array(v, dtype=np.int32) + + np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v) + + # Metadata + metadata = PuzzleDatasetMetadata( + seq_len=ARCMaxGridSize * ARCMaxGridSize, + vocab_size=10 + 2, # PAD + EOS + "0" ... "9" + + pad_id=0, + ignore_label_id=0, + + blank_identifier_id=0, + num_puzzle_identifiers=num_identifiers, + + total_groups=total_groups, + mean_puzzle_examples=total_examples / total_puzzles, + sets=list(split.keys()) + ) + + # Save metadata as JSON. + with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f: + json.dump(metadata.model_dump(), f) + + # Save IDs mapping + with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: + ids_mapping = {v: k for k, v in identifier_map.items()} + + json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f) + + +@cli.command(singleton=True) +def main(config: DataProcessConfig): + convert_dataset(config) + + +if __name__ == "__main__": + cli() diff --git a/dataset/build_maze_dataset.py b/dataset/build_maze_dataset.py new file mode 100644 index 0000000..e99baf2 --- /dev/null +++ b/dataset/build_maze_dataset.py @@ -0,0 +1,142 @@ +from typing import Optional +import math +import os +import csv +import json +import numpy as np + +from argdantic import ArgParser +from pydantic import BaseModel +from tqdm import tqdm +from huggingface_hub import hf_hub_download + +from common import PuzzleDatasetMetadata, dihedral_transform + + +CHARSET = "# SGo" + + +cli = ArgParser() + + +class DataProcessConfig(BaseModel): + source_repo: str = "imone/small-sample-challenge-maze-30x30-hard" + output_dir: str = "data/maze-30x30-hard-1k" + + subsample_size: Optional[int] = None + aug: bool = False + + +def convert_subset(set_name: str, config: DataProcessConfig): + # Read CSV + all_chars = set() + grid_size = None + inputs = [] + labels = [] + + with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore + reader = csv.reader(csvfile) + next(reader) # Skip header + for source, q, a, rating in reader: + all_chars.update(q) + all_chars.update(a) + + if grid_size is None: + n = int(len(q) ** 0.5) + grid_size = (n, n) + + inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size)) + labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size)) + + # If subsample_size is specified for the training set, + # randomly sample the desired number of examples. + if set_name == "train" and config.subsample_size is not None: + total_samples = len(inputs) + if config.subsample_size < total_samples: + indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) + inputs = [inputs[i] for i in indices] + labels = [labels[i] for i in indices] + + # Generate dataset + results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} + puzzle_id = 0 + example_id = 0 + + results["puzzle_indices"].append(0) + results["group_indices"].append(0) + + for inp, out in zip(tqdm(inputs), labels): + # Dihedral transformations for augmentation + for aug_idx in range(8 if (set_name == "train" and config.aug) else 1): + results["inputs"].append(dihedral_transform(inp, aug_idx)) + results["labels"].append(dihedral_transform(out, aug_idx)) + example_id += 1 + puzzle_id += 1 + + results["puzzle_indices"].append(example_id) + results["puzzle_identifiers"].append(0) + + # Push group + results["group_indices"].append(puzzle_id) + + # Char mappings + assert len(all_chars - set(CHARSET)) == 0 + + char2id = np.zeros(256, np.uint8) + char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1 + + # To Numpy + def _seq_to_numpy(seq): + arr = np.vstack([char2id[s.reshape(-1)] for s in seq]) + + return arr + + results = { + "inputs": _seq_to_numpy(results["inputs"]), + "labels": _seq_to_numpy(results["labels"]), + + "group_indices": np.array(results["group_indices"], dtype=np.int32), + "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), + "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), + } + + # Metadata + metadata = PuzzleDatasetMetadata( + seq_len=int(math.prod(grid_size)), # type: ignore + vocab_size=len(CHARSET) + 1, # PAD + Charset + + pad_id=0, + ignore_label_id=0, + + blank_identifier_id=0, + num_puzzle_identifiers=1, + + total_groups=len(results["group_indices"]) - 1, + mean_puzzle_examples=1, + sets=["all"] + ) + + # Save metadata as JSON. + save_dir = os.path.join(config.output_dir, set_name) + os.makedirs(save_dir, exist_ok=True) + + with open(os.path.join(save_dir, "dataset.json"), "w") as f: + json.dump(metadata.model_dump(), f) + + # Save data + for k, v in results.items(): + np.save(os.path.join(save_dir, f"all__{k}.npy"), v) + + # Save IDs mapping (for visualization only) + with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: + json.dump(["<blank>"], f) + + +@cli.command(singleton=True) +def preprocess_data(config: DataProcessConfig): + convert_subset("train", config) + convert_subset("test", config) + + +if __name__ == "__main__": + cli() diff --git a/dataset/build_sudoku_dataset.py b/dataset/build_sudoku_dataset.py new file mode 100644 index 0000000..5d5b50c --- /dev/null +++ b/dataset/build_sudoku_dataset.py @@ -0,0 +1,169 @@ +from typing import Optional +import os +import csv +import json +import numpy as np + +from argdantic import ArgParser +from pydantic import BaseModel +from tqdm import tqdm +from huggingface_hub import hf_hub_download + +from common import PuzzleDatasetMetadata + + +cli = ArgParser() + + +class DataProcessConfig(BaseModel): + source_repo: str = "imone/sudoku-hard-v2" + output_dir: str = "data/sudoku-extreme-full" + + subsample_size: Optional[int] = None + min_difficulty: Optional[int] = None + num_aug: int = 0 + + +def shuffle_sudoku(board: np.ndarray, solution: np.ndarray): + # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged + digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0)) + + # Randomly decide whether to transpose. + transpose_flag = np.random.rand() < 0.5 + + # Generate a valid row permutation: + # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows. + bands = np.random.permutation(3) + row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands]) + + # Similarly for columns (stacks). + stacks = np.random.permutation(3) + col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks]) + + # Build an 81->81 mapping. For each new cell at (i, j) + # (row index = i // 9, col index = i % 9), + # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9]. + mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)]) + + def apply_transformation(x: np.ndarray) -> np.ndarray: + # Apply transpose flag + if transpose_flag: + x = x.T + # Apply the position mapping. + new_board = x.flatten()[mapping].reshape(9, 9).copy() + # Apply digit mapping + return digit_map[new_board] + + return apply_transformation(board), apply_transformation(solution) + + +def convert_subset(set_name: str, config: DataProcessConfig): + # Read CSV + inputs = [] + labels = [] + + with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: + reader = csv.reader(csvfile) + next(reader) # Skip header + for source, q, a, rating in reader: + if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty): + assert len(q) == 81 and len(a) == 81 + + inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0')) + labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0')) + + # If subsample_size is specified for the training set, + # randomly sample the desired number of examples. + if set_name == "train" and config.subsample_size is not None: + total_samples = len(inputs) + if config.subsample_size < total_samples: + indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) + inputs = [inputs[i] for i in indices] + labels = [labels[i] for i in indices] + + # Generate dataset + num_augments = config.num_aug if set_name == "train" else 0 + + results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} + puzzle_id = 0 + example_id = 0 + + results["puzzle_indices"].append(0) + results["group_indices"].append(0) + + for orig_inp, orig_out in zip(tqdm(inputs), labels): + for aug_idx in range(1 + num_augments): + # First index is not augmented + if aug_idx == 0: + inp, out = orig_inp, orig_out + else: + inp, out = shuffle_sudoku(orig_inp, orig_out) + + # Push puzzle (only single example) + results["inputs"].append(inp) + results["labels"].append(out) + example_id += 1 + puzzle_id += 1 + + results["puzzle_indices"].append(example_id) + results["puzzle_identifiers"].append(0) + + # Push group + results["group_indices"].append(puzzle_id) + + # To Numpy + def _seq_to_numpy(seq): + arr = np.concatenate(seq).reshape(len(seq), -1) + + assert np.all((arr >= 0) & (arr <= 9)) + return arr + 1 + + results = { + "inputs": _seq_to_numpy(results["inputs"]), + "labels": _seq_to_numpy(results["labels"]), + + "group_indices": np.array(results["group_indices"], dtype=np.int32), + "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), + "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), + } + + # Metadata + metadata = PuzzleDatasetMetadata( + seq_len=81, + vocab_size=10 + 1, # PAD + "0" ... "9" + + pad_id=0, + ignore_label_id=0, + + blank_identifier_id=0, + num_puzzle_identifiers=1, + + total_groups=len(results["group_indices"]) - 1, + mean_puzzle_examples=1, + sets=["all"] + ) + + # Save metadata as JSON. + save_dir = os.path.join(config.output_dir, set_name) + os.makedirs(save_dir, exist_ok=True) + + with open(os.path.join(save_dir, "dataset.json"), "w") as f: + json.dump(metadata.model_dump(), f) + + # Save data + for k, v in results.items(): + np.save(os.path.join(save_dir, f"all__{k}.npy"), v) + + # Save IDs mapping (for visualization only) + with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: + json.dump(["<blank>"], f) + + +@cli.command(singleton=True) +def preprocess_data(config: DataProcessConfig): + convert_subset("train", config) + convert_subset("test", config) + + +if __name__ == "__main__": + cli() diff --git a/dataset/common.py b/dataset/common.py new file mode 100644 index 0000000..7bc51c6 --- /dev/null +++ b/dataset/common.py @@ -0,0 +1,51 @@ +from typing import List, Optional + +import pydantic +import numpy as np + + +# Global list mapping each dihedral transform id to its inverse. +# Index corresponds to the original tid, and the value is its inverse. +DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7] + + +class PuzzleDatasetMetadata(pydantic.BaseModel): + pad_id: int + ignore_label_id: Optional[int] + blank_identifier_id: int + + vocab_size: int + seq_len: int + num_puzzle_identifiers: int + + total_groups: int + mean_puzzle_examples: float + + sets: List[str] + + +def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray: + """8 dihedral symmetries by rotate, flip and mirror""" + + if tid == 0: + return arr # identity + elif tid == 1: + return np.rot90(arr, k=1) + elif tid == 2: + return np.rot90(arr, k=2) + elif tid == 3: + return np.rot90(arr, k=3) + elif tid == 4: + return np.fliplr(arr) # horizontal flip + elif tid == 5: + return np.flipud(arr) # vertical flip + elif tid == 6: + return arr.T # transpose (reflection along main diagonal) + elif tid == 7: + return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection + else: + return arr + + +def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray: + return dihedral_transform(arr, DIHEDRAL_INVERSE[tid]) diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..9bc6ba0 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,68 @@ +from typing import List +import yaml +import os + +import torch +import torch.distributed as dist + +import pydantic +from omegaconf import OmegaConf +from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader + + +class EvalConfig(pydantic.BaseModel): + checkpoint: str + + save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"] + + +def launch(): + eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore + + RANK = 0 + WORLD_SIZE = 1 + # Initialize distributed training if in distributed environment (e.g. torchrun) + if "LOCAL_RANK" in os.environ: + # Initialize distributed, default device and dtype + dist.init_process_group(backend="nccl") + + RANK = dist.get_rank() + WORLD_SIZE = dist.get_world_size() + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f: + config = PretrainConfig(**yaml.safe_load(f)) + + config.eval_save_outputs = eval_cfg.save_outputs + config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint) + + # Dataloader + train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, test_set_limit_examples=LIMIT_EXAMPLES, rank=RANK, world_size=WORLD_SIZE) + + # Models + train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) + # Try unwrap torch.compile + try: + train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True) + except: + train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True) + + train_state.step = 0 + ckpt_filename = os.path.basename(eval_cfg.checkpoint) + if ckpt_filename.startswith("step_"): + train_state.step = int(ckpt_filename.removeprefix("step_")) + + # Evaluate + print ("Starting evaluation") + + train_state.model.eval() + metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) + + if metrics is not None: + print (metrics) + + +if __name__ == "__main__": + launch() diff --git a/models/common.py b/models/common.py new file mode 100644 index 0000000..1a04505 --- /dev/null +++ b/models/common.py @@ -0,0 +1,32 @@ +import math + +import torch +from torch import nn + + +def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0): + # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor + # This function is a PyTorch version of jax truncated normal init (default init method in flax) + # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848 + # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199 + + with torch.no_grad(): + if std == 0: + tensor.zero_() + else: + sqrt2 = math.sqrt(2) + a = math.erf(lower / sqrt2) + b = math.erf(upper / sqrt2) + z = (b - a) / 2 + + c = (2 * math.pi) ** -0.5 + pdf_u = c * math.exp(-0.5 * lower ** 2) + pdf_l = c * math.exp(-0.5 * upper ** 2) + comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2) + + tensor.uniform_(a, b) + tensor.erfinv_() + tensor.mul_(sqrt2 * comp_std) + tensor.clip_(lower * comp_std, upper * comp_std) + + return tensor diff --git a/models/hrm/hrm_act_v1.py b/models/hrm/hrm_act_v1.py new file mode 100644 index 0000000..e91c7d1 --- /dev/null +++ b/models/hrm/hrm_act_v1.py @@ -0,0 +1,283 @@ +from typing import Tuple, List, Dict, Optional +from dataclasses import dataclass +import math + +import torch +import torch.nn.functional as F +from torch import nn +from pydantic import BaseModel + +from models.common import trunc_normal_init_ +from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear +from models.sparse_embedding import CastedSparseEmbedding + + +@dataclass +class HierarchicalReasoningModel_ACTV1InnerCarry: + z_H: torch.Tensor + z_L: torch.Tensor + + +@dataclass +class HierarchicalReasoningModel_ACTV1Carry: + inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry + + steps: torch.Tensor + halted: torch.Tensor + + current_data: Dict[str, torch.Tensor] + + +class HierarchicalReasoningModel_ACTV1Config(BaseModel): + batch_size: int + seq_len: int + puzzle_emb_ndim: int = 0 + num_puzzle_identifiers: int + vocab_size: int + + H_cycles: int + L_cycles: int + + H_layers: int + L_layers: int + + # Transformer config + hidden_size: int + expansion: float + num_heads: int + pos_encodings: str + + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + + # Halting Q-learning config + halt_max_steps: int + halt_exploration_prob: float + + forward_dtype: str = "bfloat16" + + +class HierarchicalReasoningModel_ACTV1Block(nn.Module): + def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: + super().__init__() + + self.self_attn = Attention( + hidden_size=config.hidden_size, + head_dim=config.hidden_size // config.num_heads, + num_heads=config.num_heads, + num_key_value_heads=config.num_heads, + causal=False + ) + self.mlp = SwiGLU( + hidden_size=config.hidden_size, + expansion=config.expansion, + ) + self.norm_eps = config.rms_norm_eps + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + # Post Norm + # Self Attention + hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) + # Fully Connected + hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps) + return hidden_states + + +class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module): + def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]): + super().__init__() + + self.layers = torch.nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor: + # Input injection (add) + hidden_states = hidden_states + input_injection + # Layers + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, **kwargs) + + return hidden_states + + +class HierarchicalReasoningModel_ACTV1_Inner(nn.Module): + def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: + super().__init__() + self.config = config + self.forward_dtype = getattr(torch, self.config.forward_dtype) + + # I/O + self.embed_scale = math.sqrt(self.config.hidden_size) + embed_init_std = 1.0 / self.embed_scale + + self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) + self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) + + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div + if self.config.puzzle_emb_ndim > 0: + # Zero init puzzle embeddings + self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, + batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) + + # LM Blocks + if self.config.pos_encodings == "rope": + self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, + max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, + base=self.config.rope_theta) + elif self.config.pos_encodings == "learned": + self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + else: + raise NotImplementedError() + + # Reasoning Layers + self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)]) + self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) + + # Initial states + self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + + # Q head special init + # Init Q to (almost) zero for faster learning during bootstrapping + with torch.no_grad(): + self.q_head.weight.zero_() + self.q_head.bias.fill_(-5) # type: ignore + + def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): + # Token embedding + embedding = self.embed_tokens(input.to(torch.int32)) + + # Puzzle embeddings + if self.config.puzzle_emb_ndim > 0: + puzzle_embedding = self.puzzle_emb(puzzle_identifiers) + + pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] + if pad_count > 0: + puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) + + embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) + + # Position embeddings + if self.config.pos_encodings == "learned": + # scale by 1/sqrt(2) to maintain forward variance + embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) + + # Scale + return self.embed_scale * embedding + + def empty_carry(self, batch_size: int): + return HierarchicalReasoningModel_ACTV1InnerCarry( + z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + ) + + def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry): + return HierarchicalReasoningModel_ACTV1InnerCarry( + z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), + z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), + ) + + def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + seq_info = dict( + cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, + ) + + # Input encoding + input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # Forward iterations + with torch.no_grad(): + z_H, z_L = carry.z_H, carry.z_L + + for _H_step in range(self.config.H_cycles): + for _L_step in range(self.config.L_cycles): + if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)): + z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + + if not (_H_step == self.config.H_cycles - 1): + z_H = self.H_level(z_H, z_L, **seq_info) + + assert not z_H.requires_grad and not z_L.requires_grad + + # 1-step grad + z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + z_H = self.H_level(z_H, z_L, **seq_info) + + # LM Outputs + new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad + output = self.lm_head(z_H)[:, self.puzzle_emb_len:] + + # Q head + q_logits = self.q_head(z_H[:, 0]).to(torch.float32) + + return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) + + +class HierarchicalReasoningModel_ACTV1(nn.Module): + """ACT wrapper.""" + + def __init__(self, config_dict: dict): + super().__init__() + self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict) + self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config) + + @property + def puzzle_emb(self): + return self.inner.puzzle_emb + + def initial_carry(self, batch: Dict[str, torch.Tensor]): + batch_size = batch["inputs"].shape[0] + + return HierarchicalReasoningModel_ACTV1Carry( + inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted. + + steps=torch.zeros((batch_size, ), dtype=torch.int32), + halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted + + current_data={k: torch.empty_like(v) for k, v in batch.items()} + ) + + def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]: + # Update data, carry (removing halted sequences) + new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) + + new_steps = torch.where(carry.halted, 0, carry.steps) + + new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} + + # Forward inner model + new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data) + + outputs = { + "logits": logits, + "q_halt_logits": q_halt_logits, + "q_continue_logits": q_continue_logits + } + + with torch.no_grad(): + # Step + new_steps = new_steps + 1 + is_last_step = new_steps >= self.config.halt_max_steps + + halted = is_last_step + + # if training, and ACT is enabled + if self.training and (self.config.halt_max_steps > 1): + # Halt signal + # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes + halted = halted | (q_halt_logits > q_continue_logits) + + # Exploration + min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) + + halted = halted & (new_steps >= min_halt_steps) + + # Compute target Q + # NOTE: No replay buffer and target networks for computing target Q-value. + # As batch_size is large, there're many parallel envs. + # Similar concept as PQN https://arxiv.org/abs/2407.04811 + next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1] + + outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits))) + + return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs diff --git a/models/layers.py b/models/layers.py new file mode 100644 index 0000000..4f7dee4 --- /dev/null +++ b/models/layers.py @@ -0,0 +1,150 @@ +from typing import Tuple + +import torch +from torch import nn +import torch.nn.functional as F + +from models.common import trunc_normal_init_ + + +CosSin = Tuple[torch.Tensor, torch.Tensor] + + +def _find_multiple(a, b): + return (-(a // -b)) * b + + +def rotate_half(x: torch.Tensor): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + # q, k: [bs, num_heads, seq_len, head_dim] + # cos, sin: [seq_len, head_dim] + orig_dtype = q.dtype + q = q.to(cos.dtype) + k = k.to(cos.dtype) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(orig_dtype), k_embed.to(orig_dtype) + + +class CastedLinear(nn.Module): + def __init__(self, + in_features: int, + out_features: int, + bias: bool): + super().__init__() + # Truncated LeCun normal init + self.weight = nn.Parameter( + trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5)) + ) + self.bias = None + if bias: + # Zero init bias + self.bias = nn.Parameter(torch.zeros((out_features, ))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None) + + +class CastedEmbedding(nn.Module): + def __init__(self, + num_embeddings: int, + embedding_dim: int, + init_std: float, + cast_to: torch.dtype): + super().__init__() + self.cast_to = cast_to + + # Truncated LeCun normal init + self.embedding_weight = nn.Parameter( + trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.embedding(input, self.embedding_weight.to(self.cast_to)) + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings, base, device=None): + super().__init__() + + # RoPE + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device) + freqs = torch.outer(t, inv_freq) + + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = nn.Buffer(emb.cos(), persistent=False) + self.sin_cached = nn.Buffer(emb.sin(), persistent=False) + + def forward(self): + return self.cos_cached, self.sin_cached + + +class Attention(nn.Module): + def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False): + super().__init__() + + self.hidden_size = hidden_size + self.head_dim = head_dim + self.output_size = head_dim * num_heads + self.num_heads = num_heads + self.num_key_value_heads = num_key_value_heads + self.causal = causal + + self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False) + self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False) + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + + # hidden_states: [bs, seq_len, num_heads, head_dim] + qkv = self.qkv_proj(hidden_states) + + # Split head + qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim).transpose(-2, -3) + query = qkv[:, :self.num_heads] + key = qkv[:, self.num_heads: self.num_heads + self.num_key_value_heads] + value = qkv[:, self.num_heads + self.num_key_value_heads:] + + # RoPE + if cos_sin is not None: + cos, sin = cos_sin + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + # flash attn + attn_output = F.scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal) + + # attn_output: [batch_size, num_heads, seq_len, head_dim] + attn_output = attn_output.transpose(-2, -3).view(batch_size, seq_len, self.output_size) # type: ignore + return self.o_proj(attn_output) + + +class SwiGLU(nn.Module): + def __init__(self, hidden_size: int, expansion: float): + super().__init__() + inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256) + + self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False) + self.down_proj = CastedLinear(inter, hidden_size, bias=False) + + def forward(self, x): + gate, up = self.gate_up_proj(x).chunk(2, dim=-1) + return self.down_proj(F.silu(gate) * up) + + +def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + return hidden_states.to(input_dtype) diff --git a/models/losses.py b/models/losses.py new file mode 100644 index 0000000..b3118e7 --- /dev/null +++ b/models/losses.py @@ -0,0 +1,101 @@ +from typing import Any, Tuple, Dict, Sequence, Optional + +import torch +import torch.nn.functional as F +from torch import nn + + +IGNORE_LABEL_ID = -100 + + +def s(x, epsilon=1e-30): + return torch.where( + x<0, + 1/(1-x+ epsilon), + x + 1 + ) + + +def log_stablemax(x, dim=-1): + s_x = s(x) + return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True)) + + +def stablemax_cross_entropy(logits, labels, ignore_index: int = -100): + logprobs = log_stablemax(logits.to(torch.float64), dim=-1) + + valid_mask = labels != ignore_index + transformed_labels = torch.where(valid_mask, labels, 0) + prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1) + + return -torch.where(valid_mask, prediction_logprobs, 0) + + +def softmax_cross_entropy(logits, labels, ignore_index: int = -100): + # Cast logits to f32 + # Flatten logits + return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape) + + +class ACTLossHead(nn.Module): + def __init__(self, model: nn.Module, loss_type: str): + super().__init__() + self.model = model + self.loss_fn = globals()[loss_type] + + def initial_carry(self, *args, **kwargs): + return self.model.initial_carry(*args, **kwargs) # type: ignore + + def forward( + self, + return_keys: Sequence[str], + # Model args + **model_kwargs, + ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]: + # Model logits + # B x SeqLen x D + new_carry, outputs = self.model(**model_kwargs) + labels = new_carry.current_data["labels"] + + # Correctness + with torch.no_grad(): + mask = labels != IGNORE_LABEL_ID + loss_counts = mask.sum(-1) + loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division + + is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels) + seq_is_correct = is_correct.sum(-1) == loss_counts + + # Metrics (halted) + valid_metrics = new_carry.halted & (loss_counts > 0) + metrics = { + "count": valid_metrics.sum(), + + "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(), + "exact_accuracy": (valid_metrics & seq_is_correct).sum(), + + "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(), + "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(), + } + + # Losses + # FIXME: Assuming the batch is always full + lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum() + q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum") + + metrics.update({ + "lm_loss": lm_loss.detach(), + "q_halt_loss": q_halt_loss.detach(), + }) + + # Q continue (bootstrapping target loss) + q_continue_loss = 0 + if "target_q_continue" in outputs: + q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum") + + metrics["q_continue_loss"] = q_continue_loss.detach() + + # Filter outputs for return + detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs} + + return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all() diff --git a/models/sparse_embedding.py b/models/sparse_embedding.py new file mode 100644 index 0000000..c701524 --- /dev/null +++ b/models/sparse_embedding.py @@ -0,0 +1,132 @@ +from typing import Union + +import torch +from torch import nn +import torch.distributed as dist +from torch.optim.optimizer import Optimizer, ParamsT + +from models.common import trunc_normal_init_ + + +class CastedSparseEmbedding(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype): + super().__init__() + self.cast_to = cast_to + + # Real Weights + # Truncated LeCun normal init + self.weights = nn.Buffer( + trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True + ) + + # Local weights and IDs + # Local embeddings, with gradient, not persistent + self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False) + # Local embedding IDs, not persistent + self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if not self.training: + # Test mode, no gradient + return self.weights[inputs].to(self.cast_to) + + # Training mode, fill puzzle embedding from weights + with torch.no_grad(): + self.local_weights.copy_(self.weights[inputs]) + self.local_ids.copy_(inputs) + + return self.local_weights.to(self.cast_to) + + +class CastedSparseEmbeddingSignSGD_Distributed(Optimizer): + def __init__( + self, + params: ParamsT, + + world_size: int, + lr: Union[float, torch.Tensor] = 1e-3, + weight_decay: float = 1e-2, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + weight_decay=weight_decay, + world_size=world_size + ) + super().__init__(params, defaults) + + @torch.no_grad + def step(self, closure=None): # type: ignore + for group in self.param_groups: + # Find the sparse embedding weights + local_weights_grad = None + local_ids = None + weights = None + + assert len(group["params"]) == 3 + for p in group["params"]: + if p.requires_grad: + local_weights_grad = p.grad + elif p.ndim == 1: + local_ids = p + elif p.ndim == 2: + weights = p + else: + assert False + + assert local_weights_grad is not None + assert local_ids is not None + assert weights is not None + + # Apply SignSGD + # Adam ≈ SignSGD if gradient is very sparse + _sparse_emb_signsgd_dist( + local_weights_grad, + local_ids, + weights, + + lr=group["lr"], + weight_decay=group["weight_decay"], + world_size=group["world_size"] + ) + + +def _sparse_emb_signsgd_dist( + local_weights_grad: torch.Tensor, + local_ids: torch.Tensor, + weights: torch.Tensor, + + lr: float, + weight_decay: float, + world_size: int +) -> None: + N, D = local_weights_grad.shape + + # All-gather + all_weights_grad = local_weights_grad + all_ids = local_ids + + if world_size > 1: + all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device) + all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device) + + dist.all_gather_into_tensor(all_weights_grad, local_weights_grad) + dist.all_gather_into_tensor(all_ids, local_ids) + + # Unique + grad_ids, inv = all_ids.unique(return_inverse=True) + + grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device) + grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad) + + # SignSGD with decoupled weight decay + p = weights[grad_ids] + + p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr) + + # Write updated slices back + weights[grad_ids] = p diff --git a/pretrain.py b/pretrain.py new file mode 100644 index 0000000..b939318 --- /dev/null +++ b/pretrain.py @@ -0,0 +1,454 @@ +from typing import Optional, Any, Sequence, List +from dataclasses import dataclass +import os +import math +import yaml +import shutil + +import torch +import torch.distributed as dist +from torch import nn +from torch.utils.data import DataLoader + +import tqdm +import wandb +import coolname +import hydra +import pydantic +from omegaconf import DictConfig +from wandb.util import make_artifact_name_safe +from adam_atan2 import AdamATan2 + +from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata +from utils.functions import load_model_class, get_model_source_path +from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed + + +class LossConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='allow') + + name: str + + +class ArchConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='allow') + + name: str + loss: LossConfig + + +class PretrainConfig(pydantic.BaseModel): + # Config + arch: ArchConfig + # Data + data_path: str + + # Hyperparams + global_batch_size: int + epochs: int + + lr: float + lr_min_ratio: float + lr_warmup_steps: int + + weight_decay: float + beta1: float + beta2: float + + # Puzzle embedding + puzzle_emb_lr: float + puzzle_emb_weight_decay: float + + # Names + project_name: Optional[str] = None + run_name: Optional[str] = None + checkpoint_path: Optional[str] = None + + # Extras + seed: int = 0 + checkpoint_every_eval: bool = False + eval_interval: Optional[int] = None + eval_save_outputs: List[str] = [] + + +@dataclass +class TrainState: + model: nn.Module + optimizers: Sequence[torch.optim.Optimizer] + optimizer_lrs: Sequence[float] + carry: Any + + step: int + total_steps: int + + +def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs): + dataset = PuzzleDataset(PuzzleDatasetConfig( + seed=config.seed, + + dataset_path=config.data_path, + + rank=rank, + num_replicas=world_size, + + **kwargs + ), split=split) + dataloader = DataLoader( + dataset, + batch_size=None, + + num_workers=1, + prefetch_factor=8, + + pin_memory=True, + persistent_workers=True + ) + return dataloader, dataset.metadata + + +def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): + model_cfg = dict( + **config.arch.__pydantic_extra__, # type: ignore + + batch_size=config.global_batch_size // world_size, + + vocab_size=train_metadata.vocab_size, + seq_len=train_metadata.seq_len, + num_puzzle_identifiers=train_metadata.num_puzzle_identifiers, + causal=False # Non-autoregressive + ) + + # Instantiate model with loss head + model_cls = load_model_class(config.arch.name) + loss_head_cls = load_model_class(config.arch.loss.name) + + with torch.device("cuda"): + model: nn.Module = model_cls(model_cfg) + model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore + if "DISABLE_COMPILE" not in os.environ: + model = torch.compile(model, dynamic=False, fullgraph=True) # type: ignore + + # Broadcast parameters from rank 0 + if world_size > 1: + with torch.no_grad(): + for param in list(model.parameters()) + list(model.buffers()): + dist.broadcast(param, src=0) + + # Optimizers and lr + optimizers = [ + CastedSparseEmbeddingSignSGD_Distributed( + model.model.puzzle_emb.buffers(), # type: ignore + + lr=0, # Needs to be set by scheduler + weight_decay=config.puzzle_emb_weight_decay, + + world_size=world_size + ), + AdamATan2( + model.parameters(), + + lr=0, # Needs to be set by scheduler + weight_decay=config.weight_decay, + betas=(config.beta1, config.beta2) + ) + ] + optimizer_lrs = [ + config.puzzle_emb_lr, + config.lr + ] + + return model, optimizers, optimizer_lrs + + +def cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 +): + if current_step < num_warmup_steps: + return base_lr * float(current_step) / float(max(1, num_warmup_steps)) + + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))) + + +def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): + # Estimated total training steps + total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size) + + # Model + model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size) + + return TrainState( + step=0, + total_steps=total_steps, + + model=model, + optimizers=optimizers, + optimizer_lrs=optimizer_lrs, + carry=None + ) + + +def save_train_state(config: PretrainConfig, train_state: TrainState): + # FIXME: Only saved model. + if config.checkpoint_path is None: + return + + os.makedirs(config.checkpoint_path, exist_ok=True) + torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) + + +def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState): + return cosine_schedule_with_warmup_lr_lambda( + current_step=train_state.step, + base_lr=base_lr, + num_warmup_steps=round(config.lr_warmup_steps), + num_training_steps=train_state.total_steps, + min_ratio=config.lr_min_ratio + ) + + +def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): + train_state.step += 1 + if train_state.step > train_state.total_steps: # At most train_total_steps + return + + # To device + batch = {k: v.cuda() for k, v in batch.items()} + + # Init carry if it is None + if train_state.carry is None: + with torch.device("cuda"): + train_state.carry = train_state.model.initial_carry(batch) # type: ignore + + # Forward + train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[]) + + ((1 / global_batch_size) * loss).backward() + + # Allreduce + if world_size > 1: + for param in train_state.model.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad) + + # Apply optimizer + lr_this_step = None + for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): + lr_this_step = compute_lr(base_lr, config, train_state) + + for param_group in optim.param_groups: + param_group['lr'] = lr_this_step + + optim.step() + optim.zero_grad() + + # Reduce metrics + if len(metrics): + assert not any(v.requires_grad for v in metrics.values()) + + metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. + # Reduce and reconstruct + metric_values = torch.stack([metrics[k] for k in metric_keys]) + if world_size > 1: + dist.reduce(metric_values, dst=0) + + if rank == 0: + metric_values = metric_values.cpu().numpy() + reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)} + + # Postprocess + count = max(reduced_metrics["count"], 1) # Avoid NaNs + reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()} + + reduced_metrics["train/lr"] = lr_this_step + return reduced_metrics + + +def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): + with torch.inference_mode(): + set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} + + all_preds = {} + + metric_keys = [] + metric_values = None + metric_global_batch_size = [0 for _ in range(len(set_ids))] + + carry = None + for set_name, batch, global_batch_size in eval_loader: + # To device + batch = {k: v.cuda() for k, v in batch.items()} + with torch.device("cuda"): + carry = train_state.model.initial_carry(batch) # type: ignore + + # Forward + while True: + carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs) + + if all_finish: + break + + for collection in (batch, preds): + for k, v in collection.items(): + if k in config.eval_save_outputs: + all_preds.setdefault(k, []) + all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory + + del carry, preds, batch, all_finish + + # Aggregate + set_id = set_ids[set_name] + + if metric_values is None: + metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. + metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda") + + metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) + metric_global_batch_size[set_id] += global_batch_size + + if len(all_preds) and config.checkpoint_path is not None: + all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()} + + os.makedirs(config.checkpoint_path, exist_ok=True) + torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")) + + # Logging + # Reduce to rank 0 + if metric_values is not None: + if world_size > 1: + dist.reduce(metric_values, dst=0) + + if rank == 0: + reduced_metrics = metric_values.cpu().numpy() + reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)} + for set_id, set_name in enumerate(set_ids)} + + # Postprocess + for set_name, metrics in reduced_metrics.items(): + count = metrics.pop("count") + reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()} + + return reduced_metrics + + +def save_code_and_config(config: PretrainConfig): + if config.checkpoint_path is None or wandb.run is None: + return + + os.makedirs(config.checkpoint_path, exist_ok=True) + + # Copy code + code_list = [ + get_model_source_path(config.arch.name), + get_model_source_path(config.arch.loss.name) + ] + for code_file in code_list: + if code_file is not None: + code_name = os.path.basename(code_file) + + shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name)) + + # Dump config as yaml + config_file = os.path.join(config.checkpoint_path, "all_config.yaml") + with open(config_file, "wt") as f: + yaml.dump(config.model_dump(), f) + + # Log code + wandb.run.log_code(config.checkpoint_path) + + +def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig: + objects = [None] + if rank == 0: + config = PretrainConfig(**hydra_config) # type: ignore + + # Naming + if config.project_name is None: + config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch" + if config.run_name is None: + config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}" + if config.checkpoint_path is None: + config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name) + + objects = [config] + + if world_size > 1: + dist.broadcast_object_list(objects, src=0) + + return objects[0] # type: ignore + + +@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None) +def launch(hydra_config: DictConfig): + RANK = 0 + WORLD_SIZE = 1 + + # Initialize distributed training if in distributed environment (e.g. torchrun) + if "LOCAL_RANK" in os.environ: + # Initialize distributed, default device and dtype + dist.init_process_group(backend="nccl") + + RANK = dist.get_rank() + WORLD_SIZE = dist.get_world_size() + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + # Load sync'ed config + config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE) + + # Seed RNGs to ensure consistency + torch.random.manual_seed(config.seed + RANK) + + # Dataset + train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs + total_iters = config.epochs // train_epochs_per_iter + + assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs." + + train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + + # Train state + train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) + + # Progress bar and logger + progress_bar = None + if RANK == 0: + progress_bar = tqdm.tqdm(total=train_state.total_steps) + + wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore + wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0) + save_code_and_config(config) + + # Training Loop + for _iter_id in range(total_iters): + print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}") + + ############ Train Iter + train_state.model.train() + for set_name, batch, global_batch_size in train_loader: + metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE) + + if RANK == 0 and metrics is not None: + wandb.log(metrics, step=train_state.step) + progress_bar.update(train_state.step - progress_bar.n) # type: ignore + + ############ Evaluation + train_state.model.eval() + metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) + + if RANK == 0 and metrics is not None: + wandb.log(metrics, step=train_state.step) + + ############ Checkpointing + if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)): + save_train_state(config, train_state) + + # finalize + if dist.is_initialized(): + dist.destroy_process_group() + wandb.finish() + + +if __name__ == "__main__": + launch() diff --git a/puzzle_dataset.py b/puzzle_dataset.py new file mode 100644 index 0000000..2782403 --- /dev/null +++ b/puzzle_dataset.py @@ -0,0 +1,199 @@ +import os +import json + +import numpy as np +import pydantic + +import torch +from torch.utils.data import IterableDataset, get_worker_info + +from models.losses import IGNORE_LABEL_ID +from dataset.common import PuzzleDatasetMetadata + + +def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int): + # Pack examples into a full batch + batch = [] + batch_puzzle_indices = [] + current_size = 0 + + while (start_index < group_order.size) and (current_size < global_batch_size): + # Pick a group and a puzzle from that group + group_id = group_order[start_index] + puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1]) + start_index += 1 + + # Get range of the puzzle + puzzle_start = puzzle_indices[puzzle_id] + puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start) + + append_size = min(puzzle_size, global_batch_size - current_size) + + # Put into batch + batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32)) + batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False)) + + current_size += append_size + + return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices) + + +class PuzzleDatasetConfig(pydantic.BaseModel): + seed: int + dataset_path: str + global_batch_size: int + test_set_mode: bool + + epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead. + + rank: int + num_replicas: int + + +class PuzzleDataset(IterableDataset): + def __init__(self, config: PuzzleDatasetConfig, split: str = "train"): + super().__init__() + self.config = config + self.split = split + self.metadata = self._load_metadata() + + # Checks + assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}." + self.local_batch_size = self.config.global_batch_size // self.config.num_replicas + + # State + self._data = None + self._iters = 0 + + def _load_metadata(self) -> PuzzleDatasetMetadata: + with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f: + return PuzzleDatasetMetadata(**json.load(f)) + + def _lazy_load_dataset(self): + if self._data is not None: + return + + field_mmap_modes = { + "inputs": "r", + "labels": "r", + + # Keep indices in memory + "puzzle_identifiers": None, + "puzzle_indices": None, + "group_indices": None + } + + # Load data + self._data = {} + for set_name in self.metadata.sets: + # Load subset + self._data[set_name] = { + field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode) + for field_name, mmap_mode in field_mmap_modes.items() + } + + def _collate_batch(self, batch): + # Convert dtype + batch = {k: v.astype(np.int32) for k, v in batch.items()} + + # Convert ignore label IDs + if self.metadata.ignore_label_id is not None: + batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID + + # Pad + if batch["puzzle_identifiers"].size < self.local_batch_size: + pad_size = self.local_batch_size - batch["puzzle_identifiers"].size + + pad_values = { + "inputs": self.metadata.pad_id, + "labels": IGNORE_LABEL_ID, + + "puzzle_identifiers": self.metadata.blank_identifier_id + } + batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()} + + # To tensor + return {k: torch.from_numpy(v) for k, v in batch.items()} + + def _iter_test(self): + for set_name, dataset in self._data.items(): # type: ignore + total_examples = len(dataset["inputs"]) + + # Load examples one by one + start_index = 0 + while start_index < total_examples: + # Compute indices + end_index = min(total_examples, start_index + self.config.global_batch_size) + + local_start = start_index + self.config.rank * self.local_batch_size + local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index) + + # Get batch of examples, and also puzzle IDs + puzzle_indices = [] + puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1 + for i in range(local_start, local_end): + while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]: + puzzle_index += 1 + + puzzle_indices.append(puzzle_index) + + batch = self._collate_batch({ + "inputs": dataset["inputs"][local_start: local_end], + "labels": dataset["labels"][local_start: local_end], + "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices] + }) + + yield set_name, batch, end_index - start_index + + # Advance to next batch + start_index += self.config.global_batch_size + + def _iter_train(self): + for set_name, dataset in self._data.items(): # type: ignore + # Increase epoch count + self._iters += 1 + + # Randomly shuffle groups + rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters)) + + group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)]) + start_index = 0 + + while start_index < group_order.size: + start_index, batch_indices, batch_puzzle_indices = _sample_batch( + rng, + group_order=group_order, + puzzle_indices=dataset["puzzle_indices"], + group_indices=dataset["group_indices"], + start_index=start_index, + global_batch_size=self.config.global_batch_size, + ) + + # Select current rank and collate + global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads + + # Drop last batch + if global_effective_batch_size < self.config.global_batch_size: + break + + batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size] + batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size] + batch = self._collate_batch({ + "inputs": dataset["inputs"][batch_indices], + "labels": dataset["labels"][batch_indices], + "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices] + }) + + yield set_name, batch, global_effective_batch_size + + def __iter__(self): + worker_info = get_worker_info() + assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported." + + self._lazy_load_dataset() + + # Iterate using specified mode + if self.config.test_set_mode: + yield from self._iter_test() + else: + yield from self._iter_train() diff --git a/puzzle_visualizer.html b/puzzle_visualizer.html new file mode 100644 index 0000000..bcefdf1 --- /dev/null +++ b/puzzle_visualizer.html @@ -0,0 +1,426 @@ +<!DOCTYPE html> +<html> +<head> + <meta charset="UTF-8" /> + <title>ARC‐Converted Dataset Visualizer (Upload Local Folder)</title> + <style> + body { + font-family: sans-serif; + margin: 16px; + } + .selector-area { + margin-bottom: 1rem; + } + .grid-canvas { + margin: 4px; + border: 1px solid #ccc; + } + .example-container { + display: inline-block; + margin: 0 16px 16px 0; + vertical-align: top; + } + .puzzle-display { + margin-top: 1rem; + } + .puzzle-id { + font-weight: bold; + margin-bottom: 0.5rem; + } + #groupList, #puzzleList { + margin: 1rem 0; + } + .group-item, .puzzle-item { + cursor: pointer; + margin: 4px 8px 4px 0; + padding: 2px 6px; + border: 1px solid #aaa; + display: inline-block; + } + .group-item:hover, .puzzle-item:hover { + background: #eef; + } + </style> +</head> +<body> +<h1>ARC‐Converted Dataset Visualizer (Local Directory)</h1> + +<div class="selector-area"> + <!-- 1) Directory input with webkitdirectory, mozdirectory --> + <label>Upload ARC Folder:</label> + <input type="file" id="folderInput" + webkitdirectory mozdirectory multiple + onchange="onFolderSelected(event)" /> + <br><br> + + <!-- 2) We'll enable set/subset selection after user chooses a folder and data is validated --> + <label>Set:</label> + <select id="setSelect" disabled> + <option value="train">train</option> + <option value="test">test</option> + </select> + + <label> Subset:</label> + <select id="subsetSelect" disabled> + <option value="all">all</option> + </select> + + <button id="loadBtn" disabled>Load</button> +</div> + +<div> + <div id="groupList"></div> + <div id="puzzleList"></div> + <div class="puzzle-display" id="puzzleView"></div> +</div> + +<!-- + 3) Use local 'assets/npyjs.js' from your project folder instead of a CDN. + Make sure 'assets/npyjs.js' is the unbundled or UMD version that doesn't + contain "import" statements. +--> +<script src="assets/npyjs.js"></script> + +<script> +/*************************************************************************** + * Global Maps & Variables + ***************************************************************************/ + +// Map from "train/all__inputs.npy" => File, etc. +let filesByPath = {}; + +// Once loaded, we store typed arrays for the chosen set/subset +let inputsArr, labelsArr; +let puzzleIndicesArr, groupIndicesArr, puzzleIdentifiersArr; +let identifiersJson; + +// The shape of inputs is [N_examples, seqLen], so we discover seqLen & gridSize +let seqLen = 0; +let gridSize = 0; + + +/*************************************************************************** + * 1) Handle folder selection: read all files, find identifiers.json, + * remove topmost folder from each file path, validate. + ***************************************************************************/ +function onFolderSelected(event) { + filesByPath = {}; + const fileList = event.target.files; + if (!fileList || fileList.length === 0) { + alert("No files selected!"); + return; + } + + // We'll gather all webkitRelativePaths + const paths = []; + for (let i = 0; i < fileList.length; i++) { + // Typically "arc-aug-10/train/all__inputs.npy", etc. + const file = fileList[i]; + const relPath = file.webkitRelativePath || file.mozRelativePath || file.name; + paths.push(relPath); + } + + // 1. Check if we have "identifiers.json" somewhere. + const idPath = paths.find(p => p.endsWith("identifiers.json")); + if (!idPath) { + alert("Error: No 'identifiers.json' found in the uploaded folder."); + return; + } + + // 2. Derive the top-level directory from that file's path + // e.g. if idPath = "arc-aug-10/identifiers.json", topDir = "arc-aug-10" + // If there's no slash, topDir = "" => do nothing + let topDir = ""; + const lastSlash = idPath.lastIndexOf("/"); + if (lastSlash >= 0) { + topDir = idPath.substring(0, lastSlash); + } + + // 3. Rebuild filesByPath with the top folder removed. + // For example, if topDir = "arc-aug-10", then "arc-aug-10/train/all__inputs.npy" + // becomes "train/all__inputs.npy" + for (let i = 0; i < fileList.length; i++) { + const file = fileList[i]; + let relPath = file.webkitRelativePath || file.mozRelativePath || file.name; + // If relPath starts with "arc-aug-10/", remove that prefix + if (topDir && relPath.startsWith(topDir + "/")) { + relPath = relPath.substring(topDir.length + 1); + } + filesByPath[relPath] = file; + } + + // Enable set/subset selection and "Load" + document.getElementById("setSelect").disabled = false; + document.getElementById("subsetSelect").disabled = false; + document.getElementById("loadBtn").disabled = false; +} + +// When user clicks "Load," parse the .npy for the chosen set/subset +document.getElementById("loadBtn").addEventListener("click", async () => { + document.getElementById("groupList").innerHTML = ""; + document.getElementById("puzzleList").innerHTML = ""; + document.getElementById("puzzleView").innerHTML = ""; + + const setName = document.getElementById("setSelect").value; // e.g. "train" + const subsetName = document.getElementById("subsetSelect").value; // e.g. "all" + + try { + await loadDataset(setName, subsetName); + buildGroupList(); // show groups + } catch (err) { + console.error(err); + alert("Error while loading dataset: " + err); + } +}); + + +/*************************************************************************** + * 2) Load .npy from local files using Npyjs + FileReader (ArrayBuffer) + ***************************************************************************/ +async function loadDataset(setName, subsetName) { + const prefix = `${setName}/${subsetName}__`; + // e.g. "train/all__inputs.npy" + const inputsPath = prefix + "inputs.npy"; + const labelsPath = prefix + "labels.npy"; + const pIdxPath = prefix + "puzzle_indices.npy"; + const gIdxPath = prefix + "group_indices.npy"; + const pIdsPath = prefix + "puzzle_identifiers.npy"; + const identifiersPath = "identifiers.json"; + + // Check existence + const needed = [inputsPath, labelsPath, pIdxPath, gIdxPath, pIdsPath, identifiersPath]; + for (const f of needed) { + if (!filesByPath[f]) { + throw new Error(`Missing file: ${f}`); + } + } + + // parseNpy => read from File -> ArrayBuffer -> Npyjs => typed array + const inputsNpy = await parseNpy(filesByPath[inputsPath]); + const labelsNpy = await parseNpy(filesByPath[labelsPath]); + const puzzleIndicesNpy= await parseNpy(filesByPath[pIdxPath]); + const groupIndicesNpy = await parseNpy(filesByPath[gIdxPath]); + const puzzleIdsNpy = await parseNpy(filesByPath[pIdsPath]); + + inputsArr = inputsNpy.data; + labelsArr = labelsNpy.data; + puzzleIndicesArr = puzzleIndicesNpy.data; + groupIndicesArr = groupIndicesNpy.data; + puzzleIdentifiersArr = puzzleIdsNpy.data; + + // shape e.g. [N_examples, seqLen] + seqLen = inputsNpy.shape[1]; + gridSize = Math.sqrt(seqLen); + + // read JSON + identifiersJson = await readJsonFile(filesByPath[identifiersPath]); +} + +/*************************************************************************** + * parseNpy => read a File as ArrayBuffer, parse with npyjs + ***************************************************************************/ +function parseNpy(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = async () => { + try { + const arrayBuffer = reader.result; + const npy = new npyjs(); + resolve(await npy.parse(arrayBuffer)); + } catch (err) { + reject(err); + } + }; + reader.onerror = err => reject(err); + reader.readAsArrayBuffer(file); + }); +} + +/*************************************************************************** + * readJsonFile => read a local JSON file into object + ***************************************************************************/ +function readJsonFile(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => { + try { + const obj = JSON.parse(reader.result); + resolve(obj); + } catch (err) { + reject(err); + } + }; + reader.onerror = (err) => reject(err); + reader.readAsText(file); + }); +} + +/*************************************************************************** + * 3) Build group list in UI + ***************************************************************************/ +function buildGroupList() { + document.getElementById("groupList").innerHTML = "<h3>Groups</h3>"; + const groupListDiv = document.getElementById("groupList"); + + const nGroups = groupIndicesArr.length - 1; + for (let g = 0; g < nGroups; g++) { + const div = document.createElement("span"); + div.className = "group-item"; + div.textContent = `Group ${g}`; + div.onclick = () => onSelectGroup(g); + groupListDiv.appendChild(div); + } +} + +/*************************************************************************** + * onSelectGroup => show puzzles in that group + ***************************************************************************/ +function onSelectGroup(groupIndex) { + document.getElementById("puzzleList").innerHTML = ""; + document.getElementById("puzzleView").innerHTML = ""; + + const puzzleListDiv = document.getElementById("puzzleList"); + puzzleListDiv.innerHTML = `<h4>Puzzles in Group ${groupIndex}</h4>`; + + const firstPuzzle = groupIndicesArr[groupIndex]; + const lastPuzzle = groupIndicesArr[groupIndex + 1]; + + for (let p = firstPuzzle; p < lastPuzzle; p++) { + const puzzleIntId = puzzleIdentifiersArr[p]; + const puzzleStrId = (puzzleIntId < identifiersJson.length) + ? identifiersJson[puzzleIntId] + : "<unknown>"; + + const div = document.createElement("span"); + div.className = "puzzle-item"; + div.textContent = `Puzzle #${p} [ID=${puzzleIntId}: ${puzzleStrId}]`; + div.onclick = () => onSelectPuzzle(p); + puzzleListDiv.appendChild(div); + } +} + +/*************************************************************************** + * onSelectPuzzle => show each example + ***************************************************************************/ +function onSelectPuzzle(puzzleIndex) { + const puzzleView = document.getElementById("puzzleView"); + puzzleView.innerHTML = ""; + + // puzzle ID + const puzzleIntId = puzzleIdentifiersArr[puzzleIndex]; + const puzzleStrId = (puzzleIntId < identifiersJson.length) + ? identifiersJson[puzzleIntId] + : "<unknown>"; + + const titleDiv = document.createElement("div"); + titleDiv.className = "puzzle-id"; + titleDiv.textContent = `Puzzle #${puzzleIndex} — ID: ${puzzleStrId}`; + puzzleView.appendChild(titleDiv); + + // Examples are [puzzleIndicesArr[p], puzzleIndicesArr[p+1]) + const firstExample = puzzleIndicesArr[puzzleIndex]; + const lastExample = puzzleIndicesArr[puzzleIndex + 1]; + + for (let e = firstExample; e < lastExample; e++) { + const inputSeq = slice1D(inputsArr, e*seqLen, (e+1)*seqLen); + const outputSeq = slice1D(labelsArr, e*seqLen, (e+1)*seqLen); + + const inputGrid = decodeGrid(inputSeq); + const outputGrid = decodeGrid(outputSeq); + + const exDiv = document.createElement("div"); + exDiv.className = "example-container"; + exDiv.appendChild(document.createTextNode(`Example ${e}`)); + exDiv.appendChild(document.createElement("br")); + + exDiv.appendChild(renderGrid(inputGrid)); + exDiv.appendChild(renderGrid(outputGrid)); + + puzzleView.appendChild(exDiv); + } +} + +/*************************************************************************** + * slice1D => typed array slicing + ***************************************************************************/ +function slice1D(arr, start, end) { + const result = new Uint32Array(end - start); + for (let i = start; i < end; i++) { + result[i - start] = Number(arr[i]); + } + return result; +} + +/*************************************************************************** + * decodeGrid => turn the flattened seq of length=gridSize^2 into 2D + ***************************************************************************/ +function decodeGrid(seq) { + const grid = []; + let idx = 0; + for (let r = 0; r < gridSize; r++) { + const row = []; + for (let c = 0; c < gridSize; c++) { + row.push(seq[idx]); + idx++; + } + grid.push(row); + } + return grid; +} + +/*************************************************************************** + * renderGrid => draws a 2D grid to <canvas> + ***************************************************************************/ +function renderGrid(grid2d) { + const rows = grid2d.length; + const cols = grid2d[0].length; + const scale = 10; + + const canvas = document.createElement("canvas"); + canvas.width = cols * scale; + canvas.height = rows * scale; + canvas.className = "grid-canvas"; + const ctx = canvas.getContext("2d"); + + for (let r = 0; r < rows; r++) { + for (let c = 0; c < cols; c++) { + const val = grid2d[r][c]; + ctx.fillStyle = indexToColor(val); + ctx.fillRect(c * scale, r * scale, scale, scale); + } + } + return canvas; +} + +/*************************************************************************** + * indexToColor => color palette: + * 0 => pad => white + * 1 => eos => light gray + * 2..11 => original color(0..9) + ***************************************************************************/ +function indexToColor(value) { + if (value === 0) return "#FFFFFF"; // pad => white + if (value === 1) return "#DDDDDD"; // eos => light gray + + // shift by 2 => original color in [0..9] + const colorIdx = value - 2; + const palette = [ + "#000000", // color0 => black + "#FF0000", // color1 => red + "#00FF00", // color2 => green + "#0000FF", // color3 => blue + "#FFFF00", // color4 => yellow + "#FFA500", // color5 => orange + "#800080", // color6 => purple + "#00FFFF", // color7 => cyan + "#FFC0CB", // color8 => pink + "#808080" // color9 => gray + ]; + if (colorIdx >= 0 && colorIdx < palette.length) { + return palette[colorIdx]; + } + return "#FFFFFF"; // fallback +} +</script> +</body> +</html> diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8c90d6f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +torch +adam-atan2 +einops +tqdm +coolname +pydantic +argdantic +wandb +omegaconf +hydra-core +huggingface_hub diff --git a/utils/functions.py b/utils/functions.py new file mode 100644 index 0000000..b123636 --- /dev/null +++ b/utils/functions.py @@ -0,0 +1,19 @@ +import importlib +import inspect + + +def load_model_class(identifier: str, prefix: str = "models."): + module_path, class_name = identifier.split('@') + + # Import the module + module = importlib.import_module(prefix + module_path) + cls = getattr(module, class_name) + + return cls + + +def get_model_source_path(identifier: str, prefix: str = "models."): + module_path, class_name = identifier.split('@') + + module = importlib.import_module(prefix + module_path) + return inspect.getsourcefile(module) |
