diff options
Diffstat (limited to 'scripts/run_ogb_act_task.sh')
| -rwxr-xr-x | scripts/run_ogb_act_task.sh | 119 |
1 files changed, 119 insertions, 0 deletions
diff --git a/scripts/run_ogb_act_task.sh b/scripts/run_ogb_act_task.sh new file mode 100755 index 0000000..37eb1a6 --- /dev/null +++ b/scripts/run_ogb_act_task.sh @@ -0,0 +1,119 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}" + +TASK="${TASK:-ogbg-molhiv}" +DEVICE="${DEVICE:-cuda:0}" +EPOCHS="${EPOCHS:-100}" +SEEDS="${SEEDS:-${SEED:-0}}" +HIDDEN="${HIDDEN:-128}" +BS="${BS:-128}" +LR="${LR:-}" +EVAL_EVERY="${EVAL_EVERY:-10}" +NUM_WORKERS="${NUM_WORKERS:-0}" +T="${T:-1}" +N_SUP="${N_SUP:-3}" +HALT_MAX="${HALT_MAX:-8}" +HALT_MIN="${HALT_MIN:-2}" +HALT_TARGET="${HALT_TARGET:-loss}" +HALT_LOSS_THRESHOLD="${HALT_LOSS_THRESHOLD:-0.2}" +HALT_EXPLORATION_PROB="${HALT_EXPLORATION_PROB:-0.1}" +LAM_Q="${LAM_Q:-0.1}" +Q_WARMUP="${Q_WARMUP:-0}" +ACT_TRAIN_MODE="${ACT_TRAIN_MODE:-stream}" +EMA="${EMA:-0}" +MAX_TRAIN_BATCHES="${MAX_TRAIN_BATCHES:-}" +MAX_EVAL_BATCHES="${MAX_EVAL_BATCHES:-}" +COLLECT="${COLLECT:-1}" +VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}" + +mkdir -p runs logs summaries + +fmt_float() { + python3 - "$1" <<'PY' +import sys +print(f"{float(sys.argv[1]):g}") +PY +} + +result_path() { + local view="$1" + local seed="$2" + local target_tag="${HALT_TARGET}" + local loss_tag + local lam_tag + local hex_tag + local ema_tag="" + loss_tag="$(fmt_float "${HALT_LOSS_THRESHOLD}")" + lam_tag="$(fmt_float "${LAM_Q}")" + hex_tag="$(fmt_float "${HALT_EXPLORATION_PROB}")" + if [[ "${HALT_TARGET}" == "loss" ]]; then + target_tag="loss${loss_tag}" + fi + if [[ "$(fmt_float "${EMA}")" != "0" ]]; then + ema_tag="_ema$(fmt_float "${EMA}")" + fi + echo "runs/${TASK}_${view}_rrog-act_T${T}_ns${N_SUP}_${ACT_TRAIN_MODE}_hm${HALT_MAX}_hmin${HALT_MIN}_${target_tag}_lq${lam_tag}_hex${hex_tag}_qw${Q_WARMUP}_h${HIDDEN}_e${EPOCHS}${ema_tag}_s${seed}.json" +} + +run_cell() { + local view="$1" + local seed="$2" + local out + out="$(result_path "${view}" "${seed}")" + if [[ -f "${out}" ]]; then + echo "[skip] ${out}" + return + fi + + echo "[run] ${TASK} view=${view} compute=rrog-act mode=${ACT_TRAIN_MODE} T=${T} ns=${N_SUP} seed=${seed} device=${DEVICE}" + cmd=( + python3 -m rrog.cli run + --task "${TASK}" + --view "${view}" + --compute rrog-act + --epochs "${EPOCHS}" + --hidden "${HIDDEN}" + --bs "${BS}" + --T "${T}" + --n_sup "${N_SUP}" + --halt_max_steps "${HALT_MAX}" + --halt_min_steps "${HALT_MIN}" + --halt_target "${HALT_TARGET}" + --halt_loss_threshold "${HALT_LOSS_THRESHOLD}" + --halt_exploration_prob "${HALT_EXPLORATION_PROB}" + --lam_q "${LAM_Q}" + --q_warmup_epochs "${Q_WARMUP}" + --act_train_mode "${ACT_TRAIN_MODE}" + --eval_every "${EVAL_EVERY}" + --num_workers "${NUM_WORKERS}" + --seed "${seed}" + --device "${DEVICE}" + ) + if [[ -n "${LR}" ]]; then + cmd+=(--lr "${LR}") + fi + if [[ "$(fmt_float "${EMA}")" != "0" ]]; then + cmd+=(--ema "${EMA}") + fi + if [[ -n "${MAX_TRAIN_BATCHES}" ]]; then + cmd+=(--max_train_batches "${MAX_TRAIN_BATCHES}") + fi + if [[ -n "${MAX_EVAL_BATCHES}" ]]; then + cmd+=(--max_eval_batches "${MAX_EVAL_BATCHES}") + fi + "${cmd[@]}" +} + +for seed in ${SEEDS}; do + for view in ${VIEWS}; do + run_cell "${view}" "${seed}" + done +done + +if [[ "${COLLECT}" == "1" ]]; then + python3 -m rrog.cli results --epochs "${EPOCHS}" | tee "summaries/ogb_graphprop_act_${TASK}_e${EPOCHS}.md" +fi |
