Diffstat (limited to 'README.md')
| -rw-r--r-- | README.md | 32 |
1 files changed, 27 insertions, 5 deletions
@@ -15,16 +15,16 @@ These results underscore HRM's potential as a transformative advancement towar
 Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If not present, run the following commands:
 
 ```bash
-# Install CUDA 12.4
-CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run
+# Install CUDA 12.6
+CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run
 
 wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
 sudo sh cuda_installer.run --silent --toolkit --override
 
-export CUDA_HOME=/usr/local/cuda-12.4
+export CUDA_HOME=/usr/local/cuda-12.6
 
-# Install PyTorch with CUDA 12.4
-PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu124
+# Install PyTorch with CUDA 12.6
+PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu126
 
 pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
 
@@ -32,6 +32,20 @@ pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
 pip3 install packaging ninja wheel setuptools setuptools-scm
 ```
 
+Then install FlashAttention. For Hopper GPUs, install FlashAttention 3
+
+```bash
+git clone git@github.com:Dao-AILab/flash-attention.git
+cd flash-attention/hopper
+python setup.py install
+```
+
+For Ampere or earlier GPUs, install FlashAttenion 2
+
+```bash
+pip3 install flash-attn
+```
+
 ## Install Python Dependencies
 
 ```bash
@@ -62,6 +76,14 @@ OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 e
 
 Runtime: ~10 hours on a RTX 4070 laptop GPU
 
+## Trained Checkpoints
+
+ - [ARC-AGI-2](https://huggingface.co/sapientinc/HRM-checkpoint-ARC-2)
+ - [Sudoku 9x9 Extreme (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-sudoku-extreme)
+ - [Maze 30x30 Hard (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-maze-30x30-hard)
+
+To use the checkpoints, see Evaluation section below.
+
 ## Full-scale Experiments
 
 Experiments below assume an 8-GPU setup.
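
The FlashAttention choice added in this diff depends on the GPU architecture. The sketch below is not part of the commit. It shows one way to check which of the two install paths applies, assuming that Hopper corresponds to CUDA compute capability major version 9 and that every other architecture falls back to the `pip3 install flash-attn` path. The helper name `flash_attention_install_hint` is hypothetical.

```python
from typing import Tuple


def flash_attention_install_hint(compute_capability: Tuple[int, int]) -> str:
    """Map a CUDA compute capability to one of the README's two install paths.

    Assumption: Hopper GPUs report major version 9 (e.g. H100 is 9.0).
    Any other architecture is routed to the FlashAttention 2 wheel.
    """
    major, _minor = compute_capability
    if major == 9:
        # FlashAttention 3, built from the flash-attention/hopper directory.
        return "cd flash-attention/hopper && python setup.py install"
    # FlashAttention 2 from PyPI.
    return "pip3 install flash-attn"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        # Returns a (major, minor) tuple for the current CUDA device.
        capability = torch.cuda.get_device_capability()
        print(capability, "->", flash_attention_install_hint(capability))
    else:
        print("No CUDA device visible; cannot detect the GPU architecture.")
```

Run on the training machine after PyTorch is installed, the script prints the detected capability and the matching command from the README.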
