summaryrefslogtreecommitdiff
path: root/rrog
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-21 15:33:22 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-21 15:33:22 -0500
commite42f575050efeeccb736385b43bed84e1129edb0 (patch)
tree8ed04b42218cf4c90c0b9c29b40db149f1355f4a /rrog
Initial RRoG GNN runner
Diffstat (limited to 'rrog')
-rw-r--r--rrog/__init__.py2
-rw-r--r--rrog/backbones.py72
-rw-r--r--rrog/benchmarks.py44
-rw-r--r--rrog/cli.py176
-rw-r--r--rrog/collect_results.py239
-rw-r--r--rrog/collect_zinc.py137
-rw-r--r--rrog/registry.py57
-rwxr-xr-xrrog/run_ogb_hiv_remaining.sh49
-rwxr-xr-xrrog/run_zinc_gine.sh36
-rwxr-xr-xrrog/run_zinc_gine_after_pid.sh14
-rw-r--r--rrog/runspecs.py188
-rw-r--r--rrog/train_ogb_graphprop.py685
12 files changed, 1699 insertions, 0 deletions
diff --git a/rrog/__init__.py b/rrog/__init__.py
new file mode 100644
index 0000000..c09281d
--- /dev/null
+++ b/rrog/__init__.py
@@ -0,0 +1,2 @@
+"""Experiment registry for RRoG/TRM-on-GNN sweeps."""
+
diff --git a/rrog/backbones.py b/rrog/backbones.py
new file mode 100644
index 0000000..81bfb9b
--- /dev/null
+++ b/rrog/backbones.py
@@ -0,0 +1,72 @@
+from rrog.registry import ComputeSpec, ModifierSpec, ViewSpec, by_name
+
+
+VIEWS = [
+ ViewSpec("gin", "message-passing", "2d", 1, "implemented",
+ "Plain GINConv message passing."),
+ ViewSpec("gine", "message-passing", "2d", 2, "implemented",
+ "Edge-aware GIN variant. ZINC uses a learned constant edge token; OGB uses bond features."),
+ ViewSpec("gcn", "message-passing", "2d", 3, "implemented"),
+ ViewSpec("graphsage", "message-passing", "2d", 4, "implemented"),
+ ViewSpec("gatv2", "attention-mpnn", "2d", 5, "implemented"),
+ ViewSpec("graphconv", "message-passing", "2d", 6, "implemented"),
+ ViewSpec("transformer", "attention-mpnn", "2d", 7, "implemented"),
+ ViewSpec("pna", "message-passing", "2d", 8, "implemented",
+ "Requires degree histogram from the train split."),
+ ViewSpec("gen", "message-passing", "2d", 9, "implemented"),
+ ViewSpec("film", "message-passing", "2d", 10, "implemented"),
+ ViewSpec("resgated", "message-passing", "2d", 11, "implemented"),
+ ViewSpec("tag", "higher-order-hop", "2d", 12, "implemented"),
+ ViewSpec("sgc", "propagation", "2d", 13, "implemented"),
+ ViewSpec("cheb", "spectral", "2d", 14, "implemented"),
+ ViewSpec("arma", "spectral", "2d", 15, "implemented"),
+ ViewSpec("mf", "message-passing", "2d", 16, "implemented"),
+ ViewSpec("appnp", "propagation", "2d", 17, "implemented"),
+ ViewSpec("mixhop", "higher-order-hop", "2d", 18),
+ ViewSpec("gps", "hybrid-local-global", "2d", 19),
+ ViewSpec("graphormer", "global-attention", "2d", 20),
+ ViewSpec("san", "spectral-attention", "2d", 21),
+ ViewSpec("mpnn", "message-passing", "2d", 22),
+ ViewSpec("schnet", "continuous-filter", "3d", 23),
+ ViewSpec("dimenetpp", "angle-aware", "3d", 24),
+ ViewSpec("painn", "equivariant", "3d", 25),
+ ViewSpec("gemnet", "equivariant", "3d", 26),
+ ViewSpec("egnn", "equivariant", "3d", 27),
+ ViewSpec("equiformer", "equivariant", "3d", 28),
+ ViewSpec("mace", "equivariant", "3d", 29),
+]
+
+
+COMPUTES = [
+ ComputeSpec("classic", "baseline", 1, "implemented", "Standard one-forward GNN baseline; no RRoG compute."),
+ ComputeSpec("view-only", "none", 2, "implemented", "RRoG view module only; no recursive compute."),
+ ComputeSpec("fixed-rrog", "recursive", 3, "implemented", "Fixed-depth edge-free y/z compute."),
+ ComputeSpec("rrog-act", "recursive-act", 4, "implemented", "Persistent full ACT recycling for graph batches."),
+ ComputeSpec("node-mlp", "recursive", 4),
+ ComputeSpec("gru-rrog", "recursive-gated", 5),
+ ComputeSpec("set-attn-core", "edge-free-attention", 6),
+ ComputeSpec("perceiver-core", "latent-attention", 7),
+ ComputeSpec("global-token-mixer", "token-mixer", 8),
+ ComputeSpec("equivariant-core", "3d-equivariant", 9),
+]
+
+
+MODIFIERS = [
+ ModifierSpec("none", "none", 1, "implemented"),
+ ModifierSpec("dfa-gnn", "backward", 2, "planned",
+ "Non-BP/direct-feedback-style training; start on node classification."),
+ ModifierSpec("kaft", "backward", 3, "planned",
+ "User project in ../graph-grape; low priority until main table is stable."),
+ ModifierSpec("deep-supervision", "training", 4),
+ ModifierSpec("sam", "optimizer", 5),
+ ModifierSpec("lap-pe", "feature", 6),
+ ModifierSpec("rwse", "feature", 7),
+ ModifierSpec("virtual-node", "feature", 8),
+ ModifierSpec("dropedge", "regularization", 9),
+ ModifierSpec("flag", "augmentation", 10),
+]
+
+
+VIEW_BY_NAME = by_name(VIEWS)
+COMPUTE_BY_NAME = by_name(COMPUTES)
+MODIFIER_BY_NAME = by_name(MODIFIERS)
diff --git a/rrog/benchmarks.py b/rrog/benchmarks.py
new file mode 100644
index 0000000..dcd356b
--- /dev/null
+++ b/rrog/benchmarks.py
@@ -0,0 +1,44 @@
+from rrog.registry import BenchmarkSpec, by_name
+
+
+BENCHMARKS = [
+ BenchmarkSpec("zinc-cycle56", "A", "molecule-2d", "graph-regression", "raw-mae", 1,
+ "implemented", "ZINC subset with #5/#6 cycle-count targets."),
+ BenchmarkSpec("zinc", "A", "molecule-2d", "graph-regression", "mae", 2),
+ BenchmarkSpec("ogbg-molhiv", "A", "molecule-2d", "graph-classification", "rocauc", 3,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molpcba", "A", "molecule-2d", "graph-multilabel", "ap", 4,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molbbbp", "A", "molecule-2d", "graph-classification", "rocauc", 5,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molbace", "A", "molecule-2d", "graph-classification", "rocauc", 6,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-moltox21", "A", "molecule-2d", "graph-multilabel", "rocauc", 7,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molclintox", "A", "molecule-2d", "graph-multilabel", "rocauc", 8,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molsider", "A", "molecule-2d", "graph-multilabel", "rocauc", 9,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molesol", "A", "molecule-2d", "graph-regression", "rmse", 10,
+ "implemented", "OGB ESOL graph property prediction."),
+ BenchmarkSpec("ogbg-molfreesolv", "A", "molecule-2d", "graph-regression", "rmse", 11,
+ "implemented", "OGB FreeSolv graph property prediction."),
+ BenchmarkSpec("ogbg-mollipo", "A", "molecule-2d", "graph-regression", "rmse", 12,
+ "implemented", "OGB Lipophilicity graph property prediction."),
+ BenchmarkSpec("pcqm4mv2", "A", "molecule-2d", "graph-regression", "mae", 13),
+ BenchmarkSpec("qm9", "A", "molecule-3d", "graph-regression", "mae", 14),
+ BenchmarkSpec("peptides-func", "B", "long-range", "graph-multilabel", "ap", 15),
+ BenchmarkSpec("peptides-struct", "B", "long-range", "graph-regression", "mae", 16),
+ BenchmarkSpec("pcqm-contact", "B", "long-range", "link-prediction", "mrr", 17),
+ BenchmarkSpec("pascalvoc-sp", "B", "superpixel", "node-classification", "f1", 18),
+ BenchmarkSpec("coco-sp", "B", "superpixel", "node-classification", "f1", 19),
+ BenchmarkSpec("ogbn-arxiv", "B", "citation", "node-classification", "accuracy", 20),
+ BenchmarkSpec("ogbn-products", "B", "commerce", "node-classification", "accuracy", 21),
+ BenchmarkSpec("rmd17", "C", "molecule-3d", "energy-force", "mae", 22),
+ BenchmarkSpec("oc20-s2ef", "C", "catalyst-3d", "energy-force", "mae", 23),
+ BenchmarkSpec("oc22", "C", "catalyst-3d", "energy-force", "mae", 24),
+ BenchmarkSpec("matbench-discovery", "C", "materials", "stability", "discovery-metrics", 25),
+ BenchmarkSpec("tgb-subset", "C", "temporal", "temporal-link-node", "dataset-specific", 26),
+]
+
+BENCHMARK_BY_NAME = by_name(BENCHMARKS)
diff --git a/rrog/cli.py b/rrog/cli.py
new file mode 100644
index 0000000..9c826e3
--- /dev/null
+++ b/rrog/cli.py
@@ -0,0 +1,176 @@
+import argparse
+import json
+import os
+import subprocess
+
+from rrog.backbones import COMPUTES, MODIFIERS, VIEWS
+from rrog.benchmarks import BENCHMARKS
+from rrog.collect_results import print_tables
+from rrog.collect_zinc import print_zinc
+from rrog.runspecs import find_run_spec
+
+
+def _rows(items):
+ return [
+ {
+ "name": item.name,
+ "tier": getattr(item, "tier", None),
+ "family": getattr(item, "family", None),
+ "domain": getattr(item, "domain", None),
+ "task_type": getattr(item, "task_type", None),
+ "metric": getattr(item, "metric", None),
+ "priority": item.priority,
+ "status": item.status,
+ "notes": item.notes,
+ }
+ for item in items
+ ]
+
+
+def _print_table(items):
+ for row in _rows(items):
+ cols = [row["name"], row.get("tier") or row.get("family") or "", row["status"], row["notes"]]
+ print("\t".join(str(c) for c in cols))
+
+
+def list_axis(axis: str, as_json: bool):
+ mapping = {
+ "benchmarks": BENCHMARKS,
+ "views": VIEWS,
+ "computes": COMPUTES,
+ "modifiers": MODIFIERS,
+ }
+ items = mapping[axis]
+ if as_json:
+ print(json.dumps(_rows(items), indent=2))
+ else:
+ _print_table(items)
+
+
+def build_command(args) -> list[str]:
+ spec = find_run_spec(args.task, args.view, args.compute, args.modifier)
+ run_args = dict(spec.default_args)
+ for key in [
+ "epochs", "hidden", "bs", "seed", "T", "n_sup", "halt_max_steps", "halt_target",
+ "halt_min_steps", "halt_loss_threshold", "q_warmup_epochs",
+ "eval_every", "max_train_batches", "max_eval_batches", "num_workers", "ema", "lr", "lam_q",
+ "device",
+ ]:
+ value = getattr(args, key)
+ if value is not None:
+ run_args[key] = value
+ run_args["compute"] = args.compute
+ return spec.command_builder(run_args)
+
+
+def print_matrix(args):
+ tasks = [b for b in BENCHMARKS if args.tier == "all" or b.tier == args.tier]
+ tasks = sorted(tasks, key=lambda x: x.priority)[:args.limit_tasks]
+
+ if args.kind == "main":
+ views = sorted(VIEWS, key=lambda x: x.priority)[:args.limit_views]
+ computes = [c for c in COMPUTES if c.name in ["classic", "view-only", "fixed-rrog", "rrog-act"]]
+ for task in tasks:
+ for view in views:
+ for compute in computes:
+ try:
+ find_run_spec(task.name, view.name, compute.name)
+ status = "implemented"
+ except KeyError:
+ status = "planned"
+ print(f"{task.name}\t{view.name}\t{compute.name}\tnone\t{status}")
+ return
+
+ if args.kind == "modifier":
+ task_names = {"zinc-cycle56", "ogbg-molhiv", "peptides-struct", "peptides-func", "ogbn-arxiv", "qm9"}
+ mods = [m for m in sorted(MODIFIERS, key=lambda x: x.priority) if m.name != "none"]
+ for task in tasks:
+ if task.name not in task_names:
+ continue
+ for mod in mods:
+ print(f"{task.name}\tgin\tview-only\t{mod.name}\tplanned")
+ return
+
+ raise ValueError(args.kind)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ sub = ap.add_subparsers(dest="cmd", required=True)
+
+ lp = sub.add_parser("list")
+ lp.add_argument("axis", choices=["benchmarks", "views", "computes", "modifiers"])
+ lp.add_argument("--json", action="store_true")
+
+ rp = sub.add_parser("run")
+ rp.add_argument("--task", default="zinc-cycle56")
+ rp.add_argument("--view", default="gin")
+ rp.add_argument("--compute", default="rrog-act")
+ rp.add_argument("--modifier", default="none")
+ rp.add_argument("--epochs", type=int)
+ rp.add_argument("--hidden", type=int)
+ rp.add_argument("--bs", type=int)
+ rp.add_argument("--seed", type=int)
+ rp.add_argument("--T", type=int)
+ rp.add_argument("--n_sup", type=int)
+ rp.add_argument("--halt_max_steps", type=int)
+ rp.add_argument("--halt_target", choices=["soft", "binary", "exact", "loss"])
+ rp.add_argument("--halt_min_steps", type=int)
+ rp.add_argument("--halt_loss_threshold", type=float)
+ rp.add_argument("--q_warmup_epochs", type=int)
+ rp.add_argument("--eval_every", type=int)
+ rp.add_argument("--max_train_batches", type=int)
+ rp.add_argument("--max_eval_batches", type=int)
+ rp.add_argument("--num_workers", type=int)
+ rp.add_argument("--ema", type=float)
+ rp.add_argument("--lr", type=float)
+ rp.add_argument("--lam_q", type=float)
+ rp.add_argument("--device")
+ rp.add_argument("--dry-run", action="store_true")
+
+ mp = sub.add_parser("matrix")
+ mp.add_argument("--kind", choices=["main", "modifier"], default="main")
+ mp.add_argument("--tier", choices=["A", "B", "C", "all"], default="A")
+ mp.add_argument("--limit-tasks", type=int, default=20)
+ mp.add_argument("--limit-views", type=int, default=6)
+
+ op = sub.add_parser("results")
+ op.add_argument("--runs-dir", default="runs")
+ op.add_argument("--glob", default="*.json")
+ op.add_argument("--min-epochs", type=int, default=10)
+ op.add_argument("--epochs", type=int)
+ op.add_argument("--digits", type=int, default=4)
+
+ zp = sub.add_parser("zinc-results")
+ zp.add_argument("--runs-dir", default="runs")
+ zp.add_argument("--epochs", type=int)
+ zp.add_argument("--min-epochs", type=int, default=10)
+ zp.add_argument("--digits", type=int, default=4)
+
+ args = ap.parse_args()
+ if args.cmd == "list":
+ list_axis(args.axis, args.json)
+ return
+ if args.cmd == "matrix":
+ print_matrix(args)
+ return
+ if args.cmd == "results":
+ print_tables(args)
+ return
+ if args.cmd == "zinc-results":
+ print_zinc(args)
+ return
+
+ cmd = build_command(args)
+ env = dict(os.environ)
+ env["PYTHONPATH"] = os.getcwd() + os.pathsep + env.get("PYTHONPATH", "")
+ print(" ".join(cmd), flush=True)
+ if not args.dry_run:
+ raise SystemExit(subprocess.call(cmd, env=env))
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except BrokenPipeError:
+ raise SystemExit(0)
diff --git a/rrog/collect_results.py b/rrog/collect_results.py
new file mode 100644
index 0000000..125c0ab
--- /dev/null
+++ b/rrog/collect_results.py
@@ -0,0 +1,239 @@
+import argparse
+import json
+import math
+from collections import defaultdict
+from pathlib import Path
+
+
+HIGHER_BETTER = {
+ "accuracy",
+ "ap",
+ "auc",
+ "f1",
+ "mrr",
+ "rocauc",
+}
+
+LOWER_BETTER = {
+ "mae",
+ "raw-mae",
+ "rmse",
+}
+
+
+def _metric_direction(metric: str) -> int:
+ metric = metric.lower()
+ if metric in HIGHER_BETTER:
+ return 1
+ if metric in LOWER_BETTER:
+ return -1
+ if "mae" in metric or "rmse" in metric:
+ return -1
+ return 1
+
+
+def _read(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ rep = json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+ required = {"dataset", "view", "compute", "seed", "metric", "val"}
+ if not required.issubset(rep):
+ return None
+ return rep
+
+
+def _score(rep: dict, split: str) -> float | None:
+ metric = rep.get("metric")
+ if not metric:
+ return None
+ value = rep.get(split, {}).get(metric)
+ if value is None:
+ return None
+ return float(value)
+
+
+def _rank_key(rep: dict) -> tuple[int, int, int]:
+ return (
+ int(rep.get("epochs", 0)),
+ int(rep.get("hidden", 0)),
+ int(rep.get("ep", 0) or 0),
+ )
+
+
+def _mean(xs: list[float]) -> float:
+ return sum(xs) / len(xs)
+
+
+def _std(xs: list[float]) -> float:
+ if len(xs) < 2:
+ return 0.0
+ mu = _mean(xs)
+ return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1))
+
+
+def _fmt(value: float | None, digits: int) -> str:
+ if value is None:
+ return ""
+ return f"{value:.{digits}f}"
+
+
+def _is_classic_baseline(rep: dict) -> bool:
+ return rep.get("compute") == "classic" and int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1
+
+
+def _compute_label(rep: dict) -> str:
+ label = str(rep["compute"])
+ ema = float(rep.get("ema", 0.0) or 0.0)
+ if ema > 0:
+ label += f"+ema{ema:g}"
+ return label
+
+
+def _choose_runs(paths: list[Path], min_epochs: int, epochs: int | None) -> list[dict]:
+ candidates: dict[tuple[str, str, str, int], dict] = {}
+ for path in paths:
+ rep = _read(path)
+ if rep is None:
+ continue
+ if epochs is not None and int(rep.get("epochs", -1)) != epochs:
+ continue
+ if int(rep.get("epochs", 0)) < min_epochs:
+ continue
+ val = _score(rep, "val")
+ test = _score(rep, "test")
+ if val is None or test is None:
+ continue
+ key = (str(rep["dataset"]), str(rep["view"]), _compute_label(rep), int(rep["seed"]))
+ old = candidates.get(key)
+ if old is None or _rank_key(rep) > _rank_key(old):
+ candidates[key] = rep
+ return list(candidates.values())
+
+
+def _group_by_cell(runs: list[dict]) -> dict[tuple[str, str, str], list[dict]]:
+ grouped: dict[tuple[str, str, str], list[dict]] = defaultdict(list)
+ for rep in runs:
+ grouped[(str(rep["dataset"]), str(rep["view"]), _compute_label(rep))].append(rep)
+ return dict(grouped)
+
+
+def _summarize_cell(reps: list[dict], split: str) -> tuple[float, float, int]:
+ scores = [_score(rep, split) for rep in reps]
+ xs = [x for x in scores if x is not None]
+ return _mean(xs), _std(xs), len(xs)
+
+
+def _markdown_table(headers: list[str], rows: list[list[str]]) -> str:
+ out = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"]
+ out.extend("| " + " | ".join(row) + " |" for row in rows)
+ return "\n".join(out)
+
+
+def print_tables(args) -> None:
+ paths = sorted(Path(args.runs_dir).glob(args.glob))
+ runs = _choose_runs(paths, args.min_epochs, args.epochs)
+ grouped = _group_by_cell(runs)
+
+ classic_cells = {
+ (task, view): reps
+ for (task, view, compute), reps in grouped.items()
+ if compute == "classic"
+ for reps in [[rep for rep in reps if _is_classic_baseline(rep)]]
+ if reps
+ }
+
+ baseline_rows = []
+ for (task, view), reps in sorted(classic_cells.items()):
+ metric = str(reps[0]["metric"])
+ val_mu, val_sd, n = _summarize_cell(reps, "val")
+ test_mu, test_sd, _ = _summarize_cell(reps, "test")
+ baseline_rows.append([
+ task,
+ view,
+ metric,
+ str(n),
+ f"{_fmt(val_mu, args.digits)} +/- {_fmt(val_sd, args.digits)}",
+ f"{_fmt(test_mu, args.digits)} +/- {_fmt(test_sd, args.digits)}",
+ ])
+
+ delta_rows = []
+ for (task, view, compute), reps in sorted(grouped.items()):
+ if compute == "classic":
+ continue
+ base = classic_cells.get((task, view))
+ if not base:
+ continue
+ metric = str(reps[0]["metric"])
+ direction = _metric_direction(metric)
+ base_by_seed = {int(rep["seed"]): rep for rep in base}
+ paired = []
+ for rep in reps:
+ seed = int(rep["seed"])
+ if seed in base_by_seed:
+ paired.append((rep, base_by_seed[seed]))
+ if not paired:
+ base_test_mu, _, _ = _summarize_cell(base, "test")
+ base_val_mu, _, _ = _summarize_cell(base, "val")
+ paired = [(rep, None) for rep in reps]
+ else:
+ base_test_mu = None
+ base_val_mu = None
+
+ val_scores, test_scores, val_deltas, test_deltas = [], [], [], []
+ adaptive_steps = []
+ for rep, base_rep in paired:
+ val = _score(rep, "val")
+ test = _score(rep, "test")
+ if val is None or test is None:
+ continue
+ if base_rep is None:
+ base_val = base_val_mu
+ base_test = base_test_mu
+ else:
+ base_val = _score(base_rep, "val")
+ base_test = _score(base_rep, "test")
+ if base_val is None or base_test is None:
+ continue
+ val_scores.append(val)
+ test_scores.append(test)
+ val_deltas.append(direction * (val - base_val))
+ test_deltas.append(direction * (test - base_test))
+ if rep.get("adaptive_steps") is not None:
+ adaptive_steps.append(float(rep["adaptive_steps"]))
+
+ if not test_scores:
+ continue
+ delta_rows.append([
+ task,
+ view,
+ compute,
+ metric,
+ str(len(test_scores)),
+ f"{_fmt(_mean(val_scores), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})",
+ f"{_fmt(_mean(test_scores), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})",
+ _fmt(_mean(adaptive_steps), 2) if adaptive_steps else "",
+ ])
+
+ print("\nClassic baseline: task x backbone")
+ print(_markdown_table(["task", "backbone", "metric", "n", "val", "test"], baseline_rows))
+ print("\nDelta vs matching classic")
+ print(_markdown_table([
+ "task", "backbone", "compute", "metric", "n", "val score (delta)", "test score (delta)", "steps"
+ ], delta_rows))
+
+
+def main() -> None:
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--runs-dir", default="runs")
+ ap.add_argument("--glob", default="*.json")
+ ap.add_argument("--min-epochs", type=int, default=10)
+ ap.add_argument("--epochs", type=int)
+ ap.add_argument("--digits", type=int, default=4)
+ args = ap.parse_args()
+ print_tables(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/rrog/collect_zinc.py b/rrog/collect_zinc.py
new file mode 100644
index 0000000..49ae1e9
--- /dev/null
+++ b/rrog/collect_zinc.py
@@ -0,0 +1,137 @@
+import argparse
+import json
+import math
+from collections import defaultdict
+from pathlib import Path
+
+
+def _read(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ rep = json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+ if rep.get("dataset") != "ZINC-cycle56":
+ return None
+ if rep.get("K") != 1 or rep.get("select") != "none" or float(rep.get("sigma", 0.0)) != 0.0:
+ return None
+ if "test_mae" not in rep or "val_mae" not in rep:
+ return None
+ return rep
+
+
+def _compute(rep: dict) -> str:
+ if rep.get("act"):
+ return f"rrog-act-T{rep.get('T')}-ns{rep.get('n_sup')}"
+ if int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1:
+ return "classic"
+ label = f"fixed-rrog-T{rep.get('T')}-ns{rep.get('n_sup')}"
+ if rep.get("loss_mode") == "trace":
+ label += "+trace"
+ return label
+
+
+def _view(rep: dict) -> str:
+ return str(rep.get("view", "gin"))
+
+
+def _score(rep: dict, split: str) -> float:
+ return float(sum(rep[f"{split}_mae"]))
+
+
+def _mean(xs: list[float]) -> float:
+ return sum(xs) / len(xs)
+
+
+def _std(xs: list[float]) -> float:
+ if len(xs) < 2:
+ return 0.0
+ mu = _mean(xs)
+ return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1))
+
+
+def _fmt(x: float, digits: int) -> str:
+ return f"{x:.{digits}f}"
+
+
+def _markdown(headers: list[str], rows: list[list[str]]) -> str:
+ lines = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"]
+ lines.extend("| " + " | ".join(row) + " |" for row in rows)
+ return "\n".join(lines)
+
+
+def print_zinc(args) -> None:
+ by_cell: dict[tuple[str, str, int], dict] = {}
+ for path in sorted(Path(args.runs_dir).glob("rec_rrog*_sig0.0_K1_none_T*_s*.json")):
+ rep = _read(path)
+ if rep is None:
+ continue
+ if args.epochs is not None and int(rep.get("epochs", -1)) != args.epochs:
+ continue
+ if int(rep.get("epochs", 0)) < args.min_epochs:
+ continue
+ key = (_view(rep), _compute(rep), int(rep["seed"]))
+ old = by_cell.get(key)
+ if old is None or int(rep.get("epochs", 0)) > int(old.get("epochs", 0)):
+ by_cell[key] = rep
+
+ grouped: dict[tuple[str, str], list[dict]] = defaultdict(list)
+ for rep in by_cell.values():
+ grouped[(_view(rep), _compute(rep))].append(rep)
+
+ classic_by_view = {
+ view: reps
+ for (view, compute), reps in grouped.items()
+ if compute == "classic"
+ }
+
+ base_rows = []
+ for view, classic in sorted(classic_by_view.items()):
+ vals = [_score(rep, "val") for rep in classic]
+ tests = [_score(rep, "test") for rep in classic]
+ base_rows.append([
+ "zinc-cycle56",
+ view,
+ str(len(classic)),
+ f"{_fmt(_mean(vals), args.digits)} +/- {_fmt(_std(vals), args.digits)}",
+ f"{_fmt(_mean(tests), args.digits)} +/- {_fmt(_std(tests), args.digits)}",
+ ])
+
+ delta_rows = []
+ for (view, compute), reps in sorted(grouped.items()):
+ if compute == "classic":
+ continue
+ base_by_seed = {int(rep["seed"]): rep for rep in classic_by_view.get(view, [])}
+ paired = [(rep, base_by_seed[int(rep["seed"])]) for rep in reps if int(rep["seed"]) in base_by_seed]
+ if not paired:
+ continue
+ vals = [_score(rep, "val") for rep, _ in paired]
+ tests = [_score(rep, "test") for rep, _ in paired]
+ val_deltas = [_score(base, "val") - _score(rep, "val") for rep, base in paired]
+ test_deltas = [_score(base, "test") - _score(rep, "test") for rep, base in paired]
+ delta_rows.append([
+ "zinc-cycle56",
+ view,
+ compute,
+ str(len(paired)),
+ f"{_fmt(_mean(vals), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})",
+ f"{_fmt(_mean(tests), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})",
+ ])
+
+ print("\nZINC-cycle56 classic baseline")
+ print(_markdown(["task", "backbone", "n", "val MAE-sum", "test MAE-sum"], base_rows))
+ print("\nZINC-cycle56 delta vs matching classic")
+ print(_markdown(["task", "backbone", "compute", "n", "val score (improvement)", "test score (improvement)"], delta_rows))
+
+
+def main() -> None:
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--runs-dir", default="runs")
+ ap.add_argument("--epochs", type=int)
+ ap.add_argument("--min-epochs", type=int, default=10)
+ ap.add_argument("--digits", type=int, default=4)
+ print_zinc(ap.parse_args())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/rrog/registry.py b/rrog/registry.py
new file mode 100644
index 0000000..7f6b6fe
--- /dev/null
+++ b/rrog/registry.py
@@ -0,0 +1,57 @@
+from dataclasses import dataclass, field
+from typing import Callable
+
+
+@dataclass(frozen=True)
+class BenchmarkSpec:
+ name: str
+ tier: str
+ domain: str
+ task_type: str
+ metric: str
+ priority: int
+ status: str = "planned"
+ notes: str = ""
+
+
+@dataclass(frozen=True)
+class ViewSpec:
+ name: str
+ family: str
+ graph_type: str
+ priority: int
+ status: str = "planned"
+ notes: str = ""
+
+
+@dataclass(frozen=True)
+class ComputeSpec:
+ name: str
+ family: str
+ priority: int
+ status: str = "planned"
+ notes: str = ""
+
+
+@dataclass(frozen=True)
+class ModifierSpec:
+ name: str
+ family: str
+ priority: int
+ status: str = "planned"
+ notes: str = ""
+
+
+@dataclass(frozen=True)
+class RunSpec:
+ task: str
+ view: str
+ compute: str
+ modifier: str = "none"
+ default_args: dict[str, object] = field(default_factory=dict)
+ command_builder: Callable[[dict[str, object]], list[str]] | None = None
+
+
+def by_name(items):
+ return {item.name: item for item in items}
+
diff --git a/rrog/run_ogb_hiv_remaining.sh b/rrog/run_ogb_hiv_remaining.sh
new file mode 100755
index 0000000..067e736
--- /dev/null
+++ b/rrog/run_ogb_hiv_remaining.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+DEVICE="${DEVICE:-cuda:3}"
+EPOCHS="${EPOCHS:-100}"
+SEED="${SEED:-0}"
+
+run_if_missing() {
+ local view="$1"
+ local compute="$2"
+ local t="$3"
+ local ns="$4"
+ local out="runs/ogbg-molhiv_${view}_${compute}_T${t}_ns${ns}_h128_e${EPOCHS}_s${SEED}.json"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+ echo "[run] view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}"
+ python3 -m rrog.cli run \
+ --task ogbg-molhiv \
+ --view "${view}" \
+ --compute "${compute}" \
+ --epochs "${EPOCHS}" \
+ --T "${t}" \
+ --n_sup "${ns}" \
+ --seed "${SEED}" \
+ --device "${DEVICE}"
+}
+
+# Complete remaining OGB-HIV backbone x {classic, fixed-RRoG} cells.
+# Existing json files are skipped, so the queue can be restarted safely.
+run_if_missing sgc fixed-rrog 3 3
+run_if_missing cheb classic 0 1
+run_if_missing cheb fixed-rrog 3 3
+run_if_missing arma classic 0 1
+run_if_missing arma fixed-rrog 3 3
+run_if_missing mf classic 0 1
+run_if_missing mf fixed-rrog 3 3
+run_if_missing appnp classic 0 1
+run_if_missing appnp fixed-rrog 3 3
+run_if_missing pna fixed-rrog 3 3
+run_if_missing gine classic 0 1
+run_if_missing gine fixed-rrog 3 3
+
+python3 -m rrog.cli results --epochs "${EPOCHS}"
diff --git a/rrog/run_zinc_gine.sh b/rrog/run_zinc_gine.sh
new file mode 100755
index 0000000..1ca36e1
--- /dev/null
+++ b/rrog/run_zinc_gine.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+DEVICE="${DEVICE:-cuda:3}"
+EPOCHS="${EPOCHS:-200}"
+SEED="${SEED:-0}"
+
+run_if_missing() {
+ local compute="$1"
+ local t="$2"
+ local ns="$3"
+ local out="runs/rec_rrog_gine_full_sig0.0_K1_none_T${t}_ns${ns}_trace_s${SEED}.json"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+ echo "[run] zinc-cycle56 view=gine compute=${compute} T=${t} ns=${ns} device=${DEVICE}"
+ python3 -m rrog.cli run \
+ --task zinc-cycle56 \
+ --view gine \
+ --compute "${compute}" \
+ --epochs "${EPOCHS}" \
+ --T "${t}" \
+ --n_sup "${ns}" \
+ --seed "${SEED}" \
+ --device "${DEVICE}"
+}
+
+run_if_missing classic 0 1
+run_if_missing fixed-rrog 1 3
+
+python3 -m rrog.cli zinc-results --epochs "${EPOCHS}"
diff --git a/rrog/run_zinc_gine_after_pid.sh b/rrog/run_zinc_gine_after_pid.sh
new file mode 100755
index 0000000..a4d24ae
--- /dev/null
+++ b/rrog/run_zinc_gine_after_pid.sh
@@ -0,0 +1,14 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+WAIT_PID="${1:?usage: run_zinc_gine_after_pid.sh <pid>}"
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+echo "[wait] OGB queue pid=${WAIT_PID}"
+while kill -0 "${WAIT_PID}" 2>/dev/null; do
+ sleep 60
+done
+
+echo "[start] OGB queue exited; launching ZINC gine"
+DEVICE="${DEVICE:-cuda:1}" EPOCHS="${EPOCHS:-200}" SEED="${SEED:-0}" ./rrog/run_zinc_gine.sh
diff --git a/rrog/runspecs.py b/rrog/runspecs.py
new file mode 100644
index 0000000..285348f
--- /dev/null
+++ b/rrog/runspecs.py
@@ -0,0 +1,188 @@
+from rrog.registry import RunSpec
+
+
+ZINC_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+OGB_MOL_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+OGB_MOL_TASKS = [
+ "ogbg-molhiv",
+ "ogbg-molpcba",
+ "ogbg-molbbbp",
+ "ogbg-molbace",
+ "ogbg-moltox21",
+ "ogbg-molclintox",
+ "ogbg-molsider",
+ "ogbg-molesol",
+ "ogbg-molfreesolv",
+ "ogbg-mollipo",
+]
+
+
+def _zinc_cycle56_gin_command(args: dict[str, object]) -> list[str]:
+ compute = str(args.get("compute", "rrog-act"))
+ view = str(args.get("view", "gin"))
+ epochs = int(args.get("epochs", 200))
+ hidden = int(args.get("hidden", 64))
+ batch_size = int(args.get("bs", 256))
+ seed = int(args.get("seed", 0))
+ t = int(args.get("T", 1))
+ n_sup = int(args.get("n_sup", 3))
+ halt_max_steps = int(args.get("halt_max_steps", 8))
+ halt_target = str(args.get("halt_target", "binary"))
+
+ cmd = [
+ "python3", "diag/train_rec.py",
+ "--grad_mode", "full",
+ "--T", str(t),
+ "--n_sup", str(n_sup),
+ "--hidden", str(hidden),
+ "--bs", str(batch_size),
+ "--epochs", str(epochs),
+ "--agg_layers", str(args.get("agg_layers", 5)),
+ "--compute_layers", str(args.get("compute_layers", 2)),
+ "--view", view,
+ "--sigma", "0",
+ "--K", "1",
+ "--select", "none",
+ "--seed", str(seed),
+ ]
+
+ if compute in ["classic", "view-only"]:
+ cmd[cmd.index("--T") + 1] = "0"
+ elif compute == "fixed-rrog":
+ pass
+ elif compute == "rrog-act":
+ cmd.extend([
+ "--act",
+ "--halt_max_steps", str(halt_max_steps),
+ "--halt_target", halt_target,
+ "--halt_norm_threshold", str(args.get("halt_norm_threshold", 0.3)),
+ ])
+ else:
+ raise ValueError(f"unsupported compute for zinc-cycle56/{view}: {compute}")
+ if args.get("device") is not None:
+ cmd.extend(["--device", str(args["device"])])
+ return cmd
+
+
+def _ogb_graphprop_gin_command(args: dict[str, object]) -> list[str]:
+ task = str(args["task"])
+ view = str(args.get("view", "gin"))
+ compute = str(args.get("compute", "rrog-act"))
+ cmd = [
+ "python3", "rrog/train_ogb_graphprop.py",
+ "--dataset", task,
+ "--view", view,
+ "--compute", compute,
+ "--T", str(args.get("T", 1)),
+ "--n_sup", str(args.get("n_sup", 3)),
+ "--hidden", str(args.get("hidden", 128)),
+ "--bs", str(args.get("bs", 128)),
+ "--epochs", str(args.get("epochs", 100)),
+ "--eval_every", str(args.get("eval_every", 10)),
+ "--agg_layers", str(args.get("agg_layers", 5)),
+ "--compute_layers", str(args.get("compute_layers", 2)),
+ "--seed", str(args.get("seed", 0)),
+ ]
+ for key in ["lr", "lam_q"]:
+ if args.get(key) is not None:
+ cmd.extend([f"--{key}", str(args[key])])
+ if compute == "rrog-act":
+ cmd.extend([
+ "--halt_max_steps", str(args.get("halt_max_steps", 8)),
+ "--halt_min_steps", str(args.get("halt_min_steps", 2)),
+ "--halt_target", str(args.get("halt_target", "loss")),
+ "--halt_loss_threshold", str(args.get("halt_loss_threshold", 0.2)),
+ "--halt_exploration_prob", str(args.get("halt_exploration_prob", 0.1)),
+ ])
+ if args.get("q_warmup_epochs") is not None:
+ cmd.extend(["--q_warmup_epochs", str(args["q_warmup_epochs"])])
+ if float(args.get("ema", 0.0) or 0.0) > 0:
+ cmd.extend(["--ema", str(args["ema"])])
+ if args.get("device") is not None:
+ cmd.extend(["--device", str(args["device"])])
+ for key in ["max_train_batches", "max_eval_batches", "num_workers"]:
+ if args.get(key) is not None:
+ cmd.extend([f"--{key}", str(args[key])])
+ return cmd
+
+
+RUN_SPECS = [
+]
+
+for _view in ZINC_VIEWS:
+ RUN_SPECS.extend([
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="classic",
+ default_args={"compute": "classic", "view": _view, "T": 0, "n_sup": 1, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="view-only",
+ default_args={"compute": "view-only", "view": _view, "T": 0, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="fixed-rrog",
+ default_args={"compute": "fixed-rrog", "view": _view, "T": 3, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="rrog-act",
+ default_args={"compute": "rrog-act", "view": _view, "T": 1, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ ])
+
+for _task in OGB_MOL_TASKS:
+ for _view in OGB_MOL_VIEWS:
+ RUN_SPECS.extend([
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="classic",
+ default_args={"task": _task, "view": _view, "compute": "classic", "T": 0, "n_sup": 1, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="view-only",
+ default_args={"task": _task, "view": _view, "compute": "view-only", "T": 0, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="fixed-rrog",
+ default_args={"task": _task, "view": _view, "compute": "fixed-rrog", "T": 3, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="rrog-act",
+ default_args={"task": _task, "view": _view, "compute": "rrog-act", "T": 1, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ ])
+
+
+def find_run_spec(task: str, view: str, compute: str, modifier: str = "none") -> RunSpec:
+ for spec in RUN_SPECS:
+ if (spec.task, spec.view, spec.compute, spec.modifier) == (task, view, compute, modifier):
+ return spec
+ raise KeyError(f"no implemented run spec for task={task} view={view} compute={compute} modifier={modifier}")
diff --git a/rrog/train_ogb_graphprop.py b/rrog/train_ogb_graphprop.py
new file mode 100644
index 0000000..387ef3c
--- /dev/null
+++ b/rrog/train_ogb_graphprop.py
@@ -0,0 +1,685 @@
+import argparse
+from contextlib import contextmanager
+import json
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
+from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
+from torch_geometric.data import Batch
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import (
+ APPNP,
+ ARMAConv,
+ ChebConv,
+ FiLMConv,
+ GATv2Conv,
+ GCNConv,
+ GENConv,
+ GINEConv,
+ GraphConv,
+ MFConv,
+ PNAConv,
+ ResGatedGraphConv,
+ SAGEConv,
+ SGConv,
+ TAGConv,
+ TransformerConv,
+ global_add_pool,
+)
+from torch_geometric.utils import degree
+
+
+PROJECT_ROOT = os.environ.get(
+ "RROG_ROOT",
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
+)
+DATA_ROOT = os.environ.get("RROG_DATA_DIR", os.path.join(PROJECT_ROOT, "data"))
+ROOT = os.path.join(DATA_ROOT, "ogb")
+OUT = os.environ.get("RROG_RUNS_DIR", os.path.join(PROJECT_ROOT, "runs"))
+SUPPORTED_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+SUPPORTED_MOL_DATASETS = [
+ "ogbg-molhiv",
+ "ogbg-molpcba",
+ "ogbg-molbbbp",
+ "ogbg-molbace",
+ "ogbg-moltox21",
+ "ogbg-molclintox",
+ "ogbg-molsider",
+ "ogbg-molesol",
+ "ogbg-molfreesolv",
+ "ogbg-mollipo",
+]
+HIGHER_BETTER = {"rocauc", "ap", "auc", "accuracy", "acc", "f1"}
+LOWER_BETTER = {"rmse", "mae"}
+
+_TORCH_LOAD = torch.load
+
+
+def _torch_load_ogb_compat(*args, **kwargs):
+ kwargs.setdefault("weights_only", False)
+ return _TORCH_LOAD(*args, **kwargs)
+
+
+torch.load = _torch_load_ogb_compat
+
+
+def clone_state_dict(model):
+ return {k: v.detach().clone() for k, v in model.state_dict().items()}
+
+
+@torch.no_grad()
+def update_ema_state(ema_state, model, decay):
+ if ema_state is None:
+ return
+ for key, value in model.state_dict().items():
+ if torch.is_floating_point(value):
+ ema_state[key].mul_(decay).add_(value.detach(), alpha=1.0 - decay)
+ else:
+ ema_state[key].copy_(value.detach())
+
+
+@contextmanager
+def using_ema_state(model, ema_state):
+ if ema_state is None:
+ yield
+ return
+ raw_state = clone_state_dict(model)
+ model.load_state_dict(ema_state, strict=True)
+ try:
+ yield
+ finally:
+ model.load_state_dict(raw_state, strict=True)
+
+
+def metric_direction(metric: str) -> int:
+ metric = metric.lower()
+ if metric in LOWER_BETTER or "rmse" in metric or "mae" in metric:
+ return -1
+ return 1
+
+
+def is_regression_metric(metric: str) -> bool:
+ return metric_direction(metric) < 0
+
+
+def is_better(score: float, best: float | None, metric: str) -> bool:
+ if best is None:
+ return True
+ direction = metric_direction(metric)
+ return score > best if direction > 0 else score < best
+
+
+def jsonable(obj):
+ if isinstance(obj, dict):
+ return {str(k): jsonable(v) for k, v in obj.items()}
+ if isinstance(obj, (list, tuple)):
+ return [jsonable(v) for v in obj]
+ if isinstance(obj, (np.integer, np.floating)):
+ return obj.item()
+ if isinstance(obj, torch.Tensor):
+ if obj.ndim == 0:
+ return obj.detach().cpu().item()
+ return obj.detach().cpu().tolist()
+ return obj
+
+
+def degree_histogram(dataset) -> torch.Tensor:
+ max_degree = 0
+ degs = []
+ for graph in dataset:
+ deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
+ degs.append(deg)
+ if deg.numel():
+ max_degree = max(max_degree, int(deg.max().item()))
+ hist = torch.zeros(max_degree + 1, dtype=torch.long)
+ for deg in degs:
+ hist += torch.bincount(deg, minlength=hist.numel())
+ return hist
+
+
+def make_view_layer(view: str, hidden: int, deg: torch.Tensor | None):
+ if view in {"gin", "gine"}:
+ mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
+ return GINEConv(mlp, train_eps=True)
+ if view == "gcn":
+ return GCNConv(hidden, hidden)
+ if view == "graphsage":
+ return SAGEConv(hidden, hidden)
+ if view == "gatv2":
+ return GATv2Conv(hidden, hidden, heads=4, concat=False, edge_dim=hidden)
+ if view == "graphconv":
+ return GraphConv(hidden, hidden)
+ if view == "transformer":
+ return TransformerConv(hidden, hidden, heads=4, concat=False, edge_dim=hidden)
+ if view == "pna":
+ if deg is None:
+ raise ValueError("PNA view requires a training-set degree histogram")
+ return PNAConv(
+ hidden, hidden,
+ aggregators=["mean", "min", "max", "std"],
+ scalers=["identity", "amplification", "attenuation"],
+ deg=deg,
+ edge_dim=hidden,
+ )
+ if view == "gen":
+ return GENConv(hidden, hidden, edge_dim=hidden)
+ if view == "film":
+ return FiLMConv(hidden, hidden)
+ if view == "resgated":
+ return ResGatedGraphConv(hidden, hidden, edge_dim=hidden)
+ if view == "tag":
+ return TAGConv(hidden, hidden, K=3)
+ if view == "sgc":
+ return SGConv(hidden, hidden, K=2, cached=False)
+ if view == "cheb":
+ return ChebConv(hidden, hidden, K=3)
+ if view == "arma":
+ return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2)
+ if view == "mf":
+ return MFConv(hidden, hidden)
+ if view == "appnp":
+ return APPNP(K=5, alpha=0.1)
+ raise ValueError(f"unsupported OGB view: {view}")
+
+
+EDGE_ATTR_VIEWS = {"gin", "gine", "gatv2", "transformer", "pna", "gen", "resgated"}
+
+
+class OGBRRoG(nn.Module):
+ def __init__(
+ self, hidden, num_tasks, view="gin", T=1, n_sup=3, agg_layers=5,
+ compute_layers=2, grad_mode="full", deg=None,
+ ):
+ super().__init__()
+ self.view = view
+ self.atom_encoder = AtomEncoder(hidden)
+ self.bond_encoder = BondEncoder(hidden)
+ self.agg_convs = nn.ModuleList()
+ self.agg_bns = nn.ModuleList()
+ for _ in range(agg_layers):
+ self.agg_convs.append(make_view_layer(view, hidden, deg))
+ self.agg_bns.append(nn.BatchNorm1d(hidden))
+
+ core = []
+ d = hidden
+ for _ in range(compute_layers - 1):
+ core += [nn.Linear(d, hidden), nn.GELU()]
+ d = hidden
+ core.append(nn.Linear(d, hidden))
+ self.core_norm = nn.LayerNorm(hidden)
+ self.core = nn.Sequential(*core)
+ nn.init.zeros_(self.core[-1].weight)
+ nn.init.zeros_(self.core[-1].bias)
+
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, num_tasks))
+ self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
+ with torch.no_grad():
+ self.qhead[-1].weight.zero_()
+ self.qhead[-1].bias.fill_(-5.0)
+
+ self.T = T
+ self.n_sup = n_sup
+ self.grad_mode = grad_mode
+ self.agg_layers = agg_layers
+ self.compute_layers = compute_layers
+ self.hidden = hidden
+ self.num_tasks = num_tasks
+
+ def aggregate(self, x, edge_index, edge_attr):
+ h = self.atom_encoder(x)
+ e = self.bond_encoder(edge_attr)
+ for conv, bn in zip(self.agg_convs, self.agg_bns):
+ if self.view in {"gin", "gine", "pna", "gen", "resgated"}:
+ h = bn(conv(h, edge_index, e)).relu()
+ elif self.view in {"gatv2", "transformer"}:
+ h = bn(conv(h, edge_index, edge_attr=e)).relu()
+ else:
+ h = bn(conv(h, edge_index)).relu()
+ return h
+
+ def core_step(self, combined, state):
+ return state + self.core(self.core_norm(combined))
+
+ def _z_step(self, y, z, ctx):
+ return self.core_step(ctx + y + z, z)
+
+ def _y_step(self, y, z):
+ return self.core_step(y + z, y)
+
+ def recurse(self, y, z, ctx, one_step=False):
+ if self.T == 0:
+ return y, z
+ if one_step:
+ with torch.no_grad():
+ for _ in range(self.T - 1):
+ z = self._z_step(y, z, ctx)
+ z = z.detach()
+ z = self._z_step(y, z, ctx)
+ y = self._y_step(y, z)
+ return y, z
+ for _ in range(self.T):
+ z = self._z_step(y, z, ctx)
+ y = self._y_step(y, z)
+ return y, z
+
+ def predict(self, y, batch):
+ pooled = global_add_pool(y, batch)
+ return self.head(pooled), self.qhead(pooled).view(-1)
+
+ def forward_trace(self, data, steps):
+ ctx = self.aggregate(data.x, data.edge_index, data.edge_attr)
+ y = ctx
+ z = torch.zeros_like(ctx)
+ preds, q_logits = [], []
+ for s in range(steps):
+ y, z = self.recurse(y, z, ctx, one_step=(self.grad_mode == "1step"))
+ pred, q = self.predict(y, data.batch)
+ preds.append(pred)
+ q_logits.append(q)
+ if s < steps - 1:
+ y, z = y.detach(), z.detach()
+ return preds, q_logits
+
+ def forward(self, data):
+ steps = self.n_sup
+ preds, q_logits = self.forward_trace(data, steps)
+ return preds, q_logits[-1]
+
+
+def supervised_loss(logits, y, metric):
+ per_graph, has_label = per_graph_supervised_loss(logits, y, metric)
+ if not has_label.any():
+ return logits.sum() * 0.0
+ return per_graph[has_label].mean()
+
+
+def per_graph_supervised_loss(logits, y, metric):
+ y = y.to(torch.float32)
+ mask = ~torch.isnan(y)
+ target = torch.where(mask, y, torch.zeros_like(y))
+ if is_regression_metric(metric):
+ losses = (logits - target).pow(2)
+ else:
+ losses = nn.functional.binary_cross_entropy_with_logits(logits, target, reduction="none")
+ losses = torch.where(mask, losses, torch.zeros_like(losses))
+ denom = mask.sum(dim=1).clamp_min(1)
+ return losses.sum(dim=1) / denom, mask.any(dim=1)
+
+
+@torch.no_grad()
+def halt_targets(logits, y):
+ y = y.to(torch.float32)
+ mask = ~torch.isnan(y)
+ pred = (logits > 0).to(y.dtype)
+ correct_or_missing = (~mask) | (pred == y)
+ has_label = mask.any(dim=1)
+ return (correct_or_missing.all(dim=1) & has_label).to(logits.dtype)
+
+
+@torch.no_grad()
+def evaluate(model, loader, evaluator, dev, steps=None, adaptive=False, halt_min_steps=1, max_batches=0):
+ model.eval()
+ ys, ps = [], []
+ step_sum = 0.0
+ n = 0
+ for i, batch in enumerate(loader):
+ if max_batches and i >= max_batches:
+ break
+ batch = batch.to(dev)
+ if steps is None:
+ preds, _ = model(batch)
+ pred = preds[-1]
+ used_steps = float(model.n_sup)
+ else:
+ preds, qs = model.forward_trace(batch, steps)
+ stack = torch.stack(preds, dim=0)
+ if adaptive:
+ q = torch.stack(qs, dim=0)
+ step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1)
+ halted = (q > 0) & (step_ids >= halt_min_steps)
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+ pred = stack[idx, torch.arange(stack.size(1), device=dev)]
+ used_steps = float((idx.to(torch.float32) + 1).sum().item())
+ else:
+ pred = stack[-1]
+ used_steps = float(steps * batch.num_graphs)
+ ys.append(batch.y.detach().cpu())
+ ps.append(pred.detach().cpu())
+ step_sum += used_steps
+ n += batch.num_graphs
+ y_true = torch.cat(ys, dim=0)
+ y_pred = torch.cat(ps, dim=0)
+ result = evaluator.eval({"y_true": y_true, "y_pred": y_pred})
+ return result, step_sum / max(n, 1)
+
+
+def _split_nodes(t, ptr):
+ return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)]
+
+
+def act_train_step(model, state, replacement_batch, opt, dev, args, metric):
+ replacement = replacement_batch.to_data_list()
+ batch_size = len(replacement)
+ if state is None:
+ state = {
+ "graphs": [None for _ in range(batch_size)],
+ "y": [None for _ in range(batch_size)],
+ "z": [None for _ in range(batch_size)],
+ "steps": torch.zeros(batch_size, dtype=torch.long, device=dev),
+ "halted": torch.ones(batch_size, dtype=torch.bool, device=dev),
+ }
+
+ halted_cpu = state["halted"].detach().cpu().tolist()
+ for i, halted in enumerate(halted_cpu):
+ if halted:
+ state["graphs"][i] = replacement[i]
+
+ batch = Batch.from_data_list(state["graphs"]).to(dev)
+ ctx = model.aggregate(batch.x, batch.edge_index, batch.edge_attr)
+ ptr = batch.ptr
+ y_parts, z_parts = [], []
+ for i in range(batch_size):
+ start, end = ptr[i].item(), ptr[i + 1].item()
+ if halted_cpu[i] or state["y"][i] is None:
+ y_parts.append(ctx[start:end])
+ z_parts.append(torch.zeros_like(ctx[start:end]))
+ else:
+ y_parts.append(state["y"][i].to(dev))
+ z_parts.append(state["z"][i].to(dev))
+ y = torch.cat(y_parts, dim=0)
+ z = torch.cat(z_parts, dim=0)
+
+ opt.zero_grad()
+ y, z = model.recurse(y, z, ctx, one_step=(model.grad_mode == "1step"))
+ logits, q = model.predict(y, batch.batch)
+ pred_loss = supervised_loss(logits, batch.y, metric)
+ if args.halt_target == "exact" and not is_regression_metric(metric):
+ target = halt_targets(logits.detach(), batch.y)
+ elif args.halt_target == "loss":
+ per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), batch.y, metric)
+ target = ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype)
+ else:
+ raise ValueError(args.halt_target)
+ q_loss = nn.functional.binary_cross_entropy_with_logits(q, target)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ y_det, z_det = y.detach(), z.detach()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ state["y"] = _split_nodes(y_det, ptr)
+ state["z"] = _split_nodes(z_det, ptr)
+ with torch.no_grad():
+ was_halted = state["halted"]
+ steps = torch.where(was_halted, torch.zeros_like(state["steps"]), state["steps"]) + 1
+ halted = (steps >= args.halt_max_steps) | ((q.detach() > 0) & (steps >= args.halt_min_steps))
+ if args.halt_exploration_prob > 0 and args.halt_max_steps > 1:
+ explore = torch.rand_like(q) < args.halt_exploration_prob
+ min_steps = torch.where(
+ explore,
+ torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev),
+ torch.zeros_like(steps),
+ )
+ halted = halted & (steps >= min_steps)
+ state["steps"] = steps
+ state["halted"] = halted
+
+ return state, {
+ "loss": float(loss.detach().cpu()),
+ "pred_loss": float(pred_loss.detach().cpu()),
+ "q_loss": float(q_loss.detach().cpu()),
+ "halted_frac": float(state["halted"].to(torch.float32).mean().detach().cpu()),
+ "steps": float(state["steps"].to(torch.float32).mean().detach().cpu()),
+ }
+
+
+def _halt_target(args, logits, y, metric):
+ if args.halt_target == "exact" and not is_regression_metric(metric):
+ return halt_targets(logits.detach(), y)
+ if args.halt_target == "loss":
+ per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), y, metric)
+ return ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype)
+ raise ValueError(args.halt_target)
+
+
+def act_trace_train_step(model, batch, opt, dev, args, epoch, metric):
+ batch = batch.to(dev)
+ steps = max(1, args.halt_max_steps)
+ opt.zero_grad()
+ preds, qs = model.forward_trace(batch, steps)
+ pred_loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds)
+ if epoch <= args.q_warmup_epochs:
+ q_loss = pred_loss.detach() * 0.0
+ loss = pred_loss
+ else:
+ q_losses = []
+ for step_idx, (pred, q) in enumerate(zip(preds, qs), start=1):
+ target = _halt_target(args, pred, batch.y, metric)
+ if step_idx < args.halt_min_steps:
+ target = torch.zeros_like(target)
+ q_losses.append(nn.functional.binary_cross_entropy_with_logits(q, target))
+ q_loss = sum(q_losses) / len(q_losses)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ with torch.no_grad():
+ q_stack = torch.stack(qs, dim=0)
+ step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1)
+ halted = (q_stack > 0) & (step_ids >= args.halt_min_steps)
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+
+ return {
+ "loss": float(loss.detach().cpu()),
+ "pred_loss": float(pred_loss.detach().cpu()),
+ "q_loss": float(q_loss.detach().cpu()),
+ "halted_frac": float(any_halt.to(torch.float32).mean().detach().cpu()),
+ "steps": float((idx.to(torch.float32) + 1).mean().detach().cpu()),
+ }
+
+
+def train_epoch(model, loader, opt, dev, args, act_state, ema_state, epoch, metric):
+ model.train()
+ metrics = []
+ for i, batch in enumerate(loader):
+ if args.max_train_batches and i >= args.max_train_batches:
+ break
+ if args.compute == "rrog-act":
+ m = act_trace_train_step(model, batch, opt, dev, args, epoch, metric)
+ update_ema_state(ema_state, model, args.ema)
+ metrics.append(m)
+ continue
+ batch = batch.to(dev)
+ opt.zero_grad()
+ preds, _ = model(batch)
+ loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds)
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+ update_ema_state(ema_state, model, args.ema)
+ metrics.append({"loss": float(loss.detach().cpu()), "halted_frac": 0.0, "steps": float(model.n_sup)})
+ return act_state, metrics
+
+
+def metric_value(result, evaluator):
+ return float(result[evaluator.eval_metric])
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--dataset", choices=SUPPORTED_MOL_DATASETS, default="ogbg-molhiv")
+ ap.add_argument("--view", choices=SUPPORTED_VIEWS, default="gin")
+ ap.add_argument("--compute", choices=["classic", "view-only", "fixed-rrog", "rrog-act"], default="rrog-act")
+ ap.add_argument("--T", type=int, default=1)
+ ap.add_argument("--n_sup", type=int, default=3)
+ ap.add_argument("--hidden", type=int, default=128)
+ ap.add_argument("--agg_layers", type=int, default=5)
+ ap.add_argument("--compute_layers", type=int, default=2)
+ ap.add_argument("--epochs", type=int, default=100)
+ ap.add_argument("--eval_every", type=int, default=10)
+ ap.add_argument("--lr", type=float, default=1e-3)
+ ap.add_argument("--bs", type=int, default=128)
+ ap.add_argument("--lam_q", type=float, default=1.0)
+ ap.add_argument("--halt_max_steps", type=int, default=8)
+ ap.add_argument("--halt_min_steps", type=int, default=1)
+ ap.add_argument("--halt_target", choices=["exact", "loss"], default="loss")
+ ap.add_argument("--halt_loss_threshold", type=float, default=0.2)
+ ap.add_argument("--halt_exploration_prob", type=float, default=0.1)
+ ap.add_argument("--q_warmup_epochs", type=int, default=0)
+ ap.add_argument("--ema", type=float, default=0.0)
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--num_workers", type=int, default=0)
+ ap.add_argument("--max_train_batches", type=int, default=0)
+ ap.add_argument("--max_eval_batches", type=int, default=0)
+ ap.add_argument("--device", default="auto")
+ args = ap.parse_args()
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ if args.compute == "classic":
+ args.n_sup = 1
+ if args.device == "auto":
+ dev = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ dev = args.device
+ os.makedirs(OUT, exist_ok=True)
+
+ dataset = PygGraphPropPredDataset(name=args.dataset, root=ROOT)
+ split_idx = dataset.get_idx_split()
+ evaluator = Evaluator(args.dataset)
+ metric = evaluator.eval_metric
+ num_tasks = dataset.num_tasks
+
+ train_dataset = dataset[split_idx["train"]]
+ valid_dataset = dataset[split_idx["valid"]]
+ test_dataset = dataset[split_idx["test"]]
+ train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
+ drop_last=True, num_workers=args.num_workers)
+ valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False,
+ num_workers=args.num_workers)
+ test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False,
+ num_workers=args.num_workers)
+
+ T = 0 if args.compute in ["classic", "view-only"] else args.T
+ deg = degree_histogram(train_dataset) if args.view == "pna" else None
+ model = OGBRRoG(args.hidden, num_tasks, view=args.view, T=T, n_sup=args.n_sup,
+ agg_layers=args.agg_layers, compute_layers=args.compute_layers, deg=deg).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, max(args.epochs, 1))
+ ema_state = clone_state_dict(model) if args.ema > 0 else None
+
+ t0 = time.time()
+ best_val = None
+ best = {}
+ best_state = None
+ act_state = None
+ steps = max(1, args.halt_max_steps)
+
+ for ep in range(args.epochs):
+ act_state, train_metrics = train_epoch(
+ model, train_loader, opt, dev, args, act_state, ema_state, ep + 1, metric)
+ sched.step()
+ if (ep + 1) % args.eval_every == 0 or ep == args.epochs - 1:
+ with using_ema_state(model, ema_state):
+ if args.compute == "rrog-act":
+ val_result, fixed_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps,
+ adaptive=False, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ val_adapt, adaptive_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps,
+ adaptive=True, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ else:
+ val_result, _ = evaluate(model, valid_loader, evaluator, dev,
+ max_batches=args.max_eval_batches)
+ val_score = metric_value(val_result, evaluator)
+ if is_better(val_score, best_val, metric):
+ best_val = val_score
+ if args.compute == "rrog-act":
+ test_fixed, fixed_steps = evaluate(model, test_loader, evaluator, dev, steps=steps,
+ adaptive=False, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ test_adapt, adaptive_steps = evaluate(model, test_loader, evaluator, dev, steps=steps,
+ adaptive=True, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ best = {
+ "ep": ep + 1,
+ "val": val_result,
+ "val_adaptive": val_adapt,
+ "test": test_fixed,
+ "test_adaptive": test_adapt,
+ "fixed_val_steps": fixed_val_steps,
+ "adaptive_val_steps": adaptive_val_steps,
+ "fixed_steps": fixed_steps,
+ "adaptive_steps": adaptive_steps,
+ }
+ else:
+ test_result, _ = evaluate(model, test_loader, evaluator, dev,
+ max_batches=args.max_eval_batches)
+ best = {"ep": ep + 1, "val": val_result, "test": test_result}
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
+ halted = sum(m.get("halted_frac", 0.0) for m in train_metrics) / max(len(train_metrics), 1)
+ train_steps = sum(m.get("steps", 0.0) for m in train_metrics) / max(len(train_metrics), 1)
+ msg = f"ep{ep+1} val_{evaluator.eval_metric}={val_score:.5f}"
+ if args.compute == "rrog-act":
+ msg += (
+ f" val_adapt_{evaluator.eval_metric}={metric_value(val_adapt, evaluator):.5f}"
+ f" adapt_steps={adaptive_val_steps:.2f}"
+ )
+ msg += f" halt={halted:.2f} train_steps={train_steps:.2f}"
+ print(msg, flush=True)
+
+ ema_tag = f"_ema{args.ema:g}" if args.ema > 0 else ""
+ tag = (
+ f"{args.dataset}_{args.view}_{args.compute}_T{T}_ns{args.n_sup}_"
+ f"h{args.hidden}_e{args.epochs}{ema_tag}_s{args.seed}"
+ )
+ rep = {
+ "dataset": args.dataset,
+ "tag": tag,
+ **vars(args),
+ "metric": evaluator.eval_metric,
+ "sec": round(time.time() - t0, 1),
+ "dev": dev,
+ **best,
+ }
+ with open(os.path.join(OUT, f"{tag}.json"), "w") as f:
+ json.dump(jsonable(rep), f, indent=2)
+ torch.save({
+ "state": best_state or model.state_dict(),
+ "cfg": {
+ "dataset": args.dataset,
+ "hidden": args.hidden,
+ "num_tasks": num_tasks,
+ "T": T,
+ "n_sup": args.n_sup,
+ "agg_layers": args.agg_layers,
+ "compute_layers": args.compute_layers,
+ "compute": args.compute,
+ "halt_max_steps": steps,
+ "halt_min_steps": args.halt_min_steps,
+ "halt_target": args.halt_target,
+ "halt_loss_threshold": args.halt_loss_threshold,
+ "view": args.view,
+ },
+ }, os.path.join(OUT, f"ckpt_{tag}.pt"))
+ print(f"[{tag}] best_ep={best.get('ep')} val={best.get('val')} test={best.get('test')} "
+ f"adaptive={best.get('test_adaptive')} steps={best.get('adaptive_steps')}", flush=True)
+ print(" wrote", os.path.join(OUT, f"{tag}.json"))
+
+
+if __name__ == "__main__":
+ main()