summaryrefslogtreecommitdiff
path: root/scripts/run_ogb_act_task.sh
blob: 37eb1a6081c609bfb0d056535e24503cef144cea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env bash
#
# Sweep `python3 -m rrog.cli run --compute rrog-act` over every (seed, view)
# pair for one task, then optionally summarize with `rrog.cli results`.
#
# Every setting is read from an environment variable (defaults below), e.g.
#   TASK=ogbg-molhiv SEEDS="0 1 2" VIEWS="gin gcn" scripts/run_ogb_act_task.sh
#
#   SEEDS     whitespace-separated seeds; falls back to SEED, then 0
#   VIEWS     whitespace-separated list of views to run
#   LR, MAX_TRAIN_BATCHES, MAX_EVAL_BATCHES
#             forwarded to the CLI only when non-empty
#   EMA       forwarded as --ema only when it is non-zero
#   COLLECT   1 (default) also writes
#             summaries/ogb_graphprop_act_<TASK>_e<EPOCHS>.md
#
# A cell whose result JSON already exists under runs/ is skipped, so an
# interrupted sweep can simply be re-run.
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "${ROOT_DIR}"
# Prepend the repo root. Only add the ':' separator when PYTHONPATH is
# non-empty: a trailing ':' creates an empty entry, which Python treats as
# "also search the current directory".
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"

TASK="${TASK:-ogbg-molhiv}"
DEVICE="${DEVICE:-cuda:0}"
EPOCHS="${EPOCHS:-100}"
SEEDS="${SEEDS:-${SEED:-0}}"
HIDDEN="${HIDDEN:-128}"
BS="${BS:-128}"
LR="${LR:-}"
EVAL_EVERY="${EVAL_EVERY:-10}"
NUM_WORKERS="${NUM_WORKERS:-0}"
T="${T:-1}"
N_SUP="${N_SUP:-3}"
HALT_MAX="${HALT_MAX:-8}"
HALT_MIN="${HALT_MIN:-2}"
HALT_TARGET="${HALT_TARGET:-loss}"
HALT_LOSS_THRESHOLD="${HALT_LOSS_THRESHOLD:-0.2}"
HALT_EXPLORATION_PROB="${HALT_EXPLORATION_PROB:-0.1}"
LAM_Q="${LAM_Q:-0.1}"
Q_WARMUP="${Q_WARMUP:-0}"
ACT_TRAIN_MODE="${ACT_TRAIN_MODE:-stream}"
EMA="${EMA:-0}"
MAX_TRAIN_BATCHES="${MAX_TRAIN_BATCHES:-}"
MAX_EVAL_BATCHES="${MAX_EVAL_BATCHES:-}"
COLLECT="${COLLECT:-1}"
VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"

mkdir -p runs logs summaries

#######################################
# Echo "$1" in Python's compact "g" float format, e.g. "0.10" -> "0.1",
# "0.0" -> "0", "1e-5" -> "1e-05". Equivalent spellings of the same number
# therefore map to the same tag.
# Arguments: $1 - a string Python's float() accepts
# Outputs:   the formatted number on stdout
# Returns:   python3's exit status (non-zero if "$1" is not a float)
#######################################
fmt_float() {
  python3 -c 'import sys; print(f"{float(sys.argv[1]):g}")' "$1"
}

#######################################
# Print the result JSON path for one (view, seed) cell under the current
# settings. run_cell skips a cell when this file already exists, so the name
# presumably must match what `rrog.cli run` writes (TODO confirm against the
# CLI before changing either side).
# Float-valued knobs are normalized with fmt_float so "0.10" and "0.1" share
# a tag; the loss threshold only appears when HALT_TARGET is "loss", and an
# _ema<val> suffix only when EMA is non-zero.
# Arguments: $1 - view name, $2 - seed
# Outputs:   the path on stdout
# Returns:   non-zero if a float setting used in the name fails to parse
#######################################
result_path() {
  local view="$1"
  local seed="$2"
  local target_tag="${HALT_TARGET}"
  local loss_tag lam_tag hex_tag ema_val
  local ema_tag=""
  # This function normally runs inside $(...), where bash clears errexit, so
  # each fmt_float failure is propagated explicitly. Otherwise a bad value
  # would silently become an empty tag and produce a wrong path.
  if [[ "${HALT_TARGET}" == "loss" ]]; then
    loss_tag="$(fmt_float "${HALT_LOSS_THRESHOLD}")" || return
    target_tag="loss${loss_tag}"
  fi
  lam_tag="$(fmt_float "${LAM_Q}")" || return
  hex_tag="$(fmt_float "${HALT_EXPLORATION_PROB}")" || return
  ema_val="$(fmt_float "${EMA}")" || return
  if [[ "${ema_val}" != "0" ]]; then
    ema_tag="_ema${ema_val}"
  fi
  echo "runs/${TASK}_${view}_rrog-act_T${T}_ns${N_SUP}_${ACT_TRAIN_MODE}_hm${HALT_MAX}_hmin${HALT_MIN}_${target_tag}_lq${lam_tag}_hex${hex_tag}_qw${Q_WARMUP}_h${HIDDEN}_e${EPOCHS}${ema_tag}_s${seed}.json"
}

#######################################
# Run one (view, seed) cell of the sweep via `python3 -m rrog.cli run`,
# unless its result file (from result_path) already exists.
# Globals (read): TASK DEVICE EPOCHS HIDDEN BS LR EVAL_EVERY NUM_WORKERS T
#   N_SUP HALT_MAX HALT_MIN HALT_TARGET HALT_LOSS_THRESHOLD
#   HALT_EXPLORATION_PROB LAM_Q Q_WARMUP ACT_TRAIN_MODE EMA
#   MAX_TRAIN_BATCHES MAX_EVAL_BATCHES
# Arguments: $1 - view name, $2 - seed
# Outputs:   a "[skip]" or "[run]" line on stdout, plus the CLI's own output
# Returns:   0 when skipped; otherwise the CLI's exit status. Non-zero before
#            launching anything if the result path or EMA cannot be computed.
#######################################
run_cell() {
  local view="$1"
  local seed="$2"
  local out ema_fmt
  # local so the argv array does not leak into the global scope
  local -a cmd

  out="$(result_path "${view}" "${seed}")" || return
  if [[ -f "${out}" ]]; then
    echo "[skip] ${out}"
    return
  fi

  echo "[run] ${TASK} view=${view} compute=rrog-act mode=${ACT_TRAIN_MODE} T=${T} ns=${N_SUP} seed=${seed} device=${DEVICE}"
  cmd=(
    python3 -m rrog.cli run
    --task "${TASK}"
    --view "${view}"
    --compute rrog-act
    --epochs "${EPOCHS}"
    --hidden "${HIDDEN}"
    --bs "${BS}"
    --T "${T}"
    --n_sup "${N_SUP}"
    --halt_max_steps "${HALT_MAX}"
    --halt_min_steps "${HALT_MIN}"
    --halt_target "${HALT_TARGET}"
    --halt_loss_threshold "${HALT_LOSS_THRESHOLD}"
    --halt_exploration_prob "${HALT_EXPLORATION_PROB}"
    --lam_q "${LAM_Q}"
    --q_warmup_epochs "${Q_WARMUP}"
    --act_train_mode "${ACT_TRAIN_MODE}"
    --eval_every "${EVAL_EVERY}"
    --num_workers "${NUM_WORKERS}"
    --seed "${seed}"
    --device "${DEVICE}"
  )
  # Optional flags: omitted entirely when unset, so the CLI's own defaults apply.
  if [[ -n "${LR}" ]]; then
    cmd+=(--lr "${LR}")
  fi
  # Normalize EMA ("0", "0.0", ...) before deciding whether it is enabled.
  # A parse failure is reported instead of being mistaken for "non-zero".
  ema_fmt="$(fmt_float "${EMA}")" || return
  if [[ "${ema_fmt}" != "0" ]]; then
    cmd+=(--ema "${EMA}")
  fi
  if [[ -n "${MAX_TRAIN_BATCHES}" ]]; then
    cmd+=(--max_train_batches "${MAX_TRAIN_BATCHES}")
  fi
  if [[ -n "${MAX_EVAL_BATCHES}" ]]; then
    cmd+=(--max_eval_batches "${MAX_EVAL_BATCHES}")
  fi
  "${cmd[@]}"
}

# Sweep every (seed, view) cell. SEEDS and VIEWS are whitespace-separated
# lists and are word-split on purpose. Globbing is switched off while doing
# so, which keeps an entry such as "gat*" from expanding against file names
# in ROOT_DIR.
set -f
# shellcheck disable=SC2086  # intentional word splitting of the lists
for seed in ${SEEDS}; do
  for view in ${VIEWS}; do
    run_cell "${view}" "${seed}"
  done
done
set +f

# Optionally run `rrog.cli results` and save its output (also echoed to
# stdout) as a per-task markdown summary.
if [[ "${COLLECT}" == "1" ]]; then
  python3 -m rrog.cli results --epochs "${EPOCHS}" | tee "summaries/ogb_graphprop_act_${TASK}_e${EPOCHS}.md"
fi