diff options
Diffstat (limited to 'diag/train_rec.py')
| -rw-r--r-- | diag/train_rec.py | 174 |
1 files changed, 174 insertions, 0 deletions
diff --git a/diag/train_rec.py b/diag/train_rec.py new file mode 100644 index 0000000..9866f28 --- /dev/null +++ b/diag/train_rec.py @@ -0,0 +1,174 @@ +"""Step-2: Recursive (TRM-ish) GNN on ZINC ring-counting + optional PTRM noise/selection. + +Recurrent shared-weight GIN block, deep-supervised over n_sup steps (TRM-style: carry latent +detached between steps). --grad_mode controls the LAST supervision step's recursion: + full : backprop through all T inner recursions (TRM) + 1step : backprop only the last inner recursion, first T-1 detached (HRM 1-step-gradient) +Optional per-step Gaussian noise (sigma) + K stochastic rollouts selected by a value head +(best-Q@K) for the PTRM experiments. Saves a checkpoint for the LE diagnostic (diag/lyap.py). + +Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/train_rec.py --grad_mode full --sigma 0 --K 1 +""" +import argparse, json, os, time +import numpy as np +import torch +import torch.nn as nn +from torch_geometric.loader import DataLoader +from torch_geometric.data import Data +from torch_geometric.nn import GINConv, global_add_pool +from diag.train_cycle import prepare + +OUT = '/home/yurenh2/rrog/runs' + + +def loader(recs, bs, shuffle, drop_last=False): + data = [Data(x=r['x'], edge_index=r['edge_index'], y=r['y'].view(1, 2), + num_nodes=r['x'].numel()) for r in recs] + return DataLoader(data, batch_size=bs, shuffle=shuffle, drop_last=drop_last) + + +class RecGIN(nn.Module): + def __init__(self, n_atom, hidden=128, T=3, n_sup=3, sigma=0.0, inner=2, grad_mode='full'): + super().__init__() + self.emb = nn.Embedding(n_atom, hidden) + self.convs = nn.ModuleList([GINConv(nn.Sequential( + nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)), train_eps=True) + for _ in range(inner)]) + self.bns = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(inner)]) + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2)) + self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)) + self.T, self.n_sup, self.sigma, self.grad_mode = T, n_sup, sigma, grad_mode + + def block(self, z, ei): + for conv, bn in zip(self.convs, self.bns): + z = bn(conv(z, ei)).relu() + return z + + def _inner(self, z, h0, ei, noise): + z = self.block(z + h0, ei) + if noise and self.sigma > 0: + z = z + self.sigma * torch.randn_like(z) + return z + + def recurse(self, z, h0, ei, noise, one_step=False): + if one_step: # HRM 1-step gradient + with torch.no_grad(): + for _ in range(self.T - 1): + z = self._inner(z, h0, ei, noise) + z = z.detach() + return self._inner(z, h0, ei, noise) # only last inner carries grad + for _ in range(self.T): # TRM full recursion + z = self._inner(z, h0, ei, noise) + return z + + def forward(self, x, ei, batch, noise=False): + h0 = self.emb(x) + z = torch.zeros_like(h0) + preds = [] + for s in range(self.n_sup): + if s < self.n_sup - 1: + with torch.no_grad(): + z = self.recurse(z, h0, ei, noise) + z = z.detach() + else: + z = self.recurse(z, h0, ei, noise, one_step=(self.grad_mode == '1step')) + preds.append(self.head(global_add_pool(z, batch))) + q = self.qhead(global_add_pool(z, batch)).view(-1) + return preds, q + + +@torch.no_grad() +def evaluate(model, ld, dev, ymu, ysd, K=1, select='none'): + model.eval() + ysd_d, ymu_d = ysd.to(dev), ymu.to(dev) + ae = torch.zeros(2); ae_or = torch.zeros(2); n = 0 + for b in ld: + b = b.to(dev) + if K == 1: + preds, _ = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + chosen = oracle = preds[-1] + else: + P, Q = [], [] + for _ in range(K): + preds, q = model(b.x, b.edge_index, b.batch, noise=True) + P.append(preds[-1]); Q.append(q) + P = torch.stack(P); Q = torch.stack(Q) + ar = torch.arange(P.size(1), device=dev) + chosen = P[Q.argmax(0), ar] if select == 'bestq' else P.mean(0) + oracle = P[(P - b.y.unsqueeze(0)).abs().sum(-1).argmin(0), ar] + ae += ((chosen * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + ae_or += ((oracle * ysd_d + ymu_d) - (b.y * ysd_d + ymu_d)).abs().sum(0).cpu() + n += b.num_graphs + return (ae / n).tolist(), (ae_or / n).tolist() + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--grad_mode', choices=['full', '1step'], default='full') + ap.add_argument('--sigma', type=float, default=0.0) + ap.add_argument('--K', type=int, default=1) + ap.add_argument('--select', choices=['none', 'bestq'], default='bestq') + ap.add_argument('--T', type=int, default=3) + ap.add_argument('--n_sup', type=int, default=3) + ap.add_argument('--hidden', type=int, default=128) + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--bs', type=int, default=128) + ap.add_argument('--lam_q', type=float, default=1.0) + ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + torch.manual_seed(args.seed); np.random.seed(args.seed) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + os.makedirs(OUT, exist_ok=True) + + tr, va, te = prepare('train'), prepare('val'), prepare('test') + n_atom = int(max(r['x'].max() for r in tr + va + te)) + 1 + Ytr = torch.stack([r['y'] for r in tr]); ymu, ysd = Ytr.mean(0), Ytr.std(0) + 1e-8 + for recs in (tr, va, te): + for r in recs: + r['y'] = (r['y'] - ymu) / ysd + trl = loader(tr, args.bs, True, drop_last=True) + val, tel = loader(va, 256, False), loader(te, 256, False) + + model = RecGIN(n_atom, args.hidden, args.T, args.n_sup, args.sigma, grad_mode=args.grad_mode).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.epochs) + l1 = nn.L1Loss() + + t0 = time.time(); best_val = 9e9; best = {}; best_state = None + for ep in range(args.epochs): + model.train() + for b in trl: + b = b.to(dev); opt.zero_grad() + preds, q = model(b.x, b.edge_index, b.batch, noise=model.sigma > 0) + loss = sum(l1(p, b.y) for p in preds) / len(preds) + with torch.no_grad(): + tq = -(preds[-1] - b.y).abs().mean(1) + loss = loss + args.lam_q * nn.functional.mse_loss(q, tq) + loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step() + sched.step() + if (ep + 1) % 20 == 0 or ep == args.epochs - 1: + vm, _ = evaluate(model, val, dev, ymu, ysd, args.K, args.select) + if sum(vm) < best_val: + best_val = sum(vm) + tem, teo = evaluate(model, tel, dev, ymu, ysd, args.K, args.select) + best = {'ep': ep + 1, 'val_mae': vm, 'test_mae': tem, 'test_mae_oracle': teo} + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + print(f"ep{ep+1} val_mae={[round(x,3) for x in vm]}", flush=True) + + tag = f"rec_{args.grad_mode}_sig{args.sigma}_K{args.K}_{args.select}_T{args.T}_ns{args.n_sup}_s{args.seed}" + rep = {'dataset': 'ZINC-cycle56', 'tag': tag, **vars(args), 'sec': round(time.time() - t0, 1), + 'dev': dev, 'y_std_raw': ysd.tolist(), **best} + print(f"[{tag}] test_mae={[round(x,3) for x in best.get('test_mae')]} " + f"oracle@K={[round(x,3) for x in best.get('test_mae_oracle')]} @ep{best.get('ep')} ({rep['sec']}s)") + with open(os.path.join(OUT, f"{tag}.json"), 'w') as f: + json.dump(rep, f, indent=2) + torch.save({'state': best_state or model.state_dict(), + 'cfg': {'n_atom': n_atom, 'hidden': args.hidden, 'T': args.T, 'n_sup': args.n_sup, + 'sigma': args.sigma, 'grad_mode': args.grad_mode}, + 'ymu': ymu, 'ysd': ysd}, os.path.join(OUT, f"ckpt_{tag}.pt")) + print(" wrote", os.path.join(OUT, f"ckpt_{tag}.pt")) + + +if __name__ == "__main__": + main() |
