summaryrefslogtreecommitdiff
path: root/rrog/train_ogb_graphprop.py
diff options
context:
space:
mode:
Diffstat (limited to 'rrog/train_ogb_graphprop.py')
-rw-r--r--rrog/train_ogb_graphprop.py685
1 files changed, 685 insertions, 0 deletions
diff --git a/rrog/train_ogb_graphprop.py b/rrog/train_ogb_graphprop.py
new file mode 100644
index 0000000..387ef3c
--- /dev/null
+++ b/rrog/train_ogb_graphprop.py
@@ -0,0 +1,685 @@
+import argparse
+from contextlib import contextmanager
+import json
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
+from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
+from torch_geometric.data import Batch
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import (
+ APPNP,
+ ARMAConv,
+ ChebConv,
+ FiLMConv,
+ GATv2Conv,
+ GCNConv,
+ GENConv,
+ GINEConv,
+ GraphConv,
+ MFConv,
+ PNAConv,
+ ResGatedGraphConv,
+ SAGEConv,
+ SGConv,
+ TAGConv,
+ TransformerConv,
+ global_add_pool,
+)
+from torch_geometric.utils import degree
+
+
+PROJECT_ROOT = os.environ.get(
+ "RROG_ROOT",
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
+)
+DATA_ROOT = os.environ.get("RROG_DATA_DIR", os.path.join(PROJECT_ROOT, "data"))
+ROOT = os.path.join(DATA_ROOT, "ogb")
+OUT = os.environ.get("RROG_RUNS_DIR", os.path.join(PROJECT_ROOT, "runs"))
+SUPPORTED_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+SUPPORTED_MOL_DATASETS = [
+ "ogbg-molhiv",
+ "ogbg-molpcba",
+ "ogbg-molbbbp",
+ "ogbg-molbace",
+ "ogbg-moltox21",
+ "ogbg-molclintox",
+ "ogbg-molsider",
+ "ogbg-molesol",
+ "ogbg-molfreesolv",
+ "ogbg-mollipo",
+]
+HIGHER_BETTER = {"rocauc", "ap", "auc", "accuracy", "acc", "f1"}
+LOWER_BETTER = {"rmse", "mae"}
+
+_TORCH_LOAD = torch.load
+
+
+def _torch_load_ogb_compat(*args, **kwargs):
+ kwargs.setdefault("weights_only", False)
+ return _TORCH_LOAD(*args, **kwargs)
+
+
+torch.load = _torch_load_ogb_compat
+
+
+def clone_state_dict(model):
+ return {k: v.detach().clone() for k, v in model.state_dict().items()}
+
+
+@torch.no_grad()
+def update_ema_state(ema_state, model, decay):
+ if ema_state is None:
+ return
+ for key, value in model.state_dict().items():
+ if torch.is_floating_point(value):
+ ema_state[key].mul_(decay).add_(value.detach(), alpha=1.0 - decay)
+ else:
+ ema_state[key].copy_(value.detach())
+
+
+@contextmanager
+def using_ema_state(model, ema_state):
+ if ema_state is None:
+ yield
+ return
+ raw_state = clone_state_dict(model)
+ model.load_state_dict(ema_state, strict=True)
+ try:
+ yield
+ finally:
+ model.load_state_dict(raw_state, strict=True)
+
+
+def metric_direction(metric: str) -> int:
+ metric = metric.lower()
+ if metric in LOWER_BETTER or "rmse" in metric or "mae" in metric:
+ return -1
+ return 1
+
+
+def is_regression_metric(metric: str) -> bool:
+ return metric_direction(metric) < 0
+
+
+def is_better(score: float, best: float | None, metric: str) -> bool:
+ if best is None:
+ return True
+ direction = metric_direction(metric)
+ return score > best if direction > 0 else score < best
+
+
+def jsonable(obj):
+ if isinstance(obj, dict):
+ return {str(k): jsonable(v) for k, v in obj.items()}
+ if isinstance(obj, (list, tuple)):
+ return [jsonable(v) for v in obj]
+ if isinstance(obj, (np.integer, np.floating)):
+ return obj.item()
+ if isinstance(obj, torch.Tensor):
+ if obj.ndim == 0:
+ return obj.detach().cpu().item()
+ return obj.detach().cpu().tolist()
+ return obj
+
+
+def degree_histogram(dataset) -> torch.Tensor:
+ max_degree = 0
+ degs = []
+ for graph in dataset:
+ deg = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
+ degs.append(deg)
+ if deg.numel():
+ max_degree = max(max_degree, int(deg.max().item()))
+ hist = torch.zeros(max_degree + 1, dtype=torch.long)
+ for deg in degs:
+ hist += torch.bincount(deg, minlength=hist.numel())
+ return hist
+
+
+def make_view_layer(view: str, hidden: int, deg: torch.Tensor | None):
+ if view in {"gin", "gine"}:
+ mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
+ return GINEConv(mlp, train_eps=True)
+ if view == "gcn":
+ return GCNConv(hidden, hidden)
+ if view == "graphsage":
+ return SAGEConv(hidden, hidden)
+ if view == "gatv2":
+ return GATv2Conv(hidden, hidden, heads=4, concat=False, edge_dim=hidden)
+ if view == "graphconv":
+ return GraphConv(hidden, hidden)
+ if view == "transformer":
+ return TransformerConv(hidden, hidden, heads=4, concat=False, edge_dim=hidden)
+ if view == "pna":
+ if deg is None:
+ raise ValueError("PNA view requires a training-set degree histogram")
+ return PNAConv(
+ hidden, hidden,
+ aggregators=["mean", "min", "max", "std"],
+ scalers=["identity", "amplification", "attenuation"],
+ deg=deg,
+ edge_dim=hidden,
+ )
+ if view == "gen":
+ return GENConv(hidden, hidden, edge_dim=hidden)
+ if view == "film":
+ return FiLMConv(hidden, hidden)
+ if view == "resgated":
+ return ResGatedGraphConv(hidden, hidden, edge_dim=hidden)
+ if view == "tag":
+ return TAGConv(hidden, hidden, K=3)
+ if view == "sgc":
+ return SGConv(hidden, hidden, K=2, cached=False)
+ if view == "cheb":
+ return ChebConv(hidden, hidden, K=3)
+ if view == "arma":
+ return ARMAConv(hidden, hidden, num_stacks=1, num_layers=2)
+ if view == "mf":
+ return MFConv(hidden, hidden)
+ if view == "appnp":
+ return APPNP(K=5, alpha=0.1)
+ raise ValueError(f"unsupported OGB view: {view}")
+
+
+EDGE_ATTR_VIEWS = {"gin", "gine", "gatv2", "transformer", "pna", "gen", "resgated"}
+
+
+class OGBRRoG(nn.Module):
+ def __init__(
+ self, hidden, num_tasks, view="gin", T=1, n_sup=3, agg_layers=5,
+ compute_layers=2, grad_mode="full", deg=None,
+ ):
+ super().__init__()
+ self.view = view
+ self.atom_encoder = AtomEncoder(hidden)
+ self.bond_encoder = BondEncoder(hidden)
+ self.agg_convs = nn.ModuleList()
+ self.agg_bns = nn.ModuleList()
+ for _ in range(agg_layers):
+ self.agg_convs.append(make_view_layer(view, hidden, deg))
+ self.agg_bns.append(nn.BatchNorm1d(hidden))
+
+ core = []
+ d = hidden
+ for _ in range(compute_layers - 1):
+ core += [nn.Linear(d, hidden), nn.GELU()]
+ d = hidden
+ core.append(nn.Linear(d, hidden))
+ self.core_norm = nn.LayerNorm(hidden)
+ self.core = nn.Sequential(*core)
+ nn.init.zeros_(self.core[-1].weight)
+ nn.init.zeros_(self.core[-1].bias)
+
+ self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, num_tasks))
+ self.qhead = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
+ with torch.no_grad():
+ self.qhead[-1].weight.zero_()
+ self.qhead[-1].bias.fill_(-5.0)
+
+ self.T = T
+ self.n_sup = n_sup
+ self.grad_mode = grad_mode
+ self.agg_layers = agg_layers
+ self.compute_layers = compute_layers
+ self.hidden = hidden
+ self.num_tasks = num_tasks
+
+ def aggregate(self, x, edge_index, edge_attr):
+ h = self.atom_encoder(x)
+ e = self.bond_encoder(edge_attr)
+ for conv, bn in zip(self.agg_convs, self.agg_bns):
+ if self.view in {"gin", "gine", "pna", "gen", "resgated"}:
+ h = bn(conv(h, edge_index, e)).relu()
+ elif self.view in {"gatv2", "transformer"}:
+ h = bn(conv(h, edge_index, edge_attr=e)).relu()
+ else:
+ h = bn(conv(h, edge_index)).relu()
+ return h
+
+ def core_step(self, combined, state):
+ return state + self.core(self.core_norm(combined))
+
+ def _z_step(self, y, z, ctx):
+ return self.core_step(ctx + y + z, z)
+
+ def _y_step(self, y, z):
+ return self.core_step(y + z, y)
+
+ def recurse(self, y, z, ctx, one_step=False):
+ if self.T == 0:
+ return y, z
+ if one_step:
+ with torch.no_grad():
+ for _ in range(self.T - 1):
+ z = self._z_step(y, z, ctx)
+ z = z.detach()
+ z = self._z_step(y, z, ctx)
+ y = self._y_step(y, z)
+ return y, z
+ for _ in range(self.T):
+ z = self._z_step(y, z, ctx)
+ y = self._y_step(y, z)
+ return y, z
+
+ def predict(self, y, batch):
+ pooled = global_add_pool(y, batch)
+ return self.head(pooled), self.qhead(pooled).view(-1)
+
+ def forward_trace(self, data, steps):
+ ctx = self.aggregate(data.x, data.edge_index, data.edge_attr)
+ y = ctx
+ z = torch.zeros_like(ctx)
+ preds, q_logits = [], []
+ for s in range(steps):
+ y, z = self.recurse(y, z, ctx, one_step=(self.grad_mode == "1step"))
+ pred, q = self.predict(y, data.batch)
+ preds.append(pred)
+ q_logits.append(q)
+ if s < steps - 1:
+ y, z = y.detach(), z.detach()
+ return preds, q_logits
+
+ def forward(self, data):
+ steps = self.n_sup
+ preds, q_logits = self.forward_trace(data, steps)
+ return preds, q_logits[-1]
+
+
+def supervised_loss(logits, y, metric):
+ per_graph, has_label = per_graph_supervised_loss(logits, y, metric)
+ if not has_label.any():
+ return logits.sum() * 0.0
+ return per_graph[has_label].mean()
+
+
+def per_graph_supervised_loss(logits, y, metric):
+ y = y.to(torch.float32)
+ mask = ~torch.isnan(y)
+ target = torch.where(mask, y, torch.zeros_like(y))
+ if is_regression_metric(metric):
+ losses = (logits - target).pow(2)
+ else:
+ losses = nn.functional.binary_cross_entropy_with_logits(logits, target, reduction="none")
+ losses = torch.where(mask, losses, torch.zeros_like(losses))
+ denom = mask.sum(dim=1).clamp_min(1)
+ return losses.sum(dim=1) / denom, mask.any(dim=1)
+
+
+@torch.no_grad()
+def halt_targets(logits, y):
+ y = y.to(torch.float32)
+ mask = ~torch.isnan(y)
+ pred = (logits > 0).to(y.dtype)
+ correct_or_missing = (~mask) | (pred == y)
+ has_label = mask.any(dim=1)
+ return (correct_or_missing.all(dim=1) & has_label).to(logits.dtype)
+
+
+@torch.no_grad()
+def evaluate(model, loader, evaluator, dev, steps=None, adaptive=False, halt_min_steps=1, max_batches=0):
+ model.eval()
+ ys, ps = [], []
+ step_sum = 0.0
+ n = 0
+ for i, batch in enumerate(loader):
+ if max_batches and i >= max_batches:
+ break
+ batch = batch.to(dev)
+ if steps is None:
+ preds, _ = model(batch)
+ pred = preds[-1]
+ used_steps = float(model.n_sup)
+ else:
+ preds, qs = model.forward_trace(batch, steps)
+ stack = torch.stack(preds, dim=0)
+ if adaptive:
+ q = torch.stack(qs, dim=0)
+ step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1)
+ halted = (q > 0) & (step_ids >= halt_min_steps)
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+ pred = stack[idx, torch.arange(stack.size(1), device=dev)]
+ used_steps = float((idx.to(torch.float32) + 1).sum().item())
+ else:
+ pred = stack[-1]
+ used_steps = float(steps * batch.num_graphs)
+ ys.append(batch.y.detach().cpu())
+ ps.append(pred.detach().cpu())
+ step_sum += used_steps
+ n += batch.num_graphs
+ y_true = torch.cat(ys, dim=0)
+ y_pred = torch.cat(ps, dim=0)
+ result = evaluator.eval({"y_true": y_true, "y_pred": y_pred})
+ return result, step_sum / max(n, 1)
+
+
+def _split_nodes(t, ptr):
+ return [t[ptr[i].item():ptr[i + 1].item()].detach() for i in range(ptr.numel() - 1)]
+
+
+def act_train_step(model, state, replacement_batch, opt, dev, args, metric):
+ replacement = replacement_batch.to_data_list()
+ batch_size = len(replacement)
+ if state is None:
+ state = {
+ "graphs": [None for _ in range(batch_size)],
+ "y": [None for _ in range(batch_size)],
+ "z": [None for _ in range(batch_size)],
+ "steps": torch.zeros(batch_size, dtype=torch.long, device=dev),
+ "halted": torch.ones(batch_size, dtype=torch.bool, device=dev),
+ }
+
+ halted_cpu = state["halted"].detach().cpu().tolist()
+ for i, halted in enumerate(halted_cpu):
+ if halted:
+ state["graphs"][i] = replacement[i]
+
+ batch = Batch.from_data_list(state["graphs"]).to(dev)
+ ctx = model.aggregate(batch.x, batch.edge_index, batch.edge_attr)
+ ptr = batch.ptr
+ y_parts, z_parts = [], []
+ for i in range(batch_size):
+ start, end = ptr[i].item(), ptr[i + 1].item()
+ if halted_cpu[i] or state["y"][i] is None:
+ y_parts.append(ctx[start:end])
+ z_parts.append(torch.zeros_like(ctx[start:end]))
+ else:
+ y_parts.append(state["y"][i].to(dev))
+ z_parts.append(state["z"][i].to(dev))
+ y = torch.cat(y_parts, dim=0)
+ z = torch.cat(z_parts, dim=0)
+
+ opt.zero_grad()
+ y, z = model.recurse(y, z, ctx, one_step=(model.grad_mode == "1step"))
+ logits, q = model.predict(y, batch.batch)
+ pred_loss = supervised_loss(logits, batch.y, metric)
+ if args.halt_target == "exact" and not is_regression_metric(metric):
+ target = halt_targets(logits.detach(), batch.y)
+ elif args.halt_target == "loss":
+ per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), batch.y, metric)
+ target = ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype)
+ else:
+ raise ValueError(args.halt_target)
+ q_loss = nn.functional.binary_cross_entropy_with_logits(q, target)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ y_det, z_det = y.detach(), z.detach()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ state["y"] = _split_nodes(y_det, ptr)
+ state["z"] = _split_nodes(z_det, ptr)
+ with torch.no_grad():
+ was_halted = state["halted"]
+ steps = torch.where(was_halted, torch.zeros_like(state["steps"]), state["steps"]) + 1
+ halted = (steps >= args.halt_max_steps) | ((q.detach() > 0) & (steps >= args.halt_min_steps))
+ if args.halt_exploration_prob > 0 and args.halt_max_steps > 1:
+ explore = torch.rand_like(q) < args.halt_exploration_prob
+ min_steps = torch.where(
+ explore,
+ torch.randint(2, args.halt_max_steps + 1, steps.shape, device=dev),
+ torch.zeros_like(steps),
+ )
+ halted = halted & (steps >= min_steps)
+ state["steps"] = steps
+ state["halted"] = halted
+
+ return state, {
+ "loss": float(loss.detach().cpu()),
+ "pred_loss": float(pred_loss.detach().cpu()),
+ "q_loss": float(q_loss.detach().cpu()),
+ "halted_frac": float(state["halted"].to(torch.float32).mean().detach().cpu()),
+ "steps": float(state["steps"].to(torch.float32).mean().detach().cpu()),
+ }
+
+
+def _halt_target(args, logits, y, metric):
+ if args.halt_target == "exact" and not is_regression_metric(metric):
+ return halt_targets(logits.detach(), y)
+ if args.halt_target == "loss":
+ per_graph_loss, has_label = per_graph_supervised_loss(logits.detach(), y, metric)
+ return ((per_graph_loss <= args.halt_loss_threshold) & has_label).to(logits.dtype)
+ raise ValueError(args.halt_target)
+
+
+def act_trace_train_step(model, batch, opt, dev, args, epoch, metric):
+ batch = batch.to(dev)
+ steps = max(1, args.halt_max_steps)
+ opt.zero_grad()
+ preds, qs = model.forward_trace(batch, steps)
+ pred_loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds)
+ if epoch <= args.q_warmup_epochs:
+ q_loss = pred_loss.detach() * 0.0
+ loss = pred_loss
+ else:
+ q_losses = []
+ for step_idx, (pred, q) in enumerate(zip(preds, qs), start=1):
+ target = _halt_target(args, pred, batch.y, metric)
+ if step_idx < args.halt_min_steps:
+ target = torch.zeros_like(target)
+ q_losses.append(nn.functional.binary_cross_entropy_with_logits(q, target))
+ q_loss = sum(q_losses) / len(q_losses)
+ loss = pred_loss + 0.5 * args.lam_q * q_loss
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+
+ with torch.no_grad():
+ q_stack = torch.stack(qs, dim=0)
+ step_ids = torch.arange(1, steps + 1, device=dev).view(-1, 1)
+ halted = (q_stack > 0) & (step_ids >= args.halt_min_steps)
+ any_halt = halted.any(dim=0)
+ first_halt = halted.to(torch.int64).argmax(dim=0)
+ fallback = torch.full_like(first_halt, steps - 1)
+ idx = torch.where(any_halt, first_halt, fallback)
+
+ return {
+ "loss": float(loss.detach().cpu()),
+ "pred_loss": float(pred_loss.detach().cpu()),
+ "q_loss": float(q_loss.detach().cpu()),
+ "halted_frac": float(any_halt.to(torch.float32).mean().detach().cpu()),
+ "steps": float((idx.to(torch.float32) + 1).mean().detach().cpu()),
+ }
+
+
+def train_epoch(model, loader, opt, dev, args, act_state, ema_state, epoch, metric):
+ model.train()
+ metrics = []
+ for i, batch in enumerate(loader):
+ if args.max_train_batches and i >= args.max_train_batches:
+ break
+ if args.compute == "rrog-act":
+ m = act_trace_train_step(model, batch, opt, dev, args, epoch, metric)
+ update_ema_state(ema_state, model, args.ema)
+ metrics.append(m)
+ continue
+ batch = batch.to(dev)
+ opt.zero_grad()
+ preds, _ = model(batch)
+ loss = sum(supervised_loss(pred, batch.y, metric) for pred in preds) / len(preds)
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ opt.step()
+ update_ema_state(ema_state, model, args.ema)
+ metrics.append({"loss": float(loss.detach().cpu()), "halted_frac": 0.0, "steps": float(model.n_sup)})
+ return act_state, metrics
+
+
+def metric_value(result, evaluator):
+ return float(result[evaluator.eval_metric])
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--dataset", choices=SUPPORTED_MOL_DATASETS, default="ogbg-molhiv")
+ ap.add_argument("--view", choices=SUPPORTED_VIEWS, default="gin")
+ ap.add_argument("--compute", choices=["classic", "view-only", "fixed-rrog", "rrog-act"], default="rrog-act")
+ ap.add_argument("--T", type=int, default=1)
+ ap.add_argument("--n_sup", type=int, default=3)
+ ap.add_argument("--hidden", type=int, default=128)
+ ap.add_argument("--agg_layers", type=int, default=5)
+ ap.add_argument("--compute_layers", type=int, default=2)
+ ap.add_argument("--epochs", type=int, default=100)
+ ap.add_argument("--eval_every", type=int, default=10)
+ ap.add_argument("--lr", type=float, default=1e-3)
+ ap.add_argument("--bs", type=int, default=128)
+ ap.add_argument("--lam_q", type=float, default=1.0)
+ ap.add_argument("--halt_max_steps", type=int, default=8)
+ ap.add_argument("--halt_min_steps", type=int, default=1)
+ ap.add_argument("--halt_target", choices=["exact", "loss"], default="loss")
+ ap.add_argument("--halt_loss_threshold", type=float, default=0.2)
+ ap.add_argument("--halt_exploration_prob", type=float, default=0.1)
+ ap.add_argument("--q_warmup_epochs", type=int, default=0)
+ ap.add_argument("--ema", type=float, default=0.0)
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--num_workers", type=int, default=0)
+ ap.add_argument("--max_train_batches", type=int, default=0)
+ ap.add_argument("--max_eval_batches", type=int, default=0)
+ ap.add_argument("--device", default="auto")
+ args = ap.parse_args()
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ if args.compute == "classic":
+ args.n_sup = 1
+ if args.device == "auto":
+ dev = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ dev = args.device
+ os.makedirs(OUT, exist_ok=True)
+
+ dataset = PygGraphPropPredDataset(name=args.dataset, root=ROOT)
+ split_idx = dataset.get_idx_split()
+ evaluator = Evaluator(args.dataset)
+ metric = evaluator.eval_metric
+ num_tasks = dataset.num_tasks
+
+ train_dataset = dataset[split_idx["train"]]
+ valid_dataset = dataset[split_idx["valid"]]
+ test_dataset = dataset[split_idx["test"]]
+ train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
+ drop_last=True, num_workers=args.num_workers)
+ valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False,
+ num_workers=args.num_workers)
+ test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False,
+ num_workers=args.num_workers)
+
+ T = 0 if args.compute in ["classic", "view-only"] else args.T
+ deg = degree_histogram(train_dataset) if args.view == "pna" else None
+ model = OGBRRoG(args.hidden, num_tasks, view=args.view, T=T, n_sup=args.n_sup,
+ agg_layers=args.agg_layers, compute_layers=args.compute_layers, deg=deg).to(dev)
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, max(args.epochs, 1))
+ ema_state = clone_state_dict(model) if args.ema > 0 else None
+
+ t0 = time.time()
+ best_val = None
+ best = {}
+ best_state = None
+ act_state = None
+ steps = max(1, args.halt_max_steps)
+
+ for ep in range(args.epochs):
+ act_state, train_metrics = train_epoch(
+ model, train_loader, opt, dev, args, act_state, ema_state, ep + 1, metric)
+ sched.step()
+ if (ep + 1) % args.eval_every == 0 or ep == args.epochs - 1:
+ with using_ema_state(model, ema_state):
+ if args.compute == "rrog-act":
+ val_result, fixed_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps,
+ adaptive=False, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ val_adapt, adaptive_val_steps = evaluate(model, valid_loader, evaluator, dev, steps=steps,
+ adaptive=True, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ else:
+ val_result, _ = evaluate(model, valid_loader, evaluator, dev,
+ max_batches=args.max_eval_batches)
+ val_score = metric_value(val_result, evaluator)
+ if is_better(val_score, best_val, metric):
+ best_val = val_score
+ if args.compute == "rrog-act":
+ test_fixed, fixed_steps = evaluate(model, test_loader, evaluator, dev, steps=steps,
+ adaptive=False, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ test_adapt, adaptive_steps = evaluate(model, test_loader, evaluator, dev, steps=steps,
+ adaptive=True, halt_min_steps=args.halt_min_steps,
+ max_batches=args.max_eval_batches)
+ best = {
+ "ep": ep + 1,
+ "val": val_result,
+ "val_adaptive": val_adapt,
+ "test": test_fixed,
+ "test_adaptive": test_adapt,
+ "fixed_val_steps": fixed_val_steps,
+ "adaptive_val_steps": adaptive_val_steps,
+ "fixed_steps": fixed_steps,
+ "adaptive_steps": adaptive_steps,
+ }
+ else:
+ test_result, _ = evaluate(model, test_loader, evaluator, dev,
+ max_batches=args.max_eval_batches)
+ best = {"ep": ep + 1, "val": val_result, "test": test_result}
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
+ halted = sum(m.get("halted_frac", 0.0) for m in train_metrics) / max(len(train_metrics), 1)
+ train_steps = sum(m.get("steps", 0.0) for m in train_metrics) / max(len(train_metrics), 1)
+ msg = f"ep{ep+1} val_{evaluator.eval_metric}={val_score:.5f}"
+ if args.compute == "rrog-act":
+ msg += (
+ f" val_adapt_{evaluator.eval_metric}={metric_value(val_adapt, evaluator):.5f}"
+ f" adapt_steps={adaptive_val_steps:.2f}"
+ )
+ msg += f" halt={halted:.2f} train_steps={train_steps:.2f}"
+ print(msg, flush=True)
+
+ ema_tag = f"_ema{args.ema:g}" if args.ema > 0 else ""
+ tag = (
+ f"{args.dataset}_{args.view}_{args.compute}_T{T}_ns{args.n_sup}_"
+ f"h{args.hidden}_e{args.epochs}{ema_tag}_s{args.seed}"
+ )
+ rep = {
+ "dataset": args.dataset,
+ "tag": tag,
+ **vars(args),
+ "metric": evaluator.eval_metric,
+ "sec": round(time.time() - t0, 1),
+ "dev": dev,
+ **best,
+ }
+ with open(os.path.join(OUT, f"{tag}.json"), "w") as f:
+ json.dump(jsonable(rep), f, indent=2)
+ torch.save({
+ "state": best_state or model.state_dict(),
+ "cfg": {
+ "dataset": args.dataset,
+ "hidden": args.hidden,
+ "num_tasks": num_tasks,
+ "T": T,
+ "n_sup": args.n_sup,
+ "agg_layers": args.agg_layers,
+ "compute_layers": args.compute_layers,
+ "compute": args.compute,
+ "halt_max_steps": steps,
+ "halt_min_steps": args.halt_min_steps,
+ "halt_target": args.halt_target,
+ "halt_loss_threshold": args.halt_loss_threshold,
+ "view": args.view,
+ },
+ }, os.path.join(OUT, f"ckpt_{tag}.pt"))
+ print(f"[{tag}] best_ep={best.get('ep')} val={best.get('val')} test={best.get('test')} "
+ f"adaptive={best.get('test_adaptive')} steps={best.get('adaptive_steps')}", flush=True)
+ print(" wrote", os.path.join(OUT, f"{tag}.json"))
+
+
+if __name__ == "__main__":
+ main()