summary | refs | log | tree | commit | diff
path: root/rrog/collect_zinc.py
diff options
context:
space:
mode:
Diffstat (limited to 'rrog/collect_zinc.py')
-rw-r--r-- rrog/collect_zinc.py 137
1 files changed, 137 insertions, 0 deletions
diff --git a/rrog/collect_zinc.py b/rrog/collect_zinc.py
new file mode 100644
index 0000000..49ae1e9
--- /dev/null
+++ b/rrog/collect_zinc.py
@@ -0,0 +1,137 @@
+import argparse
+import json
+import math
+from collections import defaultdict
+from pathlib import Path
+
+
+def _read(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ rep = json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+ if rep.get("dataset") != "ZINC-cycle56":
+ return None
+ if rep.get("K") != 1 or rep.get("select") != "none" or float(rep.get("sigma", 0.0)) != 0.0:
+ return None
+ if "test_mae" not in rep or "val_mae" not in rep:
+ return None
+ return rep
+
+
def _compute(rep: dict) -> str:
    """Build the compute-regime label for a report.

    Reports with a truthy ``act`` are labelled ``rrog-act-T{T}-ns{n_sup}``.
    Otherwise ``T == 0`` with ``n_sup == 1`` is ``"classic"``, and anything
    else is ``fixed-rrog-T{T}-ns{n_sup}`` with ``+trace`` appended when
    ``loss_mode`` is ``"trace"``.
    """
    steps = rep.get("T")
    n_sup = rep.get("n_sup")
    if rep.get("act"):
        return f"rrog-act-T{steps}-ns{n_sup}"
    # Missing keys fall back to -1 so they can never look like "classic".
    is_classic = int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1
    if is_classic:
        return "classic"
    suffix = "+trace" if rep.get("loss_mode") == "trace" else ""
    return f"fixed-rrog-T{steps}-ns{n_sup}{suffix}"
+
+
def _view(rep: dict) -> str:
    """Return the report's ``view`` field as a string, defaulting to ``"gin"``."""
    backbone = rep.get("view", "gin")
    return str(backbone)
+
+
def _score(rep: dict, split: str) -> float:
    """Return the sum of the report's ``{split}_mae`` entries as a float.

    Args:
        rep: Decoded report; must contain the key ``f"{split}_mae"``.
        split: Prefix of the MAE key (the callers in this file pass
            ``"val"`` or ``"test"``).
    """
    key = f"{split}_mae"
    total = sum(rep[key])
    return float(total)
+
+
def _mean(xs: list[float]) -> float:
    """Return the arithmetic mean of ``xs``.

    Raises:
        ZeroDivisionError: If ``xs`` is empty.
    """
    count = len(xs)
    total = sum(xs)
    return total / count
+
+
def _std(xs: list[float]) -> float:
    """Return the sample standard deviation of ``xs`` (n - 1 denominator).

    Inputs with fewer than two values yield 0.0.
    """
    n = len(xs)
    if n >= 2:
        center = _mean(xs)
        squared_dev = sum((value - center) ** 2 for value in xs)
        return math.sqrt(squared_dev / (n - 1))
    return 0.0
+
+
def _fmt(x: float, digits: int) -> str:
    """Format ``x`` in fixed-point notation with ``digits`` decimal places."""
    spec = f".{digits}f"
    return format(x, spec)
+
+
def _markdown(headers: list[str], rows: list[list[str]]) -> str:
    """Render a pipe-delimited markdown table.

    Args:
        headers: Column titles for the first line.
        rows: One list of cell strings per table row.

    Returns:
        The header line, a ``---`` separator line with one cell per header,
        and one line per row, joined with newlines (no trailing newline).
    """

    def render(cells: list[str]) -> str:
        return "| " + " | ".join(cells) + " |"

    table = [render(headers), render(["---"] * len(headers))]
    for row in rows:
        table.append(render(row))
    return "\n".join(table)
+
+
def print_zinc(args) -> None:
    """Print markdown summary tables for the ZINC-cycle56 runs in a directory.

    Reports matching ``rec_rrog*_sig0.0_K1_none_T*_s*.json`` under
    ``args.runs_dir`` are loaded with ``_read`` and filtered by epoch count.
    One report is kept per (view, compute label, seed): the one with the
    most epochs, the earliest path in sorted order winning ties. Two tables
    are then written to stdout:

    * the "classic" baseline per view, as mean +/- sample std over seeds;
    * every other compute label per view, restricted to seeds that also
      have a classic run, as the mean score followed in parentheses by the
      mean per-seed ``classic - variant`` difference (positive means the
      variant's MAE-sum is lower than the classic one).

    Reports whose ``seed`` or ``epochs`` (or the fields used to build the
    compute label) are missing or not integer-convertible are skipped, in
    the same silent way ``_read`` skips unusable files, so that one bad
    report cannot abort the whole summary.

    Args:
        args: Namespace providing ``runs_dir``, ``epochs`` (int or None),
            ``min_epochs`` and ``digits``.
    """
    by_cell: dict[tuple[str, str, int], dict] = {}
    pattern = "rec_rrog*_sig0.0_K1_none_T*_s*.json"
    for path in sorted(Path(args.runs_dir).glob(pattern)):
        rep = _read(path)
        if rep is None:
            continue
        try:
            epochs = int(rep.get("epochs", 0))
            key = (_view(rep), _compute(rep), int(rep["seed"]))
        except (KeyError, TypeError, ValueError):
            # Malformed report (no seed, null/non-numeric field): skip it.
            continue
        # A report without "epochs" counts as -1 here, so it never matches
        # an explicit --epochs request.
        if args.epochs is not None and int(rep.get("epochs", -1)) != args.epochs:
            continue
        if epochs < args.min_epochs:
            continue
        old = by_cell.get(key)
        if old is None or epochs > int(old.get("epochs", 0)):
            by_cell[key] = rep

    # The cell key already carries the view and compute label.
    grouped: dict[tuple[str, str], list[dict]] = defaultdict(list)
    for (view, compute, _seed), rep in by_cell.items():
        grouped[(view, compute)].append(rep)

    classic_by_view = {
        view: reps
        for (view, compute), reps in grouped.items()
        if compute == "classic"
    }

    base_rows = []
    for view, classic in sorted(classic_by_view.items()):
        vals = [_score(rep, "val") for rep in classic]
        tests = [_score(rep, "test") for rep in classic]
        base_rows.append([
            "zinc-cycle56",
            view,
            str(len(classic)),
            f"{_fmt(_mean(vals), args.digits)} +/- {_fmt(_std(vals), args.digits)}",
            f"{_fmt(_mean(tests), args.digits)} +/- {_fmt(_std(tests), args.digits)}",
        ])

    # Seed -> classic report, built once per view instead of once per group.
    classic_by_view_seed = {
        view: {int(rep["seed"]): rep for rep in reps}
        for view, reps in classic_by_view.items()
    }

    delta_rows = []
    for (view, compute), reps in sorted(grouped.items()):
        if compute == "classic":
            continue
        base_by_seed = classic_by_view_seed.get(view, {})
        # Only seeds that also have a classic run can be compared.
        paired = []
        for rep in reps:
            base = base_by_seed.get(int(rep["seed"]))
            if base is not None:
                paired.append((rep, base))
        if not paired:
            continue
        vals = [_score(rep, "val") for rep, _ in paired]
        tests = [_score(rep, "test") for rep, _ in paired]
        val_deltas = [_score(base, "val") - _score(rep, "val") for rep, base in paired]
        test_deltas = [_score(base, "test") - _score(rep, "test") for rep, base in paired]
        delta_rows.append([
            "zinc-cycle56",
            view,
            compute,
            str(len(paired)),
            f"{_fmt(_mean(vals), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})",
            f"{_fmt(_mean(tests), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})",
        ])

    print("\nZINC-cycle56 classic baseline")
    print(_markdown(["task", "backbone", "n", "val MAE-sum", "test MAE-sum"], base_rows))
    print("\nZINC-cycle56 delta vs matching classic")
    print(_markdown(["task", "backbone", "compute", "n", "val score (improvement)", "test score (improvement)"], delta_rows))
+
+
def main() -> None:
    """Parse the command-line options and print the ZINC-cycle56 tables."""
    ap = argparse.ArgumentParser(
        description="Summarize ZINC-cycle56 run reports as markdown tables.",
    )
    ap.add_argument(
        "--runs-dir",
        default="runs",
        help="directory containing the rec_rrog*.json run reports (default: %(default)s)",
    )
    ap.add_argument(
        "--epochs",
        type=int,
        help="keep only runs trained for exactly this many epochs (default: no exact filter)",
    )
    ap.add_argument(
        "--min-epochs",
        type=int,
        default=10,
        help="skip runs trained for fewer epochs than this; applied in addition to --epochs (default: %(default)s)",
    )
    ap.add_argument(
        "--digits",
        type=int,
        default=4,
        help="number of decimal places in the printed scores (default: %(default)s)",
    )
    print_zinc(ap.parse_args())


# Run the summary only when executed as a script, not when imported.
if __name__ == "__main__":
    main()