summaryrefslogtreecommitdiff
path: root/diag
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-21 15:33:22 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-21 15:33:22 -0500
commite42f575050efeeccb736385b43bed84e1129edb0 (patch)
tree8ed04b42218cf4c90c0b9c29b40db149f1355f4a /diag
Initial RRoG GNN runner
Diffstat (limited to 'diag')
-rw-r--r--diag/__init__.py0
-rw-r--r--diag/train_cycle.py188
-rw-r--r--diag/train_rec.py491
3 files changed, 679 insertions, 0 deletions
diff --git a/diag/__init__.py b/diag/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/diag/__init__.py
diff --git a/diag/train_cycle.py b/diag/train_cycle.py
new file mode 100644
index 0000000..598e349
--- /dev/null
+++ b/diag/train_cycle.py
@@ -0,0 +1,188 @@
+"""GNN-native ring-counting on real molecules (ZINC): regress [#5-cycles, #6-cycles].
+
+k-cycle counts (k>=3) are provably NOT computable by 1-WL/MPNN (Chen et al. 2020) -> a REAL
+H2 ceiling on REAL graphs. Training-based diagnosis (partition instrument is vacuous on
+feature-rich graphs):
+ GIN(L) 1-WL baseline -> should FAIL to count
+ GCN(L) sub-1-WL reference
+ GIN+RNI random feats = NOISE -> PTRM-style crude symmetry break (eval-averaged)
+ GIN+RWSE random-walk return probs-> structured >1-WL positive control
+Reads: GIN high error + RWSE fixes it = real ceiling exists; RNI also fixes = crude noise
+breaks it (bridge cashed); only RWSE = bridge needs STRUCTURED stochasticity (GRAM>PTRM).
+Targets z-scored for training; per-target MAE reported in RAW ring units.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_cycle.py --conv gin --feat none
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import networkx as nx
+from torch_geometric.datasets import ZINC
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.utils import to_networkx
+from torch_geometric.nn import GINConv, GCNConv, global_add_pool
+
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data'))
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+ROOT = os.path.join(DATA_ROOT, 'zinc')
+CACHE = os.path.join(DATA_ROOT, 'cycle_cache')
+RWSE_K = 16
+
+
+def rwse(edge_index, n, K=RWSE_K):
+ A = np.zeros((n, n), dtype=np.float64)
+ ei = edge_index.numpy()
+ A[ei[0], ei[1]] = 1.0
+ A = np.maximum(A, A.T)
+ deg = A.sum(1)
+ P = A / np.where(deg > 0, deg, 1.0)[:, None]
+ out = np.zeros((n, K), dtype=np.float32)
+ M = np.eye(n)
+ for k in range(K):
+ M = M @ P
+ out[:, k] = np.diag(M)
+ return torch.from_numpy(out)
+
+
+def c56(data):
+ G = to_networkx(data, to_undirected=True)
+ c = {5: 0, 6: 0}
+ for cyc in nx.simple_cycles(G, length_bound=6):
+ L = len(cyc)
+ if L in c:
+ c[L] += 1
+ return [float(c[5]), float(c[6])]
+
+
+def prepare(split):
+ os.makedirs(CACHE, exist_ok=True)
+ fp = os.path.join(CACHE, f"{split}.pt")
+ if os.path.exists(fp):
+ return torch.load(fp, weights_only=False)
+ ds = ZINC(ROOT, subset=True, split=split)
+ out = []
+ for g in ds:
+ out.append({'x': g.x.view(-1).long(), 'edge_index': g.edge_index,
+ 'rwse': rwse(g.edge_index, g.num_nodes),
+ 'y': torch.tensor(c56(g), dtype=torch.float)})
+ torch.save(out, fp)
+ return out
+
+
+class Net(nn.Module):
+ def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False):
+ super().__init__()
+ self.emb = nn.Embedding(n_atom, hidden)
+ self.rni, self.use_rwse = rni, use_rwse
+ din = hidden + rni + (RWSE_K if use_rwse else 0)
+ self.lin_in = nn.Linear(din, hidden)
+ self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
+ for _ in range(layers):
+ if conv == 'gin':
+ self.convs.append(GINConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True))
+ else:
+ self.convs.append(GCNConv(hidden, hidden))
+ self.bns.append(nn.BatchNorm1d(hidden))
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+
+ def forward(self, x, edge_index, batch, rwse=None):
+ h = self.emb(x)
+ parts = [h]
+ if self.use_rwse:
+ parts.append(rwse)
+ if self.rni:
+ parts.append(torch.randn(h.size(0), self.rni, device=h.device))
+ h = self.lin_in(torch.cat(parts, dim=1))
+ for conv, bn in zip(self.convs, self.bns):
+ h = bn(conv(h, edge_index)).relu()
+ return self.head(global_add_pool(h, batch))
+
+
+def to_loader(recs, bs, shuffle, drop_last=False):
+ data = [Data(x=r['x'], edge_index=r['edge_index'], rwse=r['rwse'],
+ y=r['y'].view(1, 2), num_nodes=r['x'].numel()) for r in recs]
+ return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+@torch.no_grad()
+def eval_mae(model, loader, dev, ymu, ysd, samples=1):
+ model.eval(); abs_err = torch.zeros(2); n = 0
+ for b in loader:
+ b = b.to(dev)
+ ps = torch.stack([model(b.x, b.edge_index, b.batch, b.rwse) for _ in range(samples)]).mean(0)
+ pr = ps * ysd.to(dev) + ymu.to(dev) # un-standardize -> raw ring units
+ yr = b.y * ysd.to(dev) + ymu.to(dev)
+ abs_err += (pr - yr).abs().sum(0).cpu(); n += b.num_graphs
+ return (abs_err / n).tolist()
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
+ ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none')
+ ap.add_argument('--layers', type=int, default=5)
+ ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--rni_dim', type=int, default=16)
+ ap.add_argument('--epochs', type=int, default=200)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--bs', type=int, default=128)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ os.makedirs(OUT, exist_ok=True)
+
+ tr, va, te = prepare('train'), prepare('val'), prepare('test')
+ n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+ Ytr = torch.stack([r['y'] for r in tr])
+ ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+ for recs in (tr, va, te):
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+
+ rni = args.rni_dim if args.feat == 'rni' else 0
+ use_rwse = args.feat == 'rwse'
+ samples = 8 if rni else 1
+ model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ lossf = nn.L1Loss()
+ trl = to_loader(tr, args.bs, True, drop_last=True)
+ trl_e, val, tel = to_loader(tr, 256, False), to_loader(va, 256, False), to_loader(te, 256, False)
+
+ t0 = time.time(); best_val = 9e9; best = {}
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ vm = eval_mae(model, val, dev, ymu, ysd, samples)
+ if sum(vm) < best_val:
+ best_val = sum(vm)
+ best = {'ep': ep + 1, 'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples),
+ 'val_mae': vm, 'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples)}
+ print(f"ep{ep+1} val_mae(c5,c6)={[round(x,3) for x in vm]}", flush=True)
+
+ tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}"
+ rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args),
+ 'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best}
+ tm = best.get('test_mae'); trm = best.get('train_mae')
+ print(f"[{tag}] train_mae(c5,c6)={[round(x,3) for x in trm]} test_mae={[round(x,3) for x in tm]} "
+ f"(raw rings; std={ [round(x,2) for x in ysd.tolist()] }) @ep{best.get('ep')} ({rep['sec']}s)")
+ with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f:
+ json.dump(rep, f, indent=2)
+ print(" wrote", os.path.join(OUT, f"cyc_{tag}.json"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diag/train_rec.py b/diag/train_rec.py
new file mode 100644
index 0000000..9db7eb1
--- /dev/null
+++ b/diag/train_rec.py
@@ -0,0 +1,491 @@
+"""Step-2: RRoG/TRM-on-GNN for ZINC ring-counting.
+
+The graph is encoded once with a GIN encoder. A shared edge-free node-wise compute block then
+refines hidden state over n_sup*T recurrent steps (TRM-style: carry latent detached between
+deep-supervision steps). --grad_mode controls the LAST supervision step's recursion:
+ full : backprop through all T inner recursions (TRM)
+ 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient)
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.loader import DataLoader
+from torch_geometric.data import Batch, Data
+from torch_geometric.nn import (
+ APPNP,
+ ARMAConv,
+ ChebConv,
+ FiLMConv,
+ GATv2Conv,
+ GCNConv,
+ GENConv,
+ GINEConv,
+ GINConv,
+ GraphConv,
+ MFConv,
+ PNAConv,
+ ResGatedGraphConv,
+ SAGEConv,
+ SGConv,
+ TAGConv,
+ TransformerConv,
+ global_add_pool,
+)
+from torch_geometric.utils import degree
+from diag.train_cycle import prepare
+
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+SUPPORTED_VIEWS = [
+ 'gin', 'gine', 'gcn', 'graphsage', 'gatv2', 'graphconv', 'transformer', 'pna',
+ 'gen', 'film', 'resgated', 'tag', 'sgc', 'cheb', 'arma', 'mf', 'appnp',
+]
+
+
+def data_list(recs):
+ return [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2),
+ num_nodes=r['x'].numel()) for r in recs]
+
+
+def loader(recs, bs, shuffle, drop_last=False):
+ data = recs if recs and isinstance(recs[0], Data) else data_list(recs)
+ return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+def degree_histogram(data):
+ max_degree = 0
+ degs = []
+ for graph in data:
+ deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
+ degs.append(deg)
+ if deg.numel():
+ max_degree = max(max_degree, int(deg.max().item()))
+ hist = torch.zeros(max_degree + 1, dtype=torch.long)
+ for deg in degs:
+ hist += torch.bincount(deg, minlength=hist.numel())
+ return hist
+
+
+def make_view_layer(view, hidden, deg):
+ if view == 'gin':
+ return GINConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True)
+ if view == 'gine':
+ return GINEConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)),
+ train_eps=True, edge_dim=hidden)
+ if view == 'gcn':
+ return GCNConv(hidden, hidden)
+ if view == 'graphsage':
+ return SAGEConv(hidden, hidden)
+ if view == 'gatv2':
+ return GATv2Conv(hidden, hidden, heads=4, concat=False)
+ if view == 'graphconv':
+ return GraphConv(hidden, hidden)
+ if view == 'transformer':
+ return TransformerConv(hidden, hidden, heads=4, concat=False)
+ if view == 'pna':
+ if deg is None:
+ raise ValueError('PNA view requires a training-set degree histogram')
+ return PNAConv(
+ hidden, hidden,
+ aggregators=['mean', 'min', 'max', 'std'],
+ scalers=['identity', 'amplification', 'attenuation'],
+ deg=deg,
+ )
+ if view == 'gen':
+ return GENConv(hidden, hidden)
+ if view == 'film':
+ return FiLMConv(hidden, hidden)
+ if view == 'resgated':
+ return ResGatedGraphConv(hidden, hidden)
+ if view == 'tag':
+ return TAGConv(hidden, hidden, K=3)
+ if view == 'sgc':
+ return SGConv(hidden, hidden, K=2, cached=False)
+ if view == 'cheb':
+ return ChebConv(hidden, hidden, K=3)
+ if view == 'arma':
+ return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2)
+ if view == 'mf':
+ return MFConv(hidden, hidden)
+ if view == 'appnp':
+ return APPNP(K=5, alpha=0.1)
+ raise ValueError(f'unsupported view: {view}')
+
+
+class RecGIN(nn.Module):
+ def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2,
+ grad_mode='full', agg_layers=5, compute_layers=None, view='gin', deg=None):
+ super().__init__()
+ self.view = view
+ self.agg_layers = agg_layers
+ self.compute_layers = compute_layers or inner
+ self.emb = nn.Embedding(n_atom, hidden)
+ self.edge_emb = nn.Embedding(1, hidden) if view == 'gine' else None
+ self.agg_convs = nn.ModuleList()
+ for _ in range(agg_layers):
+ self.agg_convs.append(make_view_layer(view, hidden, deg))
+ self.agg_bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(agg_layers)])
+ core = []
+ d = hidden
+ for _ in range(self.compute_layers - 1):
+ core += [nn.Linear(d, hidden), nn.GELU()]
+ d = hidden
+ core.append(nn.Linear(d, hidden))
+ self.core_norm = nn.LayerNorm(hidden)
+ self.core = nn.Sequential(*core)
+ nn.init.zeros_(self.core[-1].weight)
+ nn.init.zeros_(self.core[-1].bias)
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+ self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
+ with torch.no_grad():
+ self.qhead[-1].weight.zero_()
+ self.qhead[-1].bias.fill_(-5.0)
+ self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode
+
+ def aggregate(self, x, ei):
+ h = self.emb(x)
+ for conv, bn in zip(self.agg_convs, self.agg_bns):
+ if self.view == 'gine':
+ edge_attr = self.edge_emb(torch.zeros(ei.size(1), dtype=torch.long, device=ei.device))
+ h = bn(conv(h, ei, edge_attr)).relu()
+ else:
+ h = bn(conv(h, ei)).relu()
+ return h
+
+ def core_step(self, combined, state):
+ """Shared TRM compute core. Deliberately edge-free."""
+ return state + self.core(self.core_norm(combined))
+
+ def _z_step(self, y, z, ctx, noise):
+ z = self.core_step(ctx + y + z, z)
+ if noise and self.sigma > 0:
+ z = z + self.sigma * torch.randn_like(z)
+ return z
+
+ def _y_step(self, y, z, noise):
+ y = self.core_step(y + z, y)
+ if noise and self.sigma > 0:
+ y = y + self.sigma * torch.randn_like(y)
+ return y
+
+ def recurse(self, y, z, ctx, noise, one_step=False):
+ if self.T == 0:
+ return y, z
+ if one_step: # HRM 1-step gradient
+ with torch.no_grad():
+ for _ in range(self.T - 1):
+ z = self._z_step(y, z, ctx, noise)
+ z = z.detach()
+ z = self._z_step(y, z, ctx, noise) # only last inner carries grad
+ y = self._y_step(y, z, noise)
+ return y, z
+ for _ in range(self.T): # TRM full recursion
+ z = self._z_step(y, z, ctx, noise)
+ y = self._y_step(y, z, noise)
+ return y, z
+
+ def predict(self, y, batch):
+ pooled = global_add_pool(y, batch)
+ return self.head(pooled), self.qhead(pooled).view(-1)
+
+ def forward_trace(self, x, ei, batch, steps, noise=False):
+ ctx = self.aggregate(x, ei)
+ y = ctx
+ z = torch.zeros_like(ctx)
+ preds, q_logits = [], []
+ for s in range(steps):
+ y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step'))
+ pred, q = self.predict(y, batch)
+ preds.append(pred)
+ q_logits.append(q)
+ if s < steps - 1:
+ y, z = y.detach(), z.detach()
+ return preds, q_logits
+
+ def forward(self, x, ei, batch, noise=False):
+ ctx = self.aggregate(x, ei)
+ y = ctx
+ z = torch.zeros_like(ctx)
+ preds = []
+ for s in range(self.n_sup):
+ if s < self.n_sup - 1:
+ with torch.no_grad():
+ y, z = self.recurse(y, z, ctx, noise)
+ y, z = y.detach(), z.detach()
+ else:
+ y, z = self.recurse(y, z, ctx, noise, one_step=(self.grad_mode == '1step'))
+ pred, _ = self.predict(y, batch)
+ preds.append(pred)
+ _, q = self.predict(y, batch)
+ return preds, q
+
+
+@torch.no_grad()
+def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'):
+ model.eval()
+ ysd_d, ymu_d = ysd.to(dev), ymu.to(dev)
+ ae = torch.zeros(2); ae_or = torch.zeros(2); n = 0
+ for b in ld:
+ b = b.to(dev)
+ if K == 1:
+ preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+ chosen = oracle = preds[-1]
+ else:
+ P, Q = [], []
+ for _ in range(K):
+ preds, q = model(b.x, b.edge_index, b.batch, noise=True)
+ P.append(preds[-1]); Q.append(q)
+ P = torch.stack(P); Q = torch.stack(Q)
+ ar = torch.arange(P.size(1), device=dev)
+ chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0)
+ oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar]
+ ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ ae_or += ((oracle * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ n += b.num_graphs
+ return (ae / n).tolist(), (ae_or / n).tolist()
+
+
+@torch.no_grad()
+def evaluate_trace(model, ld, dev, ymu, ysd, steps, adaptive=False):
+ model.eval()
+ ysd_d, ymu_d = ysd.to(dev), ymu.to(dev)
+ ae = torch.zeros(2)
+ n = 0
+ step_sum = 0.0
+ for b in ld:
+ b = b.to(dev)
+ preds, q_logits = model.forward_trace(b.x, b.edge_index, b.batch, steps, noise=False)
+ P = torch.stack(preds, dim=0)
+ if adaptive:
+ Q = torch.stack(q_logits, dim=0)
+ halted = Q > 0
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+ chosen = P[idx, torch.arange(P.size(1), device=dev)]
+ step_sum += (idx.to(torch.float32) + 1).sum().item()
+ else:
+ chosen = P[-1]
+ step_sum += steps * b.num_graphs
+ ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu()
+ n += b.num_graphs
+ return (ae / n).tolist(), step_sum / max(n, 1)
+
+
+def _split_nodes(t, ptr):
+ return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)]
+
+
+def act_train_step(model, state, replacement_batch, opt, dev, args):
+ replacement = replacement_batch.to_data_list()
+ batch_size = len(replacement)
+ if state is None:
+ state = {
+ 'graphs': [None for _ in range(batch_size)],
+ 'y': [None for _ in range(batch_size)],
+ 'z': [None for _ in range(batch_size)],
+ 'steps': torch.zeros(batch_size, dtype=torch.long, device=dev),
+ 'halted': torch.ones(batch_size, dtype=torch.bool, device=dev),
+ }
+
+ halted_cpu = state['halted'].detach().cpu().tolist()
+ for i, halted in enumerate(halted_cpu):
+ if halted:
+ state['graphs'][i] = replacement[i]
+
+ b = Batch.from_data_list(state['graphs']).to(dev)
+ ctx = model.aggregate(b.x, b.edge_index)
+ ptr = b.ptr
+ y_parts, z_parts = [], []
+ for i in range(batch_size):
+ start, end = ptr[i].item(), ptr[i + 1].item()
+ if halted_cpu[i] or state['y'][i] is None:
+ y_parts.append(ctx[start:end])
+ z_parts.append(torch.zeros_like(ctx[start:end]))
+ else:
+ y_parts.append(state['y'][i].to(dev))
+ z_parts.append(state['z'][i].to(dev))
+ y = torch.cat(y_parts, dim=0)
+ z = torch.cat(z_parts, dim=0)
+
+ opt.zero_grad()
+ y, z = model.recurse(y, z, ctx, noise=False, one_step=(model.grad_mode == '1step'))
+ pred, q = model.predict(y, b.batch)
+ per_graph_err = (pred - b.y).abs().mean(1)
+ pred_loss = per_graph_err.mean()
+ with torch.no_grad():
+ if args.halt_target == 'binary':
+ halt_target = (per_graph_err <= args.halt_norm_threshold).to(q.dtype)
+ else:
+ halt_target = torch.sigmoid((args.halt_norm_threshold - per_graph_err) / args.halt_temp)
+ q_loss = nn.functional.binary_cross_entropy_with_logits(q, halt_target)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ y_det, z_det = y.detach(), z.detach()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ state['y'] = _split_nodes(y_det, ptr)
+ state['z'] = _split_nodes(z_det, ptr)
+ with torch.no_grad():
+ was_halted = state['halted']
+ steps = torch.where(was_halted, torch.zeros_like(state['steps']), state['steps']) + 1
+ halted = (steps >= args.halt_max_steps) | (q.detach() > 0)
+ if args.halt_exploration_prob > 0 and args.halt_max_steps > 1:
+ explore = torch.rand_like(q) < args.halt_exploration_prob
+ min_steps = torch.where(
+ explore,
+ torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev),
+ torch.zeros_like(steps),
+ )
+ halted = halted & (steps >= min_steps)
+ state['steps'] = steps
+ state['halted'] = halted
+
+ return state, {
+ 'loss': float(loss.detach().cpu()),
+ 'pred_loss': float(pred_loss.detach().cpu()),
+ 'q_loss': float(q_loss.detach().cpu()),
+ 'halted_frac': float(state['halted'].to(torch.float32).mean().detach().cpu()),
+ 'steps': float(state['steps'].to(torch.float32).mean().detach().cpu()),
+ }
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+ ap.add_argument('--sigma', type=float, default=0.0)
+ ap.add_argument('--K', type=int, default=1)
+ ap.add_argument('--select', choices=['none', 'bestq'], default='bestq')
+ ap.add_argument('--T', type=int, default=3)
+ ap.add_argument('--n_sup', type=int, default=3)
+ ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--agg_layers', type=int, default=5)
+ ap.add_argument('--compute_layers', type=int, default=2)
+ ap.add_argument('--view', choices=SUPPORTED_VIEWS, default='gin')
+ ap.add_argument('--epochs', type=int, default=200)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--bs', type=int, default=128)
+ ap.add_argument('--lam_q', type=float, default=1.0)
+ ap.add_argument('--act', action='store_true',
+ help='train all recurrent depths up to halt_max_steps and train qhead as a halt head')
+ ap.add_argument('--halt_max_steps', type=int, default=8)
+ ap.add_argument('--halt_norm_threshold', type=float, default=0.30)
+ ap.add_argument('--halt_temp', type=float, default=0.10)
+ ap.add_argument('--halt_target', choices=['soft', 'binary'], default='soft')
+ ap.add_argument('--halt_exploration_prob', type=float, default=0.1)
+ ap.add_argument('--loss_mode', choices=['last', 'trace'], default='trace')
+ ap.add_argument('--seed', type=int, default=0)
+ ap.add_argument('--device', default='auto')
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if args.device == 'auto' and torch.cuda.is_available() else (
+ 'cpu' if args.device == 'auto' else args.device)
+ os.makedirs(OUT, exist_ok=True)
+
+ tr, va, te = prepare('train'), prepare('val'), prepare('test')
+ n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+ Ytr = torch.stack([r['y'] for r in tr]); ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+ for recs in (tr, va, te):
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+ train_data = data_list(tr)
+ trl = loader(train_data, args.bs, True, drop_last=True)
+ val, tel = loader(va, 256, False), loader(te, 256, False)
+
+ deg = degree_histogram(train_data) if args.view == 'pna' else None
+ model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode,
+ agg_layers=args.agg_layers, compute_layers=args.compute_layers,
+ view=args.view, deg=deg).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ l1 = nn.L1Loss()
+ act_steps = max(1, args.halt_max_steps)
+
+ t0 = time.time(); best_val = 9e9; best = {}; best_state = None; act_state = None
+ for ep in range(args.epochs):
+ model.train()
+ act_metrics = []
+ for b in trl:
+ if args.act:
+ act_state, metrics = act_train_step(model, act_state, b, opt, dev, args)
+ act_metrics.append(metrics)
+ else:
+ b = b.to(dev); opt.zero_grad()
+ if args.loss_mode == 'trace':
+ preds, q_logits = model.forward_trace(
+ b.x, b.edge_index, b.batch, args.n_sup, noise=model.sigma > 0)
+ q = q_logits[-1]
+ else:
+ preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+ loss = sum(l1(p, b.y) for p in preds) / len(preds)
+ with torch.no_grad():
+ tq = -(preds[-1] - b.y).abs().mean(1)
+ loss = loss + args.lam_q * nn.functional.mse_loss(q, tq)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ if args.act:
+ vm, _ = evaluate_trace(model, val, dev, ymu, ysd, act_steps, adaptive=False)
+ else:
+ vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select)
+ if sum(vm) < best_val:
+ best_val = sum(vm)
+ if args.act:
+ tem, fixed_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=False)
+ tea, adaptive_steps = evaluate_trace(model, tel, dev, ymu, ysd, act_steps, adaptive=True)
+ best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem,
+ 'test_mae_adaptive': tea, 'fixed_steps': fixed_steps,
+ 'adaptive_steps': adaptive_steps}
+ else:
+ tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select)
+ best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo}
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
+ if args.act and act_metrics:
+ hm = sum(m['halted_frac'] for m in act_metrics) / len(act_metrics)
+ sm = sum(m['steps'] for m in act_metrics) / len(act_metrics)
+ print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]} halt={hm:.2f} train_steps={sm:.2f}", flush=True)
+ else:
+ print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True)
+
+ act_tag = f"_actfull{act_steps}_{args.halt_target}{args.halt_norm_threshold:g}_e{args.epochs}" if args.act else ""
+ loss_tag = f"_{args.loss_mode}" if (not args.act and args.loss_mode != 'last') else ""
+ view_tag = f"_{args.view}" if args.view != 'gin' else ""
+ tag = f"rec_rrog{view_tag}_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}{loss_tag}{act_tag}_s{args.seed}"
+ rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1),
+ 'dev': dev, 'arch': 'rrog_once_agg_node_compute', 'y_std_raw': ysd.tolist(), **best}
+ if args.act:
+ print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
+ f"adaptive={[round(x,3) for x in best.get('test_mae_adaptive')]} "
+ f"steps={best.get('adaptive_steps'):.2f}/{best.get('fixed_steps'):.2f} "
+ f"@ep{best.get('ep')} ({rep['sec']}s)")
+ else:
+ print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
+ f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)")
+ with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
+ json.dump(rep, f, indent=2)
+ torch.save({'state': best_state or model.state_dict(),
+ 'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup,
+ 'sigma': args.sigma, 'grad_mode': args.grad_mode,
+ 'agg_layers': args.agg_layers, 'compute_layers': args.compute_layers,
+ 'view': args.view,
+ 'loss_mode': args.loss_mode,
+ 'act': args.act, 'act_impl': 'persistent_recycle' if args.act else 'none',
+ 'halt_max_steps': act_steps,
+ 'halt_exploration_prob': args.halt_exploration_prob,
+ 'arch': 'rrog_once_agg_node_compute'},
+ 'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt"))
+ print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))
+
+
+if __name__ == "__main__":
+ main()