# Run registry for zinc-cycle56 and the OGB molecular graph-property tasks.
# Every (task, view) pair is registered once per compute mode listed in
# _COMPUTE_VARIANTS. Each RunSpec carries the default_args for that run and the
# builder that turns a run's args dict into a training-script argv.
from rrog.registry import RunSpec

# Views registered for zinc-cycle56.
ZINC_VIEWS = [
    "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv",
    "transformer", "pna", "gen", "film", "resgated", "tag",
    "sgc", "cheb", "arma", "mf", "appnp",
]

# Views registered for every OGB molecular task. Same contents as ZINC_VIEWS
# at present, but a distinct list object.
OGB_MOL_VIEWS = [
    "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv",
    "transformer", "pna", "gen", "film", "resgated", "tag",
    "sgc", "cheb", "arma", "mf", "appnp",
]

# OGB dataset names; each is forwarded to the trainer as --dataset.
OGB_MOL_TASKS = [
    "ogbg-molhiv", "ogbg-molpcba", "ogbg-molbbbp", "ogbg-molbace",
    "ogbg-moltox21", "ogbg-molclintox", "ogbg-molsider", "ogbg-molesol",
    "ogbg-molfreesolv", "ogbg-mollipo",
]


def _zinc_cycle56_gin_command(args: dict[str, object]) -> list[str]:
    """Translate a zinc-cycle56 run configuration into a ``diag/train_rec.py`` argv.

    Keys read from ``args`` (fallback used when absent): ``compute``
    ("rrog-act"), ``view`` ("gin"), ``epochs`` (200), ``hidden`` (64), ``bs``
    (256), ``seed`` (0), ``T`` (1), ``n_sup`` (3), ``agg_layers`` (5),
    ``compute_layers`` (2), ``halt_max_steps`` (8), ``halt_target``
    ("binary") and ``halt_norm_threshold`` (0.3). The three halt options are
    only emitted for "rrog-act". An optional ``device`` is forwarded when it
    is not None. ``--grad_mode full --sigma 0 --K 1 --select none`` are
    always passed.

    Compute modes:
        "classic", "view-only": ``--T 0`` is emitted whatever ``T`` is.
        "fixed-rrog": ``--T`` carries ``T`` unchanged.
        "rrog-act": like "fixed-rrog", plus ``--act`` and the halt flags.

    Raises:
        ValueError: for any other ``compute`` value. The numeric options are
            passed through ``int()`` before the mode is checked, so a
            malformed number fails first.
    """
    mode = str(args.get("compute", "rrog-act"))
    view = str(args.get("view", "gin"))

    # Coerce numeric options eagerly (halt_max_steps included, even when the
    # mode will not emit it) so malformed values fail before mode validation.
    epochs = int(args.get("epochs", 200))
    hidden = int(args.get("hidden", 64))
    batch = int(args.get("bs", 256))
    seed = int(args.get("seed", 0))
    steps = int(args.get("T", 1))
    n_sup = int(args.get("n_sup", 3))
    max_halt = int(args.get("halt_max_steps", 8))
    halt_target = str(args.get("halt_target", "binary"))

    # classic and view-only always pass --T 0, regardless of the T option.
    pinned_zero = mode in ("classic", "view-only")
    pairs = [
        ("--grad_mode", "full"),
        ("--T", "0" if pinned_zero else str(steps)),
        ("--n_sup", str(n_sup)),
        ("--hidden", str(hidden)),
        ("--bs", str(batch)),
        ("--epochs", str(epochs)),
        ("--agg_layers", str(args.get("agg_layers", 5))),
        ("--compute_layers", str(args.get("compute_layers", 2))),
        ("--view", view),
        ("--sigma", "0"),
        ("--K", "1"),
        ("--select", "none"),
        ("--seed", str(seed)),
    ]
    argv = ["python3", "diag/train_rec.py"]
    for flag, value in pairs:
        argv += [flag, value]

    if mode == "rrog-act":
        argv += [
            "--act",
            "--halt_max_steps", str(max_halt),
            "--halt_target", halt_target,
            "--halt_norm_threshold", str(args.get("halt_norm_threshold", 0.3)),
        ]
    elif mode not in ("classic", "view-only", "fixed-rrog"):
        raise ValueError(f"unsupported compute for zinc-cycle56/{view}: {mode}")

    device = args.get("device")
    if device is not None:
        argv += ["--device", str(device)]
    return argv


def _ogb_graphprop_gin_command(args: dict[str, object]) -> list[str]:
    """Translate an OGB graph-property run configuration into a training argv.

    The command runs ``rrog/train_ogb_graphprop.py``. ``task`` is required
    and becomes ``--dataset``. ``compute`` (default "rrog-act") is forwarded
    verbatim as ``--compute``; unlike the zinc builder, it is not validated
    here.

    Always emitted (fallback when absent): view ("gin"), T (1), n_sup (3),
    hidden (128), bs (128), epochs (100), eval_every (10), agg_layers (5),
    compute_layers (2), seed (0). Values are stringified, not type-checked.

    Emitted only when not None: lr, lam_q, device, max_train_batches,
    max_eval_batches, num_workers. ``ema`` is emitted only when it is
    positive. For "rrog-act" the halt options halt_max_steps (8),
    halt_min_steps (2), halt_target ("loss"), halt_loss_threshold (0.2) and
    halt_exploration_prob (0.1) are added, plus q_warmup_epochs when set.

    Raises:
        KeyError: if ``task`` is missing from ``args``.
        ValueError/TypeError: if a truthy ``ema`` cannot be passed to ``float()``.
    """
    task = str(args["task"])
    compute = str(args.get("compute", "rrog-act"))

    def with_fallbacks(*pairs: tuple[str, object]) -> list[str]:
        # "--key value" for every key, substituting the fallback when absent.
        tokens: list[str] = []
        for key, fallback in pairs:
            tokens += [f"--{key}", str(args.get(key, fallback))]
        return tokens

    def when_set(*keys: str) -> list[str]:
        # "--key value" only for keys whose value is not None.
        tokens: list[str] = []
        for key in keys:
            value = args.get(key)
            if value is not None:
                tokens += [f"--{key}", str(value)]
        return tokens

    argv = [
        "python3", "rrog/train_ogb_graphprop.py",
        "--dataset", task,
        "--view", str(args.get("view", "gin")),
        "--compute", compute,
    ]
    argv += with_fallbacks(
        ("T", 1), ("n_sup", 3), ("hidden", 128), ("bs", 128), ("epochs", 100),
        ("eval_every", 10), ("agg_layers", 5), ("compute_layers", 2), ("seed", 0),
    )
    argv += when_set("lr", "lam_q")

    if compute == "rrog-act":
        argv += with_fallbacks(
            ("halt_max_steps", 8),
            ("halt_min_steps", 2),
            ("halt_target", "loss"),
            ("halt_loss_threshold", 0.2),
            ("halt_exploration_prob", 0.1),
        )
        argv += when_set("q_warmup_epochs")

    # A missing, None, or zero ema means "disabled"; only positive values pass.
    ema = args.get("ema", 0.0)
    if float(ema or 0.0) > 0:
        argv += ["--ema", str(ema)]

    argv += when_set("device", "max_train_batches", "max_eval_batches", "num_workers")
    return argv


# (compute mode, default T, default n_sup) in registration order; the same
# table is used for zinc-cycle56 and for every OGB task.
_COMPUTE_VARIANTS = (
    ("classic", 0, 1),
    ("view-only", 0, 3),
    ("fixed-rrog", 3, 3),
    ("rrog-act", 1, 3),
)

RUN_SPECS: list[RunSpec] = []

# zinc-cycle56: one spec per view and compute mode, 200 epochs by default.
for _view in ZINC_VIEWS:
    RUN_SPECS.extend(
        RunSpec(
            task="zinc-cycle56",
            view=_view,
            compute=mode,
            default_args={"compute": mode, "view": _view, "T": t, "n_sup": n_sup, "epochs": 200},
            command_builder=_zinc_cycle56_gin_command,
        )
        for mode, t, n_sup in _COMPUTE_VARIANTS
    )

# OGB molecular tasks: one spec per task, view and compute mode, 100 epochs.
for _task in OGB_MOL_TASKS:
    for _view in OGB_MOL_VIEWS:
        RUN_SPECS.extend(
            RunSpec(
                task=_task,
                view=_view,
                compute=mode,
                default_args={
                    "task": _task,
                    "view": _view,
                    "compute": mode,
                    "T": t,
                    "n_sup": n_sup,
                    "epochs": 100,
                },
                command_builder=_ogb_graphprop_gin_command,
            )
            for mode, t, n_sup in _COMPUTE_VARIANTS
        )


def find_run_spec(task: str, view: str, compute: str, modifier: str = "none") -> RunSpec:
    """Return the first registered RunSpec matching all four identifiers.

    Specs are compared on (task, view, compute, modifier) in RUN_SPECS order.

    Raises:
        KeyError: if no registered spec matches.
    """
    wanted = (task, view, compute, modifier)
    hit = next(
        (
            spec
            for spec in RUN_SPECS
            if (spec.task, spec.view, spec.compute, spec.modifier) == wanted
        ),
        None,
    )
    if hit is None:
        raise KeyError(f"no implemented run spec for task={task} view={view} compute={compute} modifier={modifier}")
    return hit