import argparse
from contextlib import contextmanager
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    APPNP,
    ARMAConv,
    ChebConv,
    FiLMConv,
    GATv2Conv,
    GCNConv,
    GENConv,
    GINEConv,
    GraphConv,
    MFConv,
    PNAConv,
    ResGatedGraphConv,
    SAGEConv,
    SGConv,
    TAGConv,
    TransformerConv,
    global_add_pool,
)
from torch_geometric.utils import degree

# Filesystem layout; every location can be overridden through environment variables.
PROJECT_ROOT = os.environ.get(
    "RROG_ROOT",
    os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
)
DATA_ROOT = os.environ.get("RROG_DATA_DIR", os.path.join(PROJECT_ROOT, "data"))
ROOT = os.path.join(DATA_ROOT, "ogb")
OUT = os.environ.get("RROG_RUNS_DIR", os.path.join(PROJECT_ROOT, "runs"))

# Message-passing layer families selectable with --view (see make_view_layer).
SUPPORTED_VIEWS = [
    "gin",
    "gine",
    "gcn",
    "graphsage",
    "gatv2",
    "graphconv",
    "transformer",
    "pna",
    "gen",
    "film",
    "resgated",
    "tag",
    "sgc",
    "cheb",
    "arma",
    "mf",
    "appnp",
]
SUPPORTED_MOL_DATASETS = [
    "ogbg-molhiv",
    "ogbg-molpcba",
    "ogbg-molbbbp",
    "ogbg-molbace",
    "ogbg-moltox21",
    "ogbg-molclintox",
    "ogbg-molsider",
    "ogbg-molesol",
    "ogbg-molfreesolv",
    "ogbg-mollipo",
]
HIGHER_BETTER = {"rocauc", "ap", "auc", "accuracy", "acc", "f1"}
LOWER_BETTER = {"rmse", "mae"}

_TORCH_LOAD = torch.load


def _torch_load_ogb_compat(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only=False`` unless the caller overrides it.

    NOTE(review): presumably needed because newer PyTorch releases default to
    ``weights_only=True``, which cannot unpickle the processed dataset files
    OGB/PyG cache. ``weights_only=False`` unpickles arbitrary objects, so only
    trusted local files should be loaded through this shim.
    """
    kwargs.setdefault("weights_only", False)
    return _TORCH_LOAD(*args, **kwargs)


# Process-wide monkey-patch: every torch.load in this process (OGB's included)
# goes through the shim above.
torch.load = _torch_load_ogb_compat


def clone_state_dict(model):
    """Return a detached deep copy of ``model.state_dict()`` (same devices)."""
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


@torch.no_grad()
def update_ema_state(ema_state, model, decay):
    """Blend the model's current state into ``ema_state`` in place.

    Floating-point entries follow ``ema = decay * ema + (1 - decay) * value``.
    Non-float entries (e.g. integer buffers) are copied verbatim.
    Does nothing when ``ema_state`` is None (EMA disabled).
    """
    if ema_state is None:
        return
    for key, value in model.state_dict().items():
        if torch.is_floating_point(value):
            ema_state[key].mul_(decay).add_(value.detach(), alpha=1.0 - decay)
        else:
            ema_state[key].copy_(value.detach())


@contextmanager
def using_ema_state(model, ema_state):
    """Temporarily load ``ema_state`` into ``model``; restore the raw weights on exit.

    With ``ema_state`` None this is a no-op context.
    """
    if ema_state is None:
        yield
        return
    raw_state = clone_state_dict(model)
    model.load_state_dict(ema_state, strict=True)
    try:
        yield
    finally:
        # Restore even if evaluation raised, so training continues on raw weights.
        model.load_state_dict(raw_state, strict=True)


def metric_direction(metric: str) -> int:
    """Return -1 for lower-is-better metrics (rmse/mae-like), +1 otherwise."""
    metric = metric.lower()
    if metric in LOWER_BETTER or "rmse" in metric or "mae" in metric:
        return -1
    return 1


def is_regression_metric(metric: str) -> bool:
    """Treat lower-is-better metrics as regression tasks (MSE loss, no halting by exact match)."""
    return metric_direction(metric) < 0


def is_better(score: float, best: float | None, metric: str) -> bool:
    """Return True if ``score`` improves on ``best`` under the metric's direction (None always loses)."""
    if best is None:
        return True
    direction = metric_direction(metric)
    return score > best if direction > 0 else score < best


def jsonable(obj):
    """Recursively convert dicts/sequences/NumPy/torch values into JSON-serialisable Python objects."""
    if isinstance(obj, dict):
        return {str(k): jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [jsonable(v) for v in obj]
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        if obj.ndim == 0:
            return obj.detach().cpu().item()
        return obj.detach().cpu().tolist()
    return obj


def degree_histogram(dataset) -> torch.Tensor:
    """Histogram of node in-degrees (counted on ``edge_index[1]``) over every graph in ``dataset``.

    Entry ``d`` is the number of nodes with degree ``d``; used as PNAConv's ``deg``.
    """
    max_degree = 0
    degs = []
    for graph in dataset:
        deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
        degs.append(deg)
        if deg.numel():
            max_degree = max(max_degree, int(deg.max().item()))
    hist = torch.zeros(max_degree + 1, dtype=torch.long)
    for deg in degs:
        hist += torch.bincount(deg, minlength=hist.numel())
    return hist


def make_view_layer(view: str, hidden: int, deg: torch.Tensor | None):
    """Build one ``hidden -> hidden`` message-passing layer for the given view.

    Note that "gin" and "gine" both build a GINEConv, so both consume bond features.
    ``deg`` is required only for "pna".

    Raises:
        ValueError: for an unknown view, or "pna" without ``deg``.
    """
    if view in {"gin", "gine"}:
        mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
        return GINEConv(mlp, train_eps=True)
    if view == "gcn":
        return GCNConv(hidden, hidden)
    if view == "graphsage":
        return SAGEConv(hidden, hidden)
    if view == "gatv2":
        return GATv2Conv(hidden, hidden, heads=4, concat=False, edge_dim=hidden)
    if view == "graphconv":
        return GraphConv(hidden, hidden)
    if view == "transformer":
        return TransformerConv(hidden, hidden, heads=4, concat=False, edge_dim=hidden)
    if view == "pna":
        if deg is None:
            raise ValueError("PNA view requires a training-set degree histogram")
        return PNAConv(
            hidden,
            hidden,
            aggregators=["mean", "min", "max", "std"],
            scalers=["identity", "amplification", "attenuation"],
            deg=deg,
            edge_dim=hidden,
        )
    if view == "gen":
        return GENConv(hidden, hidden, edge_dim=hidden)
    if view == "film":
        return FiLMConv(hidden, hidden)
    if view == "resgated":
        return ResGatedGraphConv(hidden, hidden, edge_dim=hidden)
    if view == "tag":
        return TAGConv(hidden, hidden, K=3)
    if view == "sgc":
        return SGConv(hidden, hidden, K=2, cached=False)
    if view == "cheb":
        return ChebConv(hidden, hidden, K=3)
    if view == "arma":
        return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2)
    if view == "mf":
        return MFConv(hidden, hidden)
    if view == "appnp":
        return APPNP(K=5, alpha=0.1)
    raise ValueError(f"unsupported OGB view: {view}")


# Views whose layers are called with the encoded bond features as ``edge_attr``.
EDGE_ATTR_VIEWS = {"gin", "gine", "gatv2", "transformer", "pna", "gen", "resgated"}


class OGBRRoG(nn.Module):
    """GNN encoder followed by a recurrent residual core with a prediction head and a halting head.

    Pipeline: atom/bond encoders -> ``agg_layers`` message-passing layers (the
    "view") -> ``n_sup`` supervision steps, each running ``T`` recursions of a
    shared residual MLP core over node states (y, z) -> sum-pool -> task logits
    plus one halting logit per graph.
    """

    def __init__(
        self,
        hidden,
        num_tasks,
        view="gin",
        T=1,
        n_sup=3,
        agg_layers=5,
        compute_layers=2,
        grad_mode="full",
        deg=None,
    ):
        super().__init__()
        self.view = view
        self.atom_encoder = AtomEncoder(hidden)
        self.bond_encoder = BondEncoder(hidden)
        self.agg_convs = nn.ModuleList()
        self.agg_bns = nn.ModuleList()
        for _ in range(agg_layers):
            self.agg_convs.append(make_view_layer(view, hidden, deg))
            self.agg_bns.append(nn.BatchNorm1d(hidden))

        # Core MLP: (compute_layers - 1) x [Linear, GELU] followed by a final Linear.
        core = []
        for _ in range(compute_layers - 1):
            core += [nn.Linear(hidden, hidden), nn.GELU()]
        core.append(nn.Linear(hidden, hidden))
        self.core_norm = nn.LayerNorm(hidden)
        self.core = nn.Sequential(*core)
        # Zero-init the last layer so each residual core step starts as the identity.
        nn.init.zeros_(self.core[-1].weight)
        nn.init.zeros_(self.core[-1].bias)

        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, num_tasks))
        self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        with torch.no_grad():
            # Start with a strongly negative halting logit (sigmoid(-5) ~ 0.007): "don't halt".
            self.qhead[-1].weight.zero_()
            self.qhead[-1].bias.fill_(-5.0)

        self.T = T
        self.n_sup = n_sup
        self.grad_mode = grad_mode
        self.agg_layers = agg_layers
        self.compute_layers = compute_layers
        self.hidden = hidden
        self.num_tasks = num_tasks

    def aggregate(self, x, edge_index, edge_attr):
        """Encode atoms/bonds and run the message-passing stack; returns node features.

        Each layer is conv -> BatchNorm -> ReLU. Views in ``EDGE_ATTR_VIEWS``
        additionally receive the encoded bond features.
        """
        h = self.atom_encoder(x)
        e = self.bond_encoder(edge_attr)
        uses_edges = self.view in EDGE_ATTR_VIEWS
        for conv, bn in zip(self.agg_convs, self.agg_bns):
            out = conv(h, edge_index, edge_attr=e) if uses_edges else conv(h, edge_index)
            h = bn(out).relu()
        return h

    def core_step(self, combined, state):
        """Residual update: ``state + core(LayerNorm(combined))``."""
        return state + self.core(self.core_norm(combined))

    def _z_step(self, y, z, ctx):
        """Update the latent state z from the context, the current answer y and z itself."""
        return self.core_step(ctx + y + z, z)

    def _y_step(self, y, z):
        """Update the answer state y from y and the latent z."""
        return self.core_step(y + z, y)

    def recurse(self, y, z, ctx, one_step=False):
        """Run ``T`` rounds of (z-update, y-update); returns the new ``(y, z)``.

        With ``one_step=True`` the first ``T - 1`` z-updates run without gradient
        and only the final z-update and the y-update are differentiated.
        ``T == 0`` returns the inputs unchanged.
        """
        if self.T == 0:
            return y, z
        if one_step:
            with torch.no_grad():
                for _ in range(self.T - 1):
                    z = self._z_step(y, z, ctx)
            z = z.detach()
            z = self._z_step(y, z, ctx)
            y = self._y_step(y, z)
            return y, z
        for _ in range(self.T):
            z = self._z_step(y, z, ctx)
            y = self._y_step(y, z)
        return y, z

    def predict(self, y, batch):
        """Sum-pool node states per graph; return (task logits, flat halting logits)."""
        pooled = global_add_pool(y, batch)
        return self.head(pooled), self.qhead(pooled).view(-1)

    def forward_trace(self, data, steps):
        """Run ``steps`` supervision steps; return per-step prediction and halting-logit lists.

        y/z are detached between steps, but ``ctx`` is not, so every step's loss
        still back-propagates into the message-passing layers.
        """
        ctx = self.aggregate(data.x, data.edge_index, data.edge_attr)
        y = ctx
        z = torch.zeros_like(ctx)
        preds, q_logits = [], []
        for s in range(steps):
            y, z = self.recurse(y, z, ctx, one_step=(self.grad_mode == "1step"))
            pred, q = self.predict(y, data.batch)
            preds.append(pred)
            q_logits.append(q)
            if s < steps - 1:
                y, z = y.detach(), z.detach()
        return preds, q_logits

    def forward(self, data):
        """Run ``n_sup`` steps; return (all per-step predictions, last-step halting logits)."""
        steps = self.n_sup
        preds, q_logits = self.forward_trace(data, steps)
        return preds, q_logits[-1]


def supervised_loss(logits, y, metric):
    """Mean per-graph loss over graphs with at least one label (NaN = missing label).

    Returns a zero that still depends on ``logits`` when nothing is labelled,
    so ``backward()`` remains valid.
    """
    per_graph, has_label = per_graph_supervised_loss(logits, y, metric)
    if not has_label.any():
        return logits.sum() * 0.0
    return per_graph[has_label].mean()


def per_graph_supervised_loss(logits, y, metric):
    """Per-graph loss averaged over labelled tasks, plus a "has any label" mask.

    ``y`` uses NaN for missing labels. The loss is squared error for regression
    metrics and BCE-with-logits otherwise. Rows with no labels get loss 0.
    """
    y = y.to(torch.float32)
    mask = ~torch.isnan(y)
    target = torch.where(mask, y, torch.zeros_like(y))
    if is_regression_metric(metric):
        losses = (logits - target).pow(2)
    else:
        losses = nn.functional.binary_cross_entropy_with_logits(logits, target, reduction="none")
    losses = torch.where(mask, losses, torch.zeros_like(losses))
    denom = mask.sum(dim=1).clamp_min(1)
    return losses.sum(dim=1) / denom, mask.any(dim=1)


@torch.no_grad()
def halt_targets(logits, y):
    """Binary halting target: 1 iff the graph has a label and every labelled task is predicted right.

    A task counts as predicted right when ``(logit > 0) == label``.
    """
    y = y.to(torch.float32)
    mask = ~torch.isnan(y)
    pred = (logits > 0).to(y.dtype)
    correct_or_missing = (~mask) | (pred == y)
    has_label = mask.any(dim=1)
    return (correct_or_missing.all(dim=1) & has_label).to(logits.dtype)


def _first_halt_index(q_stack, halt_min_steps):
    """For each graph, find the first step (0-based) whose halting logit is positive.

    ``q_stack`` is indexed ``[step, graph]``. Steps numbered (1-based) below
    ``halt_min_steps`` cannot halt, and graphs that never halt fall back to the
    last step. Returns ``(any_halt, idx)``.
    """
    steps = q_stack.size(0)
    step_ids = torch.arange(1, steps + 1, device=q_stack.device).view(-1, 1)
    halted = (q_stack > 0) & (step_ids >= halt_min_steps)
    any_halt = halted.any(dim=0)
    # argmax returns the first maximal index, i.e. the first True step.
    first_halt = halted.to(torch.int64).argmax(dim=0)
    fallback = torch.full_like(first_halt, steps - 1)
    return any_halt, torch.where(any_halt, first_halt, fallback)


@torch.no_grad()
def evaluate(model, loader, evaluator, dev, steps=None, adaptive=False, halt_min_steps=1, max_batches=0):
    """Evaluate ``model`` on ``loader``; return ``(evaluator result dict, mean steps per graph)``.

    Args:
        steps: None -> use ``model(batch)`` and its last prediction (``n_sup`` steps).
            Otherwise run ``steps`` supervision steps via ``forward_trace``.
        adaptive: with ``steps`` set, take each graph's prediction at its first
            halting step (see ``_first_halt_index``) instead of the last step.
        max_batches: stop after this many batches (0 = all).
    """
    model.eval()
    ys, ps = [], []
    step_sum = 0.0
    n = 0
    for i, batch in enumerate(loader):
        if max_batches and i >= max_batches:
            break
        batch = batch.to(dev)
        if steps is None:
            preds, _ = model(batch)
            pred = preds[-1]
            # Count steps per graph so the returned value is a per-graph mean
            # like the other branches.
            used_steps = float(model.n_sup * batch.num_graphs)
        else:
            preds, qs = model.forward_trace(batch, steps)
            stack = torch.stack(preds, dim=0)
            if adaptive:
                _, idx = _first_halt_index(torch.stack(qs, dim=0), halt_min_steps)
                pred = stack[idx, torch.arange(stack.size(1), device=dev)]
                used_steps = float((idx.to(torch.float32) + 1).sum().item())
            else:
                pred = stack[-1]
                used_steps = float(steps * batch.num_graphs)
        ys.append(batch.y.detach().cpu())
        ps.append(pred.detach().cpu())
        step_sum += used_steps
        n += batch.num_graphs
    y_true = torch.cat(ys, dim=0)
    y_pred = torch.cat(ps, dim=0)
    result = evaluator.eval({"y_true": y_true, "y_pred": y_pred})
    return result, step_sum / max(n, 1)


def _split_nodes(t, ptr):
    """Split a node-level tensor into detached per-graph slices using batch ``ptr`` offsets."""
    return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)]


def act_train_step(model, state, replacement_batch, opt, dev, args, metric):
    """One online-ACT optimisation step over a persistent batch of in-flight graphs.

    Graph slots that halted on the previous call are refilled from
    ``replacement_batch`` and restart from ``y = ctx, z = 0``. Other slots resume
    from their stored (detached) y/z. After one recursion, the model is trained
    on task loss plus halting BCE, and per-slot step counters and halt flags are
    updated (optionally with random minimum-step exploration).

    NOTE: ``train_epoch`` in this file uses ``act_trace_train_step`` instead;
    this function is kept for callers that want the online variant.

    Returns ``(state, metrics_dict)``.
    """
    replacement = replacement_batch.to_data_list()
    batch_size = len(replacement)
    if state is None:
        state = {
            "graphs": [None for _ in range(batch_size)],
            "y": [None for _ in range(batch_size)],
            "z": [None for _ in range(batch_size)],
            "steps": torch.zeros(batch_size, dtype=torch.long, device=dev),
            "halted": torch.ones(batch_size, dtype=torch.bool, device=dev),
        }
    halted_cpu = state["halted"].detach().cpu().tolist()
    for i, halted in enumerate(halted_cpu):
        if halted:
            state["graphs"][i] = replacement[i]
    batch = Batch.from_data_list(state["graphs"]).to(dev)
    ctx = model.aggregate(batch.x, batch.edge_index, batch.edge_attr)
    ptr = batch.ptr
    y_parts, z_parts = [], []
    for i in range(batch_size):
        start, end = ptr[i].item(), ptr[i + 1].item()
        if halted_cpu[i] or state["y"][i] is None:
            y_parts.append(ctx[start:end])
            z_parts.append(torch.zeros_like(ctx[start:end]))
        else:
            y_parts.append(state["y"][i].to(dev))
            z_parts.append(state["z"][i].to(dev))
    y = torch.cat(y_parts, dim=0)
    z = torch.cat(z_parts, dim=0)

    opt.zero_grad()
    y, z = model.recurse(y, z, ctx, one_step=(model.grad_mode == "1step"))
    logits, q = model.predict(y, batch.batch)
    pred_loss = supervised_loss(logits, batch.y, metric)
    target = _halt_target(args, logits, batch.y, metric)
    q_loss = nn.functional.binary_cross_entropy_with_logits(q, target)
    loss = pred_loss + 0.5 * args.lam_q * q_loss
    y_det, z_det = y.detach(), z.detach()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()

    state["y"] = _split_nodes(y_det, ptr)
    state["z"] = _split_nodes(z_det, ptr)
    with torch.no_grad():
        was_halted = state["halted"]
        # Freshly loaded slots restart counting from 1; others advance by one.
        steps = torch.where(was_halted, torch.zeros_like(state["steps"]), state["steps"]) + 1
        halted = (steps >= args.halt_max_steps) | ((q.detach() > 0) & (steps >= args.halt_min_steps))
        if args.halt_exploration_prob > 0 and args.halt_max_steps > 1:
            # With probability p, force a random minimum of 2..halt_max_steps steps.
            explore = torch.rand_like(q) < args.halt_exploration_prob
            min_steps = torch.where(
                explore,
                torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev),
                torch.zeros_like(steps),
            )
            halted = halted & (steps >= min_steps)
        state["steps"] = steps
        state["halted"] = halted
    return state, {
        "loss": float(loss.detach().cpu()),
        "pred_loss": float(pred_loss.detach().cpu()),
        "q_loss": float(q_loss.detach().cpu()),
        "halted_frac": float(state["halted"].to(torch.float32).mean().detach().cpu()),
        "steps": float(state["steps"].to(torch.float32).mean().detach().cpu()),
    }


def _halt_target(args, logits, y, metric):
    """Per-graph halting target, selected by ``args.halt_target``.

    "exact" (classification metrics only) -> ``halt_targets``. "loss" -> 1 when
    the graph is labelled and its per-graph loss is <= ``args.halt_loss_threshold``.

    Raises:
        ValueError: for an unknown mode, or "exact" with a regression metric.
    """
    if args.halt_target == "exact" and not is_regression_metric(metric):
        return halt_targets(logits.detach(), y)
    if args.halt_target == "loss":
        per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), y, metric)
        return ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype)
    if args.halt_target == "exact":
        raise ValueError(
            f"halt_target 'exact' requires a classification metric, got {metric!r}; use 'loss'"
        )
    raise ValueError(args.halt_target)


def act_trace_train_step(model, batch, opt, dev, args, epoch, metric):
    """Train on the full ``halt_max_steps`` trace of one batch.

    The loss is the task loss averaged over steps plus ``0.5 * lam_q`` times the
    halting BCE averaged over steps. Halting targets are forced to 0 before
    ``halt_min_steps``, and the halting loss is skipped for epochs
    ``<= q_warmup_epochs`` (``epoch`` is 1-based). Returns a metrics dict.
    """
    batch = batch.to(dev)
    steps = max(1, args.halt_max_steps)
    opt.zero_grad()
    preds, qs = model.forward_trace(batch, steps)
    pred_loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds)
    if epoch <= args.q_warmup_epochs:
        q_loss = pred_loss.detach() * 0.0
        loss = pred_loss
    else:
        q_losses = []
        for step_idx, (pred, q) in enumerate(zip(preds, qs), start=1):
            target = _halt_target(args, pred, batch.y, metric)
            if step_idx < args.halt_min_steps:
                target = torch.zeros_like(target)
            q_losses.append(nn.functional.binary_cross_entropy_with_logits(q, target))
        q_loss = sum(q_losses) / len(q_losses)
        loss = pred_loss + 0.5 * args.lam_q * q_loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    with torch.no_grad():
        any_halt, idx = _first_halt_index(torch.stack(qs, dim=0), args.halt_min_steps)
    return {
        "loss": float(loss.detach().cpu()),
        "pred_loss": float(pred_loss.detach().cpu()),
        "q_loss": float(q_loss.detach().cpu()),
        "halted_frac": float(any_halt.to(torch.float32).mean().detach().cpu()),
        "steps": float((idx.to(torch.float32) + 1).mean().detach().cpu()),
    }


def train_epoch(model, loader, opt, dev, args, act_state, ema_state, epoch, metric):
    """Run one training epoch; return ``(act_state, list of per-batch metric dicts)``.

    "rrog-act" uses ``act_trace_train_step``. Other modes minimise the task loss
    averaged over the ``n_sup`` supervision steps. The EMA is updated after every
    optimiser step. ``act_state`` is passed through unchanged.
    """
    model.train()
    metrics = []
    for i, batch in enumerate(loader):
        if args.max_train_batches and i >= args.max_train_batches:
            break
        if args.compute == "rrog-act":
            m = act_trace_train_step(model, batch, opt, dev, args, epoch, metric)
            update_ema_state(ema_state, model, args.ema)
            metrics.append(m)
            continue
        batch = batch.to(dev)
        opt.zero_grad()
        preds, _ = model(batch)
        loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        update_ema_state(ema_state, model, args.ema)
        metrics.append({"loss": float(loss.detach().cpu()), "halted_frac": 0.0, "steps": float(model.n_sup)})
    return act_state, metrics


def metric_value(result, evaluator):
    """Extract the evaluator's headline metric from an OGB result dict as a float."""
    return float(result[evaluator.eval_metric])


def main():
    """Train an OGBRRoG model on an OGB molecule dataset and write a JSON report plus checkpoint.

    Model selection uses the validation score (fixed-step score for "rrog-act").
    When EMA is enabled, evaluation and the saved best weights use the EMA weights.
    Outputs: ``<OUT>/<tag>.json`` and ``<OUT>/ckpt_<tag>.pt``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", choices=SUPPORTED_MOL_DATASETS, default="ogbg-molhiv")
    ap.add_argument("--view", choices=SUPPORTED_VIEWS, default="gin")
    ap.add_argument("--compute", choices=["classic", "view-only", "fixed-rrog", "rrog-act"], default="rrog-act")
    ap.add_argument("--T", type=int, default=1)
    ap.add_argument("--n_sup", type=int, default=3)
    ap.add_argument("--hidden", type=int, default=128)
    ap.add_argument("--agg_layers", type=int, default=5)
    ap.add_argument("--compute_layers", type=int, default=2)
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--eval_every", type=int, default=10)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--bs", type=int, default=128)
    ap.add_argument("--lam_q", type=float, default=1.0)
    ap.add_argument("--halt_max_steps", type=int, default=8)
    ap.add_argument("--halt_min_steps", type=int, default=1)
    ap.add_argument("--halt_target", choices=["exact", "loss"], default="loss")
    ap.add_argument("--halt_loss_threshold", type=float, default=0.2)
    ap.add_argument("--halt_exploration_prob", type=float, default=0.1)
    ap.add_argument("--q_warmup_epochs", type=int, default=0)
    ap.add_argument("--ema", type=float, default=0.0)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--max_train_batches", type=int, default=0)
    ap.add_argument("--max_eval_batches", type=int, default=0)
    ap.add_argument("--device", default="auto")
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.compute == "classic":
        args.n_sup = 1
    if args.device == "auto":
        dev = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        dev = args.device
    os.makedirs(OUT, exist_ok=True)

    dataset = PygGraphPropPredDataset(name=args.dataset, root=ROOT)
    split_idx = dataset.get_idx_split()
    evaluator = Evaluator(args.dataset)
    metric = evaluator.eval_metric
    # Fail fast instead of crashing mid-training once the halting loss kicks in.
    if (
        args.compute == "rrog-act"
        and args.halt_target == "exact"
        and is_regression_metric(metric)
        and args.epochs > args.q_warmup_epochs
    ):
        ap.error(
            f"--halt_target exact requires a classification metric, but {args.dataset} "
            f"uses {metric!r}; use --halt_target loss"
        )
    num_tasks = dataset.num_tasks
    train_dataset = dataset[split_idx["train"]]
    valid_dataset = dataset[split_idx["valid"]]
    test_dataset = dataset[split_idx["test"]]
    train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=args.num_workers)

    # Baseline modes disable the recurrent core entirely.
    T = 0 if args.compute in ["classic", "view-only"] else args.T
    deg = degree_histogram(train_dataset) if args.view == "pna" else None
    model = OGBRRoG(args.hidden, num_tasks, view=args.view, T=T, n_sup=args.n_sup,
                    agg_layers=args.agg_layers, compute_layers=args.compute_layers, deg=deg).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, max(args.epochs, 1))
    ema_state = clone_state_dict(model) if args.ema > 0 else None

    t0 = time.time()
    best_val = None
    best = {}
    best_state = None
    act_state = None
    steps = max(1, args.halt_max_steps)
    for ep in range(args.epochs):
        act_state, train_metrics = train_epoch(
            model, train_loader, opt, dev, args, act_state, ema_state, ep + 1, metric)
        sched.step()
        # eval_every <= 0 means "evaluate only after the final epoch" (avoids modulo by zero).
        periodic = args.eval_every > 0 and (ep + 1) % args.eval_every == 0
        if not (periodic or ep == args.epochs - 1):
            continue
        with using_ema_state(model, ema_state):
            if args.compute == "rrog-act":
                val_result, fixed_val_steps = evaluate(
                    model, valid_loader, evaluator, dev, steps=steps, adaptive=False,
                    halt_min_steps=args.halt_min_steps, max_batches=args.max_eval_batches)
                val_adapt, adaptive_val_steps = evaluate(
                    model, valid_loader, evaluator, dev, steps=steps, adaptive=True,
                    halt_min_steps=args.halt_min_steps, max_batches=args.max_eval_batches)
            else:
                val_result, _ = evaluate(model, valid_loader, evaluator, dev, max_batches=args.max_eval_batches)
            val_score = metric_value(val_result, evaluator)
            if is_better(val_score, best_val, metric):
                best_val = val_score
                if args.compute == "rrog-act":
                    test_fixed, fixed_steps = evaluate(
                        model, test_loader, evaluator, dev, steps=steps, adaptive=False,
                        halt_min_steps=args.halt_min_steps, max_batches=args.max_eval_batches)
                    test_adapt, adaptive_steps = evaluate(
                        model, test_loader, evaluator, dev, steps=steps, adaptive=True,
                        halt_min_steps=args.halt_min_steps, max_batches=args.max_eval_batches)
                    best = {
                        "ep": ep + 1,
                        "val": val_result,
                        "val_adaptive": val_adapt,
                        "test": test_fixed,
                        "test_adaptive": test_adapt,
                        "fixed_val_steps": fixed_val_steps,
                        "adaptive_val_steps": adaptive_val_steps,
                        "fixed_steps": fixed_steps,
                        "adaptive_steps": adaptive_steps,
                    }
                else:
                    test_result, _ = evaluate(model, test_loader, evaluator, dev, max_batches=args.max_eval_batches)
                    best = {"ep": ep + 1, "val": val_result, "test": test_result}
                # Snapshot inside the EMA context so the saved weights are the ones just evaluated.
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        halted = sum(m.get("halted_frac", 0.0) for m in train_metrics) / max(len(train_metrics), 1)
        train_steps = sum(m.get("steps", 0.0) for m in train_metrics) / max(len(train_metrics), 1)
        msg = f"ep{ep+1} val_{evaluator.eval_metric}={val_score:.5f}"
        if args.compute == "rrog-act":
            msg += (
                f" val_adapt_{evaluator.eval_metric}={metric_value(val_adapt, evaluator):.5f}"
                f" adapt_steps={adaptive_val_steps:.2f}"
            )
        msg += f" halt={halted:.2f} train_steps={train_steps:.2f}"
        print(msg, flush=True)

    ema_tag = f"_ema{args.ema:g}" if args.ema > 0 else ""
    tag = (
        f"{args.dataset}_{args.view}_{args.compute}_T{T}_ns{args.n_sup}_"
        f"h{args.hidden}_e{args.epochs}{ema_tag}_s{args.seed}"
    )
    rep = {
        "dataset": args.dataset,
        "tag": tag,
        **vars(args),
        "metric": evaluator.eval_metric,
        "sec": round(time.time() - t0, 1),
        "dev": dev,
        **best,
    }
    with open(os.path.join(OUT, f"{tag}.json"), "w") as f:
        json.dump(jsonable(rep), f, indent=2)
    torch.save({
        "state": best_state or model.state_dict(),
        "cfg": {
            "dataset": args.dataset,
            "hidden": args.hidden,
            "num_tasks": num_tasks,
            "T": T,
            "n_sup": args.n_sup,
            "agg_layers": args.agg_layers,
            "compute_layers": args.compute_layers,
            "compute": args.compute,
            "halt_max_steps": steps,
            "halt_min_steps": args.halt_min_steps,
            "halt_target": args.halt_target,
            "halt_loss_threshold": args.halt_loss_threshold,
            "view": args.view,
        },
    }, os.path.join(OUT, f"ckpt_{tag}.pt"))
    print(f"[{tag}] best_ep={best.get('ep')} val={best.get('val')} test={best.get('test')} "
          f"adaptive={best.get('test_adaptive')} steps={best.get('adaptive_steps')}", flush=True)
    print(" wrote", os.path.join(OUT, f"{tag}.json"))


if __name__ == "__main__":
    main()