diff options
Diffstat (limited to 'scripts/setup_env.sh')
| -rwxr-xr-x | scripts/setup_env.sh | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh new file mode 100755 index 0000000..66a94c8 --- /dev/null +++ b/scripts/setup_env.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" + +PYTHON_BIN="${PYTHON_BIN:-python3}" +VENV_DIR="${VENV_DIR:-.venv}" +TORCH_INDEX_URL="${TORCH_INDEX_URL:-https://download.pytorch.org/whl/cu124}" + +if [[ ! -d "${VENV_DIR}" ]]; then + "${PYTHON_BIN}" -m venv "${VENV_DIR}" +fi + +source "${VENV_DIR}/bin/activate" +python -m pip install --upgrade pip wheel setuptools + +if ! python - <<'PY' >/dev/null 2>&1 +import torch +assert torch.cuda.is_available() or True +PY +then + python -m pip install torch --index-url "${TORCH_INDEX_URL}" +fi + +python -m pip install -r requirements.txt + +python - <<'PY' +import torch +import torch_geometric +import ogb +print("torch", torch.__version__, "cuda_available", torch.cuda.is_available()) +print("torch_geometric", torch_geometric.__version__) +print("ogb", ogb.__version__) +PY |
