diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-06-21 15:33:22 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-06-21 15:33:22 -0500 |
| commit | e42f575050efeeccb736385b43bed84e1129edb0 (patch) | |
| tree | 8ed04b42218cf4c90c0b9c29b40db149f1355f4a | |
Initial RRoG GNN runner
| -rw-r--r-- | .gitignore | 10 | ||||
| -rw-r--r-- | README.md | 100 | ||||
| -rw-r--r-- | diag/__init__.py | 0 | ||||
| -rw-r--r-- | diag/train_cycle.py | 188 | ||||
| -rw-r--r-- | diag/train_rec.py | 491 | ||||
| -rw-r--r-- | requirements.txt | 5 | ||||
| -rw-r--r-- | rrog/__init__.py | 2 | ||||
| -rw-r--r-- | rrog/backbones.py | 72 | ||||
| -rw-r--r-- | rrog/benchmarks.py | 44 | ||||
| -rw-r--r-- | rrog/cli.py | 176 | ||||
| -rw-r--r-- | rrog/collect_results.py | 239 | ||||
| -rw-r--r-- | rrog/collect_zinc.py | 137 | ||||
| -rw-r--r-- | rrog/registry.py | 57 | ||||
| -rwxr-xr-x | rrog/run_ogb_hiv_remaining.sh | 49 | ||||
| -rwxr-xr-x | rrog/run_zinc_gine.sh | 36 | ||||
| -rwxr-xr-x | rrog/run_zinc_gine_after_pid.sh | 14 | ||||
| -rw-r--r-- | rrog/runspecs.py | 188 | ||||
| -rw-r--r-- | rrog/train_ogb_graphprop.py | 685 | ||||
| -rwxr-xr-x | scripts/collect_results.sh | 10 | ||||
| -rwxr-xr-x | scripts/run_ogb_mol_all_tasks.sh | 17 | ||||
| -rwxr-xr-x | scripts/run_ogb_mol_task_full.sh | 54 | ||||
| -rwxr-xr-x | scripts/run_smoke.sh | 19 | ||||
| -rwxr-xr-x | scripts/run_two_a6000.sh | 32 | ||||
| -rwxr-xr-x | scripts/run_zinc_cycle56_full.sh | 54 | ||||
| -rwxr-xr-x | scripts/setup_and_run_two_a6000.sh | 15 | ||||
| -rwxr-xr-x | scripts/setup_env.sh | 35 |
26 files changed, 2729 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..307c159 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__/ +*.py[cod] +.venv/ +data/ +runs/ +logs/ +summaries/ +*.pt +*.pth +.DS_Store diff --git a/README.md b/README.md new file mode 100644 index 0000000..f07c90f --- /dev/null +++ b/README.md @@ -0,0 +1,100 @@ +# RRoG-GNN Runner + +This repo runs the current RRoG/TRM-on-GNN experiment grid. + +Core rule: + +```text +view/graph aggregation happens once; recursive compute is edge-free hidden-state refinement. +``` + +The main reported table is: + +```text +Task x Backbone -> classic baseline +Task x Backbone x fixed-RRoG -> delta against the matching classic row +``` + +`classic` is the non-RRoG baseline for every backbone: `T=0`, `n_sup=1`. + +## One-command Run On 2x A6000 + +On a clean machine with two visible GPUs: + +```bash +git clone git@github.com:YurenHao0426/rrog-gnn-runner.git +cd rrog-gnn-runner +./scripts/setup_and_run_two_a6000.sh +``` + +Defaults: + +- GPU0: `zinc-cycle56` over 17 backbones, `classic + fixed-rrog` +- GPU1: `ogbg-molhiv` over 17 backbones, `classic + fixed-rrog` +- Results: `runs/*.json` +- Logs: `logs/*.log` +- Summaries: `summaries/*.md` + +If the environment already has compatible `torch`, `torch_geometric`, and `ogb`: + +```bash +SKIP_SETUP=1 ./scripts/setup_and_run_two_a6000.sh +``` + +To override CUDA wheel index during setup: + +```bash +TORCH_INDEX_URL=https://download.pytorch.org/whl/cu121 ./scripts/setup_env.sh +``` + +## Common Commands + +Smoke test: + +```bash +./scripts/setup_env.sh +DEVICE=cuda:0 ./scripts/run_smoke.sh +``` + +Run the paired ZINC matrix only: + +```bash +DEVICE=cuda:0 EPOCHS=200 ./scripts/run_zinc_cycle56_full.sh +``` + +Run one OGB molecular task: + +```bash +TASK=ogbg-molhiv DEVICE=cuda:1 EPOCHS=100 ./scripts/run_ogb_mol_task_full.sh +``` + +Run all selected OGB molecular tasks serially on one GPU: + +```bash +DEVICE=cuda:1 ./scripts/run_ogb_mol_all_tasks.sh +``` + +Collect summaries: + +```bash +./scripts/collect_results.sh +``` + +## Backbones + +The implemented 2D view/backbone list is shared across ZINC and OGB: + +```text +gin, gine, gcn, graphsage, gatv2, graphconv, transformer, pna, +gen, film, resgated, tag, sgc, cheb, arma, mf, appnp +``` + +For ZINC `gine`, there are no bond features, so GINE uses a learned constant edge token. +For OGB molecular tasks, GINE and edge-aware backbones use OGB bond encodings. + +## Notes + +- Runs are resumable at the cell level: scripts skip existing expected JSON files. +- ZINC cycle-count cache is generated under `data/cycle_cache`. +- OGB datasets are downloaded under `data/ogb`. +- Override data/runs locations with `RROG_DATA_DIR` and `RROG_RUNS_DIR`. diff --git a/diag/__init__.py b/diag/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/diag/__init__.py diff --git a/diag/train_cycle.py b/diag/train_cycle.py new file mode 100644 index 0000000..598e349 --- /dev/null +++ b/diag/train_cycle.py @@ -0,0 +1,188 @@ +"""GNN-native ring-counting on real molecules (ZINC): regress [#5-cycles, #6-cycles]. + +k-cycle counts (k>=3) are provably NOT computable by 1-WL/MPNN (Chen et al. 2020) -> a REAL +H2 ceiling on REAL graphs. Training-based diagnosis (partition instrument is vacuous on +feature-rich graphs): + GIN(L) 1-WL baseline -> should FAIL to count + GCN(L) sub-1-WL reference + GIN+RNI random feats = NOISE -> PTRM-style crude symmetry break (eval-averaged) + GIN+RWSE random-walk return probs-> structured >1-WL positive control +Reads: GIN high error + RWSE fixes it = real ceiling exists; RNI also fixes = crude noise +breaks it (bridge cashed); only RWSE = bridge needs STRUCTURED stochasticity (GRAM>PTRM). +Targets z-scored for training; per-target MAE reported in RAW ring units. + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_cycle.py --conv gin --feat none +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn as nn +import networkx as nx +from torch_geometric.datasets import ZINC +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader +from torch_geometric.utils import to_networkx +from torch_geometric.nn import GINConv, GCNConv, global_add_pool + +PROJECT_ROOT = os.environ.get( + 'RROG_ROOT', + os.path.abspath(os.path.join(os.path.dirname(__file__), '..')), +) +DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data')) +OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs')) +ROOT = os.path.join(DATA_ROOT, 'zinc') +CACHE = os.path.join(DATA_ROOT, 'cycle_cache') +RWSE_K = 16 + + +def rwse(edge_index, n, K=RWSE_K): + A = np.zeros((n, n), dtype=np.float64) + ei = edge_index.numpy() + A[ei[0], ei[1]] = 1.0 + A = np.maximum(A, A.T) + deg = A.sum(1) + P = A / np.where(deg > 0, deg, 1.0)[:, None] + out = np.zeros((n, K), dtype=np.float32) + M = np.eye(n) + for k in range(K): + M = M @ P + out[:, k] = np.diag(M) + return torch.from_numpy(out) + + +def c56(data): + G = to_networkx(data, to_undirected=True) + c = {5: 0, 6: 0} + for cyc in nx.simple_cycles(G, length_bound=6): + L = len(cyc) + if L in c: + c[L] += 1 + return [float(c[5]), float(c[6])] + + +def prepare(split): + os.makedirs(CACHE, exist_ok=True) + fp = os.path.join(CACHE, f"{split}.pt") + if os.path.exists(fp): + return torch.load(fp, weights_only=False) + ds = ZINC(ROOT, subset=True, split=split) + out = [] + for g in ds: + out.append({'x': g.x.view(-1).long(), 'edge_index': g.edge_index, + 'rwse': rwse(g.edge_index, g.num_nodes), + 'y': torch.tensor(c56(g), dtype=torch.float)}) + torch.save(out, fp) + return out + + +class Net(nn.Module): + def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False): + super().__init__() + self.emb = nn.Embedding(n_atom, hidden) + self.rni, self.use_rwse = rni, use_rwse + din = hidden + rni + (RWSE_K if use_rwse else 0) + self.lin_in = nn.Linear(din, hidden) + self.convs, self.bns = nn.ModuleList(), nn.ModuleList() + for _ in range(layers): + if conv == 'gin': + self.convs.append(GINConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)) + else: + self.convs.append(GCNConv(hidden, hidden)) + self.bns.append(nn.BatchNorm1d(hidden)) + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2)) + + def forward(self, x, edge_index, batch, rwse=None): + h = self.emb(x) + parts = [h] + if self.use_rwse: + parts.append(rwse) + if self.rni: + parts.append(torch.randn(h.size(0), self.rni, device=h.device)) + h = self.lin_in(torch.cat(parts, dim=1)) + for conv, bn in zip(self.convs, self.bns): + h = bn(conv(h, edge_index)).relu() + return self.head(global_add_pool(h, batch)) + + +def to_loader(recs, bs, shuffle, drop_last=False): + data = [Data(x=r['x'], edge_index=r['edge_index'], rwse=r['rwse'], + y=r['y'].view(1, 2), num_nodes=r['x'].numel()) for r in recs] + return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last) + + +@torch.no_grad() +def eval_mae(model, loader, dev, ymu, ysd, samples=1): + model.eval(); abs_err = torch.zeros(2); n = 0 + for b in loader: + b = b.to(dev) + ps = torch.stack([model(b.x, b.edge_index, b.batch, b.rwse) for _ in range(samples)]).mean(0) + pr = ps * ysd.to(dev) + ymu.to(dev) # un-standardize -> raw ring units + yr = b.y * ysd.to(dev) + ymu.to(dev) + abs_err += (pr - yr).abs().sum(0).cpu(); n += b.num_graphs + return (abs_err / n).tolist() + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin') + ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none') + ap.add_argument('--layers', type=int, default=5) + ap.add_argument('--hidden', type=int, default=128) + ap.add_argument('--rni_dim', type=int, default=16) + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--bs', type=int, default=128) + ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + os.makedirs(OUT, exist_ok=True) + + tr, va, te = prepare('train'), prepare('val'), prepare('test') + n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1 + Ytr = torch.stack([r['y'] for r in tr]) + ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8 + for recs in (tr, va, te): + for r in recs: + r['y'] = (r['y'] - ymu) / ysd + + rni = args.rni_dim if args.feat == 'rni' else 0 + use_rwse = args.feat == 'rwse' + samples = 8 if rni else 1 + model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + lossf = nn.L1Loss() + trl = to_loader(tr, args.bs, True, drop_last=True) + trl_e, val, tel = to_loader(tr, 256, False), to_loader(va, 256, False), to_loader(te, 256, False) + + t0 = time.time(); best_val = 9e9; best = {} + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + vm = eval_mae(model, val, dev, ymu, ysd, samples) + if sum(vm) < best_val: + best_val = sum(vm) + best = {'ep': ep + 1, 'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples), + 'val_mae': vm, 'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples)} + print(f"ep{ep+1} val_mae(c5,c6)={[round(x,3) for x in vm]}", flush=True) + + tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}" + rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), + 'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best} + tm = best.get('test_mae'); trm = best.get('train_mae') + print(f"[{tag}] train_mae(c5,c6)={[round(x,3) for x in trm]} test_mae={[round(x,3) for x in tm]} " + f"(raw rings; std={ [round(x,2) for x in ysd.tolist()] }) @ep{best.get('ep')} ({rep['sec']}s)") + with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f: + json.dump(rep, f, indent=2) + print(" wrote", os.path.join(OUT, f"cyc_{tag}.json")) + + +if __name__ == "__main__": + main() diff --git a/diag/train_rec.py b/diag/train_rec.py new file mode 100644 index 0000000..9db7eb1 --- /dev/null +++ b/diag/train_rec.py @@ -0,0 +1,491 @@ +"""Step-2: RRoG/TRM-on-GNN for ZINC ring-counting. + +The graph is encoded once with a GIN encoder. A shared edge-free node-wise compute block then +refines hidden state over n_sup*T recurrent steps (TRM-style: carry latent detached between +deep-supervision steps). --grad_mode controls the LAST supervision step's recursion: + full : backprop through all T inner recursions (TRM) + 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient) + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn as nn +from torch_geometric.loader import DataLoader +from torch_geometric.data import Batch, Data +from torch_geometric.nn import ( + APPNP, + ARMAConv, + ChebConv, + FiLMConv, + GATv2Conv, + GCNConv, + GENConv, + GINEConv, + GINConv, + GraphConv, + MFConv, + PNAConv, + ResGatedGraphConv, + SAGEConv, + SGConv, + TAGConv, + TransformerConv, + global_add_pool, +) +from torch_geometric.utils import degree +from diag.train_cycle import prepare + +PROJECT_ROOT = os.environ.get( + 'RROG_ROOT', + os.path.abspath(os.path.join(os.path.dirname(__file__), '..')), +) +OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs')) +SUPPORTED_VIEWS = [ + 'gin', 'gine', 'gcn', 'graphsage', 'gatv2', 'graphconv', 'transformer', 'pna', + 'gen', 'film', 'resgated', 'tag', 'sgc', 'cheb', 'arma', 'mf', 'appnp', +] + + +def data_list(recs): + return [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2), + num_nodes=r['x'].numel()) for r in recs] + + +def loader(recs, bs, shuffle, drop_last=False): + data = recs if recs and isinstance(recs[0], Data) else data_list(recs) + return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last) + + +def degree_histogram(data): + max_degree = 0 + degs = [] + for graph in data: + deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long) + degs.append(deg) + if deg.numel(): + max_degree = max(max_degree, int(deg.max().item())) + hist = torch.zeros(max_degree + 1, dtype=torch.long) + for deg in degs: + hist += torch.bincount(deg, minlength=hist.numel()) + return hist + + +def make_view_layer(view, hidden, deg): + if view == 'gin': + return GINConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + if view == 'gine': + return GINEConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), + train_eps=True, edge_dim=hidden) + if view == 'gcn': + return GCNConv(hidden, hidden) + if view == 'graphsage': + return SAGEConv(hidden, hidden) + if view == 'gatv2': + return GATv2Conv(hidden, hidden, heads=4, concat=False) + if view == 'graphconv': + return GraphConv(hidden, hidden) + if view == 'transformer': + return TransformerConv(hidden, hidden, heads=4, concat=False) + if view == 'pna': + if deg is None: + raise ValueError('PNA view requires a training-set degree histogram') + return PNAConv( + hidden, hidden, + aggregators=['mean', 'min', 'max', 'std'], + scalers=['identity', 'amplification', 'attenuation'], + deg=deg, + ) + if view == 'gen': + return GENConv(hidden, hidden) + if view == 'film': + return FiLMConv(hidden, hidden) + if view == 'resgated': + return ResGatedGraphConv(hidden, hidden) + if view == 'tag': + return TAGConv(hidden, hidden, K=3) + if view == 'sgc': + return SGConv(hidden, hidden, K=2, cached=False) + if view == 'cheb': + return ChebConv(hidden, hidden, K=3) + if view == 'arma': + return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2) + if view == 'mf': + return MFConv(hidden, hidden) + if view == 'appnp': + return APPNP(K=5, alpha=0.1) + raise ValueError(f'unsupported view: {view}') + + +class RecGIN(nn.Module): + def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, + grad_mode='full', agg_layers=5, compute_layers=None, view='gin', deg=None): + super().__init__() + self.view = view + self.agg_layers = agg_layers + self.compute_layers = compute_layers or inner + self.emb = nn.Embedding(n_atom, hidden) + self.edge_emb = nn.Embedding(1, hidden) if view == 'gine' else None + self.agg_convs = nn.ModuleList() + for _ in range(agg_layers): + self.agg_convs.append(make_view_layer(view, hidden, deg)) + self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)]) + core = [] + d = hidden + for _ in range(self.compute_layers - 1): + core += [nn.Linear(d, hidden), nn.GELU()] + d = hidden + core.append(nn.Linear(d, hidden)) + self.core_norm = nn.LayerNorm(hidden) + self.core = nn.Sequential(*core) + nn.init.zeros_(self.core[-1].weight) + nn.init.zeros_(self.core[-1].bias) + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2)) + self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)) + with torch.no_grad(): + self.qhead[-1].weight.zero_() + self.qhead[-1].bias.fill_(-5.0) + self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode + + def aggregate(self, x, ei): + h = self.emb(x) + for conv, bn in zip(self.agg_convs, self.agg_bns): + if self.view == 'gine': + edge_attr = self.edge_emb(torch.zeros(ei.size(1), dtype=torch.long, device=ei.device)) + h = bn(conv(h, ei, edge_attr)).relu() + else: + h = bn(conv(h, ei)).relu() + return h + + def core_step(self, combined, state): + """Shared TRM compute core. Deliberately edge-free.""" + return state + self.core(self.core_norm(combined)) + + def _z_step(self, y, z, ctx, noise): + z = self.core_step(ctx + y + z, z) + if noise and self.sigma > 0: + z = z + self.sigma * torch.randn_like(z) + return z + + def _y_step(self, y, z, noise): + y = self.core_step(y + z, y) + if noise and self.sigma > 0: + y = y + self.sigma * torch.randn_like(y) + return y + + def recurse(self, y, z, ctx, noise, one_step=False): + if self.T == 0: + return y, z + if one_step: # HRM 1-step gradient + with torch.no_grad(): + for _ in range(self.T - 1): + z = self._z_step(y, z, ctx, noise) + z = z.detach() + z = self._z_step(y, z, ctx, noise) # only last inner carries grad + y = self._y_step(y, z, noise) + return y, z + for _ in range(self.T): # TRM full recursion + z = self._z_step(y, z, ctx, noise) + y = self._y_step(y, z, noise) + return y, z + + def predict(self, y, batch): + pooled = global_add_pool(y, batch) + return self.head(pooled), self.qhead(pooled).view(-1) + + def forward_trace(self, x, ei, batch, steps, noise=False): + ctx = self.aggregate(x, ei) + y = ctx + z = torch.zeros_like(ctx) + preds, q_logits = [], [] + for s in range(steps): + y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step')) + pred, q = self.predict(y, batch) + preds.append(pred) + q_logits.append(q) + if s < steps - 1: + y, z = y.detach(), z.detach() + return preds, q_logits + + def forward(self, x, ei, batch, noise=False): + ctx = self.aggregate(x, ei) + y = ctx + z = torch.zeros_like(ctx) + preds = [] + for s in range(self.n_sup): + if s < self.n_sup - 1: + with torch.no_grad(): + y, z = self.recurse(y, z, ctx, noise) + y, z = y.detach(), z.detach() + else: + y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step')) + pred, _ = self.predict(y, batch) + preds.append(pred) + _, q = self.predict(y, batch) + return preds, q + + +@torch.no_grad() +def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'): + model.eval() + ysd_d, ymu_d = ysd.to(dev), ymu.to(dev) + ae = torch.zeros(2); ae_or = torch.zeros(2); n = 0 + for b in ld: + b = b.to(dev) + if K == 1: + preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + chosen = oracle = preds[-1] + else: + P, Q = [], [] + for _ in range(K): + preds, q = model(b.x, b.edge_index, b.batch, noise=True) + P.append(preds[-1]); Q.append(q) + P = torch.stack(P); Q = torch.stack(Q) + ar = torch.arange(P.size(1), device=dev) + chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0) + oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar] + ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + ae_or += ((oracle * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + n += b.num_graphs + return (ae / n).tolist(), (ae_or / n).tolist() + + +@torch.no_grad() +def evaluate_trace(model, ld, dev, ymu, ysd, steps, adaptive=False): + model.eval() + ysd_d, ymu_d = ysd.to(dev), ymu.to(dev) + ae = torch.zeros(2) + n = 0 + step_sum = 0.0 + for b in ld: + b = b.to(dev) + preds, q_logits = model.forward_trace(b.x, b.edge_index, b.batch, steps, noise=False) + P = torch.stack(preds, dim=0) + if adaptive: + Q = torch.stack(q_logits, dim=0) + halted = Q > 0 + any_halt = halted.any(dim=0) + first_halt = halted.to(torch.int64).argmax(dim=0) + fallback = torch.full_like(first_halt, steps - 1) + idx = torch.where(any_halt, first_halt, fallback) + chosen = P[idx, torch.arange(P.size(1), device=dev)] + step_sum += (idx.to(torch.float32) + 1).sum().item() + else: + chosen = P[-1] + step_sum += steps * b.num_graphs + ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + n += b.num_graphs + return (ae / n).tolist(), step_sum / max(n, 1) + + +def _split_nodes(t, ptr): + return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)] + + +def act_train_step(model, state, replacement_batch, opt, dev, args): + replacement = replacement_batch.to_data_list() + batch_size = len(replacement) + if state is None: + state = { + 'graphs': [None for _ in range(batch_size)], + 'y': [None for _ in range(batch_size)], + 'z': [None for _ in range(batch_size)], + 'steps': torch.zeros(batch_size, dtype=torch.long, device=dev), + 'halted': torch.ones(batch_size, dtype=torch.bool, device=dev), + } + + halted_cpu = state['halted'].detach().cpu().tolist() + for i, halted in enumerate(halted_cpu): + if halted: + state['graphs'][i] = replacement[i] + + b = Batch.from_data_list(state['graphs']).to(dev) + ctx = model.aggregate(b.x, b.edge_index) + ptr = b.ptr + y_parts, z_parts = [], [] + for i in range(batch_size): + start, end = ptr[i].item(), ptr[i + 1].item() + if halted_cpu[i] or state['y'][i] is None: + y_parts.append(ctx[start:end]) + z_parts.append(torch.zeros_like(ctx[start:end])) + else: + y_parts.append(state['y'][i].to(dev)) + z_parts.append(state['z'][i].to(dev)) + y = torch.cat(y_parts, dim=0) + z = torch.cat(z_parts, dim=0) + + opt.zero_grad() + y, z = model.recurse(y, z, ctx, noise=False, one_step=(model.grad_mode == '1step')) + pred, q = model.predict(y, b.batch) + per_graph_err = (pred - b.y).abs().mean(1) + pred_loss = per_graph_err.mean() + with torch.no_grad(): + if args.halt_target == 'binary': + halt_target = (per_graph_err <= args.halt_norm_threshold).to(q.dtype) + else: + halt_target = torch.sigmoid((args.halt_norm_threshold - per_graph_err) / args.halt_temp) + q_loss = nn.functional.binary_cross_entropy_with_logits(q, halt_target) + loss = pred_loss + 0.5 * args.lam_q * q_loss + y_det, z_det = y.detach(), z.detach() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + + state['y'] = _split_nodes(y_det, ptr) + state['z'] = _split_nodes(z_det, ptr) + with torch.no_grad(): + was_halted = state['halted'] + steps = torch.where(was_halted, torch.zeros_like(state['steps']), state['steps']) + 1 + halted = (steps >= args.halt_max_steps) | (q.detach() > 0) + if args.halt_exploration_prob > 0 and args.halt_max_steps > 1: + explore = torch.rand_like(q) < args.halt_exploration_prob + min_steps = torch.where( + explore, + torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev), + torch.zeros_like(steps), + ) + halted = halted & (steps >= min_steps) + state['steps'] = steps + state['halted'] = halted + + return state, { + 'loss': float(loss.detach().cpu()), + 'pred_loss': float(pred_loss.detach().cpu()), + 'q_loss': float(q_loss.detach().cpu()), + 'halted_frac': float(state['halted'].to(torch.float32).mean().detach().cpu()), + 'steps': float(state['steps'].to(torch.float32).mean().detach().cpu()), + } + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--sigma', type=float, default=0.0) + ap.add_argument('--K', type=int, default=1) + ap.add_argument('--select', choices=['none', 'bestq'], default='bestq') + ap.add_argument('--T', type=int, default=3) + ap.add_argument('--n_sup', type=int, default=3) + ap.add_argument('--hidden', type=int, default=128) + ap.add_argument('--agg_layers', type=int, default=5) + ap.add_argument('--compute_layers', type=int, default=2) + ap.add_argument('--view', choices=SUPPORTED_VIEWS, default='gin') + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--bs', type=int, default=128) + ap.add_argument('--lam_q', type=float, default=1.0) + ap.add_argument('--act', action='store_true', + help='train all recurrent depths up to halt_max_steps and train qhead as a halt head') + ap.add_argument('--halt_max_steps', type=int, default=8) + ap.add_argument('--halt_norm_threshold', type=float, default=0.30) + ap.add_argument('--halt_temp', type=float, default=0.10) + ap.add_argument('--halt_target', choices=['soft', 'binary'], default='soft') + ap.add_argument('--halt_exploration_prob', type=float, default=0.1) + ap.add_argument('--loss_mode', choices=['last', 'trace'], default='trace') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--device', default='auto') + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if args.device == 'auto' and torch.cuda.is_available() else ( + 'cpu' if args.device == 'auto' else args.device) + os.makedirs(OUT, exist_ok=True) + + tr, va, te = prepare('train'), prepare('val'), prepare('test') + n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1 + Ytr = torch.stack([r['y'] for r in tr]); ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8 + for recs in (tr, va, te): + for r in recs: + r['y'] = (r['y'] - ymu) / ysd + train_data = data_list(tr) + trl = loader(train_data, args.bs, True, drop_last=True) + val, tel = loader(va, 256, False), loader(te, 256, False) + + deg = degree_histogram(train_data) if args.view == 'pna' else None + model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode, + agg_layers=args.agg_layers, compute_layers=args.compute_layers, + view=args.view, deg=deg).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + l1 = nn.L1Loss() + act_steps = max(1, args.halt_max_steps) + + t0 = time.time(); best_val = 9e9; best = {}; best_state = None; act_state = None + for ep in range(args.epochs): + model.train() + act_metrics = [] + for b in trl: + if args.act: + act_state, metrics = act_train_step(model, act_state, b, opt, dev, args) + act_metrics.append(metrics) + else: + b = b.to(dev); opt.zero_grad() + if args.loss_mode == 'trace': + preds, q_logits = model.forward_trace( + b.x, b.edge_index, b.batch, args.n_sup, noise=model.sigma > 0) + q = q_logits[-1] + else: + preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + loss = sum(l1(p, b.y) for p in preds) / len(preds) + with torch.no_grad(): + tq = -(preds[-1] - b.y).abs().mean(1) + loss = loss + args.lam_q * nn.functional.mse_loss(q, tq) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + if args.act: + vm, _ = evaluate_trace(model, val, dev, ymu, ysd, act_steps, adaptive=False) + else: + vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select) + if sum(vm) < best_val: + best_val = sum(vm) + if args.act: + tem, fixed_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=False) + tea, adaptive_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=True) + best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, + 'test_mae_adaptive': tea, 'fixed_steps': fixed_steps, + 'adaptive_steps': adaptive_steps} + else: + tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select) + best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo} + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + if args.act and act_metrics: + hm = sum(m['halted_frac'] for m in act_metrics) / len(act_metrics) + sm = sum(m['steps'] for m in act_metrics) / len(act_metrics) + print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]} halt={hm:.2f} train_steps={sm:.2f}", flush=True) + else: + print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True) + + act_tag = f"_actfull{act_steps}_{args.halt_target}{args.halt_norm_threshold:g}_e{args.epochs}" if args.act else "" + loss_tag = f"_{args.loss_mode}" if (not args.act and args.loss_mode != 'last') else "" + view_tag = f"_{args.view}" if args.view != 'gin' else "" + tag = f"rec_rrog{view_tag}_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}{loss_tag}{act_tag}_s{args.seed}" + rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1), + 'dev': dev, 'arch': 'rrog_once_agg_node_compute', 'y_std_raw': ysd.tolist(), **best} + if args.act: + print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " + f"adaptive={[round(x,3) for x in best.get('test_mae_adaptive')]} " + f"steps={best.get('adaptive_steps'):.2f}/{best.get('fixed_steps'):.2f} " + f"@ep{best.get('ep')} ({rep['sec']}s)") + else: + print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " + f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)") + with open(os.path.join(OUT, f"{tag}.json"), 'w') as f: + json.dump(rep, f, indent=2) + torch.save({'state': best_state or model.state_dict(), + 'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup, + 'sigma': args.sigma, 'grad_mode': args.grad_mode, + 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers, + 'view': args.view, + 'loss_mode': args.loss_mode, + 'act': args.act, 'act_impl': 'persistent_recycle' if args.act else 'none', + 'halt_max_steps': act_steps, + 'halt_exploration_prob': args.halt_exploration_prob, + 'arch': 'rrog_once_agg_node_compute'}, + 'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt")) + print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt")) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c0f2bc7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +numpy +networkx +ogb +torch-geometric +tqdm diff --git a/rrog/__init__.py b/rrog/__init__.py new file mode 100644 index 0000000..c09281d --- /dev/null +++ b/rrog/__init__.py @@ -0,0 +1,2 @@ +"""Experiment registry for RRoG/TRM-on-GNN sweeps.""" + diff --git a/rrog/backbones.py b/rrog/backbones.py new file mode 100644 index 0000000..81bfb9b --- /dev/null +++ b/rrog/backbones.py @@ -0,0 +1,72 @@ +from rrog.registry import ComputeSpec, ModifierSpec, ViewSpec, by_name + + +VIEWS = [ + ViewSpec("gin", "message-passing", "2d", 1, "implemented", + "Plain GINConv message passing."), + ViewSpec("gine", "message-passing", "2d", 2, "implemented", + "Edge-aware GIN variant. ZINC uses a learned constant edge token; OGB uses bond features."), + ViewSpec("gcn", "message-passing", "2d", 3, "implemented"), + ViewSpec("graphsage", "message-passing", "2d", 4, "implemented"), + ViewSpec("gatv2", "attention-mpnn", "2d", 5, "implemented"), + ViewSpec("graphconv", "message-passing", "2d", 6, "implemented"), + ViewSpec("transformer", "attention-mpnn", "2d", 7, "implemented"), + ViewSpec("pna", "message-passing", "2d", 8, "implemented", + "Requires degree histogram from the train split."), + ViewSpec("gen", "message-passing", "2d", 9, "implemented"), + ViewSpec("film", "message-passing", "2d", 10, "implemented"), + ViewSpec("resgated", "message-passing", "2d", 11, "implemented"), + ViewSpec("tag", "higher-order-hop", "2d", 12, "implemented"), + ViewSpec("sgc", "propagation", "2d", 13, "implemented"), + ViewSpec("cheb", "spectral", "2d", 14, "implemented"), + ViewSpec("arma", "spectral", "2d", 15, "implemented"), + ViewSpec("mf", "message-passing", "2d", 16, "implemented"), + ViewSpec("appnp", "propagation", "2d", 17, "implemented"), + ViewSpec("mixhop", "higher-order-hop", "2d", 18), + ViewSpec("gps", "hybrid-local-global", "2d", 19), + ViewSpec("graphormer", "global-attention", "2d", 20), + ViewSpec("san", "spectral-attention", "2d", 21), + ViewSpec("mpnn", "message-passing", "2d", 22), + ViewSpec("schnet", "continuous-filter", "3d", 23), + ViewSpec("dimenetpp", "angle-aware", "3d", 24), + ViewSpec("painn", "equivariant", "3d", 25), + ViewSpec("gemnet", "equivariant", "3d", 26), + ViewSpec("egnn", "equivariant", "3d", 27), + ViewSpec("equiformer", "equivariant", "3d", 28), + ViewSpec("mace", "equivariant", "3d", 29), +] + + +COMPUTES = [ + ComputeSpec("classic", "baseline", 1, "implemented", "Standard one-forward GNN baseline; no RRoG compute."), + ComputeSpec("view-only", "none", 2, "implemented", "RRoG view module only; no recursive compute."), + ComputeSpec("fixed-rrog", "recursive", 3, "implemented", "Fixed-depth edge-free y/z compute."), + ComputeSpec("rrog-act", "recursive-act", 4, "implemented", "Persistent full ACT recycling for graph batches."), + ComputeSpec("node-mlp", "recursive", 4), + ComputeSpec("gru-rrog", "recursive-gated", 5), + ComputeSpec("set-attn-core", "edge-free-attention", 6), + ComputeSpec("perceiver-core", "latent-attention", 7), + ComputeSpec("global-token-mixer", "token-mixer", 8), + ComputeSpec("equivariant-core", "3d-equivariant", 9), +] + + +MODIFIERS = [ + ModifierSpec("none", "none", 1, "implemented"), + ModifierSpec("dfa-gnn", "backward", 2, "planned", + "Non-BP/direct-feedback-style training; start on node classification."), + ModifierSpec("kaft", "backward", 3, "planned", + "User project in ../graph-grape; low priority until main table is stable."), + ModifierSpec("deep-supervision", "training", 4), + ModifierSpec("sam", "optimizer", 5), + ModifierSpec("lap-pe", "feature", 6), + ModifierSpec("rwse", "feature", 7), + ModifierSpec("virtual-node", "feature", 8), + ModifierSpec("dropedge", "regularization", 9), + ModifierSpec("flag", "augmentation", 10), +] + + +VIEW_BY_NAME = by_name(VIEWS) +COMPUTE_BY_NAME = by_name(COMPUTES) +MODIFIER_BY_NAME = by_name(MODIFIERS) diff --git a/rrog/benchmarks.py b/rrog/benchmarks.py new file mode 100644 index 0000000..dcd356b --- /dev/null +++ b/rrog/benchmarks.py @@ -0,0 +1,44 @@ +from rrog.registry import BenchmarkSpec, by_name + + +BENCHMARKS = [ + BenchmarkSpec("zinc-cycle56", "A", "molecule-2d", "graph-regression", "raw-mae", 1, + "implemented", "ZINC subset with #5/#6 cycle-count targets."), + BenchmarkSpec("zinc", "A", "molecule-2d", "graph-regression", "mae", 2), + BenchmarkSpec("ogbg-molhiv", "A", "molecule-2d", "graph-classification", "rocauc", 3, + "implemented", "OGB graph property prediction."), + BenchmarkSpec("ogbg-molpcba", "A", "molecule-2d", "graph-multilabel", "ap", 4, + "implemented", "OGB graph property prediction."), + BenchmarkSpec("ogbg-molbbbp", "A", "molecule-2d", "graph-classification", "rocauc", 5, + "implemented", "OGB graph property prediction."), + BenchmarkSpec("ogbg-molbace", "A", "molecule-2d", "graph-classification", "rocauc", 6, + "implemented", "OGB graph property prediction."), + BenchmarkSpec("ogbg-moltox21", "A", "molecule-2d", "graph-multilabel", "rocauc", 7, + "implemented", "OGB graph property prediction."), + BenchmarkSpec("ogbg-molclintox", "A", "molecule-2d", "graph-multilabel", "rocauc", 8, + "implemented", "OGB graph property prediction."), + BenchmarkSpec("ogbg-molsider", "A", "molecule-2d", "graph-multilabel", "rocauc", 9, + "implemented", "OGB graph property prediction."), + BenchmarkSpec("ogbg-molesol", "A", "molecule-2d", "graph-regression", "rmse", 10, + "implemented", "OGB ESOL graph property prediction."), + BenchmarkSpec("ogbg-molfreesolv", "A", "molecule-2d", "graph-regression", "rmse", 11, + "implemented", "OGB FreeSolv graph property prediction."), + BenchmarkSpec("ogbg-mollipo", "A", "molecule-2d", "graph-regression", "rmse", 12, + "implemented", "OGB Lipophilicity graph property prediction."), + BenchmarkSpec("pcqm4mv2", "A", "molecule-2d", "graph-regression", "mae", 13), + BenchmarkSpec("qm9", "A", "molecule-3d", "graph-regression", "mae", 14), + BenchmarkSpec("peptides-func", "B", "long-range", "graph-multilabel", "ap", 15), + BenchmarkSpec("peptides-struct", "B", "long-range", "graph-regression", "mae", 16), + BenchmarkSpec("pcqm-contact", "B", "long-range", "link-prediction", "mrr", 17), + BenchmarkSpec("pascalvoc-sp", "B", "superpixel", "node-classification", "f1", 18), + BenchmarkSpec("coco-sp", "B", "superpixel", "node-classification", "f1", 19), + BenchmarkSpec("ogbn-arxiv", "B", "citation", "node-classification", "accuracy", 20), + BenchmarkSpec("ogbn-products", "B", "commerce", "node-classification", "accuracy", 21), + BenchmarkSpec("rmd17", "C", "molecule-3d", "energy-force", "mae", 22), + BenchmarkSpec("oc20-s2ef", "C", "catalyst-3d", "energy-force", "mae", 23), + BenchmarkSpec("oc22", "C", "catalyst-3d", "energy-force", "mae", 24), + BenchmarkSpec("matbench-discovery", "C", "materials", "stability", "discovery-metrics", 25), + BenchmarkSpec("tgb-subset", "C", "temporal", "temporal-link-node", "dataset-specific", 26), +] + +BENCHMARK_BY_NAME = by_name(BENCHMARKS) diff --git a/rrog/cli.py b/rrog/cli.py new file mode 100644 index 0000000..9c826e3 --- /dev/null +++ b/rrog/cli.py @@ -0,0 +1,176 @@ +import argparse +import json +import os +import subprocess + +from rrog.backbones import COMPUTES, MODIFIERS, VIEWS +from rrog.benchmarks import BENCHMARKS +from rrog.collect_results import print_tables +from rrog.collect_zinc import print_zinc +from rrog.runspecs import find_run_spec + + +def _rows(items): + return [ + { + "name": item.name, + "tier": getattr(item, "tier", None), + "family": getattr(item, "family", None), + "domain": getattr(item, "domain", None), + "task_type": getattr(item, "task_type", None), + "metric": getattr(item, "metric", None), + "priority": item.priority, + "status": item.status, + "notes": item.notes, + } + for item in items + ] + + +def _print_table(items): + for row in _rows(items): + cols = [row["name"], row.get("tier") or row.get("family") or "", row["status"], row["notes"]] + print("\t".join(str(c) for c in cols)) + + +def list_axis(axis: str, as_json: bool): + mapping = { + "benchmarks": BENCHMARKS, + "views": VIEWS, + "computes": COMPUTES, + "modifiers": MODIFIERS, + } + items = mapping[axis] + if as_json: + print(json.dumps(_rows(items), indent=2)) + else: + _print_table(items) + + +def build_command(args) -> list[str]: + spec = find_run_spec(args.task, args.view, args.compute, args.modifier) + run_args = dict(spec.default_args) + for key in [ + "epochs", "hidden", "bs", "seed", "T", "n_sup", "halt_max_steps", "halt_target", + "halt_min_steps", "halt_loss_threshold", "q_warmup_epochs", + "eval_every", "max_train_batches", "max_eval_batches", "num_workers", "ema", "lr", "lam_q", + "device", + ]: + value = getattr(args, key) + if value is not None: + run_args[key] = value + run_args["compute"] = args.compute + return spec.command_builder(run_args) + + +def print_matrix(args): + tasks = [b for b in BENCHMARKS if args.tier == "all" or b.tier == args.tier] + tasks = sorted(tasks, key=lambda x: x.priority)[:args.limit_tasks] + + if args.kind == "main": + views = sorted(VIEWS, key=lambda x: x.priority)[:args.limit_views] + computes = [c for c in COMPUTES if c.name in ["classic", "view-only", "fixed-rrog", "rrog-act"]] + for task in tasks: + for view in views: + for compute in computes: + try: + find_run_spec(task.name, view.name, compute.name) + status = "implemented" + except KeyError: + status = "planned" + print(f"{task.name}\t{view.name}\t{compute.name}\tnone\t{status}") + return + + if args.kind == "modifier": + task_names = {"zinc-cycle56", "ogbg-molhiv", "peptides-struct", "peptides-func", "ogbn-arxiv", "qm9"} + mods = [m for m in sorted(MODIFIERS, key=lambda x: x.priority) if m.name != "none"] + for task in tasks: + if task.name not in task_names: + continue + for mod in mods: + print(f"{task.name}\tgin\tview-only\t{mod.name}\tplanned") + return + + raise ValueError(args.kind) + + +def main(): + ap = argparse.ArgumentParser() + sub = ap.add_subparsers(dest="cmd", required=True) + + lp = sub.add_parser("list") + lp.add_argument("axis", choices=["benchmarks", "views", "computes", "modifiers"]) + lp.add_argument("--json", action="store_true") + + rp = sub.add_parser("run") + rp.add_argument("--task", default="zinc-cycle56") + rp.add_argument("--view", default="gin") + rp.add_argument("--compute", default="rrog-act") + rp.add_argument("--modifier", default="none") + rp.add_argument("--epochs", type=int) + rp.add_argument("--hidden", type=int) + rp.add_argument("--bs", type=int) + rp.add_argument("--seed", type=int) + rp.add_argument("--T", type=int) + rp.add_argument("--n_sup", type=int) + rp.add_argument("--halt_max_steps", type=int) + rp.add_argument("--halt_target", choices=["soft", "binary", "exact", "loss"]) + rp.add_argument("--halt_min_steps", type=int) + rp.add_argument("--halt_loss_threshold", type=float) + rp.add_argument("--q_warmup_epochs", type=int) + rp.add_argument("--eval_every", type=int) + rp.add_argument("--max_train_batches", type=int) + rp.add_argument("--max_eval_batches", type=int) + rp.add_argument("--num_workers", type=int) + rp.add_argument("--ema", type=float) + rp.add_argument("--lr", type=float) + rp.add_argument("--lam_q", type=float) + rp.add_argument("--device") + rp.add_argument("--dry-run", action="store_true") + + mp = sub.add_parser("matrix") + mp.add_argument("--kind", choices=["main", "modifier"], default="main") + mp.add_argument("--tier", choices=["A", "B", "C", "all"], default="A") + mp.add_argument("--limit-tasks", type=int, default=20) + mp.add_argument("--limit-views", type=int, default=6) + + op = sub.add_parser("results") + op.add_argument("--runs-dir", default="runs") + op.add_argument("--glob", default="*.json") + op.add_argument("--min-epochs", type=int, default=10) + op.add_argument("--epochs", type=int) + op.add_argument("--digits", type=int, default=4) + + zp = sub.add_parser("zinc-results") + zp.add_argument("--runs-dir", default="runs") + zp.add_argument("--epochs", type=int) + zp.add_argument("--min-epochs", type=int, default=10) + zp.add_argument("--digits", type=int, default=4) + + args = ap.parse_args() + if args.cmd == "list": + list_axis(args.axis, args.json) + return + if args.cmd == "matrix": + print_matrix(args) + return + if args.cmd == "results": + print_tables(args) + return + if args.cmd == "zinc-results": + print_zinc(args) + return + + cmd = build_command(args) + env = dict(os.environ) + env["PYTHONPATH"] = os.getcwd() + os.pathsep + env.get("PYTHONPATH", "") + print(" ".join(cmd), flush=True) + if not args.dry_run: + raise SystemExit(subprocess.call(cmd, env=env)) + + +if __name__ == "__main__": + try: + main() + except BrokenPipeError: + raise SystemExit(0) diff --git a/rrog/collect_results.py b/rrog/collect_results.py new file mode 100644 index 0000000..125c0ab --- /dev/null +++ b/rrog/collect_results.py @@ -0,0 +1,239 @@ +import argparse +import json +import math +from collections import defaultdict +from pathlib import Path + + +HIGHER_BETTER = { + "accuracy", + "ap", + "auc", + "f1", + "mrr", + "rocauc", +} + +LOWER_BETTER = { + "mae", + "raw-mae", + "rmse", +} + + +def _metric_direction(metric: str) -> int: + metric = metric.lower() + if metric in HIGHER_BETTER: + return 1 + if metric in LOWER_BETTER: + return -1 + if "mae" in metric or "rmse" in metric: + return -1 + return 1 + + +def _read(path: Path) -> dict | None: + try: + with path.open() as f: + rep = json.load(f) + except (OSError, json.JSONDecodeError): + return None + required = {"dataset", "view", "compute", "seed", "metric", "val"} + if not required.issubset(rep): + return None + return rep + + +def _score(rep: dict, split: str) -> float | None: + metric = rep.get("metric") + if not metric: + return None + value = rep.get(split, {}).get(metric) + if value is None: + return None + return float(value) + + +def _rank_key(rep: dict) -> tuple[int, int, int]: + return ( + int(rep.get("epochs", 0)), + int(rep.get("hidden", 0)), + int(rep.get("ep", 0) or 0), + ) + + +def _mean(xs: list[float]) -> float: + return sum(xs) / len(xs) + + +def _std(xs: list[float]) -> float: + if len(xs) < 2: + return 0.0 + mu = _mean(xs) + return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)) + + +def _fmt(value: float | None, digits: int) -> str: + if value is None: + return "" + return f"{value:.{digits}f}" + + +def _is_classic_baseline(rep: dict) -> bool: + return rep.get("compute") == "classic" and int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1 + + +def _compute_label(rep: dict) -> str: + label = str(rep["compute"]) + ema = float(rep.get("ema", 0.0) or 0.0) + if ema > 0: + label += f"+ema{ema:g}" + return label + + +def _choose_runs(paths: list[Path], min_epochs: int, epochs: int | None) -> list[dict]: + candidates: dict[tuple[str, str, str, int], dict] = {} + for path in paths: + rep = _read(path) + if rep is None: + continue + if epochs is not None and int(rep.get("epochs", -1)) != epochs: + continue + if int(rep.get("epochs", 0)) < min_epochs: + continue + val = _score(rep, "val") + test = _score(rep, "test") + if val is None or test is None: + continue + key = (str(rep["dataset"]), str(rep["view"]), _compute_label(rep), int(rep["seed"])) + old = candidates.get(key) + if old is None or _rank_key(rep) > _rank_key(old): + candidates[key] = rep + return list(candidates.values()) + + +def _group_by_cell(runs: list[dict]) -> dict[tuple[str, str, str], list[dict]]: + grouped: dict[tuple[str, str, str], list[dict]] = defaultdict(list) + for rep in runs: + grouped[(str(rep["dataset"]), str(rep["view"]), _compute_label(rep))].append(rep) + return dict(grouped) + + +def _summarize_cell(reps: list[dict], split: str) -> tuple[float, float, int]: + scores = [_score(rep, split) for rep in reps] + xs = [x for x in scores if x is not None] + return _mean(xs), _std(xs), len(xs) + + +def _markdown_table(headers: list[str], rows: list[list[str]]) -> str: + out = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"] + out.extend("| " + " | ".join(row) + " |" for row in rows) + return "\n".join(out) + + +def print_tables(args) -> None: + paths = sorted(Path(args.runs_dir).glob(args.glob)) + runs = _choose_runs(paths, args.min_epochs, args.epochs) + grouped = _group_by_cell(runs) + + classic_cells = { + (task, view): reps + for (task, view, compute), reps in grouped.items() + if compute == "classic" + for reps in [[rep for rep in reps if _is_classic_baseline(rep)]] + if reps + } + + baseline_rows = [] + for (task, view), reps in sorted(classic_cells.items()): + metric = str(reps[0]["metric"]) + val_mu, val_sd, n = _summarize_cell(reps, "val") + test_mu, test_sd, _ = _summarize_cell(reps, "test") + baseline_rows.append([ + task, + view, + metric, + str(n), + f"{_fmt(val_mu, args.digits)} +/- {_fmt(val_sd, args.digits)}", + f"{_fmt(test_mu, args.digits)} +/- {_fmt(test_sd, args.digits)}", + ]) + + delta_rows = [] + for (task, view, compute), reps in sorted(grouped.items()): + if compute == "classic": + continue + base = classic_cells.get((task, view)) + if not base: + continue + metric = str(reps[0]["metric"]) + direction = _metric_direction(metric) + base_by_seed = {int(rep["seed"]): rep for rep in base} + paired = [] + for rep in reps: + seed = int(rep["seed"]) + if seed in base_by_seed: + paired.append((rep, base_by_seed[seed])) + if not paired: + base_test_mu, _, _ = _summarize_cell(base, "test") + base_val_mu, _, _ = _summarize_cell(base, "val") + paired = [(rep, None) for rep in reps] + else: + base_test_mu = None + base_val_mu = None + + val_scores, test_scores, val_deltas, test_deltas = [], [], [], [] + adaptive_steps = [] + for rep, base_rep in paired: + val = _score(rep, "val") + test = _score(rep, "test") + if val is None or test is None: + continue + if base_rep is None: + base_val = base_val_mu + base_test = base_test_mu + else: + base_val = _score(base_rep, "val") + base_test = _score(base_rep, "test") + if base_val is None or base_test is None: + continue + val_scores.append(val) + test_scores.append(test) + val_deltas.append(direction * (val - base_val)) + test_deltas.append(direction * (test - base_test)) + if rep.get("adaptive_steps") is not None: + adaptive_steps.append(float(rep["adaptive_steps"])) + + if not test_scores: + continue + delta_rows.append([ + task, + view, + compute, + metric, + str(len(test_scores)), + f"{_fmt(_mean(val_scores), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})", + f"{_fmt(_mean(test_scores), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})", + _fmt(_mean(adaptive_steps), 2) if adaptive_steps else "", + ]) + + print("\nClassic baseline: task x backbone") + print(_markdown_table(["task", "backbone", "metric", "n", "val", "test"], baseline_rows)) + print("\nDelta vs matching classic") + print(_markdown_table([ + "task", "backbone", "compute", "metric", "n", "val score (delta)", "test score (delta)", "steps" + ], delta_rows)) + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--runs-dir", default="runs") + ap.add_argument("--glob", default="*.json") + ap.add_argument("--min-epochs", type=int, default=10) + ap.add_argument("--epochs", type=int) + ap.add_argument("--digits", type=int, default=4) + args = ap.parse_args() + print_tables(args) + + +if __name__ == "__main__": + main() diff --git a/rrog/collect_zinc.py b/rrog/collect_zinc.py new file mode 100644 index 0000000..49ae1e9 --- /dev/null +++ b/rrog/collect_zinc.py @@ -0,0 +1,137 @@ +import argparse +import json +import math +from collections import defaultdict +from pathlib import Path + + +def _read(path: Path) -> dict | None: + try: + with path.open() as f: + rep = json.load(f) + except (OSError, json.JSONDecodeError): + return None + if rep.get("dataset") != "ZINC-cycle56": + return None + if rep.get("K") != 1 or rep.get("select") != "none" or float(rep.get("sigma", 0.0)) != 0.0: + return None + if "test_mae" not in rep or "val_mae" not in rep: + return None + return rep + + +def _compute(rep: dict) -> str: + if rep.get("act"): + return f"rrog-act-T{rep.get('T')}-ns{rep.get('n_sup')}" + if int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1: + return "classic" + label = f"fixed-rrog-T{rep.get('T')}-ns{rep.get('n_sup')}" + if rep.get("loss_mode") == "trace": + label += "+trace" + return label + + +def _view(rep: dict) -> str: + return str(rep.get("view", "gin")) + + +def _score(rep: dict, split: str) -> float: + return float(sum(rep[f"{split}_mae"])) + + +def _mean(xs: list[float]) -> float: + return sum(xs) / len(xs) + + +def _std(xs: list[float]) -> float: + if len(xs) < 2: + return 0.0 + mu = _mean(xs) + return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)) + + +def _fmt(x: float, digits: int) -> str: + return f"{x:.{digits}f}" + + +def _markdown(headers: list[str], rows: list[list[str]]) -> str: + lines = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"] + lines.extend("| " + " | ".join(row) + " |" for row in rows) + return "\n".join(lines) + + +def print_zinc(args) -> None: + by_cell: dict[tuple[str, str, int], dict] = {} + for path in sorted(Path(args.runs_dir).glob("rec_rrog*_sig0.0_K1_none_T*_s*.json")): + rep = _read(path) + if rep is None: + continue + if args.epochs is not None and int(rep.get("epochs", -1)) != args.epochs: + continue + if int(rep.get("epochs", 0)) < args.min_epochs: + continue + key = (_view(rep), _compute(rep), int(rep["seed"])) + old = by_cell.get(key) + if old is None or int(rep.get("epochs", 0)) > int(old.get("epochs", 0)): + by_cell[key] = rep + + grouped: dict[tuple[str, str], list[dict]] = defaultdict(list) + for rep in by_cell.values(): + grouped[(_view(rep), _compute(rep))].append(rep) + + classic_by_view = { + view: reps + for (view, compute), reps in grouped.items() + if compute == "classic" + } + + base_rows = [] + for view, classic in sorted(classic_by_view.items()): + vals = [_score(rep, "val") for rep in classic] + tests = [_score(rep, "test") for rep in classic] + base_rows.append([ + "zinc-cycle56", + view, + str(len(classic)), + f"{_fmt(_mean(vals), args.digits)} +/- {_fmt(_std(vals), args.digits)}", + f"{_fmt(_mean(tests), args.digits)} +/- {_fmt(_std(tests), args.digits)}", + ]) + + delta_rows = [] + for (view, compute), reps in sorted(grouped.items()): + if compute == "classic": + continue + base_by_seed = {int(rep["seed"]): rep for rep in classic_by_view.get(view, [])} + paired = [(rep, base_by_seed[int(rep["seed"])]) for rep in reps if int(rep["seed"]) in base_by_seed] + if not paired: + continue + vals = [_score(rep, "val") for rep, _ in paired] + tests = [_score(rep, "test") for rep, _ in paired] + val_deltas = [_score(base, "val") - _score(rep, "val") for rep, base in paired] + test_deltas = [_score(base, "test") - _score(rep, "test") for rep, base in paired] + delta_rows.append([ + "zinc-cycle56", + view, + compute, + str(len(paired)), + f"{_fmt(_mean(vals), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})", + f"{_fmt(_mean(tests), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})", + ]) + + print("\nZINC-cycle56 classic baseline") + print(_markdown(["task", "backbone", "n", "val MAE-sum", "test MAE-sum"], base_rows)) + print("\nZINC-cycle56 delta vs matching classic") + print(_markdown(["task", "backbone", "compute", "n", "val score (improvement)", "test score (improvement)"], delta_rows)) + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--runs-dir", default="runs") + ap.add_argument("--epochs", type=int) + ap.add_argument("--min-epochs", type=int, default=10) + ap.add_argument("--digits", type=int, default=4) + print_zinc(ap.parse_args()) + + +if __name__ == "__main__": + main() diff --git a/rrog/registry.py b/rrog/registry.py new file mode 100644 index 0000000..7f6b6fe --- /dev/null +++ b/rrog/registry.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass, field +from typing import Callable + + +@dataclass(frozen=True) +class BenchmarkSpec: + name: str + tier: str + domain: str + task_type: str + metric: str + priority: int + status: str = "planned" + notes: str = "" + + +@dataclass(frozen=True) +class ViewSpec: + name: str + family: str + graph_type: str + priority: int + status: str = "planned" + notes: str = "" + + +@dataclass(frozen=True) +class ComputeSpec: + name: str + family: str + priority: int + status: str = "planned" + notes: str = "" + + +@dataclass(frozen=True) +class ModifierSpec: + name: str + family: str + priority: int + status: str = "planned" + notes: str = "" + + +@dataclass(frozen=True) +class RunSpec: + task: str + view: str + compute: str + modifier: str = "none" + default_args: dict[str, object] = field(default_factory=dict) + command_builder: Callable[[dict[str, object]], list[str]] | None = None + + +def by_name(items): + return {item.name: item for item in items} + diff --git a/rrog/run_ogb_hiv_remaining.sh b/rrog/run_ogb_hiv_remaining.sh new file mode 100755 index 0000000..067e736 --- /dev/null +++ b/rrog/run_ogb_hiv_remaining.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}" + +DEVICE="${DEVICE:-cuda:3}" +EPOCHS="${EPOCHS:-100}" +SEED="${SEED:-0}" + +run_if_missing() { + local view="$1" + local compute="$2" + local t="$3" + local ns="$4" + local out="runs/ogbg-molhiv_${view}_${compute}_T${t}_ns${ns}_h128_e${EPOCHS}_s${SEED}.json" + if [[ -f "${out}" ]]; then + echo "[skip] ${out}" + return + fi + echo "[run] view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}" + python3 -m rrog.cli run \ + --task ogbg-molhiv \ + --view "${view}" \ + --compute "${compute}" \ + --epochs "${EPOCHS}" \ + --T "${t}" \ + --n_sup "${ns}" \ + --seed "${SEED}" \ + --device "${DEVICE}" +} + +# Complete remaining OGB-HIV backbone x {classic, fixed-RRoG} cells. +# Existing json files are skipped, so the queue can be restarted safely. +run_if_missing sgc fixed-rrog 3 3 +run_if_missing cheb classic 0 1 +run_if_missing cheb fixed-rrog 3 3 +run_if_missing arma classic 0 1 +run_if_missing arma fixed-rrog 3 3 +run_if_missing mf classic 0 1 +run_if_missing mf fixed-rrog 3 3 +run_if_missing appnp classic 0 1 +run_if_missing appnp fixed-rrog 3 3 +run_if_missing pna fixed-rrog 3 3 +run_if_missing gine classic 0 1 +run_if_missing gine fixed-rrog 3 3 + +python3 -m rrog.cli results --epochs "${EPOCHS}" diff --git a/rrog/run_zinc_gine.sh b/rrog/run_zinc_gine.sh new file mode 100755 index 0000000..1ca36e1 --- /dev/null +++ b/rrog/run_zinc_gine.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}" + +DEVICE="${DEVICE:-cuda:3}" +EPOCHS="${EPOCHS:-200}" +SEED="${SEED:-0}" + +run_if_missing() { + local compute="$1" + local t="$2" + local ns="$3" + local out="runs/rec_rrog_gine_full_sig0.0_K1_none_T${t}_ns${ns}_trace_s${SEED}.json" + if [[ -f "${out}" ]]; then + echo "[skip] ${out}" + return + fi + echo "[run] zinc-cycle56 view=gine compute=${compute} T=${t} ns=${ns} device=${DEVICE}" + python3 -m rrog.cli run \ + --task zinc-cycle56 \ + --view gine \ + --compute "${compute}" \ + --epochs "${EPOCHS}" \ + --T "${t}" \ + --n_sup "${ns}" \ + --seed "${SEED}" \ + --device "${DEVICE}" +} + +run_if_missing classic 0 1 +run_if_missing fixed-rrog 1 3 + +python3 -m rrog.cli zinc-results --epochs "${EPOCHS}" diff --git a/rrog/run_zinc_gine_after_pid.sh b/rrog/run_zinc_gine_after_pid.sh new file mode 100755 index 0000000..a4d24ae --- /dev/null +++ b/rrog/run_zinc_gine_after_pid.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -euo pipefail + +WAIT_PID="${1:?usage: run_zinc_gine_after_pid.sh <pid>}" +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" + +echo "[wait] OGB queue pid=${WAIT_PID}" +while kill -0 "${WAIT_PID}" 2>/dev/null; do + sleep 60 +done + +echo "[start] OGB queue exited; launching ZINC gine" +DEVICE="${DEVICE:-cuda:1}" EPOCHS="${EPOCHS:-200}" SEED="${SEED:-0}" ./rrog/run_zinc_gine.sh diff --git a/rrog/runspecs.py b/rrog/runspecs.py new file mode 100644 index 0000000..285348f --- /dev/null +++ b/rrog/runspecs.py @@ -0,0 +1,188 @@ +from rrog.registry import RunSpec + + +ZINC_VIEWS = [ + "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna", + "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp", +] +OGB_MOL_VIEWS = [ + "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna", + "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp", +] +OGB_MOL_TASKS = [ + "ogbg-molhiv", + "ogbg-molpcba", + "ogbg-molbbbp", + "ogbg-molbace", + "ogbg-moltox21", + "ogbg-molclintox", + "ogbg-molsider", + "ogbg-molesol", + "ogbg-molfreesolv", + "ogbg-mollipo", +] + + +def _zinc_cycle56_gin_command(args: dict[str, object]) -> list[str]: + compute = str(args.get("compute", "rrog-act")) + view = str(args.get("view", "gin")) + epochs = int(args.get("epochs", 200)) + hidden = int(args.get("hidden", 64)) + batch_size = int(args.get("bs", 256)) + seed = int(args.get("seed", 0)) + t = int(args.get("T", 1)) + n_sup = int(args.get("n_sup", 3)) + halt_max_steps = int(args.get("halt_max_steps", 8)) + halt_target = str(args.get("halt_target", "binary")) + + cmd = [ + "python3", "diag/train_rec.py", + "--grad_mode", "full", + "--T", str(t), + "--n_sup", str(n_sup), + "--hidden", str(hidden), + "--bs", str(batch_size), + "--epochs", str(epochs), + "--agg_layers", str(args.get("agg_layers", 5)), + "--compute_layers", str(args.get("compute_layers", 2)), + "--view", view, + "--sigma", "0", + "--K", "1", + "--select", "none", + "--seed", str(seed), + ] + + if compute in ["classic", "view-only"]: + cmd[cmd.index("--T") + 1] = "0" + elif compute == "fixed-rrog": + pass + elif compute == "rrog-act": + cmd.extend([ + "--act", + "--halt_max_steps", str(halt_max_steps), + "--halt_target", halt_target, + "--halt_norm_threshold", str(args.get("halt_norm_threshold", 0.3)), + ]) + else: + raise ValueError(f"unsupported compute for zinc-cycle56/{view}: {compute}") + if args.get("device") is not None: + cmd.extend(["--device", str(args["device"])]) + return cmd + + +def _ogb_graphprop_gin_command(args: dict[str, object]) -> list[str]: + task = str(args["task"]) + view = str(args.get("view", "gin")) + compute = str(args.get("compute", "rrog-act")) + cmd = [ + "python3", "rrog/train_ogb_graphprop.py", + "--dataset", task, + "--view", view, + "--compute", compute, + "--T", str(args.get("T", 1)), + "--n_sup", str(args.get("n_sup", 3)), + "--hidden", str(args.get("hidden", 128)), + "--bs", str(args.get("bs", 128)), + "--epochs", str(args.get("epochs", 100)), + "--eval_every", str(args.get("eval_every", 10)), + "--agg_layers", str(args.get("agg_layers", 5)), + "--compute_layers", str(args.get("compute_layers", 2)), + "--seed", str(args.get("seed", 0)), + ] + for key in ["lr", "lam_q"]: + if args.get(key) is not None: + cmd.extend([f"--{key}", str(args[key])]) + if compute == "rrog-act": + cmd.extend([ + "--halt_max_steps", str(args.get("halt_max_steps", 8)), + "--halt_min_steps", str(args.get("halt_min_steps", 2)), + "--halt_target", str(args.get("halt_target", "loss")), + "--halt_loss_threshold", str(args.get("halt_loss_threshold", 0.2)), + "--halt_exploration_prob", str(args.get("halt_exploration_prob", 0.1)), + ]) + if args.get("q_warmup_epochs") is not None: + cmd.extend(["--q_warmup_epochs", str(args["q_warmup_epochs"])]) + if float(args.get("ema", 0.0) or 0.0) > 0: + cmd.extend(["--ema", str(args["ema"])]) + if args.get("device") is not None: + cmd.extend(["--device", str(args["device"])]) + for key in ["max_train_batches", "max_eval_batches", "num_workers"]: + if args.get(key) is not None: + cmd.extend([f"--{key}", str(args[key])]) + return cmd + + +RUN_SPECS = [ +] + +for _view in ZINC_VIEWS: + RUN_SPECS.extend([ + RunSpec( + task="zinc-cycle56", + view=_view, + compute="classic", + default_args={"compute": "classic", "view": _view, "T": 0, "n_sup": 1, "epochs": 200}, + command_builder=_zinc_cycle56_gin_command, + ), + RunSpec( + task="zinc-cycle56", + view=_view, + compute="view-only", + default_args={"compute": "view-only", "view": _view, "T": 0, "n_sup": 3, "epochs": 200}, + command_builder=_zinc_cycle56_gin_command, + ), + RunSpec( + task="zinc-cycle56", + view=_view, + compute="fixed-rrog", + default_args={"compute": "fixed-rrog", "view": _view, "T": 3, "n_sup": 3, "epochs": 200}, + command_builder=_zinc_cycle56_gin_command, + ), + RunSpec( + task="zinc-cycle56", + view=_view, + compute="rrog-act", + default_args={"compute": "rrog-act", "view": _view, "T": 1, "n_sup": 3, "epochs": 200}, + command_builder=_zinc_cycle56_gin_command, + ), + ]) + +for _task in OGB_MOL_TASKS: + for _view in OGB_MOL_VIEWS: + RUN_SPECS.extend([ + RunSpec( + task=_task, + view=_view, + compute="classic", + default_args={"task": _task, "view": _view, "compute": "classic", "T": 0, "n_sup": 1, "epochs": 100}, + command_builder=_ogb_graphprop_gin_command, + ), + RunSpec( + task=_task, + view=_view, + compute="view-only", + default_args={"task": _task, "view": _view, "compute": "view-only", "T": 0, "n_sup": 3, "epochs": 100}, + command_builder=_ogb_graphprop_gin_command, + ), + RunSpec( + task=_task, + view=_view, + compute="fixed-rrog", + default_args={"task": _task, "view": _view, "compute": "fixed-rrog", "T": 3, "n_sup": 3, "epochs": 100}, + command_builder=_ogb_graphprop_gin_command, + ), + RunSpec( + task=_task, + view=_view, + compute="rrog-act", + default_args={"task": _task, "view": _view, "compute": "rrog-act", "T": 1, "n_sup": 3, "epochs": 100}, + command_builder=_ogb_graphprop_gin_command, + ), + ]) + + +def find_run_spec(task: str, view: str, compute: str, modifier: str = "none") -> RunSpec: + for spec in RUN_SPECS: + if (spec.task, spec.view, spec.compute, spec.modifier) == (task, view, compute, modifier): + return spec + raise KeyError(f"no implemented run spec for task={task} view={view} compute={compute} modifier={modifier}") diff --git a/rrog/train_ogb_graphprop.py b/rrog/train_ogb_graphprop.py new file mode 100644 index 0000000..387ef3c --- /dev/null +++ b/rrog/train_ogb_graphprop.py @@ -0,0 +1,685 @@ +import argparse +from contextlib import contextmanager +import json +import os +import time + +import numpy as np +import torch +import torch.nn as nn +from ogb.graphproppred import Evaluator, PygGraphPropPredDataset +from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder +from torch_geometric.data import Batch +from torch_geometric.loader import DataLoader +from torch_geometric.nn import ( + APPNP, + ARMAConv, + ChebConv, + FiLMConv, + GATv2Conv, + GCNConv, + GENConv, + GINEConv, + GraphConv, + MFConv, + PNAConv, + ResGatedGraphConv, + SAGEConv, + SGConv, + TAGConv, + TransformerConv, + global_add_pool, +) +from torch_geometric.utils import degree + + +PROJECT_ROOT = os.environ.get( + "RROG_ROOT", + os.path.abspath(os.path.join(os.path.dirname(__file__), "..")), +) +DATA_ROOT = os.environ.get("RROG_DATA_DIR", os.path.join(PROJECT_ROOT, "data")) +ROOT = os.path.join(DATA_ROOT, "ogb") +OUT = os.environ.get("RROG_RUNS_DIR", os.path.join(PROJECT_ROOT, "runs")) +SUPPORTED_VIEWS = [ + "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna", + "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp", +] +SUPPORTED_MOL_DATASETS = [ + "ogbg-molhiv", + "ogbg-molpcba", + "ogbg-molbbbp", + "ogbg-molbace", + "ogbg-moltox21", + "ogbg-molclintox", + "ogbg-molsider", + "ogbg-molesol", + "ogbg-molfreesolv", + "ogbg-mollipo", +] +HIGHER_BETTER = {"rocauc", "ap", "auc", "accuracy", "acc", "f1"} +LOWER_BETTER = {"rmse", "mae"} + +_TORCH_LOAD = torch.load + + +def _torch_load_ogb_compat(*args, **kwargs): + kwargs.setdefault("weights_only", False) + return _TORCH_LOAD(*args, **kwargs) + + +torch.load = _torch_load_ogb_compat + + +def clone_state_dict(model): + return {k: v.detach().clone() for k, v in model.state_dict().items()} + + +@torch.no_grad() +def update_ema_state(ema_state, model, decay): + if ema_state is None: + return + for key, value in model.state_dict().items(): + if torch.is_floating_point(value): + ema_state[key].mul_(decay).add_(value.detach(), alpha=1.0 - decay) + else: + ema_state[key].copy_(value.detach()) + + +@contextmanager +def using_ema_state(model, ema_state): + if ema_state is None: + yield + return + raw_state = clone_state_dict(model) + model.load_state_dict(ema_state, strict=True) + try: + yield + finally: + model.load_state_dict(raw_state, strict=True) + + +def metric_direction(metric: str) -> int: + metric = metric.lower() + if metric in LOWER_BETTER or "rmse" in metric or "mae" in metric: + return -1 + return 1 + + +def is_regression_metric(metric: str) -> bool: + return metric_direction(metric) < 0 + + +def is_better(score: float, best: float | None, metric: str) -> bool: + if best is None: + return True + direction = metric_direction(metric) + return score > best if direction > 0 else score < best + + +def jsonable(obj): + if isinstance(obj, dict): + return {str(k): jsonable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [jsonable(v) for v in obj] + if isinstance(obj, (np.integer, np.floating)): + return obj.item() + if isinstance(obj, torch.Tensor): + if obj.ndim == 0: + return obj.detach().cpu().item() + return obj.detach().cpu().tolist() + return obj + + +def degree_histogram(dataset) -> torch.Tensor: + max_degree = 0 + degs = [] + for graph in dataset: + deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long) + degs.append(deg) + if deg.numel(): + max_degree = max(max_degree, int(deg.max().item())) + hist = torch.zeros(max_degree + 1, dtype=torch.long) + for deg in degs: + hist += torch.bincount(deg, minlength=hist.numel()) + return hist + + +def make_view_layer(view: str, hidden: int, deg: torch.Tensor | None): + if view in {"gin", "gine"}: + mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)) + return GINEConv(mlp, train_eps=True) + if view == "gcn": + return GCNConv(hidden, hidden) + if view == "graphsage": + return SAGEConv(hidden, hidden) + if view == "gatv2": + return GATv2Conv(hidden, hidden, heads=4, concat=False, edge_dim=hidden) + if view == "graphconv": + return GraphConv(hidden, hidden) + if view == "transformer": + return TransformerConv(hidden, hidden, heads=4, concat=False, edge_dim=hidden) + if view == "pna": + if deg is None: + raise ValueError("PNA view requires a training-set degree histogram") + return PNAConv( + hidden, hidden, + aggregators=["mean", "min", "max", "std"], + scalers=["identity", "amplification", "attenuation"], + deg=deg, + edge_dim=hidden, + ) + if view == "gen": + return GENConv(hidden, hidden, edge_dim=hidden) + if view == "film": + return FiLMConv(hidden, hidden) + if view == "resgated": + return ResGatedGraphConv(hidden, hidden, edge_dim=hidden) + if view == "tag": + return TAGConv(hidden, hidden, K=3) + if view == "sgc": + return SGConv(hidden, hidden, K=2, cached=False) + if view == "cheb": + return ChebConv(hidden, hidden, K=3) + if view == "arma": + return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2) + if view == "mf": + return MFConv(hidden, hidden) + if view == "appnp": + return APPNP(K=5, alpha=0.1) + raise ValueError(f"unsupported OGB view: {view}") + + +EDGE_ATTR_VIEWS = {"gin", "gine", "gatv2", "transformer", "pna", "gen", "resgated"} + + +class OGBRRoG(nn.Module): + def __init__( + self, hidden, num_tasks, view="gin", T=1, n_sup=3, agg_layers=5, + compute_layers=2, grad_mode="full", deg=None, + ): + super().__init__() + self.view = view + self.atom_encoder = AtomEncoder(hidden) + self.bond_encoder = BondEncoder(hidden) + self.agg_convs = nn.ModuleList() + self.agg_bns = nn.ModuleList() + for _ in range(agg_layers): + self.agg_convs.append(make_view_layer(view, hidden, deg)) + self.agg_bns.append(nn.BatchNorm1d(hidden)) + + core = [] + d = hidden + for _ in range(compute_layers - 1): + core += [nn.Linear(d, hidden), nn.GELU()] + d = hidden + core.append(nn.Linear(d, hidden)) + self.core_norm = nn.LayerNorm(hidden) + self.core = nn.Sequential(*core) + nn.init.zeros_(self.core[-1].weight) + nn.init.zeros_(self.core[-1].bias) + + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, num_tasks)) + self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)) + with torch.no_grad(): + self.qhead[-1].weight.zero_() + self.qhead[-1].bias.fill_(-5.0) + + self.T = T + self.n_sup = n_sup + self.grad_mode = grad_mode + self.agg_layers = agg_layers + self.compute_layers = compute_layers + self.hidden = hidden + self.num_tasks = num_tasks + + def aggregate(self, x, edge_index, edge_attr): + h = self.atom_encoder(x) + e = self.bond_encoder(edge_attr) + for conv, bn in zip(self.agg_convs, self.agg_bns): + if self.view in {"gin", "gine", "pna", "gen", "resgated"}: + h = bn(conv(h, edge_index, e)).relu() + elif self.view in {"gatv2", "transformer"}: + h = bn(conv(h, edge_index, edge_attr=e)).relu() + else: + h = bn(conv(h, edge_index)).relu() + return h + + def core_step(self, combined, state): + return state + self.core(self.core_norm(combined)) + + def _z_step(self, y, z, ctx): + return self.core_step(ctx + y + z, z) + + def _y_step(self, y, z): + return self.core_step(y + z, y) + + def recurse(self, y, z, ctx, one_step=False): + if self.T == 0: + return y, z + if one_step: + with torch.no_grad(): + for _ in range(self.T - 1): + z = self._z_step(y, z, ctx) + z = z.detach() + z = self._z_step(y, z, ctx) + y = self._y_step(y, z) + return y, z + for _ in range(self.T): + z = self._z_step(y, z, ctx) + y = self._y_step(y, z) + return y, z + + def predict(self, y, batch): + pooled = global_add_pool(y, batch) + return self.head(pooled), self.qhead(pooled).view(-1) + + def forward_trace(self, data, steps): + ctx = self.aggregate(data.x, data.edge_index, data.edge_attr) + y = ctx + z = torch.zeros_like(ctx) + preds, q_logits = [], [] + for s in range(steps): + y, z = self.recurse(y, z, ctx, one_step=(self.grad_mode == "1step")) + pred, q = self.predict(y, data.batch) + preds.append(pred) + q_logits.append(q) + if s < steps - 1: + y, z = y.detach(), z.detach() + return preds, q_logits + + def forward(self, data): + steps = self.n_sup + preds, q_logits = self.forward_trace(data, steps) + return preds, q_logits[-1] + + +def supervised_loss(logits, y, metric): + per_graph, has_label = per_graph_supervised_loss(logits, y, metric) + if not has_label.any(): + return logits.sum() * 0.0 + return per_graph[has_label].mean() + + +def per_graph_supervised_loss(logits, y, metric): + y = y.to(torch.float32) + mask = ~torch.isnan(y) + target = torch.where(mask, y, torch.zeros_like(y)) + if is_regression_metric(metric): + losses = (logits - target).pow(2) + else: + losses = nn.functional.binary_cross_entropy_with_logits(logits, target, reduction="none") + losses = torch.where(mask, losses, torch.zeros_like(losses)) + denom = mask.sum(dim=1).clamp_min(1) + return losses.sum(dim=1) / denom, mask.any(dim=1) + + +@torch.no_grad() +def halt_targets(logits, y): + y = y.to(torch.float32) + mask = ~torch.isnan(y) + pred = (logits > 0).to(y.dtype) + correct_or_missing = (~mask) | (pred == y) + has_label = mask.any(dim=1) + return (correct_or_missing.all(dim=1) & has_label).to(logits.dtype) + + +@torch.no_grad() +def evaluate(model, loader, evaluator, dev, steps=None, adaptive=False, halt_min_steps=1, max_batches=0): + model.eval() + ys, ps = [], [] + step_sum = 0.0 + n = 0 + for i, batch in enumerate(loader): + if max_batches and i >= max_batches: + break + batch = batch.to(dev) + if steps is None: + preds, _ = model(batch) + pred = preds[-1] + used_steps = float(model.n_sup) + else: + preds, qs = model.forward_trace(batch, steps) + stack = torch.stack(preds, dim=0) + if adaptive: + q = torch.stack(qs, dim=0) + step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1) + halted = (q > 0) & (step_ids >= halt_min_steps) + any_halt = halted.any(dim=0) + first_halt = halted.to(torch.int64).argmax(dim=0) + fallback = torch.full_like(first_halt, steps - 1) + idx = torch.where(any_halt, first_halt, fallback) + pred = stack[idx, torch.arange(stack.size(1), device=dev)] + used_steps = float((idx.to(torch.float32) + 1).sum().item()) + else: + pred = stack[-1] + used_steps = float(steps * batch.num_graphs) + ys.append(batch.y.detach().cpu()) + ps.append(pred.detach().cpu()) + step_sum += used_steps + n += batch.num_graphs + y_true = torch.cat(ys, dim=0) + y_pred = torch.cat(ps, dim=0) + result = evaluator.eval({"y_true": y_true, "y_pred": y_pred}) + return result, step_sum / max(n, 1) + + +def _split_nodes(t, ptr): + return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)] + + +def act_train_step(model, state, replacement_batch, opt, dev, args, metric): + replacement = replacement_batch.to_data_list() + batch_size = len(replacement) + if state is None: + state = { + "graphs": [None for _ in range(batch_size)], + "y": [None for _ in range(batch_size)], + "z": [None for _ in range(batch_size)], + "steps": torch.zeros(batch_size, dtype=torch.long, device=dev), + "halted": torch.ones(batch_size, dtype=torch.bool, device=dev), + } + + halted_cpu = state["halted"].detach().cpu().tolist() + for i, halted in enumerate(halted_cpu): + if halted: + state["graphs"][i] = replacement[i] + + batch = Batch.from_data_list(state["graphs"]).to(dev) + ctx = model.aggregate(batch.x, batch.edge_index, batch.edge_attr) + ptr = batch.ptr + y_parts, z_parts = [], [] + for i in range(batch_size): + start, end = ptr[i].item(), ptr[i + 1].item() + if halted_cpu[i] or state["y"][i] is None: + y_parts.append(ctx[start:end]) + z_parts.append(torch.zeros_like(ctx[start:end])) + else: + y_parts.append(state["y"][i].to(dev)) + z_parts.append(state["z"][i].to(dev)) + y = torch.cat(y_parts, dim=0) + z = torch.cat(z_parts, dim=0) + + opt.zero_grad() + y, z = model.recurse(y, z, ctx, one_step=(model.grad_mode == "1step")) + logits, q = model.predict(y, batch.batch) + pred_loss = supervised_loss(logits, batch.y, metric) + if args.halt_target == "exact" and not is_regression_metric(metric): + target = halt_targets(logits.detach(), batch.y) + elif args.halt_target == "loss": + per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), batch.y, metric) + target = ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype) + else: + raise ValueError(args.halt_target) + q_loss = nn.functional.binary_cross_entropy_with_logits(q, target) + loss = pred_loss + 0.5 * args.lam_q * q_loss + y_det, z_det = y.detach(), z.detach() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + + state["y"] = _split_nodes(y_det, ptr) + state["z"] = _split_nodes(z_det, ptr) + with torch.no_grad(): + was_halted = state["halted"] + steps = torch.where(was_halted, torch.zeros_like(state["steps"]), state["steps"]) + 1 + halted = (steps >= args.halt_max_steps) | ((q.detach() > 0) & (steps >= args.halt_min_steps)) + if args.halt_exploration_prob > 0 and args.halt_max_steps > 1: + explore = torch.rand_like(q) < args.halt_exploration_prob + min_steps = torch.where( + explore, + torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev), + torch.zeros_like(steps), + ) + halted = halted & (steps >= min_steps) + state["steps"] = steps + state["halted"] = halted + + return state, { + "loss": float(loss.detach().cpu()), + "pred_loss": float(pred_loss.detach().cpu()), + "q_loss": float(q_loss.detach().cpu()), + "halted_frac": float(state["halted"].to(torch.float32).mean().detach().cpu()), + "steps": float(state["steps"].to(torch.float32).mean().detach().cpu()), + } + + +def _halt_target(args, logits, y, metric): + if args.halt_target == "exact" and not is_regression_metric(metric): + return halt_targets(logits.detach(), y) + if args.halt_target == "loss": + per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), y, metric) + return ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype) + raise ValueError(args.halt_target) + + +def act_trace_train_step(model, batch, opt, dev, args, epoch, metric): + batch = batch.to(dev) + steps = max(1, args.halt_max_steps) + opt.zero_grad() + preds, qs = model.forward_trace(batch, steps) + pred_loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds) + if epoch <= args.q_warmup_epochs: + q_loss = pred_loss.detach() * 0.0 + loss = pred_loss + else: + q_losses = [] + for step_idx, (pred, q) in enumerate(zip(preds, qs), start=1): + target = _halt_target(args, pred, batch.y, metric) + if step_idx < args.halt_min_steps: + target = torch.zeros_like(target) + q_losses.append(nn.functional.binary_cross_entropy_with_logits(q, target)) + q_loss = sum(q_losses) / len(q_losses) + loss = pred_loss + 0.5 * args.lam_q * q_loss + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + + with torch.no_grad(): + q_stack = torch.stack(qs, dim=0) + step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1) + halted = (q_stack > 0) & (step_ids >= args.halt_min_steps) + any_halt = halted.any(dim=0) + first_halt = halted.to(torch.int64).argmax(dim=0) + fallback = torch.full_like(first_halt, steps - 1) + idx = torch.where(any_halt, first_halt, fallback) + + return { + "loss": float(loss.detach().cpu()), + "pred_loss": float(pred_loss.detach().cpu()), + "q_loss": float(q_loss.detach().cpu()), + "halted_frac": float(any_halt.to(torch.float32).mean().detach().cpu()), + "steps": float((idx.to(torch.float32) + 1).mean().detach().cpu()), + } + + +def train_epoch(model, loader, opt, dev, args, act_state, ema_state, epoch, metric): + model.train() + metrics = [] + for i, batch in enumerate(loader): + if args.max_train_batches and i >= args.max_train_batches: + break + if args.compute == "rrog-act": + m = act_trace_train_step(model, batch, opt, dev, args, epoch, metric) + update_ema_state(ema_state, model, args.ema) + metrics.append(m) + continue + batch = batch.to(dev) + opt.zero_grad() + preds, _ = model(batch) + loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + update_ema_state(ema_state, model, args.ema) + metrics.append({"loss": float(loss.detach().cpu()), "halted_frac": 0.0, "steps": float(model.n_sup)}) + return act_state, metrics + + +def metric_value(result, evaluator): + return float(result[evaluator.eval_metric]) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--dataset", choices=SUPPORTED_MOL_DATASETS, default="ogbg-molhiv") + ap.add_argument("--view", choices=SUPPORTED_VIEWS, default="gin") + ap.add_argument("--compute", choices=["classic", "view-only", "fixed-rrog", "rrog-act"], default="rrog-act") + ap.add_argument("--T", type=int, default=1) + ap.add_argument("--n_sup", type=int, default=3) + ap.add_argument("--hidden", type=int, default=128) + ap.add_argument("--agg_layers", type=int, default=5) + ap.add_argument("--compute_layers", type=int, default=2) + ap.add_argument("--epochs", type=int, default=100) + ap.add_argument("--eval_every", type=int, default=10) + ap.add_argument("--lr", type=float, default=1e-3) + ap.add_argument("--bs", type=int, default=128) + ap.add_argument("--lam_q", type=float, default=1.0) + ap.add_argument("--halt_max_steps", type=int, default=8) + ap.add_argument("--halt_min_steps", type=int, default=1) + ap.add_argument("--halt_target", choices=["exact", "loss"], default="loss") + ap.add_argument("--halt_loss_threshold", type=float, default=0.2) + ap.add_argument("--halt_exploration_prob", type=float, default=0.1) + ap.add_argument("--q_warmup_epochs", type=int, default=0) + ap.add_argument("--ema", type=float, default=0.0) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--num_workers", type=int, default=0) + ap.add_argument("--max_train_batches", type=int, default=0) + ap.add_argument("--max_eval_batches", type=int, default=0) + ap.add_argument("--device", default="auto") + args = ap.parse_args() + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + if args.compute == "classic": + args.n_sup = 1 + if args.device == "auto": + dev = "cuda" if torch.cuda.is_available() else "cpu" + else: + dev = args.device + os.makedirs(OUT, exist_ok=True) + + dataset = PygGraphPropPredDataset(name=args.dataset, root=ROOT) + split_idx = dataset.get_idx_split() + evaluator = Evaluator(args.dataset) + metric = evaluator.eval_metric + num_tasks = dataset.num_tasks + + train_dataset = dataset[split_idx["train"]] + valid_dataset = dataset[split_idx["valid"]] + test_dataset = dataset[split_idx["test"]] + train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, + drop_last=True, num_workers=args.num_workers) + valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False, + num_workers=args.num_workers) + test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, + num_workers=args.num_workers) + + T = 0 if args.compute in ["classic", "view-only"] else args.T + deg = degree_histogram(train_dataset) if args.view == "pna" else None + model = OGBRRoG(args.hidden, num_tasks, view=args.view, T=T, n_sup=args.n_sup, + agg_layers=args.agg_layers, compute_layers=args.compute_layers, deg=deg).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, max(args.epochs, 1)) + ema_state = clone_state_dict(model) if args.ema > 0 else None + + t0 = time.time() + best_val = None + best = {} + best_state = None + act_state = None + steps = max(1, args.halt_max_steps) + + for ep in range(args.epochs): + act_state, train_metrics = train_epoch( + model, train_loader, opt, dev, args, act_state, ema_state, ep + 1, metric) + sched.step() + if (ep + 1) % args.eval_every == 0 or ep == args.epochs - 1: + with using_ema_state(model, ema_state): + if args.compute == "rrog-act": + val_result, fixed_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps, + adaptive=False, halt_min_steps=args.halt_min_steps, + max_batches=args.max_eval_batches) + val_adapt, adaptive_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps, + adaptive=True, halt_min_steps=args.halt_min_steps, + max_batches=args.max_eval_batches) + else: + val_result, _ = evaluate(model, valid_loader, evaluator, dev, + max_batches=args.max_eval_batches) + val_score = metric_value(val_result, evaluator) + if is_better(val_score, best_val, metric): + best_val = val_score + if args.compute == "rrog-act": + test_fixed, fixed_steps = evaluate(model, test_loader, evaluator, dev, steps=steps, + adaptive=False, halt_min_steps=args.halt_min_steps, + max_batches=args.max_eval_batches) + test_adapt, adaptive_steps = evaluate(model, test_loader, evaluator, dev, steps=steps, + adaptive=True, halt_min_steps=args.halt_min_steps, + max_batches=args.max_eval_batches) + best = { + "ep": ep + 1, + "val": val_result, + "val_adaptive": val_adapt, + "test": test_fixed, + "test_adaptive": test_adapt, + "fixed_val_steps": fixed_val_steps, + "adaptive_val_steps": adaptive_val_steps, + "fixed_steps": fixed_steps, + "adaptive_steps": adaptive_steps, + } + else: + test_result, _ = evaluate(model, test_loader, evaluator, dev, + max_batches=args.max_eval_batches) + best = {"ep": ep + 1, "val": val_result, "test": test_result} + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + halted = sum(m.get("halted_frac", 0.0) for m in train_metrics) / max(len(train_metrics), 1) + train_steps = sum(m.get("steps", 0.0) for m in train_metrics) / max(len(train_metrics), 1) + msg = f"ep{ep+1} val_{evaluator.eval_metric}={val_score:.5f}" + if args.compute == "rrog-act": + msg += ( + f" val_adapt_{evaluator.eval_metric}={metric_value(val_adapt, evaluator):.5f}" + f" adapt_steps={adaptive_val_steps:.2f}" + ) + msg += f" halt={halted:.2f} train_steps={train_steps:.2f}" + print(msg, flush=True) + + ema_tag = f"_ema{args.ema:g}" if args.ema > 0 else "" + tag = ( + f"{args.dataset}_{args.view}_{args.compute}_T{T}_ns{args.n_sup}_" + f"h{args.hidden}_e{args.epochs}{ema_tag}_s{args.seed}" + ) + rep = { + "dataset": args.dataset, + "tag": tag, + **vars(args), + "metric": evaluator.eval_metric, + "sec": round(time.time() - t0, 1), + "dev": dev, + **best, + } + with open(os.path.join(OUT, f"{tag}.json"), "w") as f: + json.dump(jsonable(rep), f, indent=2) + torch.save({ + "state": best_state or model.state_dict(), + "cfg": { + "dataset": args.dataset, + "hidden": args.hidden, + "num_tasks": num_tasks, + "T": T, + "n_sup": args.n_sup, + "agg_layers": args.agg_layers, + "compute_layers": args.compute_layers, + "compute": args.compute, + "halt_max_steps": steps, + "halt_min_steps": args.halt_min_steps, + "halt_target": args.halt_target, + "halt_loss_threshold": args.halt_loss_threshold, + "view": args.view, + }, + }, os.path.join(OUT, f"ckpt_{tag}.pt")) + print(f"[{tag}] best_ep={best.get('ep')} val={best.get('val')} test={best.get('test')} " + f"adaptive={best.get('test_adaptive')} steps={best.get('adaptive_steps')}", flush=True) + print(" wrote", os.path.join(OUT, f"{tag}.json")) + + +if __name__ == "__main__": + main() diff --git a/scripts/collect_results.sh b/scripts/collect_results.sh new file mode 100755 index 0000000..360f05b --- /dev/null +++ b/scripts/collect_results.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}" + +mkdir -p summaries +python3 -m rrog.cli zinc-results --epochs "${ZINC_EPOCHS:-200}" | tee summaries/zinc_cycle56.md +python3 -m rrog.cli results --epochs "${OGB_EPOCHS:-100}" | tee summaries/ogb_graphprop.md diff --git a/scripts/run_ogb_mol_all_tasks.sh b/scripts/run_ogb_mol_all_tasks.sh new file mode 100755 index 0000000..b191d79 --- /dev/null +++ b/scripts/run_ogb_mol_all_tasks.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" + +DEVICE="${DEVICE:-cuda:1}" +EPOCHS="${EPOCHS:-100}" +SEED="${SEED:-0}" +TASKS="${TASKS:-ogbg-molhiv ogbg-molbbbp ogbg-molbace ogbg-moltox21 ogbg-molclintox ogbg-molsider ogbg-molesol ogbg-molfreesolv ogbg-mollipo}" + +mkdir -p logs +for task in ${TASKS}; do + echo "[task] ${task}" + TASK="${task}" DEVICE="${DEVICE}" EPOCHS="${EPOCHS}" SEED="${SEED}" \ + ./scripts/run_ogb_mol_task_full.sh 2>&1 | tee "logs/${task}_${SEED}.log" +done diff --git a/scripts/run_ogb_mol_task_full.sh b/scripts/run_ogb_mol_task_full.sh new file mode 100755 index 0000000..b25bff3 --- /dev/null +++ b/scripts/run_ogb_mol_task_full.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}" + +TASK="${TASK:-ogbg-molhiv}" +DEVICE="${DEVICE:-cuda:1}" +EPOCHS="${EPOCHS:-100}" +SEED="${SEED:-0}" +HIDDEN="${HIDDEN:-128}" +VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}" + +mkdir -p runs logs + +result_path() { + local view="$1" + local compute="$2" + local t="$3" + local ns="$4" + echo "runs/${TASK}_${view}_${compute}_T${t}_ns${ns}_h${HIDDEN}_e${EPOCHS}_s${SEED}.json" +} + +run_cell() { + local view="$1" + local compute="$2" + local t="$3" + local ns="$4" + local out + out="$(result_path "${view}" "${compute}" "${t}" "${ns}")" + if [[ -f "${out}" ]]; then + echo "[skip] ${out}" + return + fi + echo "[run] ${TASK} view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}" + python3 -m rrog.cli run \ + --task "${TASK}" \ + --view "${view}" \ + --compute "${compute}" \ + --epochs "${EPOCHS}" \ + --hidden "${HIDDEN}" \ + --T "${t}" \ + --n_sup "${ns}" \ + --seed "${SEED}" \ + --device "${DEVICE}" +} + +for view in ${VIEWS}; do + run_cell "${view}" classic 0 1 + run_cell "${view}" fixed-rrog 3 3 +done + +python3 -m rrog.cli results --epochs "${EPOCHS}" diff --git a/scripts/run_smoke.sh b/scripts/run_smoke.sh new file mode 100755 index 0000000..6365cec --- /dev/null +++ b/scripts/run_smoke.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}" + +DEVICE="${DEVICE:-cuda:0}" +mkdir -p runs logs + +python3 -m rrog.cli run \ + --task ogbg-molhiv --view gin --compute classic \ + --epochs 1 --hidden 32 --bs 64 --seed 991 --device "${DEVICE}" \ + --max_train_batches 2 --max_eval_batches 2 + +python3 -m rrog.cli run \ + --task ogbg-molhiv --view gin --compute fixed-rrog \ + --epochs 1 --hidden 32 --bs 64 --T 1 --n_sup 2 --seed 992 --device "${DEVICE}" \ + --max_train_batches 2 --max_eval_batches 2 diff --git a/scripts/run_two_a6000.sh b/scripts/run_two_a6000.sh new file mode 100755 index 0000000..8d9851f --- /dev/null +++ b/scripts/run_two_a6000.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}" + +ZINC_DEVICE="${ZINC_DEVICE:-cuda:0}" +OGB_DEVICE="${OGB_DEVICE:-cuda:1}" +OGB_TASK="${OGB_TASK:-ogbg-molhiv}" +ZINC_EPOCHS="${ZINC_EPOCHS:-200}" +OGB_EPOCHS="${OGB_EPOCHS:-100}" +SEED="${SEED:-0}" + +mkdir -p runs logs + +echo "[launch] ZINC-cycle56 on ${ZINC_DEVICE}" +DEVICE="${ZINC_DEVICE}" EPOCHS="${ZINC_EPOCHS}" SEED="${SEED}" \ + ./scripts/run_zinc_cycle56_full.sh > "logs/zinc_cycle56_${SEED}.log" 2>&1 & +zinc_pid=$! + +echo "[launch] ${OGB_TASK} on ${OGB_DEVICE}" +TASK="${OGB_TASK}" DEVICE="${OGB_DEVICE}" EPOCHS="${OGB_EPOCHS}" SEED="${SEED}" \ + ./scripts/run_ogb_mol_task_full.sh > "logs/${OGB_TASK}_${SEED}.log" 2>&1 & +ogb_pid=$! + +echo "[pids] zinc=${zinc_pid} ogb=${ogb_pid}" +wait "${zinc_pid}" +wait "${ogb_pid}" + +echo "[done] collecting summaries" +./scripts/collect_results.sh diff --git a/scripts/run_zinc_cycle56_full.sh b/scripts/run_zinc_cycle56_full.sh new file mode 100755 index 0000000..151a51e --- /dev/null +++ b/scripts/run_zinc_cycle56_full.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}" + +DEVICE="${DEVICE:-cuda:0}" +EPOCHS="${EPOCHS:-200}" +SEED="${SEED:-0}" +VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}" + +mkdir -p runs logs + +result_path() { + local view="$1" + local t="$2" + local ns="$3" + local view_tag="" + if [[ "${view}" != "gin" ]]; then + view_tag="_${view}" + fi + echo "runs/rec_rrog${view_tag}_full_sig0.0_K1_none_T${t}_ns${ns}_trace_s${SEED}.json" +} + +run_cell() { + local view="$1" + local compute="$2" + local t="$3" + local ns="$4" + local out + out="$(result_path "${view}" "${t}" "${ns}")" + if [[ -f "${out}" ]]; then + echo "[skip] ${out}" + return + fi + echo "[run] zinc-cycle56 view=${view} compute=${compute} T=${t} ns=${ns} device=${DEVICE}" + python3 -m rrog.cli run \ + --task zinc-cycle56 \ + --view "${view}" \ + --compute "${compute}" \ + --epochs "${EPOCHS}" \ + --T "${t}" \ + --n_sup "${ns}" \ + --seed "${SEED}" \ + --device "${DEVICE}" +} + +for view in ${VIEWS}; do + run_cell "${view}" classic 0 1 + run_cell "${view}" fixed-rrog 1 3 +done + +python3 -m rrog.cli zinc-results --epochs "${EPOCHS}" diff --git a/scripts/setup_and_run_two_a6000.sh b/scripts/setup_and_run_two_a6000.sh new file mode 100755 index 0000000..ec4e3da --- /dev/null +++ b/scripts/setup_and_run_two_a6000.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" + +if [[ "${SKIP_SETUP:-0}" != "1" ]]; then + ./scripts/setup_env.sh +fi + +if [[ -d "${VENV_DIR:-.venv}" ]]; then + source "${VENV_DIR:-.venv}/bin/activate" +fi + +./scripts/run_two_a6000.sh diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh new file mode 100755 index 0000000..66a94c8 --- /dev/null +++ b/scripts/setup_env.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "${ROOT_DIR}" + +PYTHON_BIN="${PYTHON_BIN:-python3}" +VENV_DIR="${VENV_DIR:-.venv}" +TORCH_INDEX_URL="${TORCH_INDEX_URL:-https://download.pytorch.org/whl/cu124}" + +if [[ ! -d "${VENV_DIR}" ]]; then + "${PYTHON_BIN}" -m venv "${VENV_DIR}" +fi + +source "${VENV_DIR}/bin/activate" +python -m pip install --upgrade pip wheel setuptools + +if ! python - <<'PY' >/dev/null 2>&1 +import torch +assert torch.cuda.is_available() or True +PY +then + python -m pip install torch --index-url "${TORCH_INDEX_URL}" +fi + +python -m pip install -r requirements.txt + +python - <<'PY' +import torch +import torch_geometric +import ogb +print("torch", torch.__version__, "cuda_available", torch.cuda.is_available()) +print("torch_geometric", torch_geometric.__version__) +print("ogb", ogb.__version__) +PY |
