summaryrefslogtreecommitdiff
path: root/diag/train_rec.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/train_rec.py')
-rw-r--r--diag/train_rec.py174
1 files changed, 174 insertions, 0 deletions
diff --git a/diag/train_rec.py b/diag/train_rec.py
new file mode 100644
index 0000000..9866f28
--- /dev/null
+++ b/diag/train_rec.py
@@ -0,0 +1,174 @@
+"""Step-2: Recursive (TRM-ish) GNN on ZINC ring-counting + optional PTRM noise/selection.
+
+Recurrent shared-weight GIN block, deep-supervised over n_sup steps (TRM-style: carry latent
+detached between steps). --grad_mode controls the LAST supervision step's recursion:
+ full : backprop through all T inner recursions (TRM)
+ 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient)
+Optional per-step Gaussian noise (sigma) + K stochastic rollouts selected by a value head
+(best-Q@K) for the PTRM experiments. Saves a checkpoint for the LE diagnostic (diag/lyap.py).
+
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1
+"""
+import argparse, json, os, time
+import numpy as np
+import torch
+import torch.nn as nn
+from torch_geometric.loader import DataLoader
+from torch_geometric.data import Data
+from torch_geometric.nn import GINConv, global_add_pool
+from diag.train_cycle import prepare
+
+# Directory that main() creates if needed and writes the JSON run report and checkpoint into.
+OUT = '/home/yurenh2/rrog/runs'
+
+
+def loader(recs, bs, shuffle, drop_last=False):
+    """Wrap prepared records in a PyG ``DataLoader``.
+
+    Each record is read through its ``'x'``, ``'edge_index'`` and ``'y'``
+    entries: ``y`` is reshaped to ``(1, 2)`` and the graph's node count is
+    taken from ``x.numel()``.
+
+    Args:
+        recs: Iterable of records (presumably the output of ``prepare`` -- verify).
+        bs: Batch size.
+        shuffle: Whether the loader reshuffles on every pass.
+        drop_last: Whether a trailing incomplete batch is dropped.
+    """
+    graphs = []
+    for rec in recs:
+        feats = rec['x']
+        graphs.append(Data(x=feats, edge_index=rec['edge_index'],
+                           y=rec['y'].view(1, 2), num_nodes=feats.numel()))
+    return DataLoader(graphs, batch_size=bs, shuffle=shuffle, drop_last=drop_last)
+
+
+class RecGIN(nn.Module):
+    """Recurrent shared-weight GIN with deep supervision and a value (Q) head.
+
+    One block of ``inner`` GINConv + BatchNorm layers is applied repeatedly to
+    a latent ``z`` that starts at zero; the atom embedding ``h0`` is added to
+    ``z`` before every application. ``forward`` runs ``n_sup`` supervision
+    steps of ``T`` applications each and emits one prediction per step.
+
+    Args:
+        n_atom: Number of rows in the atom embedding table.
+        hidden: Width of the latent and of every linear layer.
+        T: Block applications per supervision step (must be >= 1).
+        n_sup: Number of supervision steps (must be >= 1).
+        sigma: Std of the Gaussian noise added after a block application when
+            ``noise=True`` is requested; ``0`` disables it.
+        inner: Number of GINConv + BatchNorm layers inside the shared block.
+        grad_mode: ``'1step'`` backpropagates only through the last block
+            application of the last supervision step; any other value
+            backpropagates through all ``T`` applications of that step.
+
+    Raises:
+        ValueError: If ``T`` or ``n_sup`` is smaller than 1.
+    """
+
+    def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, grad_mode='full'):
+        super().__init__()
+        # n_sup == 0 would return an empty prediction list, and T == 0 would make
+        # the two gradient modes run a different number of block applications.
+        if T < 1 or n_sup < 1:
+            raise ValueError(f"T and n_sup must be >= 1 (got T={T}, n_sup={n_sup})")
+        self.emb = nn.Embedding(n_atom, hidden)
+        self.convs = nn.ModuleList([
+            GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)),
+                    train_eps=True)
+            for _ in range(inner)])
+        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)])
+        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2))
+        self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
+        self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode
+
+    def block(self, z, ei):
+        """Apply the shared conv -> batch-norm -> ReLU stack once."""
+        for conv, bn in zip(self.convs, self.bns):
+            z = bn(conv(z, ei)).relu()
+        return z
+
+    def _inner(self, z, h0, ei, noise):
+        """One recursion: run the block on ``z + h0``, then optionally add noise."""
+        z = self.block(z + h0, ei)
+        if noise and self.sigma > 0:
+            z = z + self.sigma * torch.randn_like(z)
+        return z
+
+    def recurse(self, z, h0, ei, noise, one_step=False):
+        """Run ``T`` recursions from ``z`` and return the resulting latent.
+
+        With ``one_step=True`` the first ``T - 1`` recursions run without
+        gradient tracking and their result is detached, so only the final
+        recursion is part of the autograd graph.
+        """
+        if one_step:  # HRM 1-step gradient
+            with torch.no_grad():
+                for _ in range(self.T - 1):
+                    z = self._inner(z, h0, ei, noise)
+            z = z.detach()
+            return self._inner(z, h0, ei, noise)  # only last inner carries grad
+        for _ in range(self.T):  # TRM full recursion
+            z = self._inner(z, h0, ei, noise)
+        return z
+
+    def forward(self, x, ei, batch, noise=False):
+        """Run all supervision steps.
+
+        Every step but the last runs under ``no_grad`` with its latent
+        detached, so those predictions only carry gradient into ``head``; the
+        last step follows ``grad_mode``.
+
+        Returns:
+            ``(preds, q)``: ``preds`` is a list with one ``head`` output per
+            supervision step, computed from the sum-pooled latent; ``q`` is the
+            flattened ``qhead`` output for the final latent.
+        """
+        h0 = self.emb(x)
+        z = torch.zeros_like(h0)
+        preds = []
+        last = self.n_sup - 1
+        for s in range(self.n_sup):
+            if s < last:
+                with torch.no_grad():
+                    z = self.recurse(z, h0, ei, noise)
+                z = z.detach()
+            else:
+                z = self.recurse(z, h0, ei, noise, one_step=(self.grad_mode == '1step'))
+            pooled = global_add_pool(z, batch)
+            preds.append(self.head(pooled))
+        # n_sup >= 1 is enforced in __init__, so `pooled` is the pooled final latent here.
+        q = self.qhead(pooled).view(-1)
+        return preds, q
+
+
+@torch.no_grad()
+def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'):
+    """Compute per-target MAE in de-normalised units over a loader.
+
+    The model is switched to eval mode and left there. With ``K == 1`` a single
+    forward pass is scored (noisy only if ``model.sigma > 0``). With ``K > 1``
+    the model is run ``K`` times with ``noise=True`` and one prediction per
+    graph is selected: the rollout with the highest Q value when
+    ``select == 'bestq'``, otherwise the mean over rollouts.
+
+    Args:
+        model: Callable as ``model(x, edge_index, batch, noise=...)`` returning
+            ``(preds, q)``; only ``preds[-1]`` is scored.
+        ld: Iterable of batches exposing ``to``, ``x``, ``edge_index``,
+            ``batch``, ``y`` and ``num_graphs``.
+        dev: Device the batches and statistics are moved to.
+        ymu, ysd: Target mean / std used to undo the normalisation.
+        K: Number of rollouts per batch (must be >= 1).
+        select: Selection rule for ``K > 1`` (see above).
+
+    Returns:
+        ``(mae, mae_oracle)`` as two 2-element lists. ``mae_oracle`` scores,
+        per graph, the rollout closest to the target (equal to ``mae`` when
+        ``K == 1``). An empty loader yields NaN entries.
+
+    Raises:
+        ValueError: If ``K`` is smaller than 1.
+    """
+    if K < 1:
+        raise ValueError(f"K must be >= 1 (got {K})")
+    model.eval()
+    ysd_d, ymu_d = ysd.to(dev), ymu.to(dev)
+    ae = torch.zeros(2)
+    ae_or = torch.zeros(2)
+    n = 0
+    for b in ld:
+        b = b.to(dev)
+        if K == 1:
+            preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+            chosen = oracle = preds[-1]
+        else:
+            P, Q = [], []
+            for _ in range(K):
+                preds, q = model(b.x, b.edge_index, b.batch, noise=True)
+                P.append(preds[-1])
+                Q.append(q)
+            P = torch.stack(P)
+            Q = torch.stack(Q)
+            # Index pairs (rollout, graph) pick one rollout per graph.
+            ar = torch.arange(P.size(1), device=dev)
+            chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0)
+            oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar]
+        target = b.y * ysd_d + ymu_d
+        ae += ((chosen * ysd_d + ymu_d) - target).abs().sum(0).cpu()
+        ae_or += ((oracle * ysd_d + ymu_d) - target).abs().sum(0).cpu()
+        n += b.num_graphs
+    return (ae / n).tolist(), (ae_or / n).tolist()
+
+
+def main():
+    """Train a RecGIN from CLI arguments, then write a JSON report and a checkpoint.
+
+    Validation runs every 20 epochs and on the last epoch; whenever the summed
+    validation MAE improves, the test split is scored and the weights are
+    snapshotted. Both output files are written under ``OUT``.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument('--grad_mode', choices=['full', '1step'], default='full')
+    ap.add_argument('--sigma', type=float, default=0.0)
+    ap.add_argument('--K', type=int, default=1)
+    ap.add_argument('--select', choices=['none', 'bestq'], default='bestq')
+    ap.add_argument('--T', type=int, default=3)
+    ap.add_argument('--n_sup', type=int, default=3)
+    ap.add_argument('--hidden', type=int, default=128)
+    ap.add_argument('--epochs', type=int, default=200)
+    ap.add_argument('--lr', type=float, default=1e-3)
+    ap.add_argument('--bs', type=int, default=128)
+    ap.add_argument('--lam_q', type=float, default=1.0)
+    ap.add_argument('--seed', type=int, default=0)
+    args = ap.parse_args()
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+    os.makedirs(OUT, exist_ok=True)
+
+    tr, va, te = prepare('train'), prepare('val'), prepare('test')
+    n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1
+    # Normalise targets with train-split statistics only; records are rewritten in place.
+    Ytr = torch.stack([r['y'] for r in tr])
+    ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8
+    for recs in (tr, va, te):
+        for r in recs:
+            r['y'] = (r['y'] - ymu) / ysd
+    trl = loader(tr, args.bs, True, drop_last=True)
+    val, tel = loader(va, 256, False), loader(te, 256, False)
+
+    model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode).to(dev)
+    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs)
+    l1 = nn.L1Loss()
+
+    t0 = time.time()
+    best_val = 9e9
+    best = {}
+    best_state = None
+    for ep in range(args.epochs):
+        model.train()
+        for b in trl:
+            b = b.to(dev)
+            opt.zero_grad()
+            preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0)
+            # Deep supervision: average the L1 loss over every supervision step.
+            loss = sum(l1(p, b.y) for p in preds) / len(preds)
+            # Q target is the negated per-graph mean absolute error of the final prediction.
+            with torch.no_grad():
+                tq = -(preds[-1] - b.y).abs().mean(1)
+            loss = loss + args.lam_q * nn.functional.mse_loss(q, tq)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            opt.step()
+        sched.step()
+        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
+            vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select)
+            if sum(vm) < best_val:
+                best_val = sum(vm)
+                tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select)
+                best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo}
+                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
+            print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True)
+
+    tag = f"rec_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}_s{args.seed}"
+    rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1),
+           'dev': dev, 'y_std_raw': ysd.tolist(), **best}
+    if best:
+        print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} "
+              f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)")
+    else:
+        # Reached with --epochs 0, or when every validation MAE was NaN so the
+        # `<` comparison never held; still write the report and final weights.
+        print(f"[{tag}] no validation improvement recorded; saving final weights ({rep['sec']}s)")
+    with open(os.path.join(OUT, f"{tag}.json"), 'w') as f:
+        json.dump(rep, f, indent=2)
+    torch.save({'state': best_state if best_state is not None else model.state_dict(),
+                'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup,
+                        'sigma': args.sigma, 'grad_mode': args.grad_mode},
+                'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt"))
+    print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt"))
+
+
+# Run the training entry point only when executed as a script, not on import.
+if __name__ == "__main__":
+ main()