diff options
Diffstat (limited to 'rrog/collect_zinc.py')
| -rw-r--r-- | rrog/collect_zinc.py | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/rrog/collect_zinc.py b/rrog/collect_zinc.py new file mode 100644 index 0000000..49ae1e9 --- /dev/null +++ b/rrog/collect_zinc.py @@ -0,0 +1,137 @@ +import argparse +import json +import math +from collections import defaultdict +from pathlib import Path + + +def _read(path: Path) -> dict | None: + try: + with path.open() as f: + rep = json.load(f) + except (OSError, json.JSONDecodeError): + return None + if rep.get("dataset") != "ZINC-cycle56": + return None + if rep.get("K") != 1 or rep.get("select") != "none" or float(rep.get("sigma", 0.0)) != 0.0: + return None + if "test_mae" not in rep or "val_mae" not in rep: + return None + return rep + + +def _compute(rep: dict) -> str: + if rep.get("act"): + return f"rrog-act-T{rep.get('T')}-ns{rep.get('n_sup')}" + if int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1: + return "classic" + label = f"fixed-rrog-T{rep.get('T')}-ns{rep.get('n_sup')}" + if rep.get("loss_mode") == "trace": + label += "+trace" + return label + + +def _view(rep: dict) -> str: + return str(rep.get("view", "gin")) + + +def _score(rep: dict, split: str) -> float: + return float(sum(rep[f"{split}_mae"])) + + +def _mean(xs: list[float]) -> float: + return sum(xs) / len(xs) + + +def _std(xs: list[float]) -> float: + if len(xs) < 2: + return 0.0 + mu = _mean(xs) + return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)) + + +def _fmt(x: float, digits: int) -> str: + return f"{x:.{digits}f}" + + +def _markdown(headers: list[str], rows: list[list[str]]) -> str: + lines = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"] + lines.extend("| " + " | ".join(row) + " |" for row in rows) + return "\n".join(lines) + + +def print_zinc(args) -> None: + by_cell: dict[tuple[str, str, int], dict] = {} + for path in sorted(Path(args.runs_dir).glob("rec_rrog*_sig0.0_K1_none_T*_s*.json")): + rep = _read(path) + if rep is None: + continue + if args.epochs is not None and int(rep.get("epochs", -1)) != args.epochs: + continue + if int(rep.get("epochs", 0)) < args.min_epochs: + continue + key = (_view(rep), _compute(rep), int(rep["seed"])) + old = by_cell.get(key) + if old is None or int(rep.get("epochs", 0)) > int(old.get("epochs", 0)): + by_cell[key] = rep + + grouped: dict[tuple[str, str], list[dict]] = defaultdict(list) + for rep in by_cell.values(): + grouped[(_view(rep), _compute(rep))].append(rep) + + classic_by_view = { + view: reps + for (view, compute), reps in grouped.items() + if compute == "classic" + } + + base_rows = [] + for view, classic in sorted(classic_by_view.items()): + vals = [_score(rep, "val") for rep in classic] + tests = [_score(rep, "test") for rep in classic] + base_rows.append([ + "zinc-cycle56", + view, + str(len(classic)), + f"{_fmt(_mean(vals), args.digits)} +/- {_fmt(_std(vals), args.digits)}", + f"{_fmt(_mean(tests), args.digits)} +/- {_fmt(_std(tests), args.digits)}", + ]) + + delta_rows = [] + for (view, compute), reps in sorted(grouped.items()): + if compute == "classic": + continue + base_by_seed = {int(rep["seed"]): rep for rep in classic_by_view.get(view, [])} + paired = [(rep, base_by_seed[int(rep["seed"])]) for rep in reps if int(rep["seed"]) in base_by_seed] + if not paired: + continue + vals = [_score(rep, "val") for rep, _ in paired] + tests = [_score(rep, "test") for rep, _ in paired] + val_deltas = [_score(base, "val") - _score(rep, "val") for rep, base in paired] + test_deltas = [_score(base, "test") - _score(rep, "test") for rep, base in paired] + delta_rows.append([ + "zinc-cycle56", + view, + compute, + str(len(paired)), + f"{_fmt(_mean(vals), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})", + f"{_fmt(_mean(tests), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})", + ]) + + print("\nZINC-cycle56 classic baseline") + print(_markdown(["task", "backbone", "n", "val MAE-sum", "test MAE-sum"], base_rows)) + print("\nZINC-cycle56 delta vs matching classic") + print(_markdown(["task", "backbone", "compute", "n", "val score (improvement)", "test score (improvement)"], delta_rows)) + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--runs-dir", default="runs") + ap.add_argument("--epochs", type=int) + ap.add_argument("--min-epochs", type=int, default=10) + ap.add_argument("--digits", type=int, default=4) + print_zinc(ap.parse_args()) + + +if __name__ == "__main__": + main() |
