summaryrefslogtreecommitdiff
path: root/env/setup.sh
blob: 6a8435c9f6a5d44f303224e824582175ce1ab216 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/env bash
# Unified environment setup script for rrm.
# Reproduces HRM (sapientinc/HRM) and TRM (SamsungSAILMontreal/TinyRecursiveModels).
# Usage: bash env/setup.sh
# Overridable via environment:
#   ENV_NAME  conda environment name (default: rrm)
#   PY_VER    Python version used only when the environment is created (default: 3.10)
set -euo pipefail

ENV_NAME=${ENV_NAME:-rrm}
PY_VER=${PY_VER:-3.10}

if ! command -v conda >/dev/null 2>&1; then
  printf 'error: conda not found on PATH; install Miniconda/Anaconda first\n' >&2
  exit 1
fi
# Separate assignment so a failing `conda info` aborts under `set -e` instead of
# silently falling back to sourcing "/etc/profile.d/conda.sh".
conda_base=$(conda info --base)

# conda's shell functions (and packages' activate.d scripts) are not guaranteed
# to be `set -u` clean, so relax nounset only while they run.
set +u
# shellcheck source=/dev/null
source "$conda_base/etc/profile.d/conda.sh"

# Exact, literal match against the first column of `conda env list`.
# - A single awk consumes all of its input. Unlike `awk | grep -q`, no stage can
#   be killed by SIGPIPE and make the pipeline fail under pipefail.
# - Regex metacharacters in ENV_NAME (e.g. '.') cannot match other env names.
if ! conda env list \
  | awk -v name="$ENV_NAME" '$1 == name { found = 1 } END { exit !found }'; then
  conda create -n "$ENV_NAME" python="$PY_VER" -y
fi
conda activate "$ENV_NAME"
set -u

# Directory containing this script, so requirements.txt resolves regardless of
# the caller's working directory. BASH_SOURCE also works if the file is sourced.
script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)

# Conda packages first, pip afterwards: conda does not track pip-installed
# files, and conda's docs recommend running pip only after conda.
# cuda-toolkit supplies nvcc + CUDA headers inside the env. adam-atan2 is
# compiled from source below and needs them.
# The `conda` shell function reactivates the env after an install, which may
# run activate.d scripts that are not `set -u` clean. So nounset is relaxed for
# this call only.
set +u
conda install -y -c nvidia cuda-toolkit=12.6
set -u

python -m pip install --upgrade pip wheel setuptools packaging ninja setuptools-scm

# Torch 2.7.0 cu126 (pinned to match TRM's specific_requirements)
python -m pip install torch==2.7.0+cu126 torchvision==0.22.0+cu126 torchaudio==2.7.0+cu126 \
  --index-url https://download.pytorch.org/whl/cu126

# Merged requirements (HRM + TRM)
python -m pip install -r "$script_dir/requirements.txt"

# FlashAttention 2 prebuilt wheel (Ampere, e.g. A6000) for torch 2.7 / CUDA 12.
# The wheel filename encodes the CPython tag and torch's C++11 ABI flag. Both
# are derived from the env rather than hard-coded as cp310/TRUE, so a
# non-default PY_VER requests the matching wheel (provided the release
# publishes one).
fa_ver=2.8.3
py_tag=$(python -c 'import sys; print("cp%d%d" % sys.version_info[:2])')
cxx11abi=$(python -c 'import torch; print(str(torch._C._GLIBCXX_USE_CXX11_ABI).upper())')
fa_wheel="flash_attn-${fa_ver}+cu12torch2.7cxx11abi${cxx11abi}-${py_tag}-${py_tag}-linux_x86_64.whl"
python -m pip install \
  "https://github.com/Dao-AILab/flash-attention/releases/download/v${fa_ver}/${fa_wheel}"

# Build adam-atan2 against the already-installed torch (--no-build-isolation).
# CUDA_HOME points the extension build at the env's conda-installed toolkit.
CUDA_HOME="$CONDA_PREFIX" python -m pip install --no-cache-dir --no-build-isolation adam-atan2==0.0.3

# Default W&B to offline mode unless the caller already picked a mode.
# This export only reaches processes launched by this script (such as the check
# below). It does not persist into the caller's shell or into the conda env.
: "${WANDB_MODE:=offline}"
export WANDB_MODE

# Smoke test: the key packages import. Report versions and the visible GPUs.
python - <<'PY'
import torch
import flash_attn
import adam_atan2

print("torch:", torch.__version__, "CUDA:", torch.version.cuda)
gpu_count = torch.cuda.device_count()
gpu_names = [torch.cuda.get_device_name(idx) for idx in range(gpu_count)]
print("GPUs:", gpu_count, "->", gpu_names)
print("flash_attn:", flash_attn.__version__)
print("adam_atan2 ok")
PY

printf '\n==> rrm env ready. conda activate %s\n' "$ENV_NAME"