diff options
Diffstat (limited to 'rrog/train_ogb_graphprop.py')
| -rw-r--r-- | rrog/train_ogb_graphprop.py | 685 |
1 files changed, 685 insertions, 0 deletions
diff --git a/rrog/train_ogb_graphprop.py b/rrog/train_ogb_graphprop.py new file mode 100644 index 0000000..387ef3c --- /dev/null +++ b/rrog/train_ogb_graphprop.py @@ -0,0 +1,685 @@ +import argparse +from contextlib import contextmanager +import json +import os +import time + +import numpy as np +import torch +import torch.nn as nn +from ogb.graphproppred import Evaluator, PygGraphPropPredDataset +from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder +from torch_geometric.data import Batch +from torch_geometric.loader import DataLoader +from torch_geometric.nn import ( + APPNP, + ARMAConv, + ChebConv, + FiLMConv, + GATv2Conv, + GCNConv, + GENConv, + GINEConv, + GraphConv, + MFConv, + PNAConv, + ResGatedGraphConv, + SAGEConv, + SGConv, + TAGConv, + TransformerConv, + global_add_pool, +) +from torch_geometric.utils import degree + + +PROJECT_ROOT = os.environ.get( + "RROG_ROOT", + os.path.abspath(os.path.join(os.path.dirname(__file__), "..")), +) +DATA_ROOT = os.environ.get("RROG_DATA_DIR", os.path.join(PROJECT_ROOT, "data")) +ROOT = os.path.join(DATA_ROOT, "ogb") +OUT = os.environ.get("RROG_RUNS_DIR", os.path.join(PROJECT_ROOT, "runs")) +SUPPORTED_VIEWS = [ + "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna", + "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp", +] +SUPPORTED_MOL_DATASETS = [ + "ogbg-molhiv", + "ogbg-molpcba", + "ogbg-molbbbp", + "ogbg-molbace", + "ogbg-moltox21", + "ogbg-molclintox", + "ogbg-molsider", + "ogbg-molesol", + "ogbg-molfreesolv", + "ogbg-mollipo", +] +HIGHER_BETTER = {"rocauc", "ap", "auc", "accuracy", "acc", "f1"} +LOWER_BETTER = {"rmse", "mae"} + +_TORCH_LOAD = torch.load + + +def _torch_load_ogb_compat(*args, **kwargs): + kwargs.setdefault("weights_only", False) + return _TORCH_LOAD(*args, **kwargs) + + +torch.load = _torch_load_ogb_compat + + +def clone_state_dict(model): + return {k: v.detach().clone() for k, v in model.state_dict().items()} + + +@torch.no_grad() +def update_ema_state(ema_state, model, decay): + if ema_state is None: + return + for key, value in model.state_dict().items(): + if torch.is_floating_point(value): + ema_state[key].mul_(decay).add_(value.detach(), alpha=1.0 - decay) + else: + ema_state[key].copy_(value.detach()) + + +@contextmanager +def using_ema_state(model, ema_state): + if ema_state is None: + yield + return + raw_state = clone_state_dict(model) + model.load_state_dict(ema_state, strict=True) + try: + yield + finally: + model.load_state_dict(raw_state, strict=True) + + +def metric_direction(metric: str) -> int: + metric = metric.lower() + if metric in LOWER_BETTER or "rmse" in metric or "mae" in metric: + return -1 + return 1 + + +def is_regression_metric(metric: str) -> bool: + return metric_direction(metric) < 0 + + +def is_better(score: float, best: float | None, metric: str) -> bool: + if best is None: + return True + direction = metric_direction(metric) + return score > best if direction > 0 else score < best + + +def jsonable(obj): + if isinstance(obj, dict): + return {str(k): jsonable(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [jsonable(v) for v in obj] + if isinstance(obj, (np.integer, np.floating)): + return obj.item() + if isinstance(obj, torch.Tensor): + if obj.ndim == 0: + return obj.detach().cpu().item() + return obj.detach().cpu().tolist() + return obj + + +def degree_histogram(dataset) -> torch.Tensor: + max_degree = 0 + degs = [] + for graph in dataset: + deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long) + degs.append(deg) + if deg.numel(): + max_degree = max(max_degree, int(deg.max().item())) + hist = torch.zeros(max_degree + 1, dtype=torch.long) + for deg in degs: + hist += torch.bincount(deg, minlength=hist.numel()) + return hist + + +def make_view_layer(view: str, hidden: int, deg: torch.Tensor | None): + if view in {"gin", "gine"}: + mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)) + return GINEConv(mlp, train_eps=True) + if view == "gcn": + return GCNConv(hidden, hidden) + if view == "graphsage": + return SAGEConv(hidden, hidden) + if view == "gatv2": + return GATv2Conv(hidden, hidden, heads=4, concat=False, edge_dim=hidden) + if view == "graphconv": + return GraphConv(hidden, hidden) + if view == "transformer": + return TransformerConv(hidden, hidden, heads=4, concat=False, edge_dim=hidden) + if view == "pna": + if deg is None: + raise ValueError("PNA view requires a training-set degree histogram") + return PNAConv( + hidden, hidden, + aggregators=["mean", "min", "max", "std"], + scalers=["identity", "amplification", "attenuation"], + deg=deg, + edge_dim=hidden, + ) + if view == "gen": + return GENConv(hidden, hidden, edge_dim=hidden) + if view == "film": + return FiLMConv(hidden, hidden) + if view == "resgated": + return ResGatedGraphConv(hidden, hidden, edge_dim=hidden) + if view == "tag": + return TAGConv(hidden, hidden, K=3) + if view == "sgc": + return SGConv(hidden, hidden, K=2, cached=False) + if view == "cheb": + return ChebConv(hidden, hidden, K=3) + if view == "arma": + return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2) + if view == "mf": + return MFConv(hidden, hidden) + if view == "appnp": + return APPNP(K=5, alpha=0.1) + raise ValueError(f"unsupported OGB view: {view}") + + +EDGE_ATTR_VIEWS = {"gin", "gine", "gatv2", "transformer", "pna", "gen", "resgated"} + + +class OGBRRoG(nn.Module): + def __init__( + self, hidden, num_tasks, view="gin", T=1, n_sup=3, agg_layers=5, + compute_layers=2, grad_mode="full", deg=None, + ): + super().__init__() + self.view = view + self.atom_encoder = AtomEncoder(hidden) + self.bond_encoder = BondEncoder(hidden) + self.agg_convs = nn.ModuleList() + self.agg_bns = nn.ModuleList() + for _ in range(agg_layers): + self.agg_convs.append(make_view_layer(view, hidden, deg)) + self.agg_bns.append(nn.BatchNorm1d(hidden)) + + core = [] + d = hidden + for _ in range(compute_layers - 1): + core += [nn.Linear(d, hidden), nn.GELU()] + d = hidden + core.append(nn.Linear(d, hidden)) + self.core_norm = nn.LayerNorm(hidden) + self.core = nn.Sequential(*core) + nn.init.zeros_(self.core[-1].weight) + nn.init.zeros_(self.core[-1].bias) + + self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, num_tasks)) + self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)) + with torch.no_grad(): + self.qhead[-1].weight.zero_() + self.qhead[-1].bias.fill_(-5.0) + + self.T = T + self.n_sup = n_sup + self.grad_mode = grad_mode + self.agg_layers = agg_layers + self.compute_layers = compute_layers + self.hidden = hidden + self.num_tasks = num_tasks + + def aggregate(self, x, edge_index, edge_attr): + h = self.atom_encoder(x) + e = self.bond_encoder(edge_attr) + for conv, bn in zip(self.agg_convs, self.agg_bns): + if self.view in {"gin", "gine", "pna", "gen", "resgated"}: + h = bn(conv(h, edge_index, e)).relu() + elif self.view in {"gatv2", "transformer"}: + h = bn(conv(h, edge_index, edge_attr=e)).relu() + else: + h = bn(conv(h, edge_index)).relu() + return h + + def core_step(self, combined, state): + return state + self.core(self.core_norm(combined)) + + def _z_step(self, y, z, ctx): + return self.core_step(ctx + y + z, z) + + def _y_step(self, y, z): + return self.core_step(y + z, y) + + def recurse(self, y, z, ctx, one_step=False): + if self.T == 0: + return y, z + if one_step: + with torch.no_grad(): + for _ in range(self.T - 1): + z = self._z_step(y, z, ctx) + z = z.detach() + z = self._z_step(y, z, ctx) + y = self._y_step(y, z) + return y, z + for _ in range(self.T): + z = self._z_step(y, z, ctx) + y = self._y_step(y, z) + return y, z + + def predict(self, y, batch): + pooled = global_add_pool(y, batch) + return self.head(pooled), self.qhead(pooled).view(-1) + + def forward_trace(self, data, steps): + ctx = self.aggregate(data.x, data.edge_index, data.edge_attr) + y = ctx + z = torch.zeros_like(ctx) + preds, q_logits = [], [] + for s in range(steps): + y, z = self.recurse(y, z, ctx, one_step=(self.grad_mode == "1step")) + pred, q = self.predict(y, data.batch) + preds.append(pred) + q_logits.append(q) + if s < steps - 1: + y, z = y.detach(), z.detach() + return preds, q_logits + + def forward(self, data): + steps = self.n_sup + preds, q_logits = self.forward_trace(data, steps) + return preds, q_logits[-1] + + +def supervised_loss(logits, y, metric): + per_graph, has_label = per_graph_supervised_loss(logits, y, metric) + if not has_label.any(): + return logits.sum() * 0.0 + return per_graph[has_label].mean() + + +def per_graph_supervised_loss(logits, y, metric): + y = y.to(torch.float32) + mask = ~torch.isnan(y) + target = torch.where(mask, y, torch.zeros_like(y)) + if is_regression_metric(metric): + losses = (logits - target).pow(2) + else: + losses = nn.functional.binary_cross_entropy_with_logits(logits, target, reduction="none") + losses = torch.where(mask, losses, torch.zeros_like(losses)) + denom = mask.sum(dim=1).clamp_min(1) + return losses.sum(dim=1) / denom, mask.any(dim=1) + + +@torch.no_grad() +def halt_targets(logits, y): + y = y.to(torch.float32) + mask = ~torch.isnan(y) + pred = (logits > 0).to(y.dtype) + correct_or_missing = (~mask) | (pred == y) + has_label = mask.any(dim=1) + return (correct_or_missing.all(dim=1) & has_label).to(logits.dtype) + + +@torch.no_grad() +def evaluate(model, loader, evaluator, dev, steps=None, adaptive=False, halt_min_steps=1, max_batches=0): + model.eval() + ys, ps = [], [] + step_sum = 0.0 + n = 0 + for i, batch in enumerate(loader): + if max_batches and i >= max_batches: + break + batch = batch.to(dev) + if steps is None: + preds, _ = model(batch) + pred = preds[-1] + used_steps = float(model.n_sup) + else: + preds, qs = model.forward_trace(batch, steps) + stack = torch.stack(preds, dim=0) + if adaptive: + q = torch.stack(qs, dim=0) + step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1) + halted = (q > 0) & (step_ids >= halt_min_steps) + any_halt = halted.any(dim=0) + first_halt = halted.to(torch.int64).argmax(dim=0) + fallback = torch.full_like(first_halt, steps - 1) + idx = torch.where(any_halt, first_halt, fallback) + pred = stack[idx, torch.arange(stack.size(1), device=dev)] + used_steps = float((idx.to(torch.float32) + 1).sum().item()) + else: + pred = stack[-1] + used_steps = float(steps * batch.num_graphs) + ys.append(batch.y.detach().cpu()) + ps.append(pred.detach().cpu()) + step_sum += used_steps + n += batch.num_graphs + y_true = torch.cat(ys, dim=0) + y_pred = torch.cat(ps, dim=0) + result = evaluator.eval({"y_true": y_true, "y_pred": y_pred}) + return result, step_sum / max(n, 1) + + +def _split_nodes(t, ptr): + return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)] + + +def act_train_step(model, state, replacement_batch, opt, dev, args, metric): + replacement = replacement_batch.to_data_list() + batch_size = len(replacement) + if state is None: + state = { + "graphs": [None for _ in range(batch_size)], + "y": [None for _ in range(batch_size)], + "z": [None for _ in range(batch_size)], + "steps": torch.zeros(batch_size, dtype=torch.long, device=dev), + "halted": torch.ones(batch_size, dtype=torch.bool, device=dev), + } + + halted_cpu = state["halted"].detach().cpu().tolist() + for i, halted in enumerate(halted_cpu): + if halted: + state["graphs"][i] = replacement[i] + + batch = Batch.from_data_list(state["graphs"]).to(dev) + ctx = model.aggregate(batch.x, batch.edge_index, batch.edge_attr) + ptr = batch.ptr + y_parts, z_parts = [], [] + for i in range(batch_size): + start, end = ptr[i].item(), ptr[i + 1].item() + if halted_cpu[i] or state["y"][i] is None: + y_parts.append(ctx[start:end]) + z_parts.append(torch.zeros_like(ctx[start:end])) + else: + y_parts.append(state["y"][i].to(dev)) + z_parts.append(state["z"][i].to(dev)) + y = torch.cat(y_parts, dim=0) + z = torch.cat(z_parts, dim=0) + + opt.zero_grad() + y, z = model.recurse(y, z, ctx, one_step=(model.grad_mode == "1step")) + logits, q = model.predict(y, batch.batch) + pred_loss = supervised_loss(logits, batch.y, metric) + if args.halt_target == "exact" and not is_regression_metric(metric): + target = halt_targets(logits.detach(), batch.y) + elif args.halt_target == "loss": + per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), batch.y, metric) + target = ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype) + else: + raise ValueError(args.halt_target) + q_loss = nn.functional.binary_cross_entropy_with_logits(q, target) + loss = pred_loss + 0.5 * args.lam_q * q_loss + y_det, z_det = y.detach(), z.detach() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + + state["y"] = _split_nodes(y_det, ptr) + state["z"] = _split_nodes(z_det, ptr) + with torch.no_grad(): + was_halted = state["halted"] + steps = torch.where(was_halted, torch.zeros_like(state["steps"]), state["steps"]) + 1 + halted = (steps >= args.halt_max_steps) | ((q.detach() > 0) & (steps >= args.halt_min_steps)) + if args.halt_exploration_prob > 0 and args.halt_max_steps > 1: + explore = torch.rand_like(q) < args.halt_exploration_prob + min_steps = torch.where( + explore, + torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev), + torch.zeros_like(steps), + ) + halted = halted & (steps >= min_steps) + state["steps"] = steps + state["halted"] = halted + + return state, { + "loss": float(loss.detach().cpu()), + "pred_loss": float(pred_loss.detach().cpu()), + "q_loss": float(q_loss.detach().cpu()), + "halted_frac": float(state["halted"].to(torch.float32).mean().detach().cpu()), + "steps": float(state["steps"].to(torch.float32).mean().detach().cpu()), + } + + +def _halt_target(args, logits, y, metric): + if args.halt_target == "exact" and not is_regression_metric(metric): + return halt_targets(logits.detach(), y) + if args.halt_target == "loss": + per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), y, metric) + return ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype) + raise ValueError(args.halt_target) + + +def act_trace_train_step(model, batch, opt, dev, args, epoch, metric): + batch = batch.to(dev) + steps = max(1, args.halt_max_steps) + opt.zero_grad() + preds, qs = model.forward_trace(batch, steps) + pred_loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds) + if epoch <= args.q_warmup_epochs: + q_loss = pred_loss.detach() * 0.0 + loss = pred_loss + else: + q_losses = [] + for step_idx, (pred, q) in enumerate(zip(preds, qs), start=1): + target = _halt_target(args, pred, batch.y, metric) + if step_idx < args.halt_min_steps: + target = torch.zeros_like(target) + q_losses.append(nn.functional.binary_cross_entropy_with_logits(q, target)) + q_loss = sum(q_losses) / len(q_losses) + loss = pred_loss + 0.5 * args.lam_q * q_loss + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + + with torch.no_grad(): + q_stack = torch.stack(qs, dim=0) + step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1) + halted = (q_stack > 0) & (step_ids >= args.halt_min_steps) + any_halt = halted.any(dim=0) + first_halt = halted.to(torch.int64).argmax(dim=0) + fallback = torch.full_like(first_halt, steps - 1) + idx = torch.where(any_halt, first_halt, fallback) + + return { + "loss": float(loss.detach().cpu()), + "pred_loss": float(pred_loss.detach().cpu()), + "q_loss": float(q_loss.detach().cpu()), + "halted_frac": float(any_halt.to(torch.float32).mean().detach().cpu()), + "steps": float((idx.to(torch.float32) + 1).mean().detach().cpu()), + } + + +def train_epoch(model, loader, opt, dev, args, act_state, ema_state, epoch, metric): + model.train() + metrics = [] + for i, batch in enumerate(loader): + if args.max_train_batches and i >= args.max_train_batches: + break + if args.compute == "rrog-act": + m = act_trace_train_step(model, batch, opt, dev, args, epoch, metric) + update_ema_state(ema_state, model, args.ema) + metrics.append(m) + continue + batch = batch.to(dev) + opt.zero_grad() + preds, _ = model(batch) + loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + update_ema_state(ema_state, model, args.ema) + metrics.append({"loss": float(loss.detach().cpu()), "halted_frac": 0.0, "steps": float(model.n_sup)}) + return act_state, metrics + + +def metric_value(result, evaluator): + return float(result[evaluator.eval_metric]) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--dataset", choices=SUPPORTED_MOL_DATASETS, default="ogbg-molhiv") + ap.add_argument("--view", choices=SUPPORTED_VIEWS, default="gin") + ap.add_argument("--compute", choices=["classic", "view-only", "fixed-rrog", "rrog-act"], default="rrog-act") + ap.add_argument("--T", type=int, default=1) + ap.add_argument("--n_sup", type=int, default=3) + ap.add_argument("--hidden", type=int, default=128) + ap.add_argument("--agg_layers", type=int, default=5) + ap.add_argument("--compute_layers", type=int, default=2) + ap.add_argument("--epochs", type=int, default=100) + ap.add_argument("--eval_every", type=int, default=10) + ap.add_argument("--lr", type=float, default=1e-3) + ap.add_argument("--bs", type=int, default=128) + ap.add_argument("--lam_q", type=float, default=1.0) + ap.add_argument("--halt_max_steps", type=int, default=8) + ap.add_argument("--halt_min_steps", type=int, default=1) + ap.add_argument("--halt_target", choices=["exact", "loss"], default="loss") + ap.add_argument("--halt_loss_threshold", type=float, default=0.2) + ap.add_argument("--halt_exploration_prob", type=float, default=0.1) + ap.add_argument("--q_warmup_epochs", type=int, default=0) + ap.add_argument("--ema", type=float, default=0.0) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--num_workers", type=int, default=0) + ap.add_argument("--max_train_batches", type=int, default=0) + ap.add_argument("--max_eval_batches", type=int, default=0) + ap.add_argument("--device", default="auto") + args = ap.parse_args() + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + if args.compute == "classic": + args.n_sup = 1 + if args.device == "auto": + dev = "cuda" if torch.cuda.is_available() else "cpu" + else: + dev = args.device + os.makedirs(OUT, exist_ok=True) + + dataset = PygGraphPropPredDataset(name=args.dataset, root=ROOT) + split_idx = dataset.get_idx_split() + evaluator = Evaluator(args.dataset) + metric = evaluator.eval_metric + num_tasks = dataset.num_tasks + + train_dataset = dataset[split_idx["train"]] + valid_dataset = dataset[split_idx["valid"]] + test_dataset = dataset[split_idx["test"]] + train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, + drop_last=True, num_workers=args.num_workers) + valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False, + num_workers=args.num_workers) + test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, + num_workers=args.num_workers) + + T = 0 if args.compute in ["classic", "view-only"] else args.T + deg = degree_histogram(train_dataset) if args.view == "pna" else None + model = OGBRRoG(args.hidden, num_tasks, view=args.view, T=T, n_sup=args.n_sup, + agg_layers=args.agg_layers, compute_layers=args.compute_layers, deg=deg).to(dev) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, max(args.epochs, 1)) + ema_state = clone_state_dict(model) if args.ema > 0 else None + + t0 = time.time() + best_val = None + best = {} + best_state = None + act_state = None + steps = max(1, args.halt_max_steps) + + for ep in range(args.epochs): + act_state, train_metrics = train_epoch( + model, train_loader, opt, dev, args, act_state, ema_state, ep + 1, metric) + sched.step() + if (ep + 1) % args.eval_every == 0 or ep == args.epochs - 1: + with using_ema_state(model, ema_state): + if args.compute == "rrog-act": + val_result, fixed_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps, + adaptive=False, halt_min_steps=args.halt_min_steps, + max_batches=args.max_eval_batches) + val_adapt, adaptive_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps, + adaptive=True, halt_min_steps=args.halt_min_steps, + max_batches=args.max_eval_batches) + else: + val_result, _ = evaluate(model, valid_loader, evaluator, dev, + max_batches=args.max_eval_batches) + val_score = metric_value(val_result, evaluator) + if is_better(val_score, best_val, metric): + best_val = val_score + if args.compute == "rrog-act": + test_fixed, fixed_steps = evaluate(model, test_loader, evaluator, dev, steps=steps, + adaptive=False, halt_min_steps=args.halt_min_steps, + max_batches=args.max_eval_batches) + test_adapt, adaptive_steps = evaluate(model, test_loader, evaluator, dev, steps=steps, + adaptive=True, halt_min_steps=args.halt_min_steps, + max_batches=args.max_eval_batches) + best = { + "ep": ep + 1, + "val": val_result, + "val_adaptive": val_adapt, + "test": test_fixed, + "test_adaptive": test_adapt, + "fixed_val_steps": fixed_val_steps, + "adaptive_val_steps": adaptive_val_steps, + "fixed_steps": fixed_steps, + "adaptive_steps": adaptive_steps, + } + else: + test_result, _ = evaluate(model, test_loader, evaluator, dev, + max_batches=args.max_eval_batches) + best = {"ep": ep + 1, "val": val_result, "test": test_result} + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + halted = sum(m.get("halted_frac", 0.0) for m in train_metrics) / max(len(train_metrics), 1) + train_steps = sum(m.get("steps", 0.0) for m in train_metrics) / max(len(train_metrics), 1) + msg = f"ep{ep+1} val_{evaluator.eval_metric}={val_score:.5f}" + if args.compute == "rrog-act": + msg += ( + f" val_adapt_{evaluator.eval_metric}={metric_value(val_adapt, evaluator):.5f}" + f" adapt_steps={adaptive_val_steps:.2f}" + ) + msg += f" halt={halted:.2f} train_steps={train_steps:.2f}" + print(msg, flush=True) + + ema_tag = f"_ema{args.ema:g}" if args.ema > 0 else "" + tag = ( + f"{args.dataset}_{args.view}_{args.compute}_T{T}_ns{args.n_sup}_" + f"h{args.hidden}_e{args.epochs}{ema_tag}_s{args.seed}" + ) + rep = { + "dataset": args.dataset, + "tag": tag, + **vars(args), + "metric": evaluator.eval_metric, + "sec": round(time.time() - t0, 1), + "dev": dev, + **best, + } + with open(os.path.join(OUT, f"{tag}.json"), "w") as f: + json.dump(jsonable(rep), f, indent=2) + torch.save({ + "state": best_state or model.state_dict(), + "cfg": { + "dataset": args.dataset, + "hidden": args.hidden, + "num_tasks": num_tasks, + "T": T, + "n_sup": args.n_sup, + "agg_layers": args.agg_layers, + "compute_layers": args.compute_layers, + "compute": args.compute, + "halt_max_steps": steps, + "halt_min_steps": args.halt_min_steps, + "halt_target": args.halt_target, + "halt_loss_threshold": args.halt_loss_threshold, + "view": args.view, + }, + }, os.path.join(OUT, f"ckpt_{tag}.pt")) + print(f"[{tag}] best_ep={best.get('ep')} val={best.get('val')} test={best.get('test')} " + f"adaptive={best.get('test_adaptive')} steps={best.get('adaptive_steps')}", flush=True) + print(" wrote", os.path.join(OUT, f"{tag}.json")) + + +if __name__ == "__main__": + main() |
