diff options
Diffstat (limited to 'rrog/cli.py')
| -rw-r--r-- | rrog/cli.py | 176 |
1 files changed, 176 insertions, 0 deletions
diff --git a/rrog/cli.py b/rrog/cli.py new file mode 100644 index 0000000..9c826e3 --- /dev/null +++ b/rrog/cli.py @@ -0,0 +1,176 @@ +import argparse +import json +import os +import subprocess + +from rrog.backbones import COMPUTES, MODIFIERS, VIEWS +from rrog.benchmarks import BENCHMARKS +from rrog.collect_results import print_tables +from rrog.collect_zinc import print_zinc +from rrog.runspecs import find_run_spec + + +def _rows(items): + return [ + { + "name": item.name, + "tier": getattr(item, "tier", None), + "family": getattr(item, "family", None), + "domain": getattr(item, "domain", None), + "task_type": getattr(item, "task_type", None), + "metric": getattr(item, "metric", None), + "priority": item.priority, + "status": item.status, + "notes": item.notes, + } + for item in items + ] + + +def _print_table(items): + for row in _rows(items): + cols = [row["name"], row.get("tier") or row.get("family") or "", row["status"], row["notes"]] + print("\t".join(str(c) for c in cols)) + + +def list_axis(axis: str, as_json: bool): + mapping = { + "benchmarks": BENCHMARKS, + "views": VIEWS, + "computes": COMPUTES, + "modifiers": MODIFIERS, + } + items = mapping[axis] + if as_json: + print(json.dumps(_rows(items), indent=2)) + else: + _print_table(items) + + +def build_command(args) -> list[str]: + spec = find_run_spec(args.task, args.view, args.compute, args.modifier) + run_args = dict(spec.default_args) + for key in [ + "epochs", "hidden", "bs", "seed", "T", "n_sup", "halt_max_steps", "halt_target", + "halt_min_steps", "halt_loss_threshold", "q_warmup_epochs", + "eval_every", "max_train_batches", "max_eval_batches", "num_workers", "ema", "lr", "lam_q", + "device", + ]: + value = getattr(args, key) + if value is not None: + run_args[key] = value + run_args["compute"] = args.compute + return spec.command_builder(run_args) + + +def print_matrix(args): + tasks = [b for b in BENCHMARKS if args.tier == "all" or b.tier == args.tier] + tasks = sorted(tasks, key=lambda x: x.priority)[:args.limit_tasks] + + if args.kind == "main": + views = sorted(VIEWS, key=lambda x: x.priority)[:args.limit_views] + computes = [c for c in COMPUTES if c.name in ["classic", "view-only", "fixed-rrog", "rrog-act"]] + for task in tasks: + for view in views: + for compute in computes: + try: + find_run_spec(task.name, view.name, compute.name) + status = "implemented" + except KeyError: + status = "planned" + print(f"{task.name}\t{view.name}\t{compute.name}\tnone\t{status}") + return + + if args.kind == "modifier": + task_names = {"zinc-cycle56", "ogbg-molhiv", "peptides-struct", "peptides-func", "ogbn-arxiv", "qm9"} + mods = [m for m in sorted(MODIFIERS, key=lambda x: x.priority) if m.name != "none"] + for task in tasks: + if task.name not in task_names: + continue + for mod in mods: + print(f"{task.name}\tgin\tview-only\t{mod.name}\tplanned") + return + + raise ValueError(args.kind) + + +def main(): + ap = argparse.ArgumentParser() + sub = ap.add_subparsers(dest="cmd", required=True) + + lp = sub.add_parser("list") + lp.add_argument("axis", choices=["benchmarks", "views", "computes", "modifiers"]) + lp.add_argument("--json", action="store_true") + + rp = sub.add_parser("run") + rp.add_argument("--task", default="zinc-cycle56") + rp.add_argument("--view", default="gin") + rp.add_argument("--compute", default="rrog-act") + rp.add_argument("--modifier", default="none") + rp.add_argument("--epochs", type=int) + rp.add_argument("--hidden", type=int) + rp.add_argument("--bs", type=int) + rp.add_argument("--seed", type=int) + rp.add_argument("--T", type=int) + rp.add_argument("--n_sup", type=int) + rp.add_argument("--halt_max_steps", type=int) + rp.add_argument("--halt_target", choices=["soft", "binary", "exact", "loss"]) + rp.add_argument("--halt_min_steps", type=int) + rp.add_argument("--halt_loss_threshold", type=float) + rp.add_argument("--q_warmup_epochs", type=int) + rp.add_argument("--eval_every", type=int) + rp.add_argument("--max_train_batches", type=int) + rp.add_argument("--max_eval_batches", type=int) + rp.add_argument("--num_workers", type=int) + rp.add_argument("--ema", type=float) + rp.add_argument("--lr", type=float) + rp.add_argument("--lam_q", type=float) + rp.add_argument("--device") + rp.add_argument("--dry-run", action="store_true") + + mp = sub.add_parser("matrix") + mp.add_argument("--kind", choices=["main", "modifier"], default="main") + mp.add_argument("--tier", choices=["A", "B", "C", "all"], default="A") + mp.add_argument("--limit-tasks", type=int, default=20) + mp.add_argument("--limit-views", type=int, default=6) + + op = sub.add_parser("results") + op.add_argument("--runs-dir", default="runs") + op.add_argument("--glob", default="*.json") + op.add_argument("--min-epochs", type=int, default=10) + op.add_argument("--epochs", type=int) + op.add_argument("--digits", type=int, default=4) + + zp = sub.add_parser("zinc-results") + zp.add_argument("--runs-dir", default="runs") + zp.add_argument("--epochs", type=int) + zp.add_argument("--min-epochs", type=int, default=10) + zp.add_argument("--digits", type=int, default=4) + + args = ap.parse_args() + if args.cmd == "list": + list_axis(args.axis, args.json) + return + if args.cmd == "matrix": + print_matrix(args) + return + if args.cmd == "results": + print_tables(args) + return + if args.cmd == "zinc-results": + print_zinc(args) + return + + cmd = build_command(args) + env = dict(os.environ) + env["PYTHONPATH"] = os.getcwd() + os.pathsep + env.get("PYTHONPATH", "") + print(" ".join(cmd), flush=True) + if not args.dry_run: + raise SystemExit(subprocess.call(cmd, env=env)) + + +if __name__ == "__main__": + try: + main() + except BrokenPipeError: + raise SystemExit(0) |
