diff options
Diffstat (limited to 'rrog/runspecs.py')
| -rw-r--r-- | rrog/runspecs.py | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/rrog/runspecs.py b/rrog/runspecs.py new file mode 100644 index 0000000..285348f --- /dev/null +++ b/rrog/runspecs.py @@ -0,0 +1,188 @@ +from rrog.registry import RunSpec + + +ZINC_VIEWS = [ + "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna", + "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp", +] +OGB_MOL_VIEWS = [ + "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna", + "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp", +] +OGB_MOL_TASKS = [ + "ogbg-molhiv", + "ogbg-molpcba", + "ogbg-molbbbp", + "ogbg-molbace", + "ogbg-moltox21", + "ogbg-molclintox", + "ogbg-molsider", + "ogbg-molesol", + "ogbg-molfreesolv", + "ogbg-mollipo", +] + + +def _zinc_cycle56_gin_command(args: dict[str, object]) -> list[str]: + compute = str(args.get("compute", "rrog-act")) + view = str(args.get("view", "gin")) + epochs = int(args.get("epochs", 200)) + hidden = int(args.get("hidden", 64)) + batch_size = int(args.get("bs", 256)) + seed = int(args.get("seed", 0)) + t = int(args.get("T", 1)) + n_sup = int(args.get("n_sup", 3)) + halt_max_steps = int(args.get("halt_max_steps", 8)) + halt_target = str(args.get("halt_target", "binary")) + + cmd = [ + "python3", "diag/train_rec.py", + "--grad_mode", "full", + "--T", str(t), + "--n_sup", str(n_sup), + "--hidden", str(hidden), + "--bs", str(batch_size), + "--epochs", str(epochs), + "--agg_layers", str(args.get("agg_layers", 5)), + "--compute_layers", str(args.get("compute_layers", 2)), + "--view", view, + "--sigma", "0", + "--K", "1", + "--select", "none", + "--seed", str(seed), + ] + + if compute in ["classic", "view-only"]: + cmd[cmd.index("--T") + 1] = "0" + elif compute == "fixed-rrog": + pass + elif compute == "rrog-act": + cmd.extend([ + "--act", + "--halt_max_steps", str(halt_max_steps), + "--halt_target", halt_target, + "--halt_norm_threshold", str(args.get("halt_norm_threshold", 0.3)), + ]) + else: + raise ValueError(f"unsupported compute for zinc-cycle56/{view}: {compute}") + if args.get("device") is not None: + cmd.extend(["--device", str(args["device"])]) + return cmd + + +def _ogb_graphprop_gin_command(args: dict[str, object]) -> list[str]: + task = str(args["task"]) + view = str(args.get("view", "gin")) + compute = str(args.get("compute", "rrog-act")) + cmd = [ + "python3", "rrog/train_ogb_graphprop.py", + "--dataset", task, + "--view", view, + "--compute", compute, + "--T", str(args.get("T", 1)), + "--n_sup", str(args.get("n_sup", 3)), + "--hidden", str(args.get("hidden", 128)), + "--bs", str(args.get("bs", 128)), + "--epochs", str(args.get("epochs", 100)), + "--eval_every", str(args.get("eval_every", 10)), + "--agg_layers", str(args.get("agg_layers", 5)), + "--compute_layers", str(args.get("compute_layers", 2)), + "--seed", str(args.get("seed", 0)), + ] + for key in ["lr", "lam_q"]: + if args.get(key) is not None: + cmd.extend([f"--{key}", str(args[key])]) + if compute == "rrog-act": + cmd.extend([ + "--halt_max_steps", str(args.get("halt_max_steps", 8)), + "--halt_min_steps", str(args.get("halt_min_steps", 2)), + "--halt_target", str(args.get("halt_target", "loss")), + "--halt_loss_threshold", str(args.get("halt_loss_threshold", 0.2)), + "--halt_exploration_prob", str(args.get("halt_exploration_prob", 0.1)), + ]) + if args.get("q_warmup_epochs") is not None: + cmd.extend(["--q_warmup_epochs", str(args["q_warmup_epochs"])]) + if float(args.get("ema", 0.0) or 0.0) > 0: + cmd.extend(["--ema", str(args["ema"])]) + if args.get("device") is not None: + cmd.extend(["--device", str(args["device"])]) + for key in ["max_train_batches", "max_eval_batches", "num_workers"]: + if args.get(key) is not None: + cmd.extend([f"--{key}", str(args[key])]) + return cmd + + +RUN_SPECS = [ +] + +for _view in ZINC_VIEWS: + RUN_SPECS.extend([ + RunSpec( + task="zinc-cycle56", + view=_view, + compute="classic", + default_args={"compute": "classic", "view": _view, "T": 0, "n_sup": 1, "epochs": 200}, + command_builder=_zinc_cycle56_gin_command, + ), + RunSpec( + task="zinc-cycle56", + view=_view, + compute="view-only", + default_args={"compute": "view-only", "view": _view, "T": 0, "n_sup": 3, "epochs": 200}, + command_builder=_zinc_cycle56_gin_command, + ), + RunSpec( + task="zinc-cycle56", + view=_view, + compute="fixed-rrog", + default_args={"compute": "fixed-rrog", "view": _view, "T": 3, "n_sup": 3, "epochs": 200}, + command_builder=_zinc_cycle56_gin_command, + ), + RunSpec( + task="zinc-cycle56", + view=_view, + compute="rrog-act", + default_args={"compute": "rrog-act", "view": _view, "T": 1, "n_sup": 3, "epochs": 200}, + command_builder=_zinc_cycle56_gin_command, + ), + ]) + +for _task in OGB_MOL_TASKS: + for _view in OGB_MOL_VIEWS: + RUN_SPECS.extend([ + RunSpec( + task=_task, + view=_view, + compute="classic", + default_args={"task": _task, "view": _view, "compute": "classic", "T": 0, "n_sup": 1, "epochs": 100}, + command_builder=_ogb_graphprop_gin_command, + ), + RunSpec( + task=_task, + view=_view, + compute="view-only", + default_args={"task": _task, "view": _view, "compute": "view-only", "T": 0, "n_sup": 3, "epochs": 100}, + command_builder=_ogb_graphprop_gin_command, + ), + RunSpec( + task=_task, + view=_view, + compute="fixed-rrog", + default_args={"task": _task, "view": _view, "compute": "fixed-rrog", "T": 3, "n_sup": 3, "epochs": 100}, + command_builder=_ogb_graphprop_gin_command, + ), + RunSpec( + task=_task, + view=_view, + compute="rrog-act", + default_args={"task": _task, "view": _view, "compute": "rrog-act", "T": 1, "n_sup": 3, "epochs": 100}, + command_builder=_ogb_graphprop_gin_command, + ), + ]) + + +def find_run_spec(task: str, view: str, compute: str, modifier: str = "none") -> RunSpec: + for spec in RUN_SPECS: + if (spec.task, spec.view, spec.compute, spec.modifier) == (task, view, compute, modifier): + return spec + raise KeyError(f"no implemented run spec for task={task} view={view} compute={compute} modifier={modifier}") |
