summary refs log tree commit diff
path: root/rrog/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'rrog/cli.py')
-rw-r--r-- rrog/cli.py 176
1 files changed, 176 insertions, 0 deletions
diff --git a/rrog/cli.py b/rrog/cli.py
new file mode 100644
index 0000000..9c826e3
--- /dev/null
+++ b/rrog/cli.py
@@ -0,0 +1,176 @@
+import argparse
+import json
+import os
+import subprocess
+
+from rrog.backbones import COMPUTES, MODIFIERS, VIEWS
+from rrog.benchmarks import BENCHMARKS
+from rrog.collect_results import print_tables
+from rrog.collect_zinc import print_zinc
+from rrog.runspecs import find_run_spec
+
+
+def _rows(items):
+ return [
+ {
+ "name": item.name,
+ "tier": getattr(item, "tier", None),
+ "family": getattr(item, "family", None),
+ "domain": getattr(item, "domain", None),
+ "task_type": getattr(item, "task_type", None),
+ "metric": getattr(item, "metric", None),
+ "priority": item.priority,
+ "status": item.status,
+ "notes": item.notes,
+ }
+ for item in items
+ ]
+
+
def _print_table(items):
    """Print one tab-separated line per registry entry.

    The columns are: name, then the entry's ``tier`` (falling back to
    ``family``, then to an empty string when neither is set), status and
    notes. Every value is passed through ``str`` before joining.

    Args:
        items: Iterable of registry entries accepted by ``_rows``.
    """
    for row in _rows(items):
        # Second column: the first truthy value of tier / family, else "".
        group = row.get("tier") or row.get("family") or ""
        cols = [row["name"], group, row["status"], row["notes"]]
        print("\t".join(str(c) for c in cols))
+
def list_axis(axis: str, as_json: bool):
    """Print the entries registered for one axis.

    Args:
        axis: One of ``"benchmarks"``, ``"views"``, ``"computes"`` or
            ``"modifiers"``.
        as_json: When true, print the rows as indented JSON; otherwise
            print the tab-separated table.

    Raises:
        KeyError: If ``axis`` is not one of the names listed above.
    """
    mapping = {
        "benchmarks": BENCHMARKS,
        "views": VIEWS,
        "computes": COMPUTES,
        "modifiers": MODIFIERS,
    }
    items = mapping[axis]
    if as_json:
        print(json.dumps(_rows(items), indent=2))
    else:
        _print_table(items)
+
+
def build_command(args) -> list[str]:
    """Build the command line for the run selected by ``args``.

    The run spec is looked up from ``args.task``, ``args.view``,
    ``args.compute`` and ``args.modifier``. Its ``default_args`` are copied,
    every override present on ``args`` with a non-``None`` value replaces
    the matching default, and ``compute`` is always set from
    ``args.compute``.

    Args:
        args: Namespace carrying the selection fields plus any of the
            optional overrides listed in ``override_keys`` below.

    Returns:
        Whatever the spec's ``command_builder`` returns for the merged
        arguments.
    """
    spec = find_run_spec(args.task, args.view, args.compute, args.modifier)
    # Copy so the defaults stored on the spec are never mutated.
    run_args = dict(spec.default_args)
    override_keys = (
        "epochs", "hidden", "bs", "seed", "T", "n_sup", "halt_max_steps", "halt_target",
        "halt_min_steps", "halt_loss_threshold", "q_warmup_epochs",
        "eval_every", "max_train_batches", "max_eval_batches", "num_workers", "ema", "lr", "lam_q",
        "device",
    )
    for key in override_keys:
        # A missing attribute is treated like an unset option, so namespaces
        # built by hand (without every flag) keep the spec defaults.
        value = getattr(args, key, None)
        if value is not None:
            run_args[key] = value
    run_args["compute"] = args.compute
    return spec.command_builder(run_args)
+
+
def print_matrix(args):
    """Print the experiment matrix as tab-separated rows.

    Benchmarks are filtered by ``args.tier`` (``"all"`` keeps every tier),
    sorted by priority and truncated to ``args.limit_tasks``.

    * ``args.kind == "main"``: one row per task / view / compute
      combination, with views sorted by priority and truncated to
      ``args.limit_views``. The last column is ``implemented`` when
      ``find_run_spec`` finds a spec and ``planned`` when it raises
      ``KeyError``.
    * ``args.kind == "modifier"``: one row per selected task and every
      modifier other than ``"none"``, with fixed ``gin`` / ``view-only``
      columns and status ``planned``.

    Args:
        args: Namespace with ``kind``, ``tier``, ``limit_tasks`` and, for
            the main matrix, ``limit_views``.

    Raises:
        ValueError: If ``args.kind`` is neither ``"main"`` nor
            ``"modifier"``.
    """
    tasks = [b for b in BENCHMARKS if args.tier == "all" or b.tier == args.tier]
    tasks = sorted(tasks, key=lambda x: x.priority)[:args.limit_tasks]

    if args.kind == "main":
        views = sorted(VIEWS, key=lambda x: x.priority)[:args.limit_views]
        # Only these compute modes belong to the main matrix; rows follow
        # the order in which COMPUTES lists them.
        main_computes = {"classic", "view-only", "fixed-rrog", "rrog-act"}
        computes = [c for c in COMPUTES if c.name in main_computes]
        for task in tasks:
            for view in views:
                for compute in computes:
                    try:
                        find_run_spec(task.name, view.name, compute.name)
                    except KeyError:
                        status = "planned"
                    else:
                        status = "implemented"
                    print(f"{task.name}\t{view.name}\t{compute.name}\tnone\t{status}")
        return

    if args.kind == "modifier":
        task_names = {"zinc-cycle56", "ogbg-molhiv", "peptides-struct", "peptides-func", "ogbn-arxiv", "qm9"}
        mods = [m for m in sorted(MODIFIERS, key=lambda x: x.priority) if m.name != "none"]
        for task in tasks:
            if task.name not in task_names:
                continue
            for mod in mods:
                print(f"{task.name}\tgin\tview-only\t{mod.name}\tplanned")
        return

    raise ValueError(args.kind)
+
+
def main():
    """Parse the command line and dispatch to the selected subcommand.

    Subcommands:
        * ``list``: print one registry axis (table or ``--json``).
        * ``run``: build the command for a run spec, print it and, unless
          ``--dry-run`` is given, execute it and exit with its return code.
        * ``matrix``: print the experiment matrix.
        * ``results``: hand the parsed arguments to ``print_tables``.
        * ``zinc-results``: hand the parsed arguments to ``print_zinc``.

    Raises:
        SystemExit: With the child's return code after a non-dry ``run``
            (argparse also raises it on invalid arguments).
    """
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    lp = sub.add_parser("list")
    lp.add_argument("axis", choices=["benchmarks", "views", "computes", "modifiers"])
    lp.add_argument("--json", action="store_true")

    rp = sub.add_parser("run")
    rp.add_argument("--task", default="zinc-cycle56")
    rp.add_argument("--view", default="gin")
    rp.add_argument("--compute", default="rrog-act")
    rp.add_argument("--modifier", default="none")
    # Overrides default to None so build_command can tell "not given" apart
    # from an explicit value and keep the spec default in the former case.
    rp.add_argument("--epochs", type=int)
    rp.add_argument("--hidden", type=int)
    rp.add_argument("--bs", type=int)
    rp.add_argument("--seed", type=int)
    rp.add_argument("--T", type=int)
    rp.add_argument("--n_sup", type=int)
    rp.add_argument("--halt_max_steps", type=int)
    rp.add_argument("--halt_target", choices=["soft", "binary", "exact", "loss"])
    rp.add_argument("--halt_min_steps", type=int)
    rp.add_argument("--halt_loss_threshold", type=float)
    rp.add_argument("--q_warmup_epochs", type=int)
    rp.add_argument("--eval_every", type=int)
    rp.add_argument("--max_train_batches", type=int)
    rp.add_argument("--max_eval_batches", type=int)
    rp.add_argument("--num_workers", type=int)
    rp.add_argument("--ema", type=float)
    rp.add_argument("--lr", type=float)
    rp.add_argument("--lam_q", type=float)
    rp.add_argument("--device")
    rp.add_argument("--dry-run", action="store_true")

    mp = sub.add_parser("matrix")
    mp.add_argument("--kind", choices=["main", "modifier"], default="main")
    mp.add_argument("--tier", choices=["A", "B", "C", "all"], default="A")
    mp.add_argument("--limit-tasks", type=int, default=20)
    mp.add_argument("--limit-views", type=int, default=6)

    op = sub.add_parser("results")
    op.add_argument("--runs-dir", default="runs")
    op.add_argument("--glob", default="*.json")
    op.add_argument("--min-epochs", type=int, default=10)
    op.add_argument("--epochs", type=int)
    op.add_argument("--digits", type=int, default=4)

    zp = sub.add_parser("zinc-results")
    zp.add_argument("--runs-dir", default="runs")
    zp.add_argument("--epochs", type=int)
    zp.add_argument("--min-epochs", type=int, default=10)
    zp.add_argument("--digits", type=int, default=4)

    args = ap.parse_args()
    if args.cmd == "list":
        list_axis(args.axis, args.json)
        return
    if args.cmd == "matrix":
        print_matrix(args)
        return
    if args.cmd == "results":
        print_tables(args)
        return
    if args.cmd == "zinc-results":
        print_zinc(args)
        return

    # Remaining subcommand: "run".
    cmd = build_command(args)
    env = dict(os.environ)
    # Put the working directory first on the child's import path. Skip an
    # unset/empty inherited value so no trailing separator (an empty path
    # entry) is produced.
    inherited = env.get("PYTHONPATH", "")
    env["PYTHONPATH"] = os.pathsep.join(p for p in (os.getcwd(), inherited) if p)
    # Flush so the echoed command appears before any child output.
    print(" ".join(cmd), flush=True)
    if not args.dry_run:
        raise SystemExit(subprocess.call(cmd, env=env))
+
+
if __name__ == "__main__":
    import sys

    try:
        main()
        # Flush inside the try so a closed pipe is reported here rather
        # than during interpreter shutdown.
        sys.stdout.flush()
    except BrokenPipeError:
        # The reader (e.g. `head`) closed the pipe early. Point stdout at
        # devnull so the interpreter's exit-time flush cannot raise a second
        # BrokenPipeError, then exit quietly.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        raise SystemExit(0)