summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/analyze_ogb_hiv_log.py312
-rwxr-xr-xscripts/collect_results.sh10
-rwxr-xr-xscripts/run_ogb_act_task.sh119
-rwxr-xr-xscripts/run_ogb_act_two_gpu.sh91
-rwxr-xr-xscripts/run_ogb_mol_all_tasks.sh17
-rwxr-xr-xscripts/run_ogb_mol_task_full.sh56
-rwxr-xr-xscripts/run_smoke.sh19
-rwxr-xr-xscripts/run_two_a6000.sh32
-rwxr-xr-xscripts/run_zinc_cycle56_full.sh54
-rwxr-xr-xscripts/setup_and_run_two_a6000.sh15
-rwxr-xr-xscripts/setup_env.sh35
11 files changed, 760 insertions, 0 deletions
diff --git a/scripts/analyze_ogb_hiv_log.py b/scripts/analyze_ogb_hiv_log.py
new file mode 100644
index 0000000..21253fe
--- /dev/null
+++ b/scripts/analyze_ogb_hiv_log.py
@@ -0,0 +1,312 @@
+#!/usr/bin/env python3
+import argparse
+import csv
+import json
+import math
+import re
+from collections import defaultdict
+from pathlib import Path
+
+
+RUN_RE = re.compile(
+ r"^\[run\]\s+"
+ r"(?P<dataset>\S+)\s+view=(?P<view>\S+)\s+compute=(?P<compute>\S+)\s+"
+ r"T=(?P<T>\d+)\s+ns=(?P<ns>\d+)"
+)
+EP_RE = re.compile(
+ r"^ep(?P<ep>\d+)\s+val_(?P<metric>\w+)=(?P<val>[-+0-9.eE]+).*"
+ r"train_steps=(?P<steps>[-+0-9.eE]+)"
+)
+
+
+def _load_json(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ return json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+
+
+def _score(rep: dict, split: str) -> float | None:
+ metric = rep.get("metric")
+ if not metric:
+ return None
+ value = rep.get(split, {}).get(metric)
+ return None if value is None else float(value)
+
+
+def parse_curves(path: Path):
+ curves = defaultdict(list)
+ current = None
+ with path.open(errors="replace") as f:
+ for line in f:
+ line = line.strip()
+ m = RUN_RE.match(line)
+ if m:
+ current = (
+ m.group("dataset"),
+ m.group("view"),
+ m.group("compute"),
+ int(m.group("T")),
+ int(m.group("ns")),
+ )
+ continue
+ m = EP_RE.match(line)
+ if current is not None and m:
+ curves[current].append({
+ "ep": int(m.group("ep")),
+ "metric": m.group("metric"),
+ "val": float(m.group("val")),
+ "train_steps": float(m.group("steps")),
+ })
+ return curves
+
+
+def load_runs(runs_dir: Path, dataset: str):
+ runs = {}
+ for path in sorted(runs_dir.glob(f"{dataset}_*.json")):
+ rep = _load_json(path)
+ if rep is None:
+ continue
+ if rep.get("dataset") != dataset:
+ continue
+ key = (
+ rep.get("view"),
+ rep.get("compute"),
+ int(rep.get("T", -1)),
+ int(rep.get("n_sup", -1)),
+ int(rep.get("seed", -1)),
+ )
+ runs[key] = rep
+ return runs
+
+
+def curve_stats(curve: list[dict]) -> dict:
+ if not curve:
+ return {
+ "curve_best_ep": None,
+ "curve_best_val": None,
+ "curve_final_ep": None,
+ "curve_final_val": None,
+ "final_gap": None,
+ "late_slope": None,
+ "range": None,
+ }
+ best = max(curve, key=lambda x: x["val"])
+ final = curve[-1]
+ late_slope = None
+ if len(curve) >= 2:
+ late_slope = final["val"] - curve[-2]["val"]
+ vals = [x["val"] for x in curve]
+ return {
+ "curve_best_ep": best["ep"],
+ "curve_best_val": best["val"],
+ "curve_final_ep": final["ep"],
+ "curve_final_val": final["val"],
+ "final_gap": best["val"] - final["val"],
+ "late_slope": late_slope,
+ "range": max(vals) - min(vals),
+ }
+
+
+def labels_for(row: dict, *, val_tol: float, test_tol: float, collapse_gap: float) -> str:
+ labels = []
+ if row["val_delta"] >= -val_tol and row["test_delta"] >= -test_tol:
+ labels.append("positive_or_tie")
+ elif row["val_delta"] >= -val_tol and row["test_delta"] < -test_tol:
+ labels.append("val_up_test_down")
+ elif row["val_delta"] < -val_tol and row["test_delta"] >= -test_tol:
+ labels.append("test_up_val_down")
+ else:
+ labels.append("val_and_test_down")
+
+ if row["val_delta"] < -0.02:
+ labels.append("fixed_val_lags_classic")
+ if row["cand_final_gap"] is not None and row["cand_final_gap"] > collapse_gap:
+ labels.append("late_collapse")
+ if (
+ row["cand_curve_best_ep"] is not None
+ and row["cand_curve_final_ep"] is not None
+ and row["cand_curve_best_ep"] >= row["cand_curve_final_ep"] - 10
+ and (row["cand_final_gap"] is None or row["cand_final_gap"] <= 0.01)
+ ):
+ labels.append("maybe_undertrained")
+ return ",".join(labels)
+
+
+def fmt(x, digits=4):
+ if x is None:
+ return ""
+ if isinstance(x, int):
+ return str(x)
+ return f"{x:.{digits}f}"
+
+
+def markdown_table(headers, rows):
+ out = [
+ "| " + " | ".join(headers) + " |",
+ "| " + " | ".join(["---"] * len(headers)) + " |",
+ ]
+ out.extend("| " + " | ".join(row) + " |" for row in rows)
+ return "\n".join(out)
+
+
+def mean(xs):
+ return sum(xs) / len(xs) if xs else 0.0
+
+
+def corr(xs, ys):
+ if len(xs) != len(ys) or len(xs) < 2:
+ return 0.0
+ mx = mean(xs)
+ my = mean(ys)
+ den = math.sqrt(sum((x - mx) ** 2 for x in xs) * sum((y - my) ** 2 for y in ys))
+ if den == 0:
+ return 0.0
+ return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / den
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--dataset", default="ogbg-molhiv")
+ ap.add_argument("--runs-dir", default="runs")
+ ap.add_argument("--log", default="logs/ogbg-molhiv_0.log")
+ ap.add_argument("--out-dir", default="analysis")
+ ap.add_argument("--candidate-compute", default="fixed-rrog")
+ ap.add_argument("--candidate-T", type=int, default=3)
+ ap.add_argument("--candidate-n-sup", type=int, default=3)
+ ap.add_argument("--baseline-compute", default="classic")
+ ap.add_argument("--baseline-T", type=int, default=0)
+ ap.add_argument("--baseline-n-sup", type=int, default=1)
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--val-tol", type=float, default=0.0)
+ ap.add_argument("--test-tol", type=float, default=0.0)
+ ap.add_argument("--collapse-gap", type=float, default=0.03)
+ args = ap.parse_args()
+
+ runs_dir = Path(args.runs_dir)
+ curves = parse_curves(Path(args.log))
+ runs = load_runs(runs_dir, args.dataset)
+ out_dir = Path(args.out_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ rows = []
+ views = sorted({
+ view for (view, compute, _t, _ns, seed) in runs
+ if seed == args.seed and compute == args.baseline_compute
+ })
+ for view in views:
+ base_key = (view, args.baseline_compute, args.baseline_T, args.baseline_n_sup, args.seed)
+ cand_key = (view, args.candidate_compute, args.candidate_T, args.candidate_n_sup, args.seed)
+ base = runs.get(base_key)
+ cand = runs.get(cand_key)
+ if base is None or cand is None:
+ continue
+ metric = cand["metric"]
+ base_val = _score(base, "val")
+ base_test = _score(base, "test")
+ cand_val = _score(cand, "val")
+ cand_test = _score(cand, "test")
+ if None in {base_val, base_test, cand_val, cand_test}:
+ continue
+ base_curve = curve_stats(curves.get((args.dataset, view, args.baseline_compute, args.baseline_T, args.baseline_n_sup), []))
+ cand_curve = curve_stats(curves.get((args.dataset, view, args.candidate_compute, args.candidate_T, args.candidate_n_sup), []))
+ row = {
+ "view": view,
+ "metric": metric,
+ "base_ep": int(base.get("ep") or 0),
+ "base_val": base_val,
+ "base_test": base_test,
+ "cand_ep": int(cand.get("ep") or 0),
+ "cand_val": cand_val,
+ "cand_test": cand_test,
+ "val_delta": cand_val - base_val,
+ "test_delta": cand_test - base_test,
+ "base_curve_best_ep": base_curve["curve_best_ep"],
+ "base_final_gap": base_curve["final_gap"],
+ "cand_curve_best_ep": cand_curve["curve_best_ep"],
+ "cand_curve_final_ep": cand_curve["curve_final_ep"],
+ "cand_final_gap": cand_curve["final_gap"],
+ "cand_late_slope": cand_curve["late_slope"],
+ "cand_range": cand_curve["range"],
+ }
+ row["labels"] = labels_for(
+ row,
+ val_tol=args.val_tol,
+ test_tol=args.test_tol,
+ collapse_gap=args.collapse_gap,
+ )
+ rows.append(row)
+
+ csv_path = out_dir / f"{args.dataset}_{args.candidate_compute}_T{args.candidate_T}_ns{args.candidate_n_sup}_diagnostics.csv"
+ fieldnames = [
+ "view", "metric", "base_ep", "base_val", "base_test", "cand_ep", "cand_val", "cand_test",
+ "val_delta", "test_delta", "base_curve_best_ep", "base_final_gap",
+ "cand_curve_best_ep", "cand_curve_final_ep", "cand_final_gap", "cand_late_slope",
+ "cand_range", "labels",
+ ]
+ with csv_path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(rows)
+
+ counts = defaultdict(int)
+ for row in rows:
+ for label in row["labels"].split(","):
+ counts[label] += 1
+
+ pos = [r for r in rows if r["val_delta"] >= 0 and r["test_delta"] >= 0]
+ val_up_test_down = [r for r in rows if r["val_delta"] >= 0 and r["test_delta"] < 0]
+ val_down_test_up = [r for r in rows if r["val_delta"] < 0 and r["test_delta"] >= 0]
+ both_down = [r for r in rows if r["val_delta"] < 0 and r["test_delta"] < 0]
+ late = [r for r in rows if "maybe_undertrained" in r["labels"]]
+ collapse = [r for r in rows if "late_collapse" in r["labels"]]
+ val_deltas = [r["val_delta"] for r in rows]
+ test_deltas = [r["test_delta"] for r in rows]
+
+ def row_line(r):
+ return [
+ r["view"],
+ fmt(r["base_val"]),
+ fmt(r["cand_val"]),
+ fmt(r["val_delta"]),
+ fmt(r["base_test"]),
+ fmt(r["cand_test"]),
+ fmt(r["test_delta"]),
+ str(r["cand_ep"]),
+ fmt(r["cand_final_gap"]),
+ r["labels"],
+ ]
+
+ md_path = csv_path.with_suffix(".md")
+ with md_path.open("w") as f:
+ f.write(f"# {args.dataset} {args.candidate_compute} T={args.candidate_T} ns={args.candidate_n_sup} diagnostics\n\n")
+ f.write(f"- rows: {len(rows)}\n")
+ f.write(f"- val+test positive: {len(pos)}\n")
+ f.write(f"- val up, test down: {len(val_up_test_down)}\n")
+ f.write(f"- val down, test up: {len(val_down_test_up)}\n")
+ f.write(f"- val+test down: {len(both_down)}\n")
+ f.write(f"- maybe undertrained: {len(late)}\n")
+ f.write(f"- late collapse: {len(collapse)}\n\n")
+ f.write(f"- mean val delta: {fmt(mean(val_deltas))}\n")
+ f.write(f"- mean test delta: {fmt(mean(test_deltas))}\n")
+ f.write(f"- val/test delta corr: {fmt(corr(val_deltas, test_deltas))}\n\n")
+ f.write("## Label Counts\n\n")
+ f.write(markdown_table(["label", "n"], [[k, str(v)] for k, v in sorted(counts.items())]))
+ f.write("\n\n## Per-Backbone Rows\n\n")
+ headers = [
+ "view", "base_val", "cand_val", "d_val", "base_test", "cand_test",
+ "d_test", "cand_ep", "cand_final_gap", "labels",
+ ]
+ f.write(markdown_table(headers, [row_line(r) for r in rows]))
+ f.write("\n")
+
+ print(f"wrote {csv_path}")
+ print(f"wrote {md_path}")
+ print(f"rows={len(rows)} val+test={len(pos)} val_up_test_down={len(val_up_test_down)} "
+ f"val_down_test_up={len(val_down_test_up)} both_down={len(both_down)} "
+ f"maybe_undertrained={len(late)} late_collapse={len(collapse)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/collect_results.sh b/scripts/collect_results.sh
new file mode 100755
index 0000000..360f05b
--- /dev/null
+++ b/scripts/collect_results.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+mkdir -p summaries
+python3 -m rrog.cli zinc-results --epochs "${ZINC_EPOCHS:-200}" | tee summaries/zinc_cycle56.md
+python3 -m rrog.cli results --epochs "${OGB_EPOCHS:-100}" | tee summaries/ogb_graphprop.md
diff --git a/scripts/run_ogb_act_task.sh b/scripts/run_ogb_act_task.sh
new file mode 100755
index 0000000..37eb1a6
--- /dev/null
+++ b/scripts/run_ogb_act_task.sh
@@ -0,0 +1,119 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+TASK="${TASK:-ogbg-molhiv}"
+DEVICE="${DEVICE:-cuda:0}"
+EPOCHS="${EPOCHS:-100}"
+SEEDS="${SEEDS:-${SEED:-0}}"
+HIDDEN="${HIDDEN:-128}"
+BS="${BS:-128}"
+LR="${LR:-}"
+EVAL_EVERY="${EVAL_EVERY:-10}"
+NUM_WORKERS="${NUM_WORKERS:-0}"
+T="${T:-1}"
+N_SUP="${N_SUP:-3}"
+HALT_MAX="${HALT_MAX:-8}"
+HALT_MIN="${HALT_MIN:-2}"
+HALT_TARGET="${HALT_TARGET:-loss}"
+HALT_LOSS_THRESHOLD="${HALT_LOSS_THRESHOLD:-0.2}"
+HALT_EXPLORATION_PROB="${HALT_EXPLORATION_PROB:-0.1}"
+LAM_Q="${LAM_Q:-0.1}"
+Q_WARMUP="${Q_WARMUP:-0}"
+ACT_TRAIN_MODE="${ACT_TRAIN_MODE:-stream}"
+EMA="${EMA:-0}"
+MAX_TRAIN_BATCHES="${MAX_TRAIN_BATCHES:-}"
+MAX_EVAL_BATCHES="${MAX_EVAL_BATCHES:-}"
+COLLECT="${COLLECT:-1}"
+VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"
+
+mkdir -p runs logs summaries
+
+fmt_float() {
+ python3 - "$1" <<'PY'
+import sys
+print(f"{float(sys.argv[1]):g}")
+PY
+}
+
+result_path() {
+ local view="$1"
+ local seed="$2"
+ local target_tag="${HALT_TARGET}"
+ local loss_tag
+ local lam_tag
+ local hex_tag
+ local ema_tag=""
+ loss_tag="$(fmt_float "${HALT_LOSS_THRESHOLD}")"
+ lam_tag="$(fmt_float "${LAM_Q}")"
+ hex_tag="$(fmt_float "${HALT_EXPLORATION_PROB}")"
+ if [[ "${HALT_TARGET}" == "loss" ]]; then
+ target_tag="loss${loss_tag}"
+ fi
+ if [[ "$(fmt_float "${EMA}")" != "0" ]]; then
+ ema_tag="_ema$(fmt_float "${EMA}")"
+ fi
+ echo "runs/${TASK}_${view}_rrog-act_T${T}_ns${N_SUP}_${ACT_TRAIN_MODE}_hm${HALT_MAX}_hmin${HALT_MIN}_${target_tag}_lq${lam_tag}_hex${hex_tag}_qw${Q_WARMUP}_h${HIDDEN}_e${EPOCHS}${ema_tag}_s${seed}.json"
+}
+
+run_cell() {
+ local view="$1"
+ local seed="$2"
+ local out
+ out="$(result_path "${view}" "${seed}")"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+
+ echo "[run] ${TASK} view=${view} compute=rrog-act mode=${ACT_TRAIN_MODE} T=${T} ns=${N_SUP} seed=${seed} device=${DEVICE}"
+ cmd=(
+ python3 -m rrog.cli run
+ --task "${TASK}"
+ --view "${view}"
+ --compute rrog-act
+ --epochs "${EPOCHS}"
+ --hidden "${HIDDEN}"
+ --bs "${BS}"
+ --T "${T}"
+ --n_sup "${N_SUP}"
+ --halt_max_steps "${HALT_MAX}"
+ --halt_min_steps "${HALT_MIN}"
+ --halt_target "${HALT_TARGET}"
+ --halt_loss_threshold "${HALT_LOSS_THRESHOLD}"
+ --halt_exploration_prob "${HALT_EXPLORATION_PROB}"
+ --lam_q "${LAM_Q}"
+ --q_warmup_epochs "${Q_WARMUP}"
+ --act_train_mode "${ACT_TRAIN_MODE}"
+ --eval_every "${EVAL_EVERY}"
+ --num_workers "${NUM_WORKERS}"
+ --seed "${seed}"
+ --device "${DEVICE}"
+ )
+ if [[ -n "${LR}" ]]; then
+ cmd+=(--lr "${LR}")
+ fi
+ if [[ "$(fmt_float "${EMA}")" != "0" ]]; then
+ cmd+=(--ema "${EMA}")
+ fi
+ if [[ -n "${MAX_TRAIN_BATCHES}" ]]; then
+ cmd+=(--max_train_batches "${MAX_TRAIN_BATCHES}")
+ fi
+ if [[ -n "${MAX_EVAL_BATCHES}" ]]; then
+ cmd+=(--max_eval_batches "${MAX_EVAL_BATCHES}")
+ fi
+ "${cmd[@]}"
+}
+
+for seed in ${SEEDS}; do
+ for view in ${VIEWS}; do
+ run_cell "${view}" "${seed}"
+ done
+done
+
+if [[ "${COLLECT}" == "1" ]]; then
+ python3 -m rrog.cli results --epochs "${EPOCHS}" | tee "summaries/ogb_graphprop_act_${TASK}_e${EPOCHS}.md"
+fi
diff --git a/scripts/run_ogb_act_two_gpu.sh b/scripts/run_ogb_act_two_gpu.sh
new file mode 100755
index 0000000..f51262b
--- /dev/null
+++ b/scripts/run_ogb_act_two_gpu.sh
@@ -0,0 +1,91 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+GPU0="${GPU0:-cuda:0}"
+GPU1="${GPU1:-cuda:1}"
+if [[ -z "${TASKS_GPU0+x}" ]]; then
+ TASKS_GPU0="ogbg-molhiv ogbg-molbbbp ogbg-molsider ogbg-molbace"
+fi
+if [[ -z "${TASKS_GPU1+x}" ]]; then
+ TASKS_GPU1="ogbg-molesol ogbg-mollipo ogbg-moltox21 ogbg-molclintox"
+fi
+EPOCHS="${EPOCHS:-100}"
+SEEDS="${SEEDS:-${SEED:-0}}"
+HALT_MAX="${HALT_MAX:-8}"
+HALT_MIN="${HALT_MIN:-2}"
+HALT_TARGET="${HALT_TARGET:-loss}"
+HALT_LOSS_THRESHOLD="${HALT_LOSS_THRESHOLD:-0.2}"
+HALT_EXPLORATION_PROB="${HALT_EXPLORATION_PROB:-0.1}"
+LAM_Q="${LAM_Q:-0.1}"
+Q_WARMUP="${Q_WARMUP:-0}"
+ACT_TRAIN_MODE="${ACT_TRAIN_MODE:-stream}"
+
+mkdir -p logs summaries
+
+fmt_float() {
+ python3 - "$1" <<'PY'
+import sys
+print(f"{float(sys.argv[1]):g}")
+PY
+}
+
+target_log_tag() {
+ local target_tag="${HALT_TARGET}"
+ if [[ "${HALT_TARGET}" == "loss" ]]; then
+ target_tag="loss$(fmt_float "${HALT_LOSS_THRESHOLD}")"
+ fi
+ echo "${ACT_TRAIN_MODE}_hm${HALT_MAX}_hmin${HALT_MIN}_${target_tag}_lq$(fmt_float "${LAM_Q}")_hex$(fmt_float "${HALT_EXPLORATION_PROB}")_qw${Q_WARMUP}_e${EPOCHS}_s${SEEDS// /-}"
+}
+
+run_queue() {
+ local device="$1"
+ shift
+ local tasks=("$@")
+ local task
+ local tag
+ tag="$(target_log_tag)"
+ for task in "${tasks[@]}"; do
+ if [[ -z "${task}" ]]; then
+ continue
+ fi
+ echo "[task] ${task} on ${device}"
+ TASK="${task}" DEVICE="${device}" EPOCHS="${EPOCHS}" SEEDS="${SEEDS}" \
+ HALT_MAX="${HALT_MAX}" HALT_MIN="${HALT_MIN}" HALT_TARGET="${HALT_TARGET}" \
+ HALT_LOSS_THRESHOLD="${HALT_LOSS_THRESHOLD}" HALT_EXPLORATION_PROB="${HALT_EXPLORATION_PROB}" \
+ LAM_Q="${LAM_Q}" Q_WARMUP="${Q_WARMUP}" \
+ ACT_TRAIN_MODE="${ACT_TRAIN_MODE}" COLLECT=0 \
+ ./scripts/run_ogb_act_task.sh 2>&1 | tee "logs/${task}_act_${tag}.log"
+ done
+}
+
+tasks0=()
+tasks1=()
+if [[ -n "${TASKS_GPU0}" ]]; then
+ read -r -a tasks0 <<< "${TASKS_GPU0}"
+fi
+if [[ -n "${TASKS_GPU1}" ]]; then
+ read -r -a tasks1 <<< "${TASKS_GPU1}"
+fi
+
+pids=()
+if (( ${#tasks0[@]} > 0 )); then
+ echo "[launch] ${GPU0}: ${tasks0[*]}"
+ run_queue "${GPU0}" "${tasks0[@]}" &
+ pids+=("$!")
+fi
+if (( ${#tasks1[@]} > 0 )); then
+ echo "[launch] ${GPU1}: ${tasks1[*]}"
+ run_queue "${GPU1}" "${tasks1[@]}" &
+ pids+=("$!")
+fi
+
+for pid in "${pids[@]}"; do
+ wait "${pid}"
+done
+
+echo "[done] collecting summaries"
+OGB_EPOCHS="${EPOCHS}" ./scripts/collect_results.sh
diff --git a/scripts/run_ogb_mol_all_tasks.sh b/scripts/run_ogb_mol_all_tasks.sh
new file mode 100755
index 0000000..b191d79
--- /dev/null
+++ b/scripts/run_ogb_mol_all_tasks.sh
@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+DEVICE="${DEVICE:-cuda:1}"
+EPOCHS="${EPOCHS:-100}"
+SEED="${SEED:-0}"
+TASKS="${TASKS:-ogbg-molhiv ogbg-molbbbp ogbg-molbace ogbg-moltox21 ogbg-molclintox ogbg-molsider ogbg-molesol ogbg-molfreesolv ogbg-mollipo}"
+
+mkdir -p logs
+for task in ${TASKS}; do
+ echo "[task] ${task}"
+ TASK="${task}" DEVICE="${DEVICE}" EPOCHS="${EPOCHS}" SEED="${SEED}" \
+ ./scripts/run_ogb_mol_task_full.sh 2>&1 | tee "logs/${task}_${SEED}.log"
+done
diff --git a/scripts/run_ogb_mol_task_full.sh b/scripts/run_ogb_mol_task_full.sh
new file mode 100755
index 0000000..71ddba1
--- /dev/null
+++ b/scripts/run_ogb_mol_task_full.sh
@@ -0,0 +1,56 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+TASK="${TASK:-ogbg-molhiv}"
+DEVICE="${DEVICE:-cuda:1}"
+EPOCHS="${EPOCHS:-100}"
+SEED="${SEED:-0}"
+HIDDEN="${HIDDEN:-128}"
+FIXED_T="${FIXED_T:-3}"
+FIXED_NS="${FIXED_NS:-3}"
+VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"
+
+mkdir -p runs logs
+
+result_path() {
+ local view="$1"
+ local compute="$2"
+ local t="$3"
+ local ns="$4"
+ echo "runs/${TASK}_${view}_${compute}_T${t}_ns${ns}_h${HIDDEN}_e${EPOCHS}_s${SEED}.json"
+}
+
+run_cell() {
+ local view="$1"
+ local compute="$2"
+ local t="$3"
+ local ns="$4"
+ local out
+ out="$(result_path "${view}" "${compute}" "${t}" "${ns}")"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+ echo "[run] ${TASK} view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}"
+ python3 -m rrog.cli run \
+ --task "${TASK}" \
+ --view "${view}" \
+ --compute "${compute}" \
+ --epochs "${EPOCHS}" \
+ --hidden "${HIDDEN}" \
+ --T "${t}" \
+ --n_sup "${ns}" \
+ --seed "${SEED}" \
+ --device "${DEVICE}"
+}
+
+for view in ${VIEWS}; do
+ run_cell "${view}" classic 0 1
+ run_cell "${view}" fixed-rrog "${FIXED_T}" "${FIXED_NS}"
+done
+
+python3 -m rrog.cli results --epochs "${EPOCHS}"
diff --git a/scripts/run_smoke.sh b/scripts/run_smoke.sh
new file mode 100755
index 0000000..6365cec
--- /dev/null
+++ b/scripts/run_smoke.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+DEVICE="${DEVICE:-cuda:0}"
+mkdir -p runs logs
+
+python3 -m rrog.cli run \
+ --task ogbg-molhiv --view gin --compute classic \
+ --epochs 1 --hidden 32 --bs 64 --seed 991 --device "${DEVICE}" \
+ --max_train_batches 2 --max_eval_batches 2
+
+python3 -m rrog.cli run \
+ --task ogbg-molhiv --view gin --compute fixed-rrog \
+ --epochs 1 --hidden 32 --bs 64 --T 1 --n_sup 2 --seed 992 --device "${DEVICE}" \
+ --max_train_batches 2 --max_eval_batches 2
diff --git a/scripts/run_two_a6000.sh b/scripts/run_two_a6000.sh
new file mode 100755
index 0000000..8d9851f
--- /dev/null
+++ b/scripts/run_two_a6000.sh
@@ -0,0 +1,32 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+ZINC_DEVICE="${ZINC_DEVICE:-cuda:0}"
+OGB_DEVICE="${OGB_DEVICE:-cuda:1}"
+OGB_TASK="${OGB_TASK:-ogbg-molhiv}"
+ZINC_EPOCHS="${ZINC_EPOCHS:-200}"
+OGB_EPOCHS="${OGB_EPOCHS:-100}"
+SEED="${SEED:-0}"
+
+mkdir -p runs logs
+
+echo "[launch] ZINC-cycle56 on ${ZINC_DEVICE}"
+DEVICE="${ZINC_DEVICE}" EPOCHS="${ZINC_EPOCHS}" SEED="${SEED}" \
+ ./scripts/run_zinc_cycle56_full.sh > "logs/zinc_cycle56_${SEED}.log" 2>&1 &
+zinc_pid=$!
+
+echo "[launch] ${OGB_TASK} on ${OGB_DEVICE}"
+TASK="${OGB_TASK}" DEVICE="${OGB_DEVICE}" EPOCHS="${OGB_EPOCHS}" SEED="${SEED}" \
+ ./scripts/run_ogb_mol_task_full.sh > "logs/${OGB_TASK}_${SEED}.log" 2>&1 &
+ogb_pid=$!
+
+echo "[pids] zinc=${zinc_pid} ogb=${ogb_pid}"
+wait "${zinc_pid}"
+wait "${ogb_pid}"
+
+echo "[done] collecting summaries"
+./scripts/collect_results.sh
diff --git a/scripts/run_zinc_cycle56_full.sh b/scripts/run_zinc_cycle56_full.sh
new file mode 100755
index 0000000..151a51e
--- /dev/null
+++ b/scripts/run_zinc_cycle56_full.sh
@@ -0,0 +1,54 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+DEVICE="${DEVICE:-cuda:0}"
+EPOCHS="${EPOCHS:-200}"
+SEED="${SEED:-0}"
+VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"
+
+mkdir -p runs logs
+
+result_path() {
+ local view="$1"
+ local t="$2"
+ local ns="$3"
+ local view_tag=""
+ if [[ "${view}" != "gin" ]]; then
+ view_tag="_${view}"
+ fi
+ echo "runs/rec_rrog${view_tag}_full_sig0.0_K1_none_T${t}_ns${ns}_trace_s${SEED}.json"
+}
+
+run_cell() {
+ local view="$1"
+ local compute="$2"
+ local t="$3"
+ local ns="$4"
+ local out
+ out="$(result_path "${view}" "${t}" "${ns}")"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+ echo "[run] zinc-cycle56 view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}"
+ python3 -m rrog.cli run \
+ --task zinc-cycle56 \
+ --view "${view}" \
+ --compute "${compute}" \
+ --epochs "${EPOCHS}" \
+ --T "${t}" \
+ --n_sup "${ns}" \
+ --seed "${SEED}" \
+ --device "${DEVICE}"
+}
+
+for view in ${VIEWS}; do
+ run_cell "${view}" classic 0 1
+ run_cell "${view}" fixed-rrog 1 3
+done
+
+python3 -m rrog.cli zinc-results --epochs "${EPOCHS}"
diff --git a/scripts/setup_and_run_two_a6000.sh b/scripts/setup_and_run_two_a6000.sh
new file mode 100755
index 0000000..ec4e3da
--- /dev/null
+++ b/scripts/setup_and_run_two_a6000.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+if [[ "${SKIP_SETUP:-0}" != "1" ]]; then
+ ./scripts/setup_env.sh
+fi
+
+if [[ -d "${VENV_DIR:-.venv}" ]]; then
+ source "${VENV_DIR:-.venv}/bin/activate"
+fi
+
+./scripts/run_two_a6000.sh
diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh
new file mode 100755
index 0000000..66a94c8
--- /dev/null
+++ b/scripts/setup_env.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+PYTHON_BIN="${PYTHON_BIN:-python3}"
+VENV_DIR="${VENV_DIR:-.venv}"
+TORCH_INDEX_URL="${TORCH_INDEX_URL:-https://download.pytorch.org/whl/cu124}"
+
+if [[ ! -d "${VENV_DIR}" ]]; then
+ "${PYTHON_BIN}" -m venv "${VENV_DIR}"
+fi
+
+source "${VENV_DIR}/bin/activate"
+python -m pip install --upgrade pip wheel setuptools
+
+if ! python - <<'PY' >/dev/null 2>&1
+import torch
+assert torch.cuda.is_available() or True
+PY
+then
+ python -m pip install torch --index-url "${TORCH_INDEX_URL}"
+fi
+
+python -m pip install -r requirements.txt
+
+python - <<'PY'
+import torch
+import torch_geometric
+import ogb
+print("torch", torch.__version__, "cuda_available", torch.cuda.is_available())
+print("torch_geometric", torch_geometric.__version__)
+print("ogb", ogb.__version__)
+PY