"""Command-line front-end for listing, planning, and launching rrog runs."""

import argparse
import json
import os
import shlex
import subprocess
import sys

from rrog.backbones import COMPUTES, MODIFIERS, VIEWS
from rrog.benchmarks import BENCHMARKS
from rrog.collect_results import print_tables
from rrog.collect_zinc import print_zinc
from rrog.runspecs import find_run_spec

# Optional per-run overrides accepted by the `run` sub-command, as
# (argument dest, extra add_argument kwargs). Shared by the CLI definition
# and build_command so the two lists cannot drift apart; order is preserved
# because it determines the insertion order of keys into run_args.
_RUN_OVERRIDES = (
    ("epochs", {"type": int}),
    ("hidden", {"type": int}),
    ("bs", {"type": int}),
    ("seed", {"type": int}),
    ("T", {"type": int}),
    ("n_sup", {"type": int}),
    ("halt_max_steps", {"type": int}),
    ("halt_target", {"choices": ["soft", "binary", "exact", "loss"]}),
    ("halt_min_steps", {"type": int}),
    ("halt_loss_threshold", {"type": float}),
    ("q_warmup_epochs", {"type": int}),
    ("eval_every", {"type": int}),
    ("max_train_batches", {"type": int}),
    ("max_eval_batches", {"type": int}),
    ("num_workers", {"type": int}),
    ("ema", {"type": float}),
    ("lr", {"type": float}),
    ("lam_q", {"type": float}),
    ("device", {}),
)

# Compute variants enumerated by `matrix --kind main`.
_MAIN_COMPUTES = ("classic", "view-only", "fixed-rrog", "rrog-act")

# Tasks for which `matrix --kind modifier` lists modifier sweeps.
_MODIFIER_TASKS = frozenset(
    {"zinc-cycle56", "ogbg-molhiv", "peptides-struct", "peptides-func", "ogbn-arxiv", "qm9"}
)


def _rows(items):
    """Flatten registry entries into plain dicts.

    Attributes that only some axes define (tier, family, domain, task_type,
    metric) are read with ``getattr`` and default to ``None``; ``name``,
    ``priority``, ``status`` and ``notes`` are required on every item.
    """
    return [
        {
            "name": item.name,
            "tier": getattr(item, "tier", None),
            "family": getattr(item, "family", None),
            "domain": getattr(item, "domain", None),
            "task_type": getattr(item, "task_type", None),
            "metric": getattr(item, "metric", None),
            "priority": item.priority,
            "status": item.status,
            "notes": item.notes,
        }
        for item in items
    ]


def _print_table(items):
    """Print one tab-separated line per item: name, tier-or-family, status, notes."""
    for row in _rows(items):
        cols = [row["name"], row.get("tier") or row.get("family") or "", row["status"], row["notes"]]
        print("\t".join(str(c) for c in cols))


def list_axis(axis: str, as_json: bool):
    """Print the registry for one axis, either as a TSV table or as JSON.

    Args:
        axis: One of "benchmarks", "views", "computes", "modifiers".
        as_json: Emit an indented JSON array of row dicts instead of TSV.

    Raises:
        KeyError: if ``axis`` is not one of the known names.
    """
    mapping = {
        "benchmarks": BENCHMARKS,
        "views": VIEWS,
        "computes": COMPUTES,
        "modifiers": MODIFIERS,
    }
    items = mapping[axis]
    if as_json:
        print(json.dumps(_rows(items), indent=2))
    else:
        _print_table(items)


def build_command(args) -> list[str]:
    """Resolve the run spec for ``args`` and return the argv to execute.

    The spec's ``default_args`` are copied, then every override in
    ``_RUN_OVERRIDES`` that was given explicitly (i.e. is not ``None``) is
    applied on top, and ``compute`` is always set from ``args.compute``.
    """
    spec = find_run_spec(args.task, args.view, args.compute, args.modifier)
    run_args = dict(spec.default_args)
    for key, _ in _RUN_OVERRIDES:
        value = getattr(args, key)
        if value is not None:
            run_args[key] = value
    run_args["compute"] = args.compute
    return spec.command_builder(run_args)


def print_matrix(args):
    """Print the planned experiment grid as TSV rows.

    Each row is: task, view, compute, modifier, status. Tasks are filtered by
    ``args.tier`` ("all" keeps every tier), sorted by priority and truncated
    to ``args.limit_tasks``.

    kind == "main": cross the top ``args.limit_views`` views with the compute
    variants in ``_MAIN_COMPUTES``; a combination is "implemented" when
    ``find_run_spec`` resolves it and "planned" when it raises KeyError.

    kind == "modifier": for tasks in ``_MODIFIER_TASKS``, list every modifier
    except "none" on the gin / view-only baseline, all marked "planned".

    Raises:
        ValueError: for an unknown ``args.kind``.
    """
    tasks = [b for b in BENCHMARKS if args.tier == "all" or b.tier == args.tier]
    tasks = sorted(tasks, key=lambda x: x.priority)[:args.limit_tasks]

    if args.kind == "main":
        views = sorted(VIEWS, key=lambda x: x.priority)[:args.limit_views]
        computes = [c for c in COMPUTES if c.name in _MAIN_COMPUTES]
        for task in tasks:
            for view in views:
                for compute in computes:
                    try:
                        find_run_spec(task.name, view.name, compute.name)
                        status = "implemented"
                    except KeyError:
                        status = "planned"
                    print(f"{task.name}\t{view.name}\t{compute.name}\tnone\t{status}")
        return

    if args.kind == "modifier":
        mods = [m for m in sorted(MODIFIERS, key=lambda x: x.priority) if m.name != "none"]
        for task in tasks:
            if task.name not in _MODIFIER_TASKS:
                continue
            for mod in mods:
                print(f"{task.name}\tgin\tview-only\t{mod.name}\tplanned")
        return

    raise ValueError(args.kind)


def _build_parser() -> argparse.ArgumentParser:
    """Construct the top-level parser with all sub-commands."""
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    lp = sub.add_parser("list")
    lp.add_argument("axis", choices=["benchmarks", "views", "computes", "modifiers"])
    lp.add_argument("--json", action="store_true")

    rp = sub.add_parser("run")
    rp.add_argument("--task", default="zinc-cycle56")
    rp.add_argument("--view", default="gin")
    rp.add_argument("--compute", default="rrog-act")
    rp.add_argument("--modifier", default="none")
    for name, kwargs in _RUN_OVERRIDES:
        rp.add_argument(f"--{name}", **kwargs)
    rp.add_argument("--dry-run", action="store_true")

    mp = sub.add_parser("matrix")
    mp.add_argument("--kind", choices=["main", "modifier"], default="main")
    mp.add_argument("--tier", choices=["A", "B", "C", "all"], default="A")
    mp.add_argument("--limit-tasks", type=int, default=20)
    mp.add_argument("--limit-views", type=int, default=6)

    op = sub.add_parser("results")
    op.add_argument("--runs-dir", default="runs")
    op.add_argument("--glob", default="*.json")
    op.add_argument("--min-epochs", type=int, default=10)
    op.add_argument("--epochs", type=int)
    op.add_argument("--digits", type=int, default=4)

    zp = sub.add_parser("zinc-results")
    zp.add_argument("--runs-dir", default="runs")
    zp.add_argument("--epochs", type=int)
    zp.add_argument("--min-epochs", type=int, default=10)
    zp.add_argument("--digits", type=int, default=4)
    return ap


def _child_env() -> dict[str, str]:
    """Return a copy of the environment with the CWD prepended to PYTHONPATH.

    Empty components are dropped so an unset PYTHONPATH does not leave a
    trailing separator (an empty path entry) behind.
    """
    env = dict(os.environ)
    parts = (os.getcwd(), env.get("PYTHONPATH", ""))
    env["PYTHONPATH"] = os.pathsep.join(p for p in parts if p)
    return env


def main():
    """Parse the CLI and dispatch to the selected sub-command.

    For ``run``, the resolved command is echoed (shell-quoted) and, unless
    ``--dry-run`` is given, executed; its return code becomes this process's
    exit status.
    """
    args = _build_parser().parse_args()

    if args.cmd == "list":
        list_axis(args.axis, args.json)
        return
    if args.cmd == "matrix":
        print_matrix(args)
        return
    if args.cmd == "results":
        print_tables(args)
        return
    if args.cmd == "zinc-results":
        print_zinc(args)
        return

    cmd = build_command(args)
    env = _child_env()
    # shlex.join quotes arguments so the echoed line can be pasted into a shell.
    print(shlex.join(cmd), flush=True)
    if not args.dry_run:
        raise SystemExit(subprocess.call(cmd, env=env))


if __name__ == "__main__":
    try:
        main()
    except BrokenPipeError:
        # Output was piped into something that exited early (e.g. `| head`).
        # Point stdout at devnull so the interpreter's final flush does not
        # raise a second BrokenPipeError and print a traceback to stderr.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        raise SystemExit(0)