blob: 37eb1a6081c609bfb0d056535e24503cef144cea (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
#!/usr/bin/env bash
# Sweep the rrog-act compute mode of `rrog.cli run` over a list of views and
# seeds for one task, skipping cells whose result JSON already exists, then
# optionally write a markdown summary via `rrog.cli results`.
#
# All settings come from the environment (defaults below), e.g.
#   TASK=ogbg-molhiv DEVICE=cuda:1 SEEDS="0 1 2" VIEWS="gin gcn" bash <script>
# SEEDS falls back to SEED and then to 0; SEEDS and VIEWS are
# whitespace-separated lists. LR, MAX_TRAIN_BATCHES and MAX_EVAL_BATCHES are
# only forwarded when non-empty, EMA only when it formats to something other
# than 0. Set COLLECT to anything other than 1 to skip the summary step.
set -euo pipefail

# Work from the parent of this script's directory so the relative runs/,
# logs/ and summaries/ paths below resolve there.
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "${ROOT_DIR}"
# Prepend the root; the ':' separator is only added when PYTHONPATH is already
# non-empty, because a trailing empty entry is treated like the current dir.
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"

# Task, hardware and training-loop settings.
TASK="${TASK:-ogbg-molhiv}"
DEVICE="${DEVICE:-cuda:0}"
EPOCHS="${EPOCHS:-100}"
SEEDS="${SEEDS:-${SEED:-0}}"
HIDDEN="${HIDDEN:-128}"
BS="${BS:-128}"
LR="${LR:-}"
EVAL_EVERY="${EVAL_EVERY:-10}"
NUM_WORKERS="${NUM_WORKERS:-0}"

# rrog-act settings, forwarded to the CLI as --T, --n_sup, --halt_*, --lam_q,
# --q_warmup_epochs, --act_train_mode and (when non-zero) --ema.
T="${T:-1}"
N_SUP="${N_SUP:-3}"
HALT_MAX="${HALT_MAX:-8}"
HALT_MIN="${HALT_MIN:-2}"
HALT_TARGET="${HALT_TARGET:-loss}"
HALT_LOSS_THRESHOLD="${HALT_LOSS_THRESHOLD:-0.2}"
HALT_EXPLORATION_PROB="${HALT_EXPLORATION_PROB:-0.1}"
LAM_Q="${LAM_Q:-0.1}"
Q_WARMUP="${Q_WARMUP:-0}"
ACT_TRAIN_MODE="${ACT_TRAIN_MODE:-stream}"
EMA="${EMA:-0}"

# Optional batch caps (empty = not passed) and sweep control.
MAX_TRAIN_BATCHES="${MAX_TRAIN_BATCHES:-}"
MAX_EVAL_BATCHES="${MAX_EVAL_BATCHES:-}"
COLLECT="${COLLECT:-1}"
VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"

mkdir -p runs logs summaries
#######################################
# Print a number in Python's general ("g") float format, e.g. 0.10 -> 0.1,
# 1e-5 -> 1e-05, 0 -> 0. Used to build stable tags for result file names.
# Arguments: $1 - anything Python's float() accepts
# Outputs:   the formatted number on stdout
# Returns:   python3's exit status (non-zero if $1 is not a valid float)
#######################################
fmt_float() {
  python3 -c 'import sys; print(f"{float(sys.argv[1]):g}")' "$1"
}
#######################################
# Print the result-JSON path for one (view, seed) cell under the current
# settings; run_cell skips a cell when this file already exists.
# Globals (read): TASK, T, N_SUP, ACT_TRAIN_MODE, HALT_MAX, HALT_MIN,
#   HALT_TARGET, HALT_LOSS_THRESHOLD, LAM_Q, HALT_EXPLORATION_PROB,
#   Q_WARMUP, HIDDEN, EPOCHS, EMA
# Arguments: $1 - view name, $2 - seed
# Outputs:   the path on stdout
# Returns:   non-zero if any float setting cannot be formatted
#######################################
result_path() {
  local view="$1"
  local seed="$2"
  local target_tag="${HALT_TARGET}"
  local loss_tag lam_tag hex_tag ema_fmt
  local ema_tag=""
  # Callers use $(result_path ...) and bash clears errexit inside command
  # substitutions, so propagate fmt_float failures explicitly instead of
  # printing a path with an empty tag.
  loss_tag="$(fmt_float "${HALT_LOSS_THRESHOLD}")" || return
  lam_tag="$(fmt_float "${LAM_Q}")" || return
  hex_tag="$(fmt_float "${HALT_EXPLORATION_PROB}")" || return
  ema_fmt="$(fmt_float "${EMA}")" || return
  # The loss threshold is only part of the tag when halting targets the loss.
  if [[ "${HALT_TARGET}" == "loss" ]]; then
    target_tag="loss${loss_tag}"
  fi
  # EMA is left out of the name entirely when it formats to 0.
  if [[ "${ema_fmt}" != "0" ]]; then
    ema_tag="_ema${ema_fmt}"
  fi
  # NOTE(review): this name presumably has to match what `rrog.cli run`
  # writes. BS, LR, EVAL_EVERY and the MAX_*_BATCHES caps are not part of it,
  # so runs differing only in those settings map to (and skip on) the same
  # file — confirm this is intended.
  echo "runs/${TASK}_${view}_rrog-act_T${T}_ns${N_SUP}_${ACT_TRAIN_MODE}_hm${HALT_MAX}_hmin${HALT_MIN}_${target_tag}_lq${lam_tag}_hex${hex_tag}_qw${Q_WARMUP}_h${HIDDEN}_e${EPOCHS}${ema_tag}_s${seed}.json"
}
#######################################
# Run one (view, seed) cell of the sweep with the rrog-act compute mode,
# unless its result JSON (see result_path) already exists.
# Globals (read): TASK, DEVICE, EPOCHS, HIDDEN, BS, T, N_SUP, HALT_MAX,
#   HALT_MIN, HALT_TARGET, HALT_LOSS_THRESHOLD, HALT_EXPLORATION_PROB, LAM_Q,
#   Q_WARMUP, ACT_TRAIN_MODE, EVAL_EVERY, NUM_WORKERS, LR, EMA,
#   MAX_TRAIN_BATCHES, MAX_EVAL_BATCHES
# Arguments: $1 - view name, $2 - seed
# Outputs:   a "[skip] ..." or "[run] ..." line on stdout, then CLI output
# Returns:   0 when skipped, otherwise the CLI's status (under the script's
#            `set -e` a failing run aborts the sweep); non-zero if the result
#            path or EMA value cannot be formatted
#######################################
run_cell() {
  local view="$1"
  local seed="$2"
  local out ema_fmt
  local -a cmd
  out="$(result_path "${view}" "${seed}")" || return
  if [[ -f "${out}" ]]; then
    echo "[skip] ${out}"
    return
  fi
  echo "[run] ${TASK} view=${view} compute=rrog-act mode=${ACT_TRAIN_MODE} T=${T} ns=${N_SUP} seed=${seed} device=${DEVICE}"
  # Build argv as an array so values containing spaces stay single arguments.
  cmd=(
    python3 -m rrog.cli run
    --task "${TASK}"
    --view "${view}"
    --compute rrog-act
    --epochs "${EPOCHS}"
    --hidden "${HIDDEN}"
    --bs "${BS}"
    --T "${T}"
    --n_sup "${N_SUP}"
    --halt_max_steps "${HALT_MAX}"
    --halt_min_steps "${HALT_MIN}"
    --halt_target "${HALT_TARGET}"
    --halt_loss_threshold "${HALT_LOSS_THRESHOLD}"
    --halt_exploration_prob "${HALT_EXPLORATION_PROB}"
    --lam_q "${LAM_Q}"
    --q_warmup_epochs "${Q_WARMUP}"
    --act_train_mode "${ACT_TRAIN_MODE}"
    --eval_every "${EVAL_EVERY}"
    --num_workers "${NUM_WORKERS}"
    --seed "${seed}"
    --device "${DEVICE}"
  )
  # Optional flags: only forwarded when set, so the CLI keeps its own defaults.
  if [[ -n "${LR}" ]]; then
    cmd+=(--lr "${LR}")
  fi
  # Compare the formatted value so "0", "0.0", "0e0" all count as "no EMA".
  ema_fmt="$(fmt_float "${EMA}")" || return
  if [[ "${ema_fmt}" != "0" ]]; then
    cmd+=(--ema "${EMA}")
  fi
  if [[ -n "${MAX_TRAIN_BATCHES}" ]]; then
    cmd+=(--max_train_batches "${MAX_TRAIN_BATCHES}")
  fi
  if [[ -n "${MAX_EVAL_BATCHES}" ]]; then
    cmd+=(--max_eval_batches "${MAX_EVAL_BATCHES}")
  fi
  "${cmd[@]}"
}
# Split the whitespace-separated SEEDS/VIEWS lists into arrays with `read -a`
# instead of leaving them unquoted in the `for`: entries are still split on
# spaces, tabs and newlines, but are never glob-expanded against ROOT_DIR.
read -r -d '' -a seed_list < <(printf '%s\0' "${SEEDS}")
read -r -d '' -a view_list < <(printf '%s\0' "${VIEWS}")
# Guard keeps `set -u` happy on bash < 4.4 when a list is whitespace-only.
if (( ${#seed_list[@]} > 0 && ${#view_list[@]} > 0 )); then
  for seed in "${seed_list[@]}"; do
    for view in "${view_list[@]}"; do
      run_cell "${view}" "${seed}"
    done
  done
fi
# Collect all results into a markdown table, echoed and saved under summaries/.
if [[ "${COLLECT}" == "1" ]]; then
  python3 -m rrog.cli results --epochs "${EPOCHS}" \
    | tee "summaries/ogb_graphprop_act_${TASK}_e${EPOCHS}.md"
fi
|