diff options
Diffstat (limited to 'diag/train_cycle.py')
| -rw-r--r-- | diag/train_cycle.py | 188 |
1 files changed, 188 insertions, 0 deletions
# diag/train_cycle.py -- recovered from a collapsed "new file" diff ('+' markers removed).
"""GNN-native ring-counting on real molecules (ZINC): regress [#5-cycles, #6-cycles].

k-cycle counts (k>=3) are provably NOT computable by 1-WL/MPNN (Chen et al. 2020) -> a REAL
H2 ceiling on REAL graphs. Training-based diagnosis (partition instrument is vacuous on
feature-rich graphs):
  GIN(L)    1-WL baseline                 -> should FAIL to count
  GCN(L)    sub-1-WL reference
  GIN+RNI   random feats = NOISE          -> PTRM-style crude symmetry break (eval-averaged)
  GIN+RWSE  random-walk return probs      -> structured >1-WL positive control
Reads: GIN high error + RWSE fixes it = real ceiling exists; RNI also fixes = crude noise
breaks it (bridge cashed); only RWSE = bridge needs STRUCTURED stochasticity (GRAM>PTRM).
Targets z-scored for training; per-target MAE reported in RAW ring units.

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_cycle.py --conv gin --feat none
"""
import argparse
import json
import os
import time

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GINConv, global_add_pool
from torch_geometric.utils import to_networkx

# All locations can be overridden through environment variables; by default they
# are resolved relative to the parent directory of this file's directory.
PROJECT_ROOT = os.environ.get(
    'RROG_ROOT',
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
)
DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data'))
OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
ROOT = os.path.join(DATA_ROOT, 'zinc')
CACHE = os.path.join(DATA_ROOT, 'cycle_cache')
RWSE_K = 16  # number of random-walk steps stored per node


def rwse(edge_index: torch.Tensor, n: int, K: int = RWSE_K) -> torch.Tensor:
    """Return random-walk return probabilities for every node.

    Args:
        edge_index: integer tensor whose row 0 / row 1 hold edge endpoints
            (converted with ``.numpy()``, so it must live on the CPU).
        n: number of nodes.
        K: number of walk lengths to record.

    Returns:
        Float32 tensor of shape ``(n, K)``; column ``k`` holds ``diag(P^(k+1))``
        where ``P`` is the row-normalised, symmetrised adjacency matrix.
    """
    A = np.zeros((n, n), dtype=np.float64)
    ei = edge_index.numpy()
    A[ei[0], ei[1]] = 1.0
    A = np.maximum(A, A.T)  # symmetrise in case only one direction is listed
    deg = A.sum(1)
    # Isolated nodes (deg == 0) are divided by 1 so their row stays all-zero.
    P = A / np.where(deg > 0, deg, 1.0)[:, None]
    out = np.zeros((n, K), dtype=np.float32)
    M = np.eye(n)
    for k in range(K):
        M = M @ P
        out[:, k] = np.diag(M)
    return torch.from_numpy(out)


def c56(data) -> list:
    """Count simple cycles of length 5 and of length 6 in ``data``.

    The graph is converted to an undirected networkx graph and every simple
    cycle of length <= 6 is enumerated; only lengths 5 and 6 are tallied.

    Returns:
        ``[n_5_cycles, n_6_cycles]`` as floats.
    """
    G = to_networkx(data, to_undirected=True)
    counts = {5: 0, 6: 0}
    # NOTE: undirected input and `length_bound` need a recent networkx
    # (3.1+ according to its documentation) -- confirm the installed version.
    for cyc in nx.simple_cycles(G, length_bound=6):
        length = len(cyc)
        if length in counts:
            counts[length] += 1
    return [float(counts[5]), float(counts[6])]


def _record(g) -> dict:
    """Build the cached per-graph record (atom ids, edges, RWSE, cycle targets)."""
    return {
        'x': g.x.view(-1).long(),
        'edge_index': g.edge_index,
        'rwse': rwse(g.edge_index, g.num_nodes),
        'y': torch.tensor(c56(g), dtype=torch.float),
    }


def prepare(split: str) -> list:
    """Load (or build and cache) the list of per-graph records for ``split``.

    Records are cached at ``CACHE/<split>.pt``. When the cache file exists it is
    returned as-is; otherwise the ZINC subset split is processed and saved.

    Returns:
        List of dicts with keys ``'x'``, ``'edge_index'``, ``'rwse'``, ``'y'``.
    """
    os.makedirs(CACHE, exist_ok=True)
    fp = os.path.join(CACHE, f"{split}.pt")
    if os.path.exists(fp):
        # SECURITY NOTE: weights_only=False permits arbitrary pickled objects;
        # only point CACHE at files this script produced itself.
        return torch.load(fp, weights_only=False)
    ds = ZINC(ROOT, subset=True, split=split)
    out = [_record(g) for g in ds]
    # Write to a temporary name and rename, so an interrupted run can never
    # leave a truncated cache file that later runs would try to load.
    tmp = fp + ".tmp"
    torch.save(out, tmp)
    os.replace(tmp, fp)
    return out


class Net(nn.Module):
    """Graph regressor: atom embedding -> conv stack -> add-pool -> 2-output MLP.

    Args:
        n_atom: size of the atom-type vocabulary for the embedding.
        hidden: hidden width used throughout.
        layers: number of conv + BatchNorm + ReLU layers.
        conv: ``'gin'`` for GINConv (trainable eps); anything else uses GCNConv.
        rni: number of fresh Gaussian noise channels appended per forward pass
            (0 disables random node initialisation).
        use_rwse: append the ``RWSE_K`` random-walk features to the input.
    """

    def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False):
        super().__init__()
        self.emb = nn.Embedding(n_atom, hidden)
        self.rni, self.use_rwse = rni, use_rwse
        din = hidden + rni + (RWSE_K if use_rwse else 0)
        self.lin_in = nn.Linear(din, hidden)
        self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
        for _ in range(layers):
            if conv == 'gin':
                mlp = nn.Sequential(
                    nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
                self.convs.append(GINConv(mlp, train_eps=True))
            else:
                self.convs.append(GCNConv(hidden, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))

    def forward(self, x, edge_index, batch, rwse=None):
        """Return one 2-vector prediction per graph in the batch.

        Raises:
            ValueError: if the model was built with ``use_rwse=True`` but no
                ``rwse`` features are supplied.
        """
        h = self.emb(x)
        parts = [h]
        if self.use_rwse:
            if rwse is None:
                raise ValueError(
                    "use_rwse=True but no `rwse` features were passed to forward()")
            parts.append(rwse)
        if self.rni:
            # Noise is re-sampled on every call, in train and eval mode alike.
            parts.append(torch.randn(h.size(0), self.rni, device=h.device))
        h = self.lin_in(torch.cat(parts, dim=1))
        for conv, bn in zip(self.convs, self.bns):
            h = bn(conv(h, edge_index)).relu()
        return self.head(global_add_pool(h, batch))


def to_loader(recs, bs, shuffle, drop_last=False):
    """Wrap cached records into PyG ``Data`` objects and return a ``DataLoader``.

    Each target is reshaped to ``(1, 2)`` so batching stacks targets row-wise.
    """
    data = [
        Data(x=r['x'], edge_index=r['edge_index'], rwse=r['rwse'],
             y=r['y'].view(1, 2), num_nodes=r['x'].numel())
        for r in recs
    ]
    return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last)


@torch.no_grad()
def eval_mae(model, loader, dev, ymu, ysd, samples=1):
    """Return per-target mean absolute error in raw (un-standardised) units.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: batches whose ``y`` is standardised with ``ymu`` / ``ysd``.
        dev: device the batches are moved to.
        ymu, ysd: per-target mean / std used to undo the standardisation.
        samples: number of forward passes averaged per batch.

    Returns:
        List of two floats: MAE for target 0 and target 1.
    """
    model.eval()
    # Move the statistics once instead of twice per batch.
    ymu_d, ysd_d = ymu.to(dev), ysd.to(dev)
    abs_err = torch.zeros(2)
    n = 0
    for b in loader:
        b = b.to(dev)
        ps = torch.stack(
            [model(b.x, b.edge_index, b.batch, b.rwse) for _ in range(samples)]
        ).mean(0)
        pr = ps * ysd_d + ymu_d  # un-standardize -> raw ring units
        yr = b.y * ysd_d + ymu_d
        abs_err += (pr - yr).abs().sum(0).cpu()
        n += b.num_graphs
    return (abs_err / n).tolist()


def _rounded(vals, ndigits):
    """Round each entry of ``vals``; pass ``None`` through unchanged."""
    if vals is None:
        return None
    return [round(x, ndigits) for x in vals]


def main():
    """Train one configuration, keep the best-validation checkpoint's metrics,
    print a summary and write ``OUT/cyc_<tag>.json``."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
    ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none')
    ap.add_argument('--layers', type=int, default=5)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--rni_dim', type=int, default=16)
    ap.add_argument('--epochs', type=int, default=200)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=128)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    tr, va, te = prepare('train'), prepare('val'), prepare('test')
    n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
    # Standardise targets with TRAIN statistics only (applied to all splits).
    Ytr = torch.stack([r['y'] for r in tr])
    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
    for recs in (tr, va, te):
        for r in recs:
            r['y'] = (r['y'] - ymu) / ysd

    rni = args.rni_dim if args.feat == 'rni' else 0
    use_rwse = args.feat == 'rwse'
    samples = 8 if rni else 1  # average several noise draws at eval time for RNI
    model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    lossf = nn.L1Loss()
    trl = to_loader(tr, args.bs, True, drop_last=True)
    trl_e = to_loader(tr, 256, False)
    val = to_loader(va, 256, False)
    tel = to_loader(te, 256, False)

    t0 = time.time()
    # inf (rather than a large finite sentinel) so any finite validation score
    # can be selected.
    best_val = float('inf')
    best = {}
    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev)
            opt.zero_grad()
            loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            vm = eval_mae(model, val, dev, ymu, ysd, samples)
            if sum(vm) < best_val:
                best_val = sum(vm)
                best = {
                    'ep': ep + 1,
                    'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples),
                    'val_mae': vm,
                    'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples),
                }
            print(f"ep{ep+1} val_mae(c5,c6)={_rounded(vm, 3)}", flush=True)

    tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}"
    rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args),
           'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best}
    tm = best.get('test_mae')
    trm = best.get('train_mae')
    if not best:
        # Happens with --epochs 0 or when every validation MAE was NaN; the
        # report is still written so the run leaves a trace instead of crashing.
        print(f"[{tag}] WARNING: no checkpoint was selected "
              f"(epochs={args.epochs}, or validation MAE was never finite)", flush=True)
    print(f"[{tag}] train_mae(c5,c6)={_rounded(trm, 3)} test_mae={_rounded(tm, 3)} "
          f"(raw rings; std={_rounded(ysd.tolist(), 2)}) @ep{best.get('ep')} ({rep['sec']}s)")
    with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f:
        json.dump(rep, f, indent=2)
    print(" wrote", os.path.join(OUT, f"cyc_{tag}.json"))


if __name__ == "__main__":
    main()
