summaryrefslogtreecommitdiff
path: root/diag/train_cycle.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/train_cycle.py')
-rw-r--r--diag/train_cycle.py188
1 files changed, 188 insertions, 0 deletions
diff --git a/diag/train_cycle.py b/diag/train_cycle.py
new file mode 100644
index 0000000..598e349
--- /dev/null
+++ b/diag/train_cycle.py
@@ -0,0 +1,188 @@
+"""GNN-native ring-counting on real molecules (ZINC): regress [#5-cycles, #6-cycles].
+
+k-cycle counts (k>=3) are provably NOT computable by 1-WL/MPNN (Chen et al. 2020) -> a REAL
+H2 ceiling on REAL graphs. Training-based diagnosis (partition instrument is vacuous on
+feature-rich graphs):
+ GIN(L) 1-WL baseline -> should FAIL to count
+ GCN(L) sub-1-WL reference
+ GIN+RNI random feats = NOISE -> PTRM-style crude symmetry break (eval-averaged)
+ GIN+RWSE random-walk return probs-> structured >1-WL positive control
+Reads: GIN high error + RWSE fixes it = real ceiling exists; RNI also fixes = crude noise
+breaks it (bridge cashed); only RWSE = bridge needs STRUCTURED stochasticity (GRAM>PTRM).
+Targets z-scored for training; per-target MAE reported in RAW ring units.
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_cycle.py --conv gin --feat none
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+import networkx as nx
+from torch_geometric.datasets import ZINC
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.utils import to_networkx
+from torch_geometric.nn import GINConv, GCNConv, global_add_pool
+
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data'))
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+ROOT = os.path.join(DATA_ROOT, 'zinc')
+CACHE = os.path.join(DATA_ROOT, 'cycle_cache')
+RWSE_K = 16
+
+
+def rwse(edge_index, n, K=RWSE_K):
+ A = np.zeros((n, n), dtype=np.float64)
+ ei = edge_index.numpy()
+ A[ei[0], ei[1]] = 1.0
+ A = np.maximum(A, A.T)
+ deg = A.sum(1)
+ P = A / np.where(deg > 0, deg, 1.0)[:, None]
+ out = np.zeros((n, K), dtype=np.float32)
+ M = np.eye(n)
+ for k in range(K):
+ M = M @ P
+ out[:, k] = np.diag(M)
+ return torch.from_numpy(out)
+
+
+def c56(data):
+ G = to_networkx(data, to_undirected=True)
+ c = {5: 0, 6: 0}
+ for cyc in nx.simple_cycles(G, length_bound=6):
+ L = len(cyc)
+ if L in c:
+ c[L] += 1
+ return [float(c[5]), float(c[6])]
+
+
+def prepare(split):
+ os.makedirs(CACHE, exist_ok=True)
+ fp = os.path.join(CACHE, f"{split}.pt")
+ if os.path.exists(fp):
+ return torch.load(fp, weights_only=False)
+ ds = ZINC(ROOT, subset=True, split=split)
+ out = []
+ for g in ds:
+ out.append({'x': g.x.view(-1).long(), 'edge_index': g.edge_index,
+ 'rwse': rwse(g.edge_index, g.num_nodes),
+ 'y': torch.tensor(c56(g), dtype=torch.float)})
+ torch.save(out, fp)
+ return out
+
+
+class Net(nn.Module):
+ def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False):
+ super().__init__()
+ self.emb = nn.Embedding(n_atom, hidden)
+ self.rni, self.use_rwse = rni, use_rwse
+ din = hidden + rni + (RWSE_K if use_rwse else 0)
+ self.lin_in = nn.Linear(din, hidden)
+ self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
+ for _ in range(layers):
+ if conv == 'gin':
+ self.convs.append(GINConv(nn.Sequential(
+ nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True))
+ else:
+ self.convs.append(GCNConv(hidden, hidden))
+ self.bns.append(nn.BatchNorm1d(hidden))
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+
+ def forward(self, x, edge_index, batch, rwse=None):
+ h = self.emb(x)
+ parts = [h]
+ if self.use_rwse:
+ parts.append(rwse)
+ if self.rni:
+ parts.append(torch.randn(h.size(0), self.rni, device=h.device))
+ h = self.lin_in(torch.cat(parts, dim=1))
+ for conv, bn in zip(self.convs, self.bns):
+ h = bn(conv(h, edge_index)).relu()
+ return self.head(global_add_pool(h, batch))
+
+
+def to_loader(recs, bs, shuffle, drop_last=False):
+ data = [Data(x=r['x'], edge_index=r['edge_index'], rwse=r['rwse'],
+ y=r['y'].view(1, 2), num_nodes=r['x'].numel()) for r in recs]
+ return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+@torch.no_grad()
+def eval_mae(model, loader, dev, ymu, ysd, samples=1):
+ model.eval(); abs_err = torch.zeros(2); n = 0
+ for b in loader:
+ b = b.to(dev)
+ ps = torch.stack([model(b.x, b.edge_index, b.batch, b.rwse) for _ in range(samples)]).mean(0)
+ pr = ps * ysd.to(dev) + ymu.to(dev) # un-standardize -> raw ring units
+ yr = b.y * ysd.to(dev) + ymu.to(dev)
+ abs_err += (pr - yr).abs().sum(0).cpu(); n += b.num_graphs
+ return (abs_err / n).tolist()
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
+ ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none')
+ ap.add_argument('--layers', type=int, default=5)
+ ap.add_argument('--hidden', type=int, default=128)
+ ap.add_argument('--rni_dim', type=int, default=16)
+ ap.add_argument('--epochs', type=int, default=200)
+ ap.add_argument('--lr', type=float, default=1e-3)
+ ap.add_argument('--bs', type=int, default=128)
+ ap.add_argument('--seed', type=int, default=0)
+ args = ap.parse_args()
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ os.makedirs(OUT, exist_ok=True)
+
+ tr, va, te = prepare('train'), prepare('val'), prepare('test')
+ n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+ Ytr = torch.stack([r['y'] for r in tr])
+ ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+ for recs in (tr, va, te):
+ for r in recs:
+ r['y'] = (r['y'] - ymu) / ysd
+
+ rni = args.rni_dim if args.feat == 'rni' else 0
+ use_rwse = args.feat == 'rwse'
+ samples = 8 if rni else 1
+ model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+ lossf = nn.L1Loss()
+ trl = to_loader(tr, args.bs, True, drop_last=True)
+ trl_e, val, tel = to_loader(tr, 256, False), to_loader(va, 256, False), to_loader(te, 256, False)
+
+ t0 = time.time(); best_val = 9e9; best = {}
+ for ep in range(args.epochs):
+ model.train()
+ for b in trl:
+ b = b.to(dev); opt.zero_grad()
+ loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y)
+ loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
+ sched.step()
+ if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+ vm = eval_mae(model, val, dev, ymu, ysd, samples)
+ if sum(vm) < best_val:
+ best_val = sum(vm)
+ best = {'ep': ep + 1, 'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples),
+ 'val_mae': vm, 'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples)}
+ print(f"ep{ep+1} val_mae(c5,c6)={[round(x,3) for x in vm]}", flush=True)
+
+ tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}"
+ rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args),
+ 'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best}
+ tm = best.get('test_mae'); trm = best.get('train_mae')
+ print(f"[{tag}] train_mae(c5,c6)={[round(x,3) for x in trm]} test_mae={[round(x,3) for x in tm]} "
+ f"(raw rings; std={ [round(x,2) for x in ysd.tolist()] }) @ep{best.get('ep')} ({rep['sec']}s)")
+ with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f:
+ json.dump(rep, f, indent=2)
+ print(" wrote", os.path.join(OUT, f"cyc_{tag}.json"))
+
+
+if __name__ == "__main__":
+ main()