summaryrefslogtreecommitdiff
path: root/scripts/setup_env.sh
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/setup_env.sh')
-rwxr-xr-xscripts/setup_env.sh35
1 files changed, 35 insertions, 0 deletions
diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh
new file mode 100755
index 0000000..66a94c8
--- /dev/null
+++ b/scripts/setup_env.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+PYTHON_BIN="${PYTHON_BIN:-python3}"
+VENV_DIR="${VENV_DIR:-.venv}"
+TORCH_INDEX_URL="${TORCH_INDEX_URL:-https://download.pytorch.org/whl/cu124}"
+
+if [[ ! -d "${VENV_DIR}" ]]; then
+ "${PYTHON_BIN}" -m venv "${VENV_DIR}"
+fi
+
+source "${VENV_DIR}/bin/activate"
+python -m pip install --upgrade pip wheel setuptools
+
+if ! python - <<'PY' >/dev/null 2>&1
+import torch
+assert torch.cuda.is_available() or True
+PY
+then
+ python -m pip install torch --index-url "${TORCH_INDEX_URL}"
+fi
+
+python -m pip install -r requirements.txt
+
+python - <<'PY'
+import torch
+import torch_geometric
+import ogb
+print("torch", torch.__version__, "cuda_available", torch.cuda.is_available())
+print("torch_geometric", torch_geometric.__version__)
+print("ogb", ogb.__version__)
+PY