summaryrefslogtreecommitdiff
path: root/rrog/runspecs.py
diff options
context:
space:
mode:
Diffstat (limited to 'rrog/runspecs.py')
-rw-r--r--rrog/runspecs.py188
1 files changed, 188 insertions, 0 deletions
diff --git a/rrog/runspecs.py b/rrog/runspecs.py
new file mode 100644
index 0000000..285348f
--- /dev/null
+++ b/rrog/runspecs.py
@@ -0,0 +1,188 @@
+from rrog.registry import RunSpec
+
+
+ZINC_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+OGB_MOL_VIEWS = [
+ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna",
+ "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp",
+]
+OGB_MOL_TASKS = [
+ "ogbg-molhiv",
+ "ogbg-molpcba",
+ "ogbg-molbbbp",
+ "ogbg-molbace",
+ "ogbg-moltox21",
+ "ogbg-molclintox",
+ "ogbg-molsider",
+ "ogbg-molesol",
+ "ogbg-molfreesolv",
+ "ogbg-mollipo",
+]
+
+
+def _zinc_cycle56_gin_command(args: dict[str, object]) -> list[str]:
+ compute = str(args.get("compute", "rrog-act"))
+ view = str(args.get("view", "gin"))
+ epochs = int(args.get("epochs", 200))
+ hidden = int(args.get("hidden", 64))
+ batch_size = int(args.get("bs", 256))
+ seed = int(args.get("seed", 0))
+ t = int(args.get("T", 1))
+ n_sup = int(args.get("n_sup", 3))
+ halt_max_steps = int(args.get("halt_max_steps", 8))
+ halt_target = str(args.get("halt_target", "binary"))
+
+ cmd = [
+ "python3", "diag/train_rec.py",
+ "--grad_mode", "full",
+ "--T", str(t),
+ "--n_sup", str(n_sup),
+ "--hidden", str(hidden),
+ "--bs", str(batch_size),
+ "--epochs", str(epochs),
+ "--agg_layers", str(args.get("agg_layers", 5)),
+ "--compute_layers", str(args.get("compute_layers", 2)),
+ "--view", view,
+ "--sigma", "0",
+ "--K", "1",
+ "--select", "none",
+ "--seed", str(seed),
+ ]
+
+ if compute in ["classic", "view-only"]:
+ cmd[cmd.index("--T") + 1] = "0"
+ elif compute == "fixed-rrog":
+ pass
+ elif compute == "rrog-act":
+ cmd.extend([
+ "--act",
+ "--halt_max_steps", str(halt_max_steps),
+ "--halt_target", halt_target,
+ "--halt_norm_threshold", str(args.get("halt_norm_threshold", 0.3)),
+ ])
+ else:
+ raise ValueError(f"unsupported compute for zinc-cycle56/{view}: {compute}")
+ if args.get("device") is not None:
+ cmd.extend(["--device", str(args["device"])])
+ return cmd
+
+
+def _ogb_graphprop_gin_command(args: dict[str, object]) -> list[str]:
+ task = str(args["task"])
+ view = str(args.get("view", "gin"))
+ compute = str(args.get("compute", "rrog-act"))
+ cmd = [
+ "python3", "rrog/train_ogb_graphprop.py",
+ "--dataset", task,
+ "--view", view,
+ "--compute", compute,
+ "--T", str(args.get("T", 1)),
+ "--n_sup", str(args.get("n_sup", 3)),
+ "--hidden", str(args.get("hidden", 128)),
+ "--bs", str(args.get("bs", 128)),
+ "--epochs", str(args.get("epochs", 100)),
+ "--eval_every", str(args.get("eval_every", 10)),
+ "--agg_layers", str(args.get("agg_layers", 5)),
+ "--compute_layers", str(args.get("compute_layers", 2)),
+ "--seed", str(args.get("seed", 0)),
+ ]
+ for key in ["lr", "lam_q"]:
+ if args.get(key) is not None:
+ cmd.extend([f"--{key}", str(args[key])])
+ if compute == "rrog-act":
+ cmd.extend([
+ "--halt_max_steps", str(args.get("halt_max_steps", 8)),
+ "--halt_min_steps", str(args.get("halt_min_steps", 2)),
+ "--halt_target", str(args.get("halt_target", "loss")),
+ "--halt_loss_threshold", str(args.get("halt_loss_threshold", 0.2)),
+ "--halt_exploration_prob", str(args.get("halt_exploration_prob", 0.1)),
+ ])
+ if args.get("q_warmup_epochs") is not None:
+ cmd.extend(["--q_warmup_epochs", str(args["q_warmup_epochs"])])
+ if float(args.get("ema", 0.0) or 0.0) > 0:
+ cmd.extend(["--ema", str(args["ema"])])
+ if args.get("device") is not None:
+ cmd.extend(["--device", str(args["device"])])
+ for key in ["max_train_batches", "max_eval_batches", "num_workers"]:
+ if args.get(key) is not None:
+ cmd.extend([f"--{key}", str(args[key])])
+ return cmd
+
+
+RUN_SPECS = [
+]
+
+for _view in ZINC_VIEWS:
+ RUN_SPECS.extend([
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="classic",
+ default_args={"compute": "classic", "view": _view, "T": 0, "n_sup": 1, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="view-only",
+ default_args={"compute": "view-only", "view": _view, "T": 0, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="fixed-rrog",
+ default_args={"compute": "fixed-rrog", "view": _view, "T": 3, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ RunSpec(
+ task="zinc-cycle56",
+ view=_view,
+ compute="rrog-act",
+ default_args={"compute": "rrog-act", "view": _view, "T": 1, "n_sup": 3, "epochs": 200},
+ command_builder=_zinc_cycle56_gin_command,
+ ),
+ ])
+
+for _task in OGB_MOL_TASKS:
+ for _view in OGB_MOL_VIEWS:
+ RUN_SPECS.extend([
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="classic",
+ default_args={"task": _task, "view": _view, "compute": "classic", "T": 0, "n_sup": 1, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="view-only",
+ default_args={"task": _task, "view": _view, "compute": "view-only", "T": 0, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="fixed-rrog",
+ default_args={"task": _task, "view": _view, "compute": "fixed-rrog", "T": 3, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ RunSpec(
+ task=_task,
+ view=_view,
+ compute="rrog-act",
+ default_args={"task": _task, "view": _view, "compute": "rrog-act", "T": 1, "n_sup": 3, "epochs": 100},
+ command_builder=_ogb_graphprop_gin_command,
+ ),
+ ])
+
+
+def find_run_spec(task: str, view: str, compute: str, modifier: str = "none") -> RunSpec:
+ for spec in RUN_SPECS:
+ if (spec.task, spec.view, spec.compute, spec.modifier) == (task, view, compute, modifier):
+ return spec
+ raise KeyError(f"no implemented run spec for task={task} view={view} compute={compute} modifier={modifier}")