summaryrefslogtreecommitdiff
path: root/scripts/run_ogb_act_task.sh
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/run_ogb_act_task.sh')
-rwxr-xr-xscripts/run_ogb_act_task.sh119
1 files changed, 119 insertions, 0 deletions
diff --git a/scripts/run_ogb_act_task.sh b/scripts/run_ogb_act_task.sh
new file mode 100755
index 0000000..37eb1a6
--- /dev/null
+++ b/scripts/run_ogb_act_task.sh
@@ -0,0 +1,119 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+TASK="${TASK:-ogbg-molhiv}"
+DEVICE="${DEVICE:-cuda:0}"
+EPOCHS="${EPOCHS:-100}"
+SEEDS="${SEEDS:-${SEED:-0}}"
+HIDDEN="${HIDDEN:-128}"
+BS="${BS:-128}"
+LR="${LR:-}"
+EVAL_EVERY="${EVAL_EVERY:-10}"
+NUM_WORKERS="${NUM_WORKERS:-0}"
+T="${T:-1}"
+N_SUP="${N_SUP:-3}"
+HALT_MAX="${HALT_MAX:-8}"
+HALT_MIN="${HALT_MIN:-2}"
+HALT_TARGET="${HALT_TARGET:-loss}"
+HALT_LOSS_THRESHOLD="${HALT_LOSS_THRESHOLD:-0.2}"
+HALT_EXPLORATION_PROB="${HALT_EXPLORATION_PROB:-0.1}"
+LAM_Q="${LAM_Q:-0.1}"
+Q_WARMUP="${Q_WARMUP:-0}"
+ACT_TRAIN_MODE="${ACT_TRAIN_MODE:-stream}"
+EMA="${EMA:-0}"
+MAX_TRAIN_BATCHES="${MAX_TRAIN_BATCHES:-}"
+MAX_EVAL_BATCHES="${MAX_EVAL_BATCHES:-}"
+COLLECT="${COLLECT:-1}"
+VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"
+
+mkdir -p runs logs summaries
+
+fmt_float() {
+ python3 - "$1" <<'PY'
+import sys
+print(f"{float(sys.argv[1]):g}")
+PY
+}
+
+result_path() {
+ local view="$1"
+ local seed="$2"
+ local target_tag="${HALT_TARGET}"
+ local loss_tag
+ local lam_tag
+ local hex_tag
+ local ema_tag=""
+ loss_tag="$(fmt_float "${HALT_LOSS_THRESHOLD}")"
+ lam_tag="$(fmt_float "${LAM_Q}")"
+ hex_tag="$(fmt_float "${HALT_EXPLORATION_PROB}")"
+ if [[ "${HALT_TARGET}" == "loss" ]]; then
+ target_tag="loss${loss_tag}"
+ fi
+ if [[ "$(fmt_float "${EMA}")" != "0" ]]; then
+ ema_tag="_ema$(fmt_float "${EMA}")"
+ fi
+ echo "runs/${TASK}_${view}_rrog-act_T${T}_ns${N_SUP}_${ACT_TRAIN_MODE}_hm${HALT_MAX}_hmin${HALT_MIN}_${target_tag}_lq${lam_tag}_hex${hex_tag}_qw${Q_WARMUP}_h${HIDDEN}_e${EPOCHS}${ema_tag}_s${seed}.json"
+}
+
+run_cell() {
+ local view="$1"
+ local seed="$2"
+ local out
+ out="$(result_path "${view}" "${seed}")"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+
+ echo "[run] ${TASK} view=${view} compute=rrog-act mode=${ACT_TRAIN_MODE} T=${T} ns=${N_SUP} seed=${seed} device=${DEVICE}"
+ cmd=(
+ python3 -m rrog.cli run
+ --task "${TASK}"
+ --view "${view}"
+ --compute rrog-act
+ --epochs "${EPOCHS}"
+ --hidden "${HIDDEN}"
+ --bs "${BS}"
+ --T "${T}"
+ --n_sup "${N_SUP}"
+ --halt_max_steps "${HALT_MAX}"
+ --halt_min_steps "${HALT_MIN}"
+ --halt_target "${HALT_TARGET}"
+ --halt_loss_threshold "${HALT_LOSS_THRESHOLD}"
+ --halt_exploration_prob "${HALT_EXPLORATION_PROB}"
+ --lam_q "${LAM_Q}"
+ --q_warmup_epochs "${Q_WARMUP}"
+ --act_train_mode "${ACT_TRAIN_MODE}"
+ --eval_every "${EVAL_EVERY}"
+ --num_workers "${NUM_WORKERS}"
+ --seed "${seed}"
+ --device "${DEVICE}"
+ )
+ if [[ -n "${LR}" ]]; then
+ cmd+=(--lr "${LR}")
+ fi
+ if [[ "$(fmt_float "${EMA}")" != "0" ]]; then
+ cmd+=(--ema "${EMA}")
+ fi
+ if [[ -n "${MAX_TRAIN_BATCHES}" ]]; then
+ cmd+=(--max_train_batches "${MAX_TRAIN_BATCHES}")
+ fi
+ if [[ -n "${MAX_EVAL_BATCHES}" ]]; then
+ cmd+=(--max_eval_batches "${MAX_EVAL_BATCHES}")
+ fi
+ "${cmd[@]}"
+}
+
+for seed in ${SEEDS}; do
+ for view in ${VIEWS}; do
+ run_cell "${view}" "${seed}"
+ done
+done
+
+if [[ "${COLLECT}" == "1" ]]; then
+ python3 -m rrog.cli results --epochs "${EPOCHS}" | tee "summaries/ogb_graphprop_act_${TASK}_e${EPOCHS}.md"
+fi