summaryrefslogtreecommitdiff
path: root/diag/train_cycle.py
blob: 598e349b4a9da74f229ef182db85b5ab03a1a401 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""GNN-native ring-counting on real molecules (ZINC): regress [#5-cycles, #6-cycles].

k-cycle counts (k>=3) are provably NOT computable by 1-WL/MPNN (Chen et al. 2020) -> a REAL
H2 ceiling on REAL graphs. Training-based diagnosis (partition instrument is vacuous on
feature-rich graphs):
  GIN(L)        1-WL baseline           -> should FAIL to count
  GCN(L)        sub-1-WL reference
  GIN+RNI       random feats = NOISE    -> PTRM-style crude symmetry break (eval-averaged)
  GIN+RWSE      random-walk return probs-> structured >1-WL positive control
Reads: GIN high error + RWSE fixes it = real ceiling exists; RNI also fixes = crude noise
breaks it (bridge cashed); only RWSE = bridge needs STRUCTURED stochasticity (GRAM>PTRM).
Targets z-scored for training; per-target MAE reported in RAW ring units.

Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_cycle.py --conv gin --feat none
"""
import argparse, json, os, time
import numpy as np
import torch
import torch.nn as nn
import networkx as nx
from torch_geometric.datasets import ZINC
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
from torch_geometric.nn import GINConv, GCNConv, global_add_pool

# Filesystem layout. Each location can be overridden through an environment
# variable; otherwise it is resolved relative to the project root, which
# defaults to the parent of this file's directory.
PROJECT_ROOT = os.environ.get(
    'RROG_ROOT',
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
)
DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data'))
ROOT = os.path.join(DATA_ROOT, 'zinc')          # root directory handed to PyG's ZINC dataset
CACHE = os.path.join(DATA_ROOT, 'cycle_cache')  # per-split record cache written by prepare()
OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))  # JSON reports
RWSE_K = 16  # number of random-walk steps, i.e. width of the rwse() feature block


def rwse(edge_index, n, K=RWSE_K):
    """Random-walk structural encoding: per-node k-step return probabilities.

    Args:
        edge_index: index tensor whose first two rows are source / target node
            ids (the graph is symmetrised, so direction does not matter).
        n: number of nodes.
        K: number of walk steps to encode.

    Returns:
        float32 tensor of shape (n, K); column k holds diag(P^(k+1)) where
        P = D^-1 A is the row-normalised transition matrix. Isolated nodes get
        an all-zero transition row and therefore all-zero features.
    """
    ei = edge_index.numpy()
    adj = np.zeros((n, n), dtype=np.float64)
    adj[ei[0], ei[1]] = 1.0
    adj = np.maximum(adj, adj.T)  # treat every edge as undirected
    degree = adj.sum(1)
    # Divide by 1 instead of 0 on isolated nodes (their row is all zeros anyway).
    trans = adj / np.where(degree > 0, degree, 1.0)[:, None]
    returns = np.zeros((n, K), dtype=np.float32)
    walk = np.eye(n)
    for step in range(K):
        walk = walk @ trans
        returns[:, step] = walk.diagonal()
    return torch.from_numpy(returns)


def c56(data):
    """Count the simple cycles of length 5 and 6 in a PyG graph.

    The graph is converted to an undirected networkx graph and every simple
    cycle of length <= 6 is enumerated. Only cycles of length exactly 5 or 6
    are tallied.

    Returns:
        [float(#5-cycles), float(#6-cycles)]
    """
    graph = to_networkx(data, to_undirected=True)
    lengths = [len(cycle) for cycle in nx.simple_cycles(graph, length_bound=6)]
    return [float(lengths.count(size)) for size in (5, 6)]


def prepare(split):
    """Load, or build and cache, the cycle-count records for one ZINC split.

    Each record is a dict with:
      'x'          -- integer node labels flattened to a 1-D long tensor
                      (Net embeds them via nn.Embedding),
      'edge_index' -- the split's PyG edge index, unchanged,
      'rwse'       -- rwse() features, shape (num_nodes, RWSE_K),
      'y'          -- float tensor [#5-cycles, #6-cycles] from c56().

    The list is saved to CACHE/<split>.pt and reloaded on later calls.

    NOTE: the cache is keyed only by split name. Delete it after changing
    RWSE_K or the cycle definition, otherwise stale records are reused.
    """
    os.makedirs(CACHE, exist_ok=True)
    fp = os.path.join(CACHE, f"{split}.pt")
    if os.path.exists(fp):
        # weights_only=False unpickles arbitrary objects. That is acceptable only
        # because this file is a cache written by this script; never point
        # CACHE at untrusted data.
        return torch.load(fp, weights_only=False)
    ds = ZINC(ROOT, subset=True, split=split)
    out = [{'x': g.x.view(-1).long(), 'edge_index': g.edge_index,
            'rwse': rwse(g.edge_index, g.num_nodes),
            'y': torch.tensor(c56(g), dtype=torch.float)} for g in ds]
    # Save to a temporary file and atomically rename it into place. An
    # interrupted or failed save then cannot leave a truncated cache file that
    # every later run would try (and fail) to load.
    tmp = f"{fp}.tmp.{os.getpid()}"
    try:
        torch.save(out, tmp)
        os.replace(tmp, fp)
    finally:
        if os.path.exists(tmp):
            os.remove(tmp)
    return out


class Net(nn.Module):
    """Graph-level regressor: node embedding -> message passing -> sum pool -> 2 outputs.

    Args:
        n_atom: size of the node-label vocabulary for the input embedding.
        hidden: width of every hidden layer.
        layers: number of message-passing layers, each followed by BatchNorm + ReLU.
        conv: 'gin' selects GINConv (2-layer MLP, trainable eps); any other
            value selects GCNConv.
        rni: number of N(0, 1) noise channels appended to every node's input on
            each forward call (0 disables). Noise is redrawn in train AND eval
            mode.
        use_rwse: if True, forward() concatenates its ``rwse`` argument
            (RWSE_K channels) to the node input features.
    """

    def __init__(self, n_atom, hidden, layers, conv='gin', rni=0, use_rwse=False):
        super().__init__()
        self.emb = nn.Embedding(n_atom, hidden)
        self.rni, self.use_rwse = rni, use_rwse
        # Input projection width: embedding + noise channels + optional RWSE channels.
        din = hidden + rni + (RWSE_K if use_rwse else 0)
        self.lin_in = nn.Linear(din, hidden)
        self.convs, self.bns = nn.ModuleList(), nn.ModuleList()
        for _ in range(layers):
            if conv == 'gin':
                self.convs.append(GINConv(nn.Sequential(
                    nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True))
            else:
                self.convs.append(GCNConv(hidden, hidden))
            self.bns.append(nn.BatchNorm1d(hidden))
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))

    def forward(self, x, edge_index, batch, rwse=None):
        """Return one 2-vector prediction per graph in ``batch``.

        Raises:
            ValueError: if the model uses RWSE features but ``rwse`` is None.
        """
        if self.use_rwse and rwse is None:
            # Without this check torch.cat fails with an opaque type error.
            raise ValueError("Net was built with use_rwse=True but forward() received rwse=None")
        h = self.emb(x)
        parts = [h]
        if self.use_rwse:
            parts.append(rwse)
        if self.rni:
            # Fresh noise on every call, including eval; eval_mae averages several passes.
            parts.append(torch.randn(h.size(0), self.rni, device=h.device))
        h = self.lin_in(torch.cat(parts, dim=1))
        for conv, bn in zip(self.convs, self.bns):
            h = bn(conv(h, edge_index)).relu()
        return self.head(global_add_pool(h, batch))


def to_loader(recs, bs, shuffle, drop_last=False):
    """Wrap cached record dicts (see prepare) as PyG graphs in a DataLoader.

    ``y`` is reshaped to (1, 2) so that batched targets line up row-for-row
    with the model's per-graph 2-vector output. ``num_nodes`` is taken from
    the number of entries in ``x``.
    """
    def as_graph(rec):
        # One record -> one Data object carrying x, edges, RWSE features and target.
        node_labels = rec['x']
        return Data(
            x=node_labels,
            edge_index=rec['edge_index'],
            rwse=rec['rwse'],
            y=rec['y'].view(1, 2),
            num_nodes=node_labels.numel(),
        )

    graphs = list(map(as_graph, recs))
    return DataLoader(graphs, batch_size=bs, shuffle=shuffle, drop_last=drop_last)


@torch.no_grad()
def eval_mae(model, loader, dev, ymu, ysd, samples=1):
    """Per-target mean absolute error of ``model`` over ``loader``, in raw units.

    Predictions are averaged over ``samples`` forward passes, which matters
    only for a stochastic model such as the RNI variant. Both the averaged
    prediction and the standardized target are mapped back through
    ``* ysd + ymu`` before the absolute error is taken.

    Returns:
        Python list [mae_target0, mae_target1], averaged over all graphs seen.

    Side effect: switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    mu, sd = ymu.to(dev), ysd.to(dev)  # loop-invariant; same tensors every batch
    total = torch.zeros(2)
    graphs = 0
    for batch in loader:
        batch = batch.to(dev)
        draws = [model(batch.x, batch.edge_index, batch.batch, batch.rwse)
                 for _ in range(samples)]
        mean_pred = torch.stack(draws).mean(0)
        raw_pred = mean_pred * sd + mu   # un-standardize -> raw ring units
        raw_true = batch.y * sd + mu
        total += (raw_pred - raw_true).abs().sum(0).cpu()
        graphs += batch.num_graphs
    return (total / graphs).tolist()


def main():
    """CLI entry point: train one model on ZINC cycle counts and write a JSON report.

    Workflow:
      - Parse flags and build (or load cached) splits.
      - Z-score the targets with train-split mean/std.
      - Train with L1 loss, Adam and a cosine LR schedule.
      - Evaluate every 20 epochs and at the last epoch.
      - Keep the train/val/test MAEs of the evaluation with the lowest summed
        val MAE, and write them to OUT/cyc_<tag>.json.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--conv', choices=['gin', 'gcn'], default='gin')
    ap.add_argument('--feat', choices=['none', 'rni', 'rwse'], default='none')
    ap.add_argument('--layers', type=int, default=5)
    ap.add_argument('--hidden', type=int, default=128)
    ap.add_argument('--rni_dim', type=int, default=16)
    ap.add_argument('--epochs', type=int, default=200)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--bs', type=int, default=128)
    ap.add_argument('--seed', type=int, default=0)
    args = ap.parse_args()
    torch.manual_seed(args.seed); np.random.seed(args.seed)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    os.makedirs(OUT, exist_ok=True)

    tr, va, te = prepare('train'), prepare('val'), prepare('test')
    n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
    Ytr = torch.stack([r['y'] for r in tr])
    # Standardize all splits with TRAIN statistics; eval_mae maps predictions back.
    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
    for recs in (tr, va, te):
        for r in recs:
            r['y'] = (r['y'] - ymu) / ysd

    rni = args.rni_dim if args.feat == 'rni' else 0
    use_rwse = args.feat == 'rwse'
    samples = 8 if rni else 1  # average 8 noise draws at eval time for RNI
    model = Net(n_atom, args.hidden, args.layers, args.conv, rni, use_rwse).to(dev)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
    lossf = nn.L1Loss()
    trl = to_loader(tr, args.bs, True, drop_last=True)
    trl_e, val, tel = to_loader(tr, 256, False), to_loader(va, 256, False), to_loader(te, 256, False)

    t0 = time.time(); best_val = 9e9; best = {}
    for ep in range(args.epochs):
        model.train()
        for b in trl:
            b = b.to(dev); opt.zero_grad()
            loss = lossf(model(b.x, b.edge_index, b.batch, b.rwse), b.y)
            loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
        sched.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            vm = eval_mae(model, val, dev, ymu, ysd, samples)
            if sum(vm) < best_val:
                best_val = sum(vm)
                best = {'ep': ep + 1, 'train_mae': eval_mae(model, trl_e, dev, ymu, ysd, samples),
                        'val_mae': vm, 'test_mae': eval_mae(model, tel, dev, ymu, ysd, samples)}
            print(f"ep{ep+1} val_mae(c5,c6)={[round(x,3) for x in vm]}", flush=True)

    tag = f"{args.conv}_{args.feat}_L{args.layers}_s{args.seed}"
    rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args),
           'y_std_raw': ysd.tolist(), 'sec': round(time.time() - t0, 1), 'dev': dev, **best}
    if not best:
        # Nothing was recorded: either no evaluation ran (epochs <= 0) or no val
        # MAE ever fell below the sentinel (e.g. NaN, since NaN < x is False).
        print(f"[{tag}] WARNING: no validation MAE below the sentinel was recorded; "
              f"report contains no metrics", flush=True)

    def r3(vals):
        # Round a metric list for display; None if that metric was never recorded.
        return None if vals is None else [round(x, 3) for x in vals]

    tm = best.get('test_mae'); trm = best.get('train_mae')
    print(f"[{tag}] train_mae(c5,c6)={r3(trm)} test_mae={r3(tm)} "
          f"(raw rings; std={ [round(x,2) for x in ysd.tolist()] }) @ep{best.get('ep')} ({rep['sec']}s)")
    with open(os.path.join(OUT, f"cyc_{tag}.json"), 'w') as f:
        json.dump(rep, f, indent=2)
    print("  wrote", os.path.join(OUT, f"cyc_{tag}.json"))


if __name__ == "__main__":
    # Script entry point; see the module docstring for an example command line.
    main()