import argparse import json import math from collections import defaultdict from pathlib import Path def _read(path: Path) -> dict | None: try: with path.open() as f: rep = json.load(f) except (OSError, json.JSONDecodeError): return None if rep.get("dataset") != "ZINC-cycle56": return None if rep.get("K") != 1 or rep.get("select") != "none" or float(rep.get("sigma", 0.0)) != 0.0: return None if "test_mae" not in rep or "val_mae" not in rep: return None return rep def _compute(rep: dict) -> str: if rep.get("act"): return f"rrog-act-T{rep.get('T')}-ns{rep.get('n_sup')}" if int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1: return "classic" label = f"fixed-rrog-T{rep.get('T')}-ns{rep.get('n_sup')}" if rep.get("loss_mode") == "trace": label += "+trace" return label def _view(rep: dict) -> str: return str(rep.get("view", "gin")) def _score(rep: dict, split: str) -> float: return float(sum(rep[f"{split}_mae"])) def _mean(xs: list[float]) -> float: return sum(xs) / len(xs) def _std(xs: list[float]) -> float: if len(xs) < 2: return 0.0 mu = _mean(xs) return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)) def _fmt(x: float, digits: int) -> str: return f"{x:.{digits}f}" def _markdown(headers: list[str], rows: list[list[str]]) -> str: lines = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"] lines.extend("| " + " | ".join(row) + " |" for row in rows) return "\n".join(lines) def print_zinc(args) -> None: by_cell: dict[tuple[str, str, int], dict] = {} for path in sorted(Path(args.runs_dir).glob("rec_rrog*_sig0.0_K1_none_T*_s*.json")): rep = _read(path) if rep is None: continue if args.epochs is not None and int(rep.get("epochs", -1)) != args.epochs: continue if int(rep.get("epochs", 0)) < args.min_epochs: continue key = (_view(rep), _compute(rep), int(rep["seed"])) old = by_cell.get(key) if old is None or int(rep.get("epochs", 0)) > int(old.get("epochs", 0)): by_cell[key] = rep grouped: dict[tuple[str, str], list[dict]] = defaultdict(list) for rep in by_cell.values(): grouped[(_view(rep), _compute(rep))].append(rep) classic_by_view = { view: reps for (view, compute), reps in grouped.items() if compute == "classic" } base_rows = [] for view, classic in sorted(classic_by_view.items()): vals = [_score(rep, "val") for rep in classic] tests = [_score(rep, "test") for rep in classic] base_rows.append([ "zinc-cycle56", view, str(len(classic)), f"{_fmt(_mean(vals), args.digits)} +/- {_fmt(_std(vals), args.digits)}", f"{_fmt(_mean(tests), args.digits)} +/- {_fmt(_std(tests), args.digits)}", ]) delta_rows = [] for (view, compute), reps in sorted(grouped.items()): if compute == "classic": continue base_by_seed = {int(rep["seed"]): rep for rep in classic_by_view.get(view, [])} paired = [(rep, base_by_seed[int(rep["seed"])]) for rep in reps if int(rep["seed"]) in base_by_seed] if not paired: continue vals = [_score(rep, "val") for rep, _ in paired] tests = [_score(rep, "test") for rep, _ in paired] val_deltas = [_score(base, "val") - _score(rep, "val") for rep, base in paired] test_deltas = [_score(base, "test") - _score(rep, "test") for rep, base in paired] delta_rows.append([ "zinc-cycle56", view, compute, str(len(paired)), f"{_fmt(_mean(vals), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})", f"{_fmt(_mean(tests), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})", ]) print("\nZINC-cycle56 classic baseline") print(_markdown(["task", "backbone", "n", "val MAE-sum", "test MAE-sum"], base_rows)) print("\nZINC-cycle56 delta vs matching classic") print(_markdown(["task", "backbone", "compute", "n", "val score (improvement)", "test score (improvement)"], delta_rows)) def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--runs-dir", default="runs") ap.add_argument("--epochs", type=int) ap.add_argument("--min-epochs", type=int, default=10) ap.add_argument("--digits", type=int, default=4) print_zinc(ap.parse_args()) if __name__ == "__main__": main()