summaryrefslogtreecommitdiff
path: root/scripts/setup_env.sh
blob: 66a94c8b44545218c8eb29bce292fcedf7883b11 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#!/usr/bin/env bash
#
# Create (or reuse) a Python virtual environment at the repository root and
# install the project's dependencies into it.
#
# Steps:
#   1. Create the venv at VENV_DIR unless it already has an activation script.
#   2. Upgrade pip, wheel and setuptools inside the venv.
#   3. Install torch from TORCH_INDEX_URL, only if `import torch` does not
#      already succeed inside the venv.
#   4. Install requirements.txt.
#   5. Import torch, torch_geometric and ogb and print their versions as a
#      smoke test; the script fails if any of these imports fails.
#
# Environment variables (all optional):
#   PYTHON_BIN       interpreter used to create the venv (default: python3)
#   VENV_DIR         venv location; relative paths are resolved from the
#                    repository root (default: .venv)
#   TORCH_INDEX_URL  pip index used for the torch install
#                    (default: https://download.pytorch.org/whl/cu124)
set -euo pipefail

die() {
  printf 'setup_env: %s\n' "$*" >&2
  exit 1
}

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
# Work from the repository root so VENV_DIR and requirements.txt resolve the
# same way regardless of the caller's working directory.
cd "${ROOT_DIR}"

PYTHON_BIN="${PYTHON_BIN:-python3}"
VENV_DIR="${VENV_DIR:-.venv}"
TORCH_INDEX_URL="${TORCH_INDEX_URL:-https://download.pytorch.org/whl/cu124}"

# Test for the activation script rather than the directory: a leftover or
# half-created VENV_DIR would otherwise skip creation and then make the
# `source` below fail. `python -m venv` populates an existing directory.
if [[ ! -f "${VENV_DIR}/bin/activate" ]]; then
  # Only needed when creating the venv; an existing venv is reused as-is.
  command -v "${PYTHON_BIN}" >/dev/null 2>&1 \
    || die "python interpreter '${PYTHON_BIN}' not found (set PYTHON_BIN)"
  "${PYTHON_BIN}" -m venv "${VENV_DIR}"
fi

# shellcheck source=/dev/null
source "${VENV_DIR}/bin/activate"
python -m pip install --upgrade pip wheel setuptools

# Install torch from the dedicated index only when it is not importable yet.
# This is deliberately an import-only check: an already-installed torch build
# (CPU or CUDA) is kept. CUDA availability is reported by the smoke test below.
if ! python -c 'import torch' >/dev/null 2>&1; then
  python -m pip install torch --index-url "${TORCH_INDEX_URL}"
fi

python -m pip install -r requirements.txt

# Smoke test: fail loudly if a core dependency cannot be imported, and print
# the installed versions plus whether torch sees a CUDA device.
python - <<'PY'
import torch
import torch_geometric
import ogb
print("torch", torch.__version__, "cuda_available", torch.cuda.is_available())
print("torch_geometric", torch_geometric.__version__)
print("ogb", ogb.__version__)
PY