summaryrefslogtreecommitdiff
path: root/rrog/collect_results.py
diff options
context:
space:
mode:
Diffstat (limited to 'rrog/collect_results.py')
-rw-r--r--rrog/collect_results.py239
1 files changed, 239 insertions, 0 deletions
diff --git a/rrog/collect_results.py b/rrog/collect_results.py
new file mode 100644
index 0000000..125c0ab
--- /dev/null
+++ b/rrog/collect_results.py
@@ -0,0 +1,239 @@
+import argparse
+import json
+import math
+from collections import defaultdict
+from pathlib import Path
+
+
+HIGHER_BETTER = {
+ "accuracy",
+ "ap",
+ "auc",
+ "f1",
+ "mrr",
+ "rocauc",
+}
+
+LOWER_BETTER = {
+ "mae",
+ "raw-mae",
+ "rmse",
+}
+
+
+def _metric_direction(metric: str) -> int:
+ metric = metric.lower()
+ if metric in HIGHER_BETTER:
+ return 1
+ if metric in LOWER_BETTER:
+ return -1
+ if "mae" in metric or "rmse" in metric:
+ return -1
+ return 1
+
+
+def _read(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ rep = json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+ required = {"dataset", "view", "compute", "seed", "metric", "val"}
+ if not required.issubset(rep):
+ return None
+ return rep
+
+
+def _score(rep: dict, split: str) -> float | None:
+ metric = rep.get("metric")
+ if not metric:
+ return None
+ value = rep.get(split, {}).get(metric)
+ if value is None:
+ return None
+ return float(value)
+
+
+def _rank_key(rep: dict) -> tuple[int, int, int]:
+ return (
+ int(rep.get("epochs", 0)),
+ int(rep.get("hidden", 0)),
+ int(rep.get("ep", 0) or 0),
+ )
+
+
+def _mean(xs: list[float]) -> float:
+ return sum(xs) / len(xs)
+
+
+def _std(xs: list[float]) -> float:
+ if len(xs) < 2:
+ return 0.0
+ mu = _mean(xs)
+ return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1))
+
+
+def _fmt(value: float | None, digits: int) -> str:
+ if value is None:
+ return ""
+ return f"{value:.{digits}f}"
+
+
+def _is_classic_baseline(rep: dict) -> bool:
+ return rep.get("compute") == "classic" and int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1
+
+
+def _compute_label(rep: dict) -> str:
+ label = str(rep["compute"])
+ ema = float(rep.get("ema", 0.0) or 0.0)
+ if ema > 0:
+ label += f"+ema{ema:g}"
+ return label
+
+
+def _choose_runs(paths: list[Path], min_epochs: int, epochs: int | None) -> list[dict]:
+ candidates: dict[tuple[str, str, str, int], dict] = {}
+ for path in paths:
+ rep = _read(path)
+ if rep is None:
+ continue
+ if epochs is not None and int(rep.get("epochs", -1)) != epochs:
+ continue
+ if int(rep.get("epochs", 0)) < min_epochs:
+ continue
+ val = _score(rep, "val")
+ test = _score(rep, "test")
+ if val is None or test is None:
+ continue
+ key = (str(rep["dataset"]), str(rep["view"]), _compute_label(rep), int(rep["seed"]))
+ old = candidates.get(key)
+ if old is None or _rank_key(rep) > _rank_key(old):
+ candidates[key] = rep
+ return list(candidates.values())
+
+
+def _group_by_cell(runs: list[dict]) -> dict[tuple[str, str, str], list[dict]]:
+ grouped: dict[tuple[str, str, str], list[dict]] = defaultdict(list)
+ for rep in runs:
+ grouped[(str(rep["dataset"]), str(rep["view"]), _compute_label(rep))].append(rep)
+ return dict(grouped)
+
+
+def _summarize_cell(reps: list[dict], split: str) -> tuple[float, float, int]:
+ scores = [_score(rep, split) for rep in reps]
+ xs = [x for x in scores if x is not None]
+ return _mean(xs), _std(xs), len(xs)
+
+
+def _markdown_table(headers: list[str], rows: list[list[str]]) -> str:
+ out = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"]
+ out.extend("| " + " | ".join(row) + " |" for row in rows)
+ return "\n".join(out)
+
+
+def print_tables(args) -> None:
+ paths = sorted(Path(args.runs_dir).glob(args.glob))
+ runs = _choose_runs(paths, args.min_epochs, args.epochs)
+ grouped = _group_by_cell(runs)
+
+ classic_cells = {
+ (task, view): reps
+ for (task, view, compute), reps in grouped.items()
+ if compute == "classic"
+ for reps in [[rep for rep in reps if _is_classic_baseline(rep)]]
+ if reps
+ }
+
+ baseline_rows = []
+ for (task, view), reps in sorted(classic_cells.items()):
+ metric = str(reps[0]["metric"])
+ val_mu, val_sd, n = _summarize_cell(reps, "val")
+ test_mu, test_sd, _ = _summarize_cell(reps, "test")
+ baseline_rows.append([
+ task,
+ view,
+ metric,
+ str(n),
+ f"{_fmt(val_mu, args.digits)} +/- {_fmt(val_sd, args.digits)}",
+ f"{_fmt(test_mu, args.digits)} +/- {_fmt(test_sd, args.digits)}",
+ ])
+
+ delta_rows = []
+ for (task, view, compute), reps in sorted(grouped.items()):
+ if compute == "classic":
+ continue
+ base = classic_cells.get((task, view))
+ if not base:
+ continue
+ metric = str(reps[0]["metric"])
+ direction = _metric_direction(metric)
+ base_by_seed = {int(rep["seed"]): rep for rep in base}
+ paired = []
+ for rep in reps:
+ seed = int(rep["seed"])
+ if seed in base_by_seed:
+ paired.append((rep, base_by_seed[seed]))
+ if not paired:
+ base_test_mu, _, _ = _summarize_cell(base, "test")
+ base_val_mu, _, _ = _summarize_cell(base, "val")
+ paired = [(rep, None) for rep in reps]
+ else:
+ base_test_mu = None
+ base_val_mu = None
+
+ val_scores, test_scores, val_deltas, test_deltas = [], [], [], []
+ adaptive_steps = []
+ for rep, base_rep in paired:
+ val = _score(rep, "val")
+ test = _score(rep, "test")
+ if val is None or test is None:
+ continue
+ if base_rep is None:
+ base_val = base_val_mu
+ base_test = base_test_mu
+ else:
+ base_val = _score(base_rep, "val")
+ base_test = _score(base_rep, "test")
+ if base_val is None or base_test is None:
+ continue
+ val_scores.append(val)
+ test_scores.append(test)
+ val_deltas.append(direction * (val - base_val))
+ test_deltas.append(direction * (test - base_test))
+ if rep.get("adaptive_steps") is not None:
+ adaptive_steps.append(float(rep["adaptive_steps"]))
+
+ if not test_scores:
+ continue
+ delta_rows.append([
+ task,
+ view,
+ compute,
+ metric,
+ str(len(test_scores)),
+ f"{_fmt(_mean(val_scores), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})",
+ f"{_fmt(_mean(test_scores), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})",
+ _fmt(_mean(adaptive_steps), 2) if adaptive_steps else "",
+ ])
+
+ print("\nClassic baseline: task x backbone")
+ print(_markdown_table(["task", "backbone", "metric", "n", "val", "test"], baseline_rows))
+ print("\nDelta vs matching classic")
+ print(_markdown_table([
+ "task", "backbone", "compute", "metric", "n", "val score (delta)", "test score (delta)", "steps"
+ ], delta_rows))
+
+
+def main() -> None:
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--runs-dir", default="runs")
+ ap.add_argument("--glob", default="*.json")
+ ap.add_argument("--min-epochs", type=int, default=10)
+ ap.add_argument("--epochs", type=int)
+ ap.add_argument("--digits", type=int, default=4)
+ args = ap.parse_args()
+ print_tables(args)
+
+
+if __name__ == "__main__":
+ main()