summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore10
-rw-r--r--README.md100
-rw-r--r--diag/__init__.py0
-rw-r--r--diag/train_cycle.py188
-rw-r--r--diag/train_rec.py491
-rw-r--r--requirements.txt5
-rw-r--r--rrog/__init__.py2
-rw-r--r--rrog/backbones.py72
-rw-r--r--rrog/benchmarks.py44
-rw-r--r--rrog/cli.py176
-rw-r--r--rrog/collect_results.py239
-rw-r--r--rrog/collect_zinc.py137
-rw-r--r--rrog/registry.py57
-rwxr-xr-xrrog/run_ogb_hiv_remaining.sh49
-rwxr-xr-xrrog/run_zinc_gine.sh36
-rwxr-xr-xrrog/run_zinc_gine_after_pid.sh14
-rw-r--r--rrog/runspecs.py188
-rw-r--r--rrog/train_ogb_graphprop.py685
-rwxr-xr-xscripts/collect_results.sh10
-rwxr-xr-xscripts/run_ogb_mol_all_tasks.sh17
-rwxr-xr-xscripts/run_ogb_mol_task_full.sh54
-rwxr-xr-xscripts/run_smoke.sh19
-rwxr-xr-xscripts/run_two_a6000.sh32
-rwxr-xr-xscripts/run_zinc_cycle56_full.sh54
-rwxr-xr-xscripts/setup_and_run_two_a6000.sh15
-rwxr-xr-xscripts/setup_env.sh35
26 files changed, 2729 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..307c159
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+__pycache__/
+*.py[cod]
+.venv/
+data/
+runs/
+logs/
+summaries/
+*.pt
+*.pth
+.DS_Store
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f07c90f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,100 @@
+# RRoG-GNN Runner
+
+This repo runs the current RRoG/TRM-on-GNN experiment grid.
+
+Core rule:
+
+```text
+view/graph aggregation happens once; recursive compute is edge-free hidden-state refinement.
+```
+
+The main reported table is:
+
+```text
+Task x Backbone -> classic baseline
+Task x Backbone x fixed-RRoG -> delta against the matching classic row
+```
+
+`classic` is the non-RRoG baseline for every backbone: `T=0`, `n_sup=1`.
+
+## One-command Run On 2x A6000
+
+On a clean machine with two visible GPUs:
+
+```bash
+git clone git@github.com:YurenHao0426/rrog-gnn-runner.git
+cd rrog-gnn-runner
+./scripts/setup_and_run_two_a6000.sh
+```
+
+Defaults:
+
+- GPU0: `zinc-cycle56` over 17 backbones, `classic + fixed-rrog`
+- GPU1: `ogbg-molhiv` over 17 backbones, `classic + fixed-rrog`
+- Results: `runs/*.json`
+- Logs: `logs/*.log`
+- Summaries: `summaries/*.md`
+
+If the environment already has compatible `torch`, `torch_geometric`, and `ogb`:
+
+```bash
+SKIP_SETUP=1 ./scripts/setup_and_run_two_a6000.sh
+```
+
+To override CUDA wheel index during setup:
+
+```bash
+TORCH_INDEX_URL=https://download.pytorch.org/whl/cu121 ./scripts/setup_env.sh
+```
+
+## Common Commands
+
+Smoke test:
+
+```bash
+./scripts/setup_env.sh
+DEVICE=cuda:0 ./scripts/run_smoke.sh
+```
+
+Run the paired ZINC matrix only:
+
+```bash
+DEVICE=cuda:0 EPOCHS=200 ./scripts/run_zinc_cycle56_full.sh
+```
+
+Run one OGB molecular task:
+
+```bash
+TASK=ogbg-molhiv DEVICE=cuda:1 EPOCHS=100 ./scripts/run_ogb_mol_task_full.sh
+```
+
+Run all selected OGB molecular tasks serially on one GPU:
+
+```bash
+DEVICE=cuda:1 ./scripts/run_ogb_mol_all_tasks.sh
+```
+
+Collect summaries:
+
+```bash
+./scripts/collect_results.sh
+```
+
+## Backbones
+
+The implemented 2D view/backbone list is shared across ZINC and OGB:
+
+```text
+gin, gine, gcn, graphsage, gatv2, graphconv, transformer, pna,
+gen, film, resgated, tag, sgc, cheb, arma, mf, appnp
+```
+
+For ZINC `gine`, there are no bond features, so GINE uses a learned constant edge token.
+For OGB molecular tasks, GINE and edge-aware backbones use OGB bond encodings.
+
+## Notes
+
+- Runs are resumable at the cell level: scripts skip existing expected JSON files.
+- ZINC cycle-count cache is generated under `data/cycle_cache`.
+- OGB datasets are downloaded under `data/ogb`.
+- Override data/runs locations with `RROG_DATA_DIR` and `RROG_RUNS_DIR`.
diff --git a/diag/__init__.py b/diag/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/diag/__init__.py
diff --git a/diag/train_cycle.py b/diag/train_cycle.py
new file mode 100644
index 0000000..598e349
--- /dev/null
+++ b/diag/train_cycle.py
@@ -0,0 +1,188 @@
+"""GNN-native ring-counting on real molecules (ZINC): regress [#5-cycles, #6-cycles].
+
+k-cycle counts (k>=3) are provably NOT computable by 1-WL/MPNN (Chen et al. 2020) -> a REAL
+H2 ceiling on REAL graphs. Training-based diagnosis (partition instrument is vacuous on
+feature-rich graphs):
+ GIN(L) 1-WL baseline -> should FAIL to count
+ GCN(L) sub-1-WL reference
+ GIN+RNI random feats = NOISE -> PTRM-style crude symmetry break (eval-averaged)
+ GIN+RWSE random-walk return probs-> structured >1-WL positive control
+Reads: GIN high error + RWSE fixes it = real ceiling exists; RNI also fixes = crude noise
+breaks it (bridge cashed); only RWSE = bridge needs STRUCTURED stochasticity (GRAM>PTRM).
+Targets z-scored for training; per-target MAE reported in RAW ring units.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_cycle.py --conv gin --feat none
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import networkx as nx
+from torch_geometric.datasets import ZINC
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.utils import to_networkx
+from torch_geometric.nn import GINConv, GCNConv, global_add_pool
+
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data'))
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+ROOT = os.path.join(DATA_ROOT, 'zinc')
+CACHE = os.path.join(DATA_ROOT, 'cycle_cache')
+RWSE_K = 16
+
+
+def rwse(edge_index, n, K=RWSE_K):
+ A = np.zeros((n, n), dtype=np.float64)
+ ei = edge_index.numpy()
+ A[ei[0], ei[1]] = 1.0
+ A = np.maximum(A, A.T)
+ deg = A.sum(1)
+ P = A / np.where(deg > 0, deg, 1.0)[:, None]
+ out = np.zeros((n, K), dtype=np.float32)
+ M = np.eye(n)
+ for k in range(K):
+ M = M @ P
+ out[:, k] = np.diag(M)
+ return torch.from_numpy(out)
+
+
+def c56(data):
+ G = to_networkx(data, to_undirected=True)
+ c = {5: 0, 6: 0}
+ for cyc in nx.simple_cycles(G, length_bound=6):
+ L = len(cyc)
+ if L in c:
+ c[L] += 1
+ return [float(c[5]), float(c[6])]
+
+
+def prepare(split):
+ os.makedirs(CACHE, exist_ok=True)
+ fp = os.path.join(CACHE, f"{split}.pt")
+ if os.path.exists(fp):
+ return torch.load(fp, weights_only=False)
+ ds = ZINC(ROOT, subset=True, split=split)
+ out = []
+ for g in ds:
+ out.append({'x': g.x.view(-1).long(), 'edge_index': g.edge_index,
+ 'rwse': rwse(g.edge_index, g.num_nodes),
+ 'y': torch.tensor(c56(g), dtype=torch.float)})
+ torch.save(out, fp)
+ return out
+
+
+class Net(nn.Module):
+ def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False):
+ super().__init__()
+ self.emb = nn.Embedding(n_atom, hidden)
+ self.rni, self.use_rwse = rni, use_rwse
+ din = hidden + rni + (RWSE_K if use_rwse else 0)
+ self.lin_in = nn.Linear(din, hidden)
+ self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
+ for _ in range(layers):
+ if conv == 'gin':
+ self.convs.append(GINConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True))
+ else:
+ self.convs.append(GCNConv(hidden, hidden))
+ self.bns.append(nn.BatchNorm1d(hidden))
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+
+ def forward(self, x, edge_index, batch, rwse=None):
+ h = self.emb(x)
+ parts = [h]
+ if self.use_rwse:
+ parts.append(rwse)
+ if self.rni:
+ parts.append(torch.randn(h.size(0), self.rni, device=h.device))
+ h = self.lin_in(torch.cat(parts, dim=1))
+ for conv, bn in zip(self.convs, self.bns):
+ h = bn(conv(h, edge_index)).relu()
+ return self.head(global_add_pool(h, batch))
+
+
+def to_loader(recs, bs, shuffle, drop_last=False):
+ data = [Data(x=r['x'], edge_index=r['edge_index'], rwse=r['rwse'],
+ y=r['y'].view(1, 2), num_nodes=r['x'].numel()) for r in recs]
+ return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+@torch.no_grad()
+def eval_mae(model, loader, dev, ymu, ysd, samples=1):
+ model.eval(); abs_err = torch.zeros(2); n = 0
+ for b in loader:
+ b = b.to(dev)
+ ps = torch.stack([model(b.x, b.edge_index, b.batch, b.rwse) for _ in range(samples)]).mean(0)
+ pr = ps * ysd.to(dev) + ymu.to(dev) # un-standardize -> raw ring units
+ yr = b.y * ysd.to(dev) + ymu.to(dev)
+ abs_err += (pr - yr).abs().sum(0).cpu(); n += b.num_graphs
+ return (abs_err / n).tolist()
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
+ ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none')
+ ap.add_argument('--layers', type=int, default=5)
+ ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--rni_dim', type=int, default=16)
+ ap.add_argument('--epochs', type=int, default=200)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--bs', type=int, default=128)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ os.makedirs(OUT, exist_ok=True)
+
+ tr, va, te = prepare('train'), prepare('val'), prepare('test')
+ n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+ Ytr = torch.stack([r['y'] for r in tr])
+ ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+ for recs in (tr, va, te):
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+
+ rni = args.rni_dim if args.feat == 'rni' else 0
+ use_rwse = args.feat == 'rwse'
+ samples = 8 if rni else 1
+ model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ lossf = nn.L1Loss()
+ trl = to_loader(tr, args.bs, True, drop_last=True)
+ trl_e, val, tel = to_loader(tr, 256, False), to_loader(va, 256, False), to_loader(te, 256, False)
+
+ t0 = time.time(); best_val = 9e9; best = {}
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ vm = eval_mae(model, val, dev, ymu, ysd, samples)
+ if sum(vm) < best_val:
+ best_val = sum(vm)
+ best = {'ep': ep + 1, 'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples),
+ 'val_mae': vm, 'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples)}
+ print(f"ep{ep+1} val_mae(c5,c6)={[round(x,3) for x in vm]}", flush=True)
+
+ tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}"
+ rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args),
+ 'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best}
+ tm = best.get('test_mae'); trm = best.get('train_mae')
+ print(f"[{tag}] train_mae(c5,c6)={[round(x,3) for x in trm]} test_mae={[round(x,3) for x in tm]} "
+ f"(raw rings; std={ [round(x,2) for x in ysd.tolist()] }) @ep{best.get('ep')} ({rep['sec']}s)")
+ with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f:
+ json.dump(rep, f, indent=2)
+ print(" wrote", os.path.join(OUT, f"cyc_{tag}.json"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_rec.py b/diag/train_rec.py
new file mode 100644
index 0000000..9db7eb1
--- /dev/null
+++ b/diag/train_rec.py
@@ -0,0 +1,491 @@
+"""Step-2: RRoG/TRM-on-GNN for ZINC ring-counting.
+
+The graph is encoded once with a GIN encoder. A shared edge-free node-wise compute block then
+refines hidden state over n_sup*T recurrent steps (TRM-style: carry latent detached between
+deep-supervision steps). --grad_mode controls the LAST supervision step's recursion:
+ full : backprop through all T inner recursions (TRM)
+ 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient)
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.loader import DataLoader
+from torch_geometric.data import Batch, Data
+from torch_geometric.nn import (
+ APPNP,
+ ARMAConv,
+ ChebConv,
+ FiLMConv,
+ GATv2Conv,
+ GCNConv,
+ GENConv,
+ GINEConv,
+ GINConv,
+ GraphConv,
+ MFConv,
+ PNAConv,
+ ResGatedGraphConv,
+ SAGEConv,
+ SGConv,
+ TAGConv,
+ TransformerConv,
+ global_add_pool,
+)
+from torch_geometric.utils import degree
+from diag.train_cycle import prepare
+
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+SUPPORTED_VIEWS = [
+ 'gin', 'gine', 'gcn', 'graphsage', 'gatv2', 'graphconv', 'transformer', 'pna',
+ 'gen', 'film', 'resgated', 'tag', 'sgc', 'cheb', 'arma', 'mf', 'appnp',
+]
+
+
+def data_list(recs):
+ return [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2),
+ num_nodes=r['x'].numel()) for r in recs]
+
+
+def loader(recs, bs, shuffle, drop_last=False):
+ data = recs if recs and isinstance(recs[0], Data) else data_list(recs)
+ return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+def degree_histogram(data):
+ max_degree = 0
+ degs = []
+ for graph in data:
+ deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
+ degs.append(deg)
+ if deg.numel():
+ max_degree = max(max_degree, int(deg.max().item()))
+ hist = torch.zeros(max_degree + 1, dtype=torch.long)
+ for deg in degs:
+ hist += torch.bincount(deg, minlength=hist.numel())
+ return hist
+
+
+def make_view_layer(view, hidden, deg):
+ if view == 'gin':
+ return GINConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)
+ if view == 'gine':
+ return GINEConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)),
+ train_eps=True, edge_dim=hidden)
+ if view == 'gcn':
+ return GCNConv(hidden, hidden)
+ if view == 'graphsage':
+ return SAGEConv(hidden, hidden)
+ if view == 'gatv2':
+ return GATv2Conv(hidden, hidden, heads=4, concat=False)
+ if view == 'graphconv':
+ return GraphConv(hidden, hidden)
+ if view == 'transformer':
+ return TransformerConv(hidden, hidden, heads=4, concat=False)
+ if view == 'pna':
+ if deg is None:
+ raise ValueError('PNA view requires a training-set degree histogram')
+ return PNAConv(
+ hidden, hidden,
+ aggregators=['mean', 'min', 'max', 'std'],
+ scalers=['identity', 'amplification', 'attenuation'],
+ deg=deg,
+ )
+ if view == 'gen':
+ return GENConv(hidden, hidden)
+ if view == 'film':
+ return FiLMConv(hidden, hidden)
+ if view == 'resgated':
+ return ResGatedGraphConv(hidden, hidden)
+ if view == 'tag':
+ return TAGConv(hidden, hidden, K=3)
+ if view == 'sgc':
+ return SGConv(hidden, hidden, K=2, cached=False)
+ if view == 'cheb':
+ return ChebConv(hidden, hidden, K=3)
+ if view == 'arma':
+ return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2)
+ if view == 'mf':
+ return MFConv(hidden, hidden)
+ if view == 'appnp':
+ return APPNP(K=5, alpha=0.1)
+ raise ValueError(f'unsupported view: {view}')
+
+
+class RecGIN(nn.Module):
+ def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2,
+ grad_mode='full', agg_layers=5, compute_layers=None, view='gin', deg=None):
+ super().__init__()
+ self.view = view
+ self.agg_layers = agg_layers
+ self.compute_layers = compute_layers or inner
+ self.emb = nn.Embedding(n_atom, hidden)
+ self.edge_emb = nn.Embedding(1, hidden) if view == 'gine' else None
+ self.agg_convs = nn.ModuleList()
+ for _ in range(agg_layers):
+ self.agg_convs.append(make_view_layer(view, hidden, deg))
+ self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)])
+ core = []
+ d = hidden
+ for _ in range(self.compute_layers - 1):
+ core += [nn.Linear(d, hidden), nn.GELU()]
+ d = hidden
+ core.append(nn.Linear(d, hidden))
+ self.core_norm = nn.LayerNorm(hidden)
+ self.core = nn.Sequential(*core)
+ nn.init.zeros_(self.core[-1].weight)
+ nn.init.zeros_(self.core[-1].bias)
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+ self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
+ with torch.no_grad():
+ self.qhead[-1].weight.zero_()
+ self.qhead[-1].bias.fill_(-5.0)
+ self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode
+
+ def aggregate(self, x, ei):
+ h = self.emb(x)
+ for conv, bn in zip(self.agg_convs, self.agg_bns):
+ if self.view == 'gine':
+ edge_attr = self.edge_emb(torch.zeros(ei.size(1), dtype=torch.long, device=ei.device))
+ h = bn(conv(h, ei, edge_attr)).relu()
+ else:
+ h = bn(conv(h, ei)).relu()
+ return h
+
+ def core_step(self, combined, state):
+ """Shared TRM compute core. Deliberately edge-free."""
+ return state + self.core(self.core_norm(combined))
+
+ def _z_step(self, y, z, ctx, noise):
+ z = self.core_step(ctx + y + z, z)
+ if noise and self.sigma > 0:
+ z = z + self.sigma * torch.randn_like(z)
+ return z
+
+ def _y_step(self, y, z, noise):
+ y = self.core_step(y + z, y)
+ if noise and self.sigma > 0:
+ y = y + self.sigma * torch.randn_like(y)
+ return y
+
+ def recurse(self, y, z, ctx, noise, one_step=False):
+ if self.T == 0:
+ return y, z
+ if one_step: # HRM 1-step gradient
+ with torch.no_grad():
+ for _ in range(self.T - 1):
+ z = self._z_step(y, z, ctx, noise)
+ z = z.detach()
+ z = self._z_step(y, z, ctx, noise) # only last inner carries grad
+ y = self._y_step(y, z, noise)
+ return y, z
+ for _ in range(self.T): # TRM full recursion
+ z = self._z_step(y, z, ctx, noise)
+ y = self._y_step(y, z, noise)
+ return y, z
+
+ def predict(self, y, batch):
+ pooled = global_add_pool(y, batch)
+ return self.head(pooled), self.qhead(pooled).view(-1)
+
+ def forward_trace(self, x, ei, batch, steps, noise=False):
+ ctx = self.aggregate(x, ei)
+ y = ctx
+ z = torch.zeros_like(ctx)
+ preds, q_logits = [], []
+ for s in range(steps):
+ y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step'))
+ pred, q = self.predict(y, batch)
+ preds.append(pred)
+ q_logits.append(q)
+ if s < steps - 1:
+ y, z = y.detach(), z.detach()
+ return preds, q_logits
+
+ def forward(self, x, ei, batch, noise=False):
+ ctx = self.aggregate(x, ei)
+ y = ctx
+ z = torch.zeros_like(ctx)
+ preds = []
+ for s in range(self.n_sup):
+ if s < self.n_sup - 1:
+ with torch.no_grad():
+ y, z = self.recurse(y, z, ctx, noise)
+ y, z = y.detach(), z.detach()
+ else:
+ y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step'))
+ pred, _ = self.predict(y, batch)
+ preds.append(pred)
+ _, q = self.predict(y, batch)
+ return preds, q
+
+
+@torch.no_grad()
+def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'):
+ model.eval()
+ ysd_d, ymu_d = ysd.to(dev), ymu.to(dev)
+ ae = torch.zeros(2); ae_or = torch.zeros(2); n = 0
+ for b in ld:
+ b = b.to(dev)
+ if K == 1:
+ preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+ chosen = oracle = preds[-1]
+ else:
+ P, Q = [], []
+ for _ in range(K):
+ preds, q = model(b.x, b.edge_index, b.batch, noise=True)
+ P.append(preds[-1]); Q.append(q)
+ P = torch.stack(P); Q = torch.stack(Q)
+ ar = torch.arange(P.size(1), device=dev)
+ chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0)
+ oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar]
+ ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ ae_or += ((oracle * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ n += b.num_graphs
+ return (ae / n).tolist(), (ae_or / n).tolist()
+
+
+@torch.no_grad()
+def evaluate_trace(model, ld, dev, ymu, ysd, steps, adaptive=False):
+ model.eval()
+ ysd_d, ymu_d = ysd.to(dev), ymu.to(dev)
+ ae = torch.zeros(2)
+ n = 0
+ step_sum = 0.0
+ for b in ld:
+ b = b.to(dev)
+ preds, q_logits = model.forward_trace(b.x, b.edge_index, b.batch, steps, noise=False)
+ P = torch.stack(preds, dim=0)
+ if adaptive:
+ Q = torch.stack(q_logits, dim=0)
+ halted = Q > 0
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+ chosen = P[idx, torch.arange(P.size(1), device=dev)]
+ step_sum += (idx.to(torch.float32) + 1).sum().item()
+ else:
+ chosen = P[-1]
+ step_sum += steps * b.num_graphs
+ ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ n += b.num_graphs
+ return (ae / n).tolist(), step_sum / max(n, 1)
+
+
+def _split_nodes(t, ptr):
+ return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)]
+
+
+def act_train_step(model, state, replacement_batch, opt, dev, args):
+ replacement = replacement_batch.to_data_list()
+ batch_size = len(replacement)
+ if state is None:
+ state = {
+ 'graphs': [None for _ in range(batch_size)],
+ 'y': [None for _ in range(batch_size)],
+ 'z': [None for _ in range(batch_size)],
+ 'steps': torch.zeros(batch_size, dtype=torch.long, device=dev),
+ 'halted': torch.ones(batch_size, dtype=torch.bool, device=dev),
+ }
+
+ halted_cpu = state['halted'].detach().cpu().tolist()
+ for i, halted in enumerate(halted_cpu):
+ if halted:
+ state['graphs'][i] = replacement[i]
+
+ b = Batch.from_data_list(state['graphs']).to(dev)
+ ctx = model.aggregate(b.x, b.edge_index)
+ ptr = b.ptr
+ y_parts, z_parts = [], []
+ for i in range(batch_size):
+ start, end = ptr[i].item(), ptr[i + 1].item()
+ if halted_cpu[i] or state['y'][i] is None:
+ y_parts.append(ctx[start:end])
+ z_parts.append(torch.zeros_like(ctx[start:end]))
+ else:
+ y_parts.append(state['y'][i].to(dev))
+ z_parts.append(state['z'][i].to(dev))
+ y = torch.cat(y_parts, dim=0)
+ z = torch.cat(z_parts, dim=0)
+
+ opt.zero_grad()
+ y, z = model.recurse(y, z, ctx, noise=False, one_step=(model.grad_mode == '1step'))
+ pred, q = model.predict(y, b.batch)
+ per_graph_err = (pred - b.y).abs().mean(1)
+ pred_loss = per_graph_err.mean()
+ with torch.no_grad():
+ if args.halt_target == 'binary':
+ halt_target = (per_graph_err <= args.halt_norm_threshold).to(q.dtype)
+ else:
+ halt_target = torch.sigmoid((args.halt_norm_threshold - per_graph_err) / args.halt_temp)
+ q_loss = nn.functional.binary_cross_entropy_with_logits(q, halt_target)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ y_det, z_det = y.detach(), z.detach()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ state['y'] = _split_nodes(y_det, ptr)
+ state['z'] = _split_nodes(z_det, ptr)
+ with torch.no_grad():
+ was_halted = state['halted']
+ steps = torch.where(was_halted, torch.zeros_like(state['steps']), state['steps']) + 1
+ halted = (steps >= args.halt_max_steps) | (q.detach() > 0)
+ if args.halt_exploration_prob > 0 and args.halt_max_steps > 1:
+ explore = torch.rand_like(q) < args.halt_exploration_prob
+ min_steps = torch.where(
+ explore,
+ torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev),
+ torch.zeros_like(steps),
+ )
+ halted = halted & (steps >= min_steps)
+ state['steps'] = steps
+ state['halted'] = halted
+
+ return state, {
+ 'loss': float(loss.detach().cpu()),
+ 'pred_loss': float(pred_loss.detach().cpu()),
+ 'q_loss': float(q_loss.detach().cpu()),
+ 'halted_frac': float(state['halted'].to(torch.float32).mean().detach().cpu()),
+ 'steps': float(state['steps'].to(torch.float32).mean().detach().cpu()),
+ }
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+ ap.add_argument('--sigma', type=float, default=0.0)
+ ap.add_argument('--K', type=int, default=1)
+ ap.add_argument('--select', choices=['none', 'bestq'], default='bestq')
+ ap.add_argument('--T', type=int, default=3)
+ ap.add_argument('--n_sup', type=int, default=3)
+ ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--agg_layers', type=int, default=5)
+ ap.add_argument('--compute_layers', type=int, default=2)
+ ap.add_argument('--view', choices=SUPPORTED_VIEWS, default='gin')
+ ap.add_argument('--epochs', type=int, default=200)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--bs', type=int, default=128)
+ ap.add_argument('--lam_q', type=float, default=1.0)
+ ap.add_argument('--act', action='store_true',
+ help='train all recurrent depths up to halt_max_steps and train qhead as a halt head')
+ ap.add_argument('--halt_max_steps', type=int, default=8)
+ ap.add_argument('--halt_norm_threshold', type=float, default=0.30)
+ ap.add_argument('--halt_temp', type=float, default=0.10)
+ ap.add_argument('--halt_target', choices=['soft', 'binary'], default='soft')
+ ap.add_argument('--halt_exploration_prob', type=float, default=0.1)
+ ap.add_argument('--loss_mode', choices=['last', 'trace'], default='trace')
+ ap.add_argument('--seed', type=int, default=0)
+ ap.add_argument('--device', default='auto')
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if args.device == 'auto' and torch.cuda.is_available() else (
+ 'cpu' if args.device == 'auto' else args.device)
+ os.makedirs(OUT, exist_ok=True)
+
+ tr, va, te = prepare('train'), prepare('val'), prepare('test')
+ n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+ Ytr = torch.stack([r['y'] for r in tr]); ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+ for recs in (tr, va, te):
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+ train_data = data_list(tr)
+ trl = loader(train_data, args.bs, True, drop_last=True)
+ val, tel = loader(va, 256, False), loader(te, 256, False)
+
+ deg = degree_histogram(train_data) if args.view == 'pna' else None
+ model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode,
+ agg_layers=args.agg_layers, compute_layers=args.compute_layers,
+ view=args.view, deg=deg).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ l1 = nn.L1Loss()
+ act_steps = max(1, args.halt_max_steps)
+
+ t0 = time.time(); best_val = 9e9; best = {}; best_state = None; act_state = None
+ for ep in range(args.epochs):
+ model.train()
+ act_metrics = []
+ for b in trl:
+ if args.act:
+ act_state, metrics = act_train_step(model, act_state, b, opt, dev, args)
+ act_metrics.append(metrics)
+ else:
+ b = b.to(dev); opt.zero_grad()
+ if args.loss_mode == 'trace':
+ preds, q_logits = model.forward_trace(
+ b.x, b.edge_index, b.batch, args.n_sup, noise=model.sigma > 0)
+ q = q_logits[-1]
+ else:
+ preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+ loss = sum(l1(p, b.y) for p in preds) / len(preds)
+ with torch.no_grad():
+ tq = -(preds[-1] - b.y).abs().mean(1)
+ loss = loss + args.lam_q * nn.functional.mse_loss(q, tq)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ if args.act:
+ vm, _ = evaluate_trace(model, val, dev, ymu, ysd, act_steps, adaptive=False)
+ else:
+ vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select)
+ if sum(vm) < best_val:
+ best_val = sum(vm)
+ if args.act:
+ tem, fixed_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=False)
+ tea, adaptive_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=True)
+ best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem,
+ 'test_mae_adaptive': tea, 'fixed_steps': fixed_steps,
+ 'adaptive_steps': adaptive_steps}
+ else:
+ tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select)
+ best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo}
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
+ if args.act and act_metrics:
+ hm = sum(m['halted_frac'] for m in act_metrics) / len(act_metrics)
+ sm = sum(m['steps'] for m in act_metrics) / len(act_metrics)
+ print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]} halt={hm:.2f} train_steps={sm:.2f}", flush=True)
+ else:
+ print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True)
+
+ act_tag = f"_actfull{act_steps}_{args.halt_target}{args.halt_norm_threshold:g}_e{args.epochs}" if args.act else ""
+ loss_tag = f"_{args.loss_mode}" if (not args.act and args.loss_mode != 'last') else ""
+ view_tag = f"_{args.view}" if args.view != 'gin' else ""
+ tag = f"rec_rrog{view_tag}_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}{loss_tag}{act_tag}_s{args.seed}"
+ rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1),
+ 'dev': dev, 'arch': 'rrog_once_agg_node_compute', 'y_std_raw': ysd.tolist(), **best}
+ if args.act:
+ print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
+ f"adaptive={[round(x,3) for x in best.get('test_mae_adaptive')]} "
+ f"steps={best.get('adaptive_steps'):.2f}/{best.get('fixed_steps'):.2f} "
+ f"@ep{best.get('ep')} ({rep['sec']}s)")
+ else:
+ print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
+ f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)")
+ with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
+ json.dump(rep, f, indent=2)
+ torch.save({'state': best_state or model.state_dict(),
+ 'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup,
+ 'sigma': args.sigma, 'grad_mode': args.grad_mode,
+ 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers,
+ 'view': args.view,
+ 'loss_mode': args.loss_mode,
+ 'act': args.act, 'act_impl': 'persistent_recycle' if args.act else 'none',
+ 'halt_max_steps': act_steps,
+ 'halt_exploration_prob': args.halt_exploration_prob,
+ 'arch': 'rrog_once_agg_node_compute'},
+ 'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt"))
+ print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c0f2bc7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+numpy
+networkx
+ogb
+torch-geometric
+tqdm
diff --git a/rrog/__init__.py b/rrog/__init__.py
new file mode 100644
index 0000000..c09281d
--- /dev/null
+++ b/rrog/__init__.py
@@ -0,0 +1,2 @@
+"""Experiment registry for RRoG/TRM-on-GNN sweeps."""
+
diff --git a/rrog/backbones.py b/rrog/backbones.py
new file mode 100644
index 0000000..81bfb9b
--- /dev/null
+++ b/rrog/backbones.py
@@ -0,0 +1,72 @@
+from rrog.registry import ComputeSpec, ModifierSpec, ViewSpec, by_name
+
+
+VIEWS = [
+ ViewSpec("gin", "message-passing", "2d", 1, "implemented",
+ "Plain GINConv message passing."),
+ ViewSpec("gine", "message-passing", "2d", 2, "implemented",
+ "Edge-aware GIN variant. ZINC uses a learned constant edge token; OGB uses bond features."),
+ ViewSpec("gcn", "message-passing", "2d", 3, "implemented"),
+ ViewSpec("graphsage", "message-passing", "2d", 4, "implemented"),
+ ViewSpec("gatv2", "attention-mpnn", "2d", 5, "implemented"),
+ ViewSpec("graphconv", "message-passing", "2d", 6, "implemented"),
+ ViewSpec("transformer", "attention-mpnn", "2d", 7, "implemented"),
+ ViewSpec("pna", "message-passing", "2d", 8, "implemented",
+ "Requires degree histogram from the train split."),
+ ViewSpec("gen", "message-passing", "2d", 9, "implemented"),
+ ViewSpec("film", "message-passing", "2d", 10, "implemented"),
+ ViewSpec("resgated", "message-passing", "2d", 11, "implemented"),
+ ViewSpec("tag", "higher-order-hop", "2d", 12, "implemented"),
+ ViewSpec("sgc", "propagation", "2d", 13, "implemented"),
+ ViewSpec("cheb", "spectral", "2d", 14, "implemented"),
+ ViewSpec("arma", "spectral", "2d", 15, "implemented"),
+ ViewSpec("mf", "message-passing", "2d", 16, "implemented"),
+ ViewSpec("appnp", "propagation", "2d", 17, "implemented"),
+ ViewSpec("mixhop", "higher-order-hop", "2d", 18),
+ ViewSpec("gps", "hybrid-local-global", "2d", 19),
+ ViewSpec("graphormer", "global-attention", "2d", 20),
+ ViewSpec("san", "spectral-attention", "2d", 21),
+ ViewSpec("mpnn", "message-passing", "2d", 22),
+ ViewSpec("schnet", "continuous-filter", "3d", 23),
+ ViewSpec("dimenetpp", "angle-aware", "3d", 24),
+ ViewSpec("painn", "equivariant", "3d", 25),
+ ViewSpec("gemnet", "equivariant", "3d", 26),
+ ViewSpec("egnn", "equivariant", "3d", 27),
+ ViewSpec("equiformer", "equivariant", "3d", 28),
+ ViewSpec("mace", "equivariant", "3d", 29),
+]
+
+
+COMPUTES = [
+ ComputeSpec("classic", "baseline", 1, "implemented", "Standard one-forward GNN baseline; no RRoG compute."),
+ ComputeSpec("view-only", "none", 2, "implemented", "RRoG view module only; no recursive compute."),
+ ComputeSpec("fixed-rrog", "recursive", 3, "implemented", "Fixed-depth edge-free y/z compute."),
+ ComputeSpec("rrog-act", "recursive-act", 4, "implemented", "Persistent full ACT recycling for graph batches."),
+ ComputeSpec("node-mlp", "recursive", 4),
+ ComputeSpec("gru-rrog", "recursive-gated", 5),
+ ComputeSpec("set-attn-core", "edge-free-attention", 6),
+ ComputeSpec("perceiver-core", "latent-attention", 7),
+ ComputeSpec("global-token-mixer", "token-mixer", 8),
+ ComputeSpec("equivariant-core", "3d-equivariant", 9),
+]
+
+
+MODIFIERS = [
+ ModifierSpec("none", "none", 1, "implemented"),
+ ModifierSpec("dfa-gnn", "backward", 2, "planned",
+ "Non-BP/direct-feedback-style training; start on node classification."),
+ ModifierSpec("kaft", "backward", 3, "planned",
+ "User project in ../graph-grape; low priority until main table is stable."),
+ ModifierSpec("deep-supervision", "training", 4),
+ ModifierSpec("sam", "optimizer", 5),
+ ModifierSpec("lap-pe", "feature", 6),
+ ModifierSpec("rwse", "feature", 7),
+ ModifierSpec("virtual-node", "feature", 8),
+ ModifierSpec("dropedge", "regularization", 9),
+ ModifierSpec("flag", "augmentation", 10),
+]
+
+
+VIEW_BY_NAME = by_name(VIEWS)
+COMPUTE_BY_NAME = by_name(COMPUTES)
+MODIFIER_BY_NAME = by_name(MODIFIERS)
diff --git a/rrog/benchmarks.py b/rrog/benchmarks.py
new file mode 100644
index 0000000..dcd356b
--- /dev/null
+++ b/rrog/benchmarks.py
@@ -0,0 +1,44 @@
+from rrog.registry import BenchmarkSpec, by_name
+
+
+BENCHMARKS = [
+ BenchmarkSpec("zinc-cycle56", "A", "molecule-2d", "graph-regression", "raw-mae", 1,
+ "implemented", "ZINC subset with #5/#6 cycle-count targets."),
+ BenchmarkSpec("zinc", "A", "molecule-2d", "graph-regression", "mae", 2),
+ BenchmarkSpec("ogbg-molhiv", "A", "molecule-2d", "graph-classification", "rocauc", 3,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molpcba", "A", "molecule-2d", "graph-multilabel", "ap", 4,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molbbbp", "A", "molecule-2d", "graph-classification", "rocauc", 5,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molbace", "A", "molecule-2d", "graph-classification", "rocauc", 6,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-moltox21", "A", "molecule-2d", "graph-multilabel", "rocauc", 7,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molclintox", "A", "molecule-2d", "graph-multilabel", "rocauc", 8,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molsider", "A", "molecule-2d", "graph-multilabel", "rocauc", 9,
+ "implemented", "OGB graph property prediction."),
+ BenchmarkSpec("ogbg-molesol", "A", "molecule-2d", "graph-regression", "rmse", 10,
+ "implemented", "OGB ESOL graph property prediction."),
+ BenchmarkSpec("ogbg-molfreesolv", "A", "molecule-2d", "graph-regression", "rmse", 11,
+ "implemented", "OGB FreeSolv graph property prediction."),
+ BenchmarkSpec("ogbg-mollipo", "A", "molecule-2d", "graph-regression", "rmse", 12,
+ "implemented", "OGB Lipophilicity graph property prediction."),
+ BenchmarkSpec("pcqm4mv2", "A", "molecule-2d", "graph-regression", "mae", 13),
+ BenchmarkSpec("qm9", "A", "molecule-3d", "graph-regression", "mae", 14),
+ BenchmarkSpec("peptides-func", "B", "long-range", "graph-multilabel", "ap", 15),
+ BenchmarkSpec("peptides-struct", "B", "long-range", "graph-regression", "mae", 16),
+ BenchmarkSpec("pcqm-contact", "B", "long-range", "link-prediction", "mrr", 17),
+ BenchmarkSpec("pascalvoc-sp", "B", "superpixel", "node-classification", "f1", 18),
+ BenchmarkSpec("coco-sp", "B", "superpixel", "node-classification", "f1", 19),
+ BenchmarkSpec("ogbn-arxiv", "B", "citation", "node-classification", "accuracy", 20),
+ BenchmarkSpec("ogbn-products", "B", "commerce", "node-classification", "accuracy", 21),
+ BenchmarkSpec("rmd17", "C", "molecule-3d", "energy-force", "mae", 22),
+ BenchmarkSpec("oc20-s2ef", "C", "catalyst-3d", "energy-force", "mae", 23),
+ BenchmarkSpec("oc22", "C", "catalyst-3d", "energy-force", "mae", 24),
+ BenchmarkSpec("matbench-discovery", "C", "materials", "stability", "discovery-metrics", 25),
+ BenchmarkSpec("tgb-subset", "C", "temporal", "temporal-link-node", "dataset-specific", 26),
+]
+
+BENCHMARK_BY_NAME = by_name(BENCHMARKS)
diff --git a/rrog/cli.py b/rrog/cli.py
new file mode 100644
index 0000000..9c826e3
--- /dev/null
+++ b/rrog/cli.py
@@ -0,0 +1,176 @@
+import argparse
+import json
+import os
+import subprocess
+
+from rrog.backbones import COMPUTES, MODIFIERS, VIEWS
+from rrog.benchmarks import BENCHMARKS
+from rrog.collect_results import print_tables
+from rrog.collect_zinc import print_zinc
+from rrog.runspecs import find_run_spec
+
+
+def _rows(items):
+ return [
+ {
+ "name": item.name,
+ "tier": getattr(item, "tier", None),
+ "family": getattr(item, "family", None),
+ "domain": getattr(item, "domain", None),
+ "task_type": getattr(item, "task_type", None),
+ "metric": getattr(item, "metric", None),
+ "priority": item.priority,
+ "status": item.status,
+ "notes": item.notes,
+ }
+ for item in items
+ ]
+
+
+def _print_table(items):
+ for row in _rows(items):
+ cols = [row["name"], row.get("tier") or row.get("family") or "", row["status"], row["notes"]]
+ print("\t".join(str(c) for c in cols))
+
+
+def list_axis(axis: str, as_json: bool):
+ mapping = {
+ "benchmarks": BENCHMARKS,
+ "views": VIEWS,
+ "computes": COMPUTES,
+ "modifiers": MODIFIERS,
+ }
+ items = mapping[axis]
+ if as_json:
+ print(json.dumps(_rows(items), indent=2))
+ else:
+ _print_table(items)
+
+
+def build_command(args) -> list[str]:
+ spec = find_run_spec(args.task, args.view, args.compute, args.modifier)
+ run_args = dict(spec.default_args)
+ for key in [
+ "epochs", "hidden", "bs", "seed", "T", "n_sup", "halt_max_steps", "halt_target",
+ "halt_min_steps", "halt_loss_threshold", "q_warmup_epochs",
+ "eval_every", "max_train_batches", "max_eval_batches", "num_workers", "ema", "lr", "lam_q",
+ "device",
+ ]:
+ value = getattr(args, key)
+ if value is not None:
+ run_args[key] = value
+ run_args["compute"] = args.compute
+ return spec.command_builder(run_args)
+
+
+def print_matrix(args):
+ tasks = [b for b in BENCHMARKS if args.tier == "all" or b.tier == args.tier]
+ tasks = sorted(tasks, key=lambda x: x.priority)[:args.limit_tasks]
+
+ if args.kind == "main":
+ views = sorted(VIEWS, key=lambda x: x.priority)[:args.limit_views]
+ computes = [c for c in COMPUTES if c.name in ["classic", "view-only", "fixed-rrog", "rrog-act"]]
+ for task in tasks:
+ for view in views:
+ for compute in computes:
+ try:
+ find_run_spec(task.name, view.name, compute.name)
+ status = "implemented"
+ except KeyError:
+ status = "planned"
+ print(f"{task.name}\t{view.name}\t{compute.name}\tnone\t{status}")
+ return
+
+ if args.kind == "modifier":
+ task_names = {"zinc-cycle56", "ogbg-molhiv", "peptides-struct", "peptides-func", "ogbn-arxiv", "qm9"}
+ mods = [m for m in sorted(MODIFIERS, key=lambda x: x.priority) if m.name != "none"]
+ for task in tasks:
+ if task.name not in task_names:
+ continue
+ for mod in mods:
+ print(f"{task.name}\tgin\tview-only\t{mod.name}\tplanned")
+ return
+
+ raise ValueError(args.kind)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ sub = ap.add_subparsers(dest="cmd", required=True)
+
+ lp = sub.add_parser("list")
+ lp.add_argument("axis", choices=["benchmarks", "views", "computes", "modifiers"])
+ lp.add_argument("--json", action="store_true")
+
+ rp = sub.add_parser("run")
+ rp.add_argument("--task", default="zinc-cycle56")
+ rp.add_argument("--view", default="gin")
+ rp.add_argument("--compute", default="rrog-act")
+ rp.add_argument("--modifier", default="none")
+ rp.add_argument("--epochs", type=int)
+ rp.add_argument("--hidden", type=int)
+ rp.add_argument("--bs", type=int)
+ rp.add_argument("--seed", type=int)
+ rp.add_argument("--T", type=int)
+ rp.add_argument("--n_sup", type=int)
+ rp.add_argument("--halt_max_steps", type=int)
+ rp.add_argument("--halt_target", choices=["soft", "binary", "exact", "loss"])
+ rp.add_argument("--halt_min_steps", type=int)
+ rp.add_argument("--halt_loss_threshold", type=float)
+ rp.add_argument("--q_warmup_epochs", type=int)
+ rp.add_argument("--eval_every", type=int)
+ rp.add_argument("--max_train_batches", type=int)
+ rp.add_argument("--max_eval_batches", type=int)
+ rp.add_argument("--num_workers", type=int)
+ rp.add_argument("--ema", type=float)
+ rp.add_argument("--lr", type=float)
+ rp.add_argument("--lam_q", type=float)
+ rp.add_argument("--device")
+ rp.add_argument("--dry-run", action="store_true")
+
+ mp = sub.add_parser("matrix")
+ mp.add_argument("--kind", choices=["main", "modifier"], default="main")
+ mp.add_argument("--tier", choices=["A", "B", "C", "all"], default="A")
+ mp.add_argument("--limit-tasks", type=int, default=20)
+ mp.add_argument("--limit-views", type=int, default=6)
+
+ op = sub.add_parser("results")
+ op.add_argument("--runs-dir", default="runs")
+ op.add_argument("--glob", default="*.json")
+ op.add_argument("--min-epochs", type=int, default=10)
+ op.add_argument("--epochs", type=int)
+ op.add_argument("--digits", type=int, default=4)
+
+ zp = sub.add_parser("zinc-results")
+ zp.add_argument("--runs-dir", default="runs")
+ zp.add_argument("--epochs", type=int)
+ zp.add_argument("--min-epochs", type=int, default=10)
+ zp.add_argument("--digits", type=int, default=4)
+
+ args = ap.parse_args()
+ if args.cmd == "list":
+ list_axis(args.axis, args.json)
+ return
+ if args.cmd == "matrix":
+ print_matrix(args)
+ return
+ if args.cmd == "results":
+ print_tables(args)
+ return
+ if args.cmd == "zinc-results":
+ print_zinc(args)
+ return
+
+ cmd = build_command(args)
+ env = dict(os.environ)
+ env["PYTHONPATH"] = os.getcwd() + os.pathsep + env.get("PYTHONPATH", "")
+ print(" ".join(cmd), flush=True)
+ if not args.dry_run:
+ raise SystemExit(subprocess.call(cmd, env=env))
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except BrokenPipeError:
+ raise SystemExit(0)
diff --git a/rrog/collect_results.py b/rrog/collect_results.py
new file mode 100644
index 0000000..125c0ab
--- /dev/null
+++ b/rrog/collect_results.py
@@ -0,0 +1,239 @@
+import argparse
+import json
+import math
+from collections import defaultdict
+from pathlib import Path
+
+
+HIGHER_BETTER = {
+ "accuracy",
+ "ap",
+ "auc",
+ "f1",
+ "mrr",
+ "rocauc",
+}
+
+LOWER_BETTER = {
+ "mae",
+ "raw-mae",
+ "rmse",
+}
+
+
+def _metric_direction(metric: str) -> int:
+ metric = metric.lower()
+ if metric in HIGHER_BETTER:
+ return 1
+ if metric in LOWER_BETTER:
+ return -1
+ if "mae" in metric or "rmse" in metric:
+ return -1
+ return 1
+
+
+def _read(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ rep = json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+ required = {"dataset", "view", "compute", "seed", "metric", "val"}
+ if not required.issubset(rep):
+ return None
+ return rep
+
+
+def _score(rep: dict, split: str) -> float | None:
+ metric = rep.get("metric")
+ if not metric:
+ return None
+ value = rep.get(split, {}).get(metric)
+ if value is None:
+ return None
+ return float(value)
+
+
+def _rank_key(rep: dict) -> tuple[int, int, int]:
+ return (
+ int(rep.get("epochs", 0)),
+ int(rep.get("hidden", 0)),
+ int(rep.get("ep", 0) or 0),
+ )
+
+
+def _mean(xs: list[float]) -> float:
+ return sum(xs) / len(xs)
+
+
+def _std(xs: list[float]) -> float:
+ if len(xs) < 2:
+ return 0.0
+ mu = _mean(xs)
+ return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1))
+
+
+def _fmt(value: float | None, digits: int) -> str:
+ if value is None:
+ return ""
+ return f"{value:.{digits}f}"
+
+
+def _is_classic_baseline(rep: dict) -> bool:
+ return rep.get("compute") == "classic" and int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1
+
+
+def _compute_label(rep: dict) -> str:
+ label = str(rep["compute"])
+ ema = float(rep.get("ema", 0.0) or 0.0)
+ if ema > 0:
+ label += f"+ema{ema:g}"
+ return label
+
+
+def _choose_runs(paths: list[Path], min_epochs: int, epochs: int | None) -> list[dict]:
+ candidates: dict[tuple[str, str, str, int], dict] = {}
+ for path in paths:
+ rep = _read(path)
+ if rep is None:
+ continue
+ if epochs is not None and int(rep.get("epochs", -1)) != epochs:
+ continue
+ if int(rep.get("epochs", 0)) < min_epochs:
+ continue
+ val = _score(rep, "val")
+ test = _score(rep, "test")
+ if val is None or test is None:
+ continue
+ key = (str(rep["dataset"]), str(rep["view"]), _compute_label(rep), int(rep["seed"]))
+ old = candidates.get(key)
+ if old is None or _rank_key(rep) > _rank_key(old):
+ candidates[key] = rep
+ return list(candidates.values())
+
+
+def _group_by_cell(runs: list[dict]) -> dict[tuple[str, str, str], list[dict]]:
+ grouped: dict[tuple[str, str, str], list[dict]] = defaultdict(list)
+ for rep in runs:
+ grouped[(str(rep["dataset"]), str(rep["view"]), _compute_label(rep))].append(rep)
+ return dict(grouped)
+
+
+def _summarize_cell(reps: list[dict], split: str) -> tuple[float, float, int]:
+ scores = [_score(rep, split) for rep in reps]
+ xs = [x for x in scores if x is not None]
+ return _mean(xs), _std(xs), len(xs)
+
+
+def _markdown_table(headers: list[str], rows: list[list[str]]) -> str:
+ out = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"]
+ out.extend("| " + " | ".join(row) + " |" for row in rows)
+ return "\n".join(out)
+
+
+def print_tables(args) -> None:
+ paths = sorted(Path(args.runs_dir).glob(args.glob))
+ runs = _choose_runs(paths, args.min_epochs, args.epochs)
+ grouped = _group_by_cell(runs)
+
+ classic_cells = {
+ (task, view): reps
+ for (task, view, compute), reps in grouped.items()
+ if compute == "classic"
+ for reps in [[rep for rep in reps if _is_classic_baseline(rep)]]
+ if reps
+ }
+
+ baseline_rows = []
+ for (task, view), reps in sorted(classic_cells.items()):
+ metric = str(reps[0]["metric"])
+ val_mu, val_sd, n = _summarize_cell(reps, "val")
+ test_mu, test_sd, _ = _summarize_cell(reps, "test")
+ baseline_rows.append([
+ task,
+ view,
+ metric,
+ str(n),
+ f"{_fmt(val_mu, args.digits)} +/- {_fmt(val_sd, args.digits)}",
+ f"{_fmt(test_mu, args.digits)} +/- {_fmt(test_sd, args.digits)}",
+ ])
+
+ delta_rows = []
+ for (task, view, compute), reps in sorted(grouped.items()):
+ if compute == "classic":
+ continue
+ base = classic_cells.get((task, view))
+ if not base:
+ continue
+ metric = str(reps[0]["metric"])
+ direction = _metric_direction(metric)
+ base_by_seed = {int(rep["seed"]): rep for rep in base}
+ paired = []
+ for rep in reps:
+ seed = int(rep["seed"])
+ if seed in base_by_seed:
+ paired.append((rep, base_by_seed[seed]))
+ if not paired:
+ base_test_mu, _, _ = _summarize_cell(base, "test")
+ base_val_mu, _, _ = _summarize_cell(base, "val")
+ paired = [(rep, None) for rep in reps]
+ else:
+ base_test_mu = None
+ base_val_mu = None
+
+ val_scores, test_scores, val_deltas, test_deltas = [], [], [], []
+ adaptive_steps = []
+ for rep, base_rep in paired:
+ val = _score(rep, "val")
+ test = _score(rep, "test")
+ if val is None or test is None:
+ continue
+ if base_rep is None:
+ base_val = base_val_mu
+ base_test = base_test_mu
+ else:
+ base_val = _score(base_rep, "val")
+ base_test = _score(base_rep, "test")
+ if base_val is None or base_test is None:
+ continue
+ val_scores.append(val)
+ test_scores.append(test)
+ val_deltas.append(direction * (val - base_val))
+ test_deltas.append(direction * (test - base_test))
+ if rep.get("adaptive_steps") is not None:
+ adaptive_steps.append(float(rep["adaptive_steps"]))
+
+ if not test_scores:
+ continue
+ delta_rows.append([
+ task,
+ view,
+ compute,
+ metric,
+ str(len(test_scores)),
+ f"{_fmt(_mean(val_scores), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})",
+ f"{_fmt(_mean(test_scores), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})",
+ _fmt(_mean(adaptive_steps), 2) if adaptive_steps else "",
+ ])
+
+ print("\nClassic baseline: task x backbone")
+ print(_markdown_table(["task", "backbone", "metric", "n", "val", "test"], baseline_rows))
+ print("\nDelta vs matching classic")
+ print(_markdown_table([
+ "task", "backbone", "compute", "metric", "n", "val score (delta)", "test score (delta)", "steps"
+ ], delta_rows))
+
+
+def main() -> None:
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--runs-dir", default="runs")
+ ap.add_argument("--glob", default="*.json")
+ ap.add_argument("--min-epochs", type=int, default=10)
+ ap.add_argument("--epochs", type=int)
+ ap.add_argument("--digits", type=int, default=4)
+ args = ap.parse_args()
+ print_tables(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/rrog/collect_zinc.py b/rrog/collect_zinc.py
new file mode 100644
index 0000000..49ae1e9
--- /dev/null
+++ b/rrog/collect_zinc.py
@@ -0,0 +1,137 @@
+import argparse
+import json
+import math
+from collections import defaultdict
+from pathlib import Path
+
+
+def _read(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ rep = json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+ if rep.get("dataset") != "ZINC-cycle56":
+ return None
+ if rep.get("K") != 1 or rep.get("select") != "none" or float(rep.get("sigma", 0.0)) != 0.0:
+ return None
+ if "test_mae" not in rep or "val_mae" not in rep:
+ return None
+ return rep
+
+
+def _compute(rep: dict) -> str:
+ if rep.get("act"):
+ return f"rrog-act-T{rep.get('T')}-ns{rep.get('n_sup')}"
+ if int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1:
+ return "classic"
+ label = f"fixed-rrog-T{rep.get('T')}-ns{rep.get('n_sup')}"
+ if rep.get("loss_mode") == "trace":
+ label += "+trace"
+ return label
+
+
+def _view(rep: dict) -> str:
+ return str(rep.get("view", "gin"))
+
+
+def _score(rep: dict, split: str) -> float:
+ return float(sum(rep[f"{split}_mae"]))
+
+
+def _mean(xs: list[float]) -> float:
+ return sum(xs) / len(xs)
+
+
+def _std(xs: list[float]) -> float:
+ if len(xs) < 2:
+ return 0.0
+ mu = _mean(xs)
+ return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1))
+
+
+def _fmt(x: float, digits: int) -> str:
+ return f"{x:.{digits}f}"
+
+
+def _markdown(headers: list[str], rows: list[list[str]]) -> str:
+ lines = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"]
+ lines.extend("| " + " | ".join(row) + " |" for row in rows)
+ return "\n".join(lines)
+
+
+def print_zinc(args) -> None:
+ by_cell: dict[tuple[str, str, int], dict] = {}
+ for path in sorted(Path(args.runs_dir).glob("rec_rrog*_sig0.0_K1_none_T*_s*.json")):
+ rep = _read(path)
+ if rep is None:
+ continue
+ if args.epochs is not None and int(rep.get("epochs", -1)) != args.epochs:
+ continue
+ if int(rep.get("epochs", 0)) < args.min_epochs:
+ continue
+ key = (_view(rep), _compute(rep), int(rep["seed"]))
+ old = by_cell.get(key)
+ if old is None or int(rep.get("epochs", 0)) > int(old.get("epochs", 0)):
+ by_cell[key] = rep
+
+ grouped: dict[tuple[str, str], list[dict]] = defaultdict(list)
+ for rep in by_cell.values():
+ grouped[(_view(rep), _compute(rep))].append(rep)
+
+ classic_by_view = {
+ view: reps
+ for (view, compute), reps in grouped.items()
+ if compute == "classic"
+ }
+
+ base_rows = []
+ for view, classic in sorted(classic_by_view.items()):
+ vals = [_score(rep, "val") for rep in classic]
+ tests = [_score(rep, "test") for rep in classic]
+ base_rows.append([
+ "zinc-cycle56",
+ view,
+ str(len(classic)),
+ f"{_fmt(_mean(vals), args.digits)} +/- {_fmt(_std(vals), args.digits)}",
+ f"{_fmt(_mean(tests), args.digits)} +/- {_fmt(_std(tests), args.digits)}",
+ ])
+
+ delta_rows = []
+ for (view, compute), reps in sorted(grouped.items()):
+ if compute == "classic":
+ continue
+ base_by_seed = {int(rep["seed"]): rep for rep in classic_by_view.get(view, [])}
+ paired = [(rep, base_by_seed[int(rep["seed"])]) for rep in reps if int(rep["seed"]) in base_by_seed]
+ if not paired:
+ continue
+ vals = [_score(rep, "val") for rep, _ in paired]
+ tests = [_score(rep, "test") for rep, _ in paired]
+ val_deltas = [_score(base, "val") - _score(rep, "val") for rep, base in paired]
+ test_deltas = [_score(base, "test") - _score(rep, "test") for rep, base in paired]
+ delta_rows.append([
+ "zinc-cycle56",
+ view,
+ compute,
+ str(len(paired)),
+ f"{_fmt(_mean(vals), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})",
+ f"{_fmt(_mean(tests), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})",
+ ])
+
+ print("\nZINC-cycle56 classic baseline")
+ print(_markdown(["task", "backbone", "n", "val MAE-sum", "test MAE-sum"], base_rows))
+ print("\nZINC-cycle56 delta vs matching classic")
+ print(_markdown(["task", "backbone", "compute", "n", "val score (improvement)", "test score (improvement)"], delta_rows))
+
+
+def main() -> None:
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--runs-dir", default="runs")
+ ap.add_argument("--epochs", type=int)
+ ap.add_argument("--min-epochs", type=int, default=10)
+ ap.add_argument("--digits", type=int, default=4)
+ print_zinc(ap.parse_args())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/rrog/registry.py b/rrog/registry.py
new file mode 100644
index 0000000..7f6b6fe
--- /dev/null
+++ b/rrog/registry.py
@@ -0,0 +1,57 @@
+from dataclasses import dataclass, field
+from typing import Callable
+
+
+@dataclass(frozen=True)
+class BenchmarkSpec:
+ name: str
+ tier: str
+ domain: str
+ task_type: str
+ metric: str
+ priority: int
+ status: str = "planned"
+ notes: str = ""
+
+
+@dataclass(frozen=True)
+class ViewSpec:
+ name: str
+ family: str
+ graph_type: str
+ priority: int
+ status: str = "planned"
+ notes: str = ""
+
+
+@dataclass(frozen=True)
+class ComputeSpec:
+ name: str
+ family: str
+ priority: int
+ status: str = "planned"
+ notes: str = ""
+
+
+@dataclass(frozen=True)
+class ModifierSpec:
+ name: str
+ family: str
+ priority: int
+ status: str = "planned"
+ notes: str = ""
+
+
+@dataclass(frozen=True)
+class RunSpec:
+ task: str
+ view: str
+ compute: str
+ modifier: str = "none"
+ default_args: dict[str, object] = field(default_factory=dict)
+ command_builder: Callable[[dict[str, object]], list[str]] | None = None
+
+
+def by_name(items):
+ return {item.name: item for item in items}
+
diff --git a/rrog/run_ogb_hiv_remaining.sh b/rrog/run_ogb_hiv_remaining.sh
new file mode 100755
index 0000000..067e736
--- /dev/null
+++ b/rrog/run_ogb_hiv_remaining.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+DEVICE="${DEVICE:-cuda:3}"
+EPOCHS="${EPOCHS:-100}"
+SEED="${SEED:-0}"
+
+run_if_missing() {
+ local view="$1"
+ local compute="$2"
+ local t="$3"
+ local ns="$4"
+ local out="runs/ogbg-molhiv_${view}_${compute}_T${t}_ns${ns}_h128_e${EPOCHS}_s${SEED}.json"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+ echo "[run] view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}"
+ python3 -m rrog.cli run \
+ --task ogbg-molhiv \
+ --view "${view}" \
+ --compute "${compute}" \
+ --epochs "${EPOCHS}" \
+ --T "${t}" \
+ --n_sup "${ns}" \
+ --seed "${SEED}" \
+ --device "${DEVICE}"
+}
+
+# Complete remaining OGB-HIV backbone x {classic, fixed-RRoG} cells.
+# Existing json files are skipped, so the queue can be restarted safely.
+run_if_missing sgc fixed-rrog 3 3
+run_if_missing cheb classic 0 1
+run_if_missing cheb fixed-rrog 3 3
+run_if_missing arma classic 0 1
+run_if_missing arma fixed-rrog 3 3
+run_if_missing mf classic 0 1
+run_if_missing mf fixed-rrog 3 3
+run_if_missing appnp classic 0 1
+run_if_missing appnp fixed-rrog 3 3
+run_if_missing pna fixed-rrog 3 3
+run_if_missing gine classic 0 1
+run_if_missing gine fixed-rrog 3 3
+
+python3 -m rrog.cli results --epochs "${EPOCHS}"
diff --git a/rrog/run_zinc_gine.sh b/rrog/run_zinc_gine.sh
new file mode 100755
index 0000000..1ca36e1
--- /dev/null
+++ b/rrog/run_zinc_gine.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+DEVICE="${DEVICE:-cuda:3}"
+EPOCHS="${EPOCHS:-200}"
+SEED="${SEED:-0}"
+
+run_if_missing() {
+ local compute="$1"
+ local t="$2"
+ local ns="$3"
+ local out="runs/rec_rrog_gine_full_sig0.0_K1_none_T${t}_ns${ns}_trace_s${SEED}.json"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+ echo "[run] zinc-cycle56 view=gine compute=${compute} T=${t} ns=${ns} device=${DEVICE}"
+ python3 -m rrog.cli run \
+ --task zinc-cycle56 \
+ --view gine \
+ --compute "${compute}" \
+ --epochs "${EPOCHS}" \
+ --T "${t}" \
+ --n_sup "${ns}" \
+ --seed "${SEED}" \
+ --device "${DEVICE}"
+}
+
+run_if_missing classic 0 1
+run_if_missing fixed-rrog 1 3
+
+python3 -m rrog.cli zinc-results --epochs "${EPOCHS}"
diff --git a/rrog/run_zinc_gine_after_pid.sh b/rrog/run_zinc_gine_after_pid.sh
new file mode 100755
index 0000000..a4d24ae
--- /dev/null
+++ b/rrog/run_zinc_gine_after_pid.sh
@@ -0,0 +1,14 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+WAIT_PID="${1:?usage: run_zinc_gine_after_pid.sh <pid>}"
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+echo "[wait] OGB queue pid=${WAIT_PID}"
+while kill -0 "${WAIT_PID}" 2>/dev/null; do
+ sleep 60
+done
+
+echo "[start] OGB queue exited; launching ZINC gine"
+DEVICE="${DEVICE:-cuda:1}" EPOCHS="${EPOCHS:-200}" SEED="${SEED:-0}" ./rrog/run_zinc_gine.sh
diff --git a/rrog/runspecs.py b/rrog/runspecs.py
new file mode 100644
index 0000000..285348f
--- /dev/null
+++ b/rrog/runspecs.py
@@ -0,0 +1,188 @@
+from rrog.registry import RunSpec
+
+
+ZINC_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+OGB_MOL_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+OGB_MOL_TASKS = [
+ "ogbg-molhiv",
+ "ogbg-molpcba",
+ "ogbg-molbbbp",
+ "ogbg-molbace",
+ "ogbg-moltox21",
+ "ogbg-molclintox",
+ "ogbg-molsider",
+ "ogbg-molesol",
+ "ogbg-molfreesolv",
+ "ogbg-mollipo",
+]
+
+
+def _zinc_cycle56_gin_command(args: dict[str, object]) -> list[str]:
+ compute = str(args.get("compute", "rrog-act"))
+ view = str(args.get("view", "gin"))
+ epochs = int(args.get("epochs", 200))
+ hidden = int(args.get("hidden", 64))
+ batch_size = int(args.get("bs", 256))
+ seed = int(args.get("seed", 0))
+ t = int(args.get("T", 1))
+ n_sup = int(args.get("n_sup", 3))
+ halt_max_steps = int(args.get("halt_max_steps", 8))
+ halt_target = str(args.get("halt_target", "binary"))
+
+ cmd = [
+ "python3", "diag/train_rec.py",
+ "--grad_mode", "full",
+ "--T", str(t),
+ "--n_sup", str(n_sup),
+ "--hidden", str(hidden),
+ "--bs", str(batch_size),
+ "--epochs", str(epochs),
+ "--agg_layers", str(args.get("agg_layers", 5)),
+ "--compute_layers", str(args.get("compute_layers", 2)),
+ "--view", view,
+ "--sigma", "0",
+ "--K", "1",
+ "--select", "none",
+ "--seed", str(seed),
+ ]
+
+ if compute in ["classic", "view-only"]:
+ cmd[cmd.index("--T") + 1] = "0"
+ elif compute == "fixed-rrog":
+ pass
+ elif compute == "rrog-act":
+ cmd.extend([
+ "--act",
+ "--halt_max_steps", str(halt_max_steps),
+ "--halt_target", halt_target,
+ "--halt_norm_threshold", str(args.get("halt_norm_threshold", 0.3)),
+ ])
+ else:
+ raise ValueError(f"unsupported compute for zinc-cycle56/{view}: {compute}")
+ if args.get("device") is not None:
+ cmd.extend(["--device", str(args["device"])])
+ return cmd
+
+
+def _ogb_graphprop_gin_command(args: dict[str, object]) -> list[str]:
+ task = str(args["task"])
+ view = str(args.get("view", "gin"))
+ compute = str(args.get("compute", "rrog-act"))
+ cmd = [
+ "python3", "rrog/train_ogb_graphprop.py",
+ "--dataset", task,
+ "--view", view,
+ "--compute", compute,
+ "--T", str(args.get("T", 1)),
+ "--n_sup", str(args.get("n_sup", 3)),
+ "--hidden", str(args.get("hidden", 128)),
+ "--bs", str(args.get("bs", 128)),
+ "--epochs", str(args.get("epochs", 100)),
+ "--eval_every", str(args.get("eval_every", 10)),
+ "--agg_layers", str(args.get("agg_layers", 5)),
+ "--compute_layers", str(args.get("compute_layers", 2)),
+ "--seed", str(args.get("seed", 0)),
+ ]
+ for key in ["lr", "lam_q"]:
+ if args.get(key) is not None:
+ cmd.extend([f"--{key}", str(args[key])])
+ if compute == "rrog-act":
+ cmd.extend([
+ "--halt_max_steps", str(args.get("halt_max_steps", 8)),
+ "--halt_min_steps", str(args.get("halt_min_steps", 2)),
+ "--halt_target", str(args.get("halt_target", "loss")),
+ "--halt_loss_threshold", str(args.get("halt_loss_threshold", 0.2)),
+ "--halt_exploration_prob", str(args.get("halt_exploration_prob", 0.1)),
+ ])
+ if args.get("q_warmup_epochs") is not None:
+ cmd.extend(["--q_warmup_epochs", str(args["q_warmup_epochs"])])
+ if float(args.get("ema", 0.0) or 0.0) > 0:
+ cmd.extend(["--ema", str(args["ema"])])
+ if args.get("device") is not None:
+ cmd.extend(["--device", str(args["device"])])
+ for key in ["max_train_batches", "max_eval_batches", "num_workers"]:
+ if args.get(key) is not None:
+ cmd.extend([f"--{key}", str(args[key])])
+ return cmd
+
+
+RUN_SPECS = [
+]
+
+for _view in ZINC_VIEWS:
+ RUN_SPECS.extend([
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="classic",
+ default_args={"compute": "classic", "view": _view, "T": 0, "n_sup": 1, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="view-only",
+ default_args={"compute": "view-only", "view": _view, "T": 0, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="fixed-rrog",
+ default_args={"compute": "fixed-rrog", "view": _view, "T": 3, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="rrog-act",
+ default_args={"compute": "rrog-act", "view": _view, "T": 1, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ ])
+
+for _task in OGB_MOL_TASKS:
+ for _view in OGB_MOL_VIEWS:
+ RUN_SPECS.extend([
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="classic",
+ default_args={"task": _task, "view": _view, "compute": "classic", "T": 0, "n_sup": 1, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="view-only",
+ default_args={"task": _task, "view": _view, "compute": "view-only", "T": 0, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="fixed-rrog",
+ default_args={"task": _task, "view": _view, "compute": "fixed-rrog", "T": 3, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="rrog-act",
+ default_args={"task": _task, "view": _view, "compute": "rrog-act", "T": 1, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ ])
+
+
+def find_run_spec(task: str, view: str, compute: str, modifier: str = "none") -> RunSpec:
+ for spec in RUN_SPECS:
+ if (spec.task, spec.view, spec.compute, spec.modifier) == (task, view, compute, modifier):
+ return spec
+ raise KeyError(f"no implemented run spec for task={task} view={view} compute={compute} modifier={modifier}")
diff --git a/rrog/train_ogb_graphprop.py b/rrog/train_ogb_graphprop.py
new file mode 100644
index 0000000..387ef3c
--- /dev/null
+++ b/rrog/train_ogb_graphprop.py
@@ -0,0 +1,685 @@
+import argparse
+from contextlib import contextmanager
+import json
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
+from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
+from torch_geometric.data import Batch
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import (
+ APPNP,
+ ARMAConv,
+ ChebConv,
+ FiLMConv,
+ GATv2Conv,
+ GCNConv,
+ GENConv,
+ GINEConv,
+ GraphConv,
+ MFConv,
+ PNAConv,
+ ResGatedGraphConv,
+ SAGEConv,
+ SGConv,
+ TAGConv,
+ TransformerConv,
+ global_add_pool,
+)
+from torch_geometric.utils import degree
+
+
+PROJECT_ROOT = os.environ.get(
+ "RROG_ROOT",
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
+)
+DATA_ROOT = os.environ.get("RROG_DATA_DIR", os.path.join(PROJECT_ROOT, "data"))
+ROOT = os.path.join(DATA_ROOT, "ogb")
+OUT = os.environ.get("RROG_RUNS_DIR", os.path.join(PROJECT_ROOT, "runs"))
+SUPPORTED_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+SUPPORTED_MOL_DATASETS = [
+ "ogbg-molhiv",
+ "ogbg-molpcba",
+ "ogbg-molbbbp",
+ "ogbg-molbace",
+ "ogbg-moltox21",
+ "ogbg-molclintox",
+ "ogbg-molsider",
+ "ogbg-molesol",
+ "ogbg-molfreesolv",
+ "ogbg-mollipo",
+]
+HIGHER_BETTER = {"rocauc", "ap", "auc", "accuracy", "acc", "f1"}
+LOWER_BETTER = {"rmse", "mae"}
+
+_TORCH_LOAD = torch.load
+
+
+def _torch_load_ogb_compat(*args, **kwargs):
+ kwargs.setdefault("weights_only", False)
+ return _TORCH_LOAD(*args, **kwargs)
+
+
+torch.load = _torch_load_ogb_compat
+
+
+def clone_state_dict(model):
+ return {k: v.detach().clone() for k, v in model.state_dict().items()}
+
+
+@torch.no_grad()
+def update_ema_state(ema_state, model, decay):
+ if ema_state is None:
+ return
+ for key, value in model.state_dict().items():
+ if torch.is_floating_point(value):
+ ema_state[key].mul_(decay).add_(value.detach(), alpha=1.0 - decay)
+ else:
+ ema_state[key].copy_(value.detach())
+
+
+@contextmanager
+def using_ema_state(model, ema_state):
+ if ema_state is None:
+ yield
+ return
+ raw_state = clone_state_dict(model)
+ model.load_state_dict(ema_state, strict=True)
+ try:
+ yield
+ finally:
+ model.load_state_dict(raw_state, strict=True)
+
+
+def metric_direction(metric: str) -> int:
+ metric = metric.lower()
+ if metric in LOWER_BETTER or "rmse" in metric or "mae" in metric:
+ return -1
+ return 1
+
+
+def is_regression_metric(metric: str) -> bool:
+ return metric_direction(metric) < 0
+
+
+def is_better(score: float, best: float | None, metric: str) -> bool:
+ if best is None:
+ return True
+ direction = metric_direction(metric)
+ return score > best if direction > 0 else score < best
+
+
+def jsonable(obj):
+ if isinstance(obj, dict):
+ return {str(k): jsonable(v) for k, v in obj.items()}
+ if isinstance(obj, (list, tuple)):
+ return [jsonable(v) for v in obj]
+ if isinstance(obj, (np.integer, np.floating)):
+ return obj.item()
+ if isinstance(obj, torch.Tensor):
+ if obj.ndim == 0:
+ return obj.detach().cpu().item()
+ return obj.detach().cpu().tolist()
+ return obj
+
+
+def degree_histogram(dataset) -> torch.Tensor:
+ max_degree = 0
+ degs = []
+ for graph in dataset:
+ deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
+ degs.append(deg)
+ if deg.numel():
+ max_degree = max(max_degree, int(deg.max().item()))
+ hist = torch.zeros(max_degree + 1, dtype=torch.long)
+ for deg in degs:
+ hist += torch.bincount(deg, minlength=hist.numel())
+ return hist
+
+
+def make_view_layer(view: str, hidden: int, deg: torch.Tensor | None):
+ if view in {"gin", "gine"}:
+ mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
+ return GINEConv(mlp, train_eps=True)
+ if view == "gcn":
+ return GCNConv(hidden, hidden)
+ if view == "graphsage":
+ return SAGEConv(hidden, hidden)
+ if view == "gatv2":
+ return GATv2Conv(hidden, hidden, heads=4, concat=False, edge_dim=hidden)
+ if view == "graphconv":
+ return GraphConv(hidden, hidden)
+ if view == "transformer":
+ return TransformerConv(hidden, hidden, heads=4, concat=False, edge_dim=hidden)
+ if view == "pna":
+ if deg is None:
+ raise ValueError("PNA view requires a training-set degree histogram")
+ return PNAConv(
+ hidden, hidden,
+ aggregators=["mean", "min", "max", "std"],
+ scalers=["identity", "amplification", "attenuation"],
+ deg=deg,
+ edge_dim=hidden,
+ )
+ if view == "gen":
+ return GENConv(hidden, hidden, edge_dim=hidden)
+ if view == "film":
+ return FiLMConv(hidden, hidden)
+ if view == "resgated":
+ return ResGatedGraphConv(hidden, hidden, edge_dim=hidden)
+ if view == "tag":
+ return TAGConv(hidden, hidden, K=3)
+ if view == "sgc":
+ return SGConv(hidden, hidden, K=2, cached=False)
+ if view == "cheb":
+ return ChebConv(hidden, hidden, K=3)
+ if view == "arma":
+ return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2)
+ if view == "mf":
+ return MFConv(hidden, hidden)
+ if view == "appnp":
+ return APPNP(K=5, alpha=0.1)
+ raise ValueError(f"unsupported OGB view: {view}")
+
+
+EDGE_ATTR_VIEWS = {"gin", "gine", "gatv2", "transformer", "pna", "gen", "resgated"}
+
+
+class OGBRRoG(nn.Module):
+ def __init__(
+ self, hidden, num_tasks, view="gin", T=1, n_sup=3, agg_layers=5,
+ compute_layers=2, grad_mode="full", deg=None,
+ ):
+ super().__init__()
+ self.view = view
+ self.atom_encoder = AtomEncoder(hidden)
+ self.bond_encoder = BondEncoder(hidden)
+ self.agg_convs = nn.ModuleList()
+ self.agg_bns = nn.ModuleList()
+ for _ in range(agg_layers):
+ self.agg_convs.append(make_view_layer(view, hidden, deg))
+ self.agg_bns.append(nn.BatchNorm1d(hidden))
+
+ core = []
+ d = hidden
+ for _ in range(compute_layers - 1):
+ core += [nn.Linear(d, hidden), nn.GELU()]
+ d = hidden
+ core.append(nn.Linear(d, hidden))
+ self.core_norm = nn.LayerNorm(hidden)
+ self.core = nn.Sequential(*core)
+ nn.init.zeros_(self.core[-1].weight)
+ nn.init.zeros_(self.core[-1].bias)
+
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, num_tasks))
+ self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
+ with torch.no_grad():
+ self.qhead[-1].weight.zero_()
+ self.qhead[-1].bias.fill_(-5.0)
+
+ self.T = T
+ self.n_sup = n_sup
+ self.grad_mode = grad_mode
+ self.agg_layers = agg_layers
+ self.compute_layers = compute_layers
+ self.hidden = hidden
+ self.num_tasks = num_tasks
+
+ def aggregate(self, x, edge_index, edge_attr):
+ h = self.atom_encoder(x)
+ e = self.bond_encoder(edge_attr)
+ for conv, bn in zip(self.agg_convs, self.agg_bns):
+ if self.view in {"gin", "gine", "pna", "gen", "resgated"}:
+ h = bn(conv(h, edge_index, e)).relu()
+ elif self.view in {"gatv2", "transformer"}:
+ h = bn(conv(h, edge_index, edge_attr=e)).relu()
+ else:
+ h = bn(conv(h, edge_index)).relu()
+ return h
+
+ def core_step(self, combined, state):
+ return state + self.core(self.core_norm(combined))
+
+ def _z_step(self, y, z, ctx):
+ return self.core_step(ctx + y + z, z)
+
+ def _y_step(self, y, z):
+ return self.core_step(y + z, y)
+
+ def recurse(self, y, z, ctx, one_step=False):
+ if self.T == 0:
+ return y, z
+ if one_step:
+ with torch.no_grad():
+ for _ in range(self.T - 1):
+ z = self._z_step(y, z, ctx)
+ z = z.detach()
+ z = self._z_step(y, z, ctx)
+ y = self._y_step(y, z)
+ return y, z
+ for _ in range(self.T):
+ z = self._z_step(y, z, ctx)
+ y = self._y_step(y, z)
+ return y, z
+
+ def predict(self, y, batch):
+ pooled = global_add_pool(y, batch)
+ return self.head(pooled), self.qhead(pooled).view(-1)
+
+ def forward_trace(self, data, steps):
+ ctx = self.aggregate(data.x, data.edge_index, data.edge_attr)
+ y = ctx
+ z = torch.zeros_like(ctx)
+ preds, q_logits = [], []
+ for s in range(steps):
+ y, z = self.recurse(y, z, ctx, one_step=(self.grad_mode == "1step"))
+ pred, q = self.predict(y, data.batch)
+ preds.append(pred)
+ q_logits.append(q)
+ if s < steps - 1:
+ y, z = y.detach(), z.detach()
+ return preds, q_logits
+
+ def forward(self, data):
+ steps = self.n_sup
+ preds, q_logits = self.forward_trace(data, steps)
+ return preds, q_logits[-1]
+
+
+def supervised_loss(logits, y, metric):
+ per_graph, has_label = per_graph_supervised_loss(logits, y, metric)
+ if not has_label.any():
+ return logits.sum() * 0.0
+ return per_graph[has_label].mean()
+
+
+def per_graph_supervised_loss(logits, y, metric):
+ y = y.to(torch.float32)
+ mask = ~torch.isnan(y)
+ target = torch.where(mask, y, torch.zeros_like(y))
+ if is_regression_metric(metric):
+ losses = (logits - target).pow(2)
+ else:
+ losses = nn.functional.binary_cross_entropy_with_logits(logits, target, reduction="none")
+ losses = torch.where(mask, losses, torch.zeros_like(losses))
+ denom = mask.sum(dim=1).clamp_min(1)
+ return losses.sum(dim=1) / denom, mask.any(dim=1)
+
+
+@torch.no_grad()
+def halt_targets(logits, y):
+ y = y.to(torch.float32)
+ mask = ~torch.isnan(y)
+ pred = (logits > 0).to(y.dtype)
+ correct_or_missing = (~mask) | (pred == y)
+ has_label = mask.any(dim=1)
+ return (correct_or_missing.all(dim=1) & has_label).to(logits.dtype)
+
+
+@torch.no_grad()
+def evaluate(model, loader, evaluator, dev, steps=None, adaptive=False, halt_min_steps=1, max_batches=0):
+ model.eval()
+ ys, ps = [], []
+ step_sum = 0.0
+ n = 0
+ for i, batch in enumerate(loader):
+ if max_batches and i >= max_batches:
+ break
+ batch = batch.to(dev)
+ if steps is None:
+ preds, _ = model(batch)
+ pred = preds[-1]
+ used_steps = float(model.n_sup)
+ else:
+ preds, qs = model.forward_trace(batch, steps)
+ stack = torch.stack(preds, dim=0)
+ if adaptive:
+ q = torch.stack(qs, dim=0)
+ step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1)
+ halted = (q > 0) & (step_ids >= halt_min_steps)
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+ pred = stack[idx, torch.arange(stack.size(1), device=dev)]
+ used_steps = float((idx.to(torch.float32) + 1).sum().item())
+ else:
+ pred = stack[-1]
+ used_steps = float(steps * batch.num_graphs)
+ ys.append(batch.y.detach().cpu())
+ ps.append(pred.detach().cpu())
+ step_sum += used_steps
+ n += batch.num_graphs
+ y_true = torch.cat(ys, dim=0)
+ y_pred = torch.cat(ps, dim=0)
+ result = evaluator.eval({"y_true": y_true, "y_pred": y_pred})
+ return result, step_sum / max(n, 1)
+
+
+def _split_nodes(t, ptr):
+ return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)]
+
+
+def act_train_step(model, state, replacement_batch, opt, dev, args, metric):
+ replacement = replacement_batch.to_data_list()
+ batch_size = len(replacement)
+ if state is None:
+ state = {
+ "graphs": [None for _ in range(batch_size)],
+ "y": [None for _ in range(batch_size)],
+ "z": [None for _ in range(batch_size)],
+ "steps": torch.zeros(batch_size, dtype=torch.long, device=dev),
+ "halted": torch.ones(batch_size, dtype=torch.bool, device=dev),
+ }
+
+ halted_cpu = state["halted"].detach().cpu().tolist()
+ for i, halted in enumerate(halted_cpu):
+ if halted:
+ state["graphs"][i] = replacement[i]
+
+ batch = Batch.from_data_list(state["graphs"]).to(dev)
+ ctx = model.aggregate(batch.x, batch.edge_index, batch.edge_attr)
+ ptr = batch.ptr
+ y_parts, z_parts = [], []
+ for i in range(batch_size):
+ start, end = ptr[i].item(), ptr[i + 1].item()
+ if halted_cpu[i] or state["y"][i] is None:
+ y_parts.append(ctx[start:end])
+ z_parts.append(torch.zeros_like(ctx[start:end]))
+ else:
+ y_parts.append(state["y"][i].to(dev))
+ z_parts.append(state["z"][i].to(dev))
+ y = torch.cat(y_parts, dim=0)
+ z = torch.cat(z_parts, dim=0)
+
+ opt.zero_grad()
+ y, z = model.recurse(y, z, ctx, one_step=(model.grad_mode == "1step"))
+ logits, q = model.predict(y, batch.batch)
+ pred_loss = supervised_loss(logits, batch.y, metric)
+ if args.halt_target == "exact" and not is_regression_metric(metric):
+ target = halt_targets(logits.detach(), batch.y)
+ elif args.halt_target == "loss":
+ per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), batch.y, metric)
+ target = ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype)
+ else:
+ raise ValueError(args.halt_target)
+ q_loss = nn.functional.binary_cross_entropy_with_logits(q, target)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ y_det, z_det = y.detach(), z.detach()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ state["y"] = _split_nodes(y_det, ptr)
+ state["z"] = _split_nodes(z_det, ptr)
+ with torch.no_grad():
+ was_halted = state["halted"]
+ steps = torch.where(was_halted, torch.zeros_like(state["steps"]), state["steps"]) + 1
+ halted = (steps >= args.halt_max_steps) | ((q.detach() > 0) & (steps >= args.halt_min_steps))
+ if args.halt_exploration_prob > 0 and args.halt_max_steps > 1:
+ explore = torch.rand_like(q) < args.halt_exploration_prob
+ min_steps = torch.where(
+ explore,
+ torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev),
+ torch.zeros_like(steps),
+ )
+ halted = halted & (steps >= min_steps)
+ state["steps"] = steps
+ state["halted"] = halted
+
+ return state, {
+ "loss": float(loss.detach().cpu()),
+ "pred_loss": float(pred_loss.detach().cpu()),
+ "q_loss": float(q_loss.detach().cpu()),
+ "halted_frac": float(state["halted"].to(torch.float32).mean().detach().cpu()),
+ "steps": float(state["steps"].to(torch.float32).mean().detach().cpu()),
+ }
+
+
+def _halt_target(args, logits, y, metric):
+ if args.halt_target == "exact" and not is_regression_metric(metric):
+ return halt_targets(logits.detach(), y)
+ if args.halt_target == "loss":
+ per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), y, metric)
+ return ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype)
+ raise ValueError(args.halt_target)
+
+
+def act_trace_train_step(model, batch, opt, dev, args, epoch, metric):
+ batch = batch.to(dev)
+ steps = max(1, args.halt_max_steps)
+ opt.zero_grad()
+ preds, qs = model.forward_trace(batch, steps)
+ pred_loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds)
+ if epoch <= args.q_warmup_epochs:
+ q_loss = pred_loss.detach() * 0.0
+ loss = pred_loss
+ else:
+ q_losses = []
+ for step_idx, (pred, q) in enumerate(zip(preds, qs), start=1):
+ target = _halt_target(args, pred, batch.y, metric)
+ if step_idx < args.halt_min_steps:
+ target = torch.zeros_like(target)
+ q_losses.append(nn.functional.binary_cross_entropy_with_logits(q, target))
+ q_loss = sum(q_losses) / len(q_losses)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ with torch.no_grad():
+ q_stack = torch.stack(qs, dim=0)
+ step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1)
+ halted = (q_stack > 0) & (step_ids >= args.halt_min_steps)
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+
+ return {
+ "loss": float(loss.detach().cpu()),
+ "pred_loss": float(pred_loss.detach().cpu()),
+ "q_loss": float(q_loss.detach().cpu()),
+ "halted_frac": float(any_halt.to(torch.float32).mean().detach().cpu()),
+ "steps": float((idx.to(torch.float32) + 1).mean().detach().cpu()),
+ }
+
+
+def train_epoch(model, loader, opt, dev, args, act_state, ema_state, epoch, metric):
+ model.train()
+ metrics = []
+ for i, batch in enumerate(loader):
+ if args.max_train_batches and i >= args.max_train_batches:
+ break
+ if args.compute == "rrog-act":
+ m = act_trace_train_step(model, batch, opt, dev, args, epoch, metric)
+ update_ema_state(ema_state, model, args.ema)
+ metrics.append(m)
+ continue
+ batch = batch.to(dev)
+ opt.zero_grad()
+ preds, _ = model(batch)
+ loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds)
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+ update_ema_state(ema_state, model, args.ema)
+ metrics.append({"loss": float(loss.detach().cpu()), "halted_frac": 0.0, "steps": float(model.n_sup)})
+ return act_state, metrics
+
+
+def metric_value(result, evaluator):
+ return float(result[evaluator.eval_metric])
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--dataset", choices=SUPPORTED_MOL_DATASETS, default="ogbg-molhiv")
+ ap.add_argument("--view", choices=SUPPORTED_VIEWS, default="gin")
+ ap.add_argument("--compute", choices=["classic", "view-only", "fixed-rrog", "rrog-act"], default="rrog-act")
+ ap.add_argument("--T", type=int, default=1)
+ ap.add_argument("--n_sup", type=int, default=3)
+ ap.add_argument("--hidden", type=int, default=128)
+ ap.add_argument("--agg_layers", type=int, default=5)
+ ap.add_argument("--compute_layers", type=int, default=2)
+ ap.add_argument("--epochs", type=int, default=100)
+ ap.add_argument("--eval_every", type=int, default=10)
+ ap.add_argument("--lr", type=float, default=1e-3)
+ ap.add_argument("--bs", type=int, default=128)
+ ap.add_argument("--lam_q", type=float, default=1.0)
+ ap.add_argument("--halt_max_steps", type=int, default=8)
+ ap.add_argument("--halt_min_steps", type=int, default=1)
+ ap.add_argument("--halt_target", choices=["exact", "loss"], default="loss")
+ ap.add_argument("--halt_loss_threshold", type=float, default=0.2)
+ ap.add_argument("--halt_exploration_prob", type=float, default=0.1)
+ ap.add_argument("--q_warmup_epochs", type=int, default=0)
+ ap.add_argument("--ema", type=float, default=0.0)
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--num_workers", type=int, default=0)
+ ap.add_argument("--max_train_batches", type=int, default=0)
+ ap.add_argument("--max_eval_batches", type=int, default=0)
+ ap.add_argument("--device", default="auto")
+ args = ap.parse_args()
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ if args.compute == "classic":
+ args.n_sup = 1
+ if args.device == "auto":
+ dev = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ dev = args.device
+ os.makedirs(OUT, exist_ok=True)
+
+ dataset = PygGraphPropPredDataset(name=args.dataset, root=ROOT)
+ split_idx = dataset.get_idx_split()
+ evaluator = Evaluator(args.dataset)
+ metric = evaluator.eval_metric
+ num_tasks = dataset.num_tasks
+
+ train_dataset = dataset[split_idx["train"]]
+ valid_dataset = dataset[split_idx["valid"]]
+ test_dataset = dataset[split_idx["test"]]
+ train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
+ drop_last=True, num_workers=args.num_workers)
+ valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False,
+ num_workers=args.num_workers)
+ test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False,
+ num_workers=args.num_workers)
+
+ T = 0 if args.compute in ["classic", "view-only"] else args.T
+ deg = degree_histogram(train_dataset) if args.view == "pna" else None
+ model = OGBRRoG(args.hidden, num_tasks, view=args.view, T=T, n_sup=args.n_sup,
+ agg_layers=args.agg_layers, compute_layers=args.compute_layers, deg=deg).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, max(args.epochs, 1))
+ ema_state = clone_state_dict(model) if args.ema > 0 else None
+
+ t0 = time.time()
+ best_val = None
+ best = {}
+ best_state = None
+ act_state = None
+ steps = max(1, args.halt_max_steps)
+
+ for ep in range(args.epochs):
+ act_state, train_metrics = train_epoch(
+ model, train_loader, opt, dev, args, act_state, ema_state, ep + 1, metric)
+ sched.step()
+ if (ep + 1) % args.eval_every == 0 or ep == args.epochs - 1:
+ with using_ema_state(model, ema_state):
+ if args.compute == "rrog-act":
+ val_result, fixed_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps,
+ adaptive=False, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ val_adapt, adaptive_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps,
+ adaptive=True, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ else:
+ val_result, _ = evaluate(model, valid_loader, evaluator, dev,
+ max_batches=args.max_eval_batches)
+ val_score = metric_value(val_result, evaluator)
+ if is_better(val_score, best_val, metric):
+ best_val = val_score
+ if args.compute == "rrog-act":
+ test_fixed, fixed_steps = evaluate(model, test_loader, evaluator, dev, steps=steps,
+ adaptive=False, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ test_adapt, adaptive_steps = evaluate(model, test_loader, evaluator, dev, steps=steps,
+ adaptive=True, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ best = {
+ "ep": ep + 1,
+ "val": val_result,
+ "val_adaptive": val_adapt,
+ "test": test_fixed,
+ "test_adaptive": test_adapt,
+ "fixed_val_steps": fixed_val_steps,
+ "adaptive_val_steps": adaptive_val_steps,
+ "fixed_steps": fixed_steps,
+ "adaptive_steps": adaptive_steps,
+ }
+ else:
+ test_result, _ = evaluate(model, test_loader, evaluator, dev,
+ max_batches=args.max_eval_batches)
+ best = {"ep": ep + 1, "val": val_result, "test": test_result}
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
+ halted = sum(m.get("halted_frac", 0.0) for m in train_metrics) / max(len(train_metrics), 1)
+ train_steps = sum(m.get("steps", 0.0) for m in train_metrics) / max(len(train_metrics), 1)
+ msg = f"ep{ep+1} val_{evaluator.eval_metric}={val_score:.5f}"
+ if args.compute == "rrog-act":
+ msg += (
+ f" val_adapt_{evaluator.eval_metric}={metric_value(val_adapt, evaluator):.5f}"
+ f" adapt_steps={adaptive_val_steps:.2f}"
+ )
+ msg += f" halt={halted:.2f} train_steps={train_steps:.2f}"
+ print(msg, flush=True)
+
+ ema_tag = f"_ema{args.ema:g}" if args.ema > 0 else ""
+ tag = (
+ f"{args.dataset}_{args.view}_{args.compute}_T{T}_ns{args.n_sup}_"
+ f"h{args.hidden}_e{args.epochs}{ema_tag}_s{args.seed}"
+ )
+ rep = {
+ "dataset": args.dataset,
+ "tag": tag,
+ **vars(args),
+ "metric": evaluator.eval_metric,
+ "sec": round(time.time() - t0, 1),
+ "dev": dev,
+ **best,
+ }
+ with open(os.path.join(OUT, f"{tag}.json"), "w") as f:
+ json.dump(jsonable(rep), f, indent=2)
+ torch.save({
+ "state": best_state or model.state_dict(),
+ "cfg": {
+ "dataset": args.dataset,
+ "hidden": args.hidden,
+ "num_tasks": num_tasks,
+ "T": T,
+ "n_sup": args.n_sup,
+ "agg_layers": args.agg_layers,
+ "compute_layers": args.compute_layers,
+ "compute": args.compute,
+ "halt_max_steps": steps,
+ "halt_min_steps": args.halt_min_steps,
+ "halt_target": args.halt_target,
+ "halt_loss_threshold": args.halt_loss_threshold,
+ "view": args.view,
+ },
+ }, os.path.join(OUT, f"ckpt_{tag}.pt"))
+ print(f"[{tag}] best_ep={best.get('ep')} val={best.get('val')} test={best.get('test')} "
+ f"adaptive={best.get('test_adaptive')} steps={best.get('adaptive_steps')}", flush=True)
+ print(" wrote", os.path.join(OUT, f"{tag}.json"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/collect_results.sh b/scripts/collect_results.sh
new file mode 100755
index 0000000..360f05b
--- /dev/null
+++ b/scripts/collect_results.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+mkdir -p summaries
+python3 -m rrog.cli zinc-results --epochs "${ZINC_EPOCHS:-200}" | tee summaries/zinc_cycle56.md
+python3 -m rrog.cli results --epochs "${OGB_EPOCHS:-100}" | tee summaries/ogb_graphprop.md
diff --git a/scripts/run_ogb_mol_all_tasks.sh b/scripts/run_ogb_mol_all_tasks.sh
new file mode 100755
index 0000000..b191d79
--- /dev/null
+++ b/scripts/run_ogb_mol_all_tasks.sh
@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+DEVICE="${DEVICE:-cuda:1}"
+EPOCHS="${EPOCHS:-100}"
+SEED="${SEED:-0}"
+TASKS="${TASKS:-ogbg-molhiv ogbg-molbbbp ogbg-molbace ogbg-moltox21 ogbg-molclintox ogbg-molsider ogbg-molesol ogbg-molfreesolv ogbg-mollipo}"
+
+mkdir -p logs
+for task in ${TASKS}; do
+ echo "[task] ${task}"
+ TASK="${task}" DEVICE="${DEVICE}" EPOCHS="${EPOCHS}" SEED="${SEED}" \
+ ./scripts/run_ogb_mol_task_full.sh 2>&1 | tee "logs/${task}_${SEED}.log"
+done
diff --git a/scripts/run_ogb_mol_task_full.sh b/scripts/run_ogb_mol_task_full.sh
new file mode 100755
index 0000000..b25bff3
--- /dev/null
+++ b/scripts/run_ogb_mol_task_full.sh
@@ -0,0 +1,54 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+TASK="${TASK:-ogbg-molhiv}"
+DEVICE="${DEVICE:-cuda:1}"
+EPOCHS="${EPOCHS:-100}"
+SEED="${SEED:-0}"
+HIDDEN="${HIDDEN:-128}"
+VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"
+
+mkdir -p runs logs
+
+result_path() {
+ local view="$1"
+ local compute="$2"
+ local t="$3"
+ local ns="$4"
+ echo "runs/${TASK}_${view}_${compute}_T${t}_ns${ns}_h${HIDDEN}_e${EPOCHS}_s${SEED}.json"
+}
+
+run_cell() {
+ local view="$1"
+ local compute="$2"
+ local t="$3"
+ local ns="$4"
+ local out
+ out="$(result_path "${view}" "${compute}" "${t}" "${ns}")"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+ echo "[run] ${TASK} view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}"
+ python3 -m rrog.cli run \
+ --task "${TASK}" \
+ --view "${view}" \
+ --compute "${compute}" \
+ --epochs "${EPOCHS}" \
+ --hidden "${HIDDEN}" \
+ --T "${t}" \
+ --n_sup "${ns}" \
+ --seed "${SEED}" \
+ --device "${DEVICE}"
+}
+
+for view in ${VIEWS}; do
+ run_cell "${view}" classic 0 1
+ run_cell "${view}" fixed-rrog 3 3
+done
+
+python3 -m rrog.cli results --epochs "${EPOCHS}"
diff --git a/scripts/run_smoke.sh b/scripts/run_smoke.sh
new file mode 100755
index 0000000..6365cec
--- /dev/null
+++ b/scripts/run_smoke.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+DEVICE="${DEVICE:-cuda:0}"
+mkdir -p runs logs
+
+python3 -m rrog.cli run \
+ --task ogbg-molhiv --view gin --compute classic \
+ --epochs 1 --hidden 32 --bs 64 --seed 991 --device "${DEVICE}" \
+ --max_train_batches 2 --max_eval_batches 2
+
+python3 -m rrog.cli run \
+ --task ogbg-molhiv --view gin --compute fixed-rrog \
+ --epochs 1 --hidden 32 --bs 64 --T 1 --n_sup 2 --seed 992 --device "${DEVICE}" \
+ --max_train_batches 2 --max_eval_batches 2
diff --git a/scripts/run_two_a6000.sh b/scripts/run_two_a6000.sh
new file mode 100755
index 0000000..8d9851f
--- /dev/null
+++ b/scripts/run_two_a6000.sh
@@ -0,0 +1,32 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+ZINC_DEVICE="${ZINC_DEVICE:-cuda:0}"
+OGB_DEVICE="${OGB_DEVICE:-cuda:1}"
+OGB_TASK="${OGB_TASK:-ogbg-molhiv}"
+ZINC_EPOCHS="${ZINC_EPOCHS:-200}"
+OGB_EPOCHS="${OGB_EPOCHS:-100}"
+SEED="${SEED:-0}"
+
+mkdir -p runs logs
+
+echo "[launch] ZINC-cycle56 on ${ZINC_DEVICE}"
+DEVICE="${ZINC_DEVICE}" EPOCHS="${ZINC_EPOCHS}" SEED="${SEED}" \
+ ./scripts/run_zinc_cycle56_full.sh > "logs/zinc_cycle56_${SEED}.log" 2>&1 &
+zinc_pid=$!
+
+echo "[launch] ${OGB_TASK} on ${OGB_DEVICE}"
+TASK="${OGB_TASK}" DEVICE="${OGB_DEVICE}" EPOCHS="${OGB_EPOCHS}" SEED="${SEED}" \
+ ./scripts/run_ogb_mol_task_full.sh > "logs/${OGB_TASK}_${SEED}.log" 2>&1 &
+ogb_pid=$!
+
+echo "[pids] zinc=${zinc_pid} ogb=${ogb_pid}"
+wait "${zinc_pid}"
+wait "${ogb_pid}"
+
+echo "[done] collecting summaries"
+./scripts/collect_results.sh
diff --git a/scripts/run_zinc_cycle56_full.sh b/scripts/run_zinc_cycle56_full.sh
new file mode 100755
index 0000000..151a51e
--- /dev/null
+++ b/scripts/run_zinc_cycle56_full.sh
@@ -0,0 +1,54 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"
+
+DEVICE="${DEVICE:-cuda:0}"
+EPOCHS="${EPOCHS:-200}"
+SEED="${SEED:-0}"
+VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"
+
+mkdir -p runs logs
+
+result_path() {
+ local view="$1"
+ local t="$2"
+ local ns="$3"
+ local view_tag=""
+ if [[ "${view}" != "gin" ]]; then
+ view_tag="_${view}"
+ fi
+ echo "runs/rec_rrog${view_tag}_full_sig0.0_K1_none_T${t}_ns${ns}_trace_s${SEED}.json"
+}
+
+run_cell() {
+ local view="$1"
+ local compute="$2"
+ local t="$3"
+ local ns="$4"
+ local out
+ out="$(result_path "${view}" "${t}" "${ns}")"
+ if [[ -f "${out}" ]]; then
+ echo "[skip] ${out}"
+ return
+ fi
+ echo "[run] zinc-cycle56 view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}"
+ python3 -m rrog.cli run \
+ --task zinc-cycle56 \
+ --view "${view}" \
+ --compute "${compute}" \
+ --epochs "${EPOCHS}" \
+ --T "${t}" \
+ --n_sup "${ns}" \
+ --seed "${SEED}" \
+ --device "${DEVICE}"
+}
+
+for view in ${VIEWS}; do
+ run_cell "${view}" classic 0 1
+ run_cell "${view}" fixed-rrog 1 3
+done
+
+python3 -m rrog.cli zinc-results --epochs "${EPOCHS}"
diff --git a/scripts/setup_and_run_two_a6000.sh b/scripts/setup_and_run_two_a6000.sh
new file mode 100755
index 0000000..ec4e3da
--- /dev/null
+++ b/scripts/setup_and_run_two_a6000.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+if [[ "${SKIP_SETUP:-0}" != "1" ]]; then
+ ./scripts/setup_env.sh
+fi
+
+if [[ -d "${VENV_DIR:-.venv}" ]]; then
+ source "${VENV_DIR:-.venv}/bin/activate"
+fi
+
+./scripts/run_two_a6000.sh
diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh
new file mode 100755
index 0000000..66a94c8
--- /dev/null
+++ b/scripts/setup_env.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "${ROOT_DIR}"
+
+PYTHON_BIN="${PYTHON_BIN:-python3}"
+VENV_DIR="${VENV_DIR:-.venv}"
+TORCH_INDEX_URL="${TORCH_INDEX_URL:-https://download.pytorch.org/whl/cu124}"
+
+if [[ ! -d "${VENV_DIR}" ]]; then
+ "${PYTHON_BIN}" -m venv "${VENV_DIR}"
+fi
+
+source "${VENV_DIR}/bin/activate"
+python -m pip install --upgrade pip wheel setuptools
+
+if ! python - <<'PY' >/dev/null 2>&1
+import torch
+assert torch.cuda.is_available() or True
+PY
+then
+ python -m pip install torch --index-url "${TORCH_INDEX_URL}"
+fi
+
+python -m pip install -r requirements.txt
+
+python - <<'PY'
+import torch
+import torch_geometric
+import ogb
+print("torch", torch.__version__, "cuda_available", torch.cuda.is_available())
+print("torch_geometric", torch_geometric.__version__)
+print("ogb", ogb.__version__)
+PY