diff options
Diffstat (limited to 'rrog/collect_results.py')
| -rw-r--r-- | rrog/collect_results.py | 239 |
1 files changed, 239 insertions, 0 deletions
diff --git a/rrog/collect_results.py b/rrog/collect_results.py new file mode 100644 index 0000000..125c0ab --- /dev/null +++ b/rrog/collect_results.py @@ -0,0 +1,239 @@ +import argparse +import json +import math +from collections import defaultdict +from pathlib import Path + + +HIGHER_BETTER = { + "accuracy", + "ap", + "auc", + "f1", + "mrr", + "rocauc", +} + +LOWER_BETTER = { + "mae", + "raw-mae", + "rmse", +} + + +def _metric_direction(metric: str) -> int: + metric = metric.lower() + if metric in HIGHER_BETTER: + return 1 + if metric in LOWER_BETTER: + return -1 + if "mae" in metric or "rmse" in metric: + return -1 + return 1 + + +def _read(path: Path) -> dict | None: + try: + with path.open() as f: + rep = json.load(f) + except (OSError, json.JSONDecodeError): + return None + required = {"dataset", "view", "compute", "seed", "metric", "val"} + if not required.issubset(rep): + return None + return rep + + +def _score(rep: dict, split: str) -> float | None: + metric = rep.get("metric") + if not metric: + return None + value = rep.get(split, {}).get(metric) + if value is None: + return None + return float(value) + + +def _rank_key(rep: dict) -> tuple[int, int, int]: + return ( + int(rep.get("epochs", 0)), + int(rep.get("hidden", 0)), + int(rep.get("ep", 0) or 0), + ) + + +def _mean(xs: list[float]) -> float: + return sum(xs) / len(xs) + + +def _std(xs: list[float]) -> float: + if len(xs) < 2: + return 0.0 + mu = _mean(xs) + return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)) + + +def _fmt(value: float | None, digits: int) -> str: + if value is None: + return "" + return f"{value:.{digits}f}" + + +def _is_classic_baseline(rep: dict) -> bool: + return rep.get("compute") == "classic" and int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1 + + +def _compute_label(rep: dict) -> str: + label = str(rep["compute"]) + ema = float(rep.get("ema", 0.0) or 0.0) + if ema > 0: + label += f"+ema{ema:g}" + return label + + +def _choose_runs(paths: list[Path], min_epochs: int, epochs: int | None) -> list[dict]: + candidates: dict[tuple[str, str, str, int], dict] = {} + for path in paths: + rep = _read(path) + if rep is None: + continue + if epochs is not None and int(rep.get("epochs", -1)) != epochs: + continue + if int(rep.get("epochs", 0)) < min_epochs: + continue + val = _score(rep, "val") + test = _score(rep, "test") + if val is None or test is None: + continue + key = (str(rep["dataset"]), str(rep["view"]), _compute_label(rep), int(rep["seed"])) + old = candidates.get(key) + if old is None or _rank_key(rep) > _rank_key(old): + candidates[key] = rep + return list(candidates.values()) + + +def _group_by_cell(runs: list[dict]) -> dict[tuple[str, str, str], list[dict]]: + grouped: dict[tuple[str, str, str], list[dict]] = defaultdict(list) + for rep in runs: + grouped[(str(rep["dataset"]), str(rep["view"]), _compute_label(rep))].append(rep) + return dict(grouped) + + +def _summarize_cell(reps: list[dict], split: str) -> tuple[float, float, int]: + scores = [_score(rep, split) for rep in reps] + xs = [x for x in scores if x is not None] + return _mean(xs), _std(xs), len(xs) + + +def _markdown_table(headers: list[str], rows: list[list[str]]) -> str: + out = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"] + out.extend("| " + " | ".join(row) + " |" for row in rows) + return "\n".join(out) + + +def print_tables(args) -> None: + paths = sorted(Path(args.runs_dir).glob(args.glob)) + runs = _choose_runs(paths, args.min_epochs, args.epochs) + grouped = _group_by_cell(runs) + + classic_cells = { + (task, view): reps + for (task, view, compute), reps in grouped.items() + if compute == "classic" + for reps in [[rep for rep in reps if _is_classic_baseline(rep)]] + if reps + } + + baseline_rows = [] + for (task, view), reps in sorted(classic_cells.items()): + metric = str(reps[0]["metric"]) + val_mu, val_sd, n = _summarize_cell(reps, "val") + test_mu, test_sd, _ = _summarize_cell(reps, "test") + baseline_rows.append([ + task, + view, + metric, + str(n), + f"{_fmt(val_mu, args.digits)} +/- {_fmt(val_sd, args.digits)}", + f"{_fmt(test_mu, args.digits)} +/- {_fmt(test_sd, args.digits)}", + ]) + + delta_rows = [] + for (task, view, compute), reps in sorted(grouped.items()): + if compute == "classic": + continue + base = classic_cells.get((task, view)) + if not base: + continue + metric = str(reps[0]["metric"]) + direction = _metric_direction(metric) + base_by_seed = {int(rep["seed"]): rep for rep in base} + paired = [] + for rep in reps: + seed = int(rep["seed"]) + if seed in base_by_seed: + paired.append((rep, base_by_seed[seed])) + if not paired: + base_test_mu, _, _ = _summarize_cell(base, "test") + base_val_mu, _, _ = _summarize_cell(base, "val") + paired = [(rep, None) for rep in reps] + else: + base_test_mu = None + base_val_mu = None + + val_scores, test_scores, val_deltas, test_deltas = [], [], [], [] + adaptive_steps = [] + for rep, base_rep in paired: + val = _score(rep, "val") + test = _score(rep, "test") + if val is None or test is None: + continue + if base_rep is None: + base_val = base_val_mu + base_test = base_test_mu + else: + base_val = _score(base_rep, "val") + base_test = _score(base_rep, "test") + if base_val is None or base_test is None: + continue + val_scores.append(val) + test_scores.append(test) + val_deltas.append(direction * (val - base_val)) + test_deltas.append(direction * (test - base_test)) + if rep.get("adaptive_steps") is not None: + adaptive_steps.append(float(rep["adaptive_steps"])) + + if not test_scores: + continue + delta_rows.append([ + task, + view, + compute, + metric, + str(len(test_scores)), + f"{_fmt(_mean(val_scores), args.digits)} ({_fmt(_mean(val_deltas), args.digits)})", + f"{_fmt(_mean(test_scores), args.digits)} ({_fmt(_mean(test_deltas), args.digits)})", + _fmt(_mean(adaptive_steps), 2) if adaptive_steps else "", + ]) + + print("\nClassic baseline: task x backbone") + print(_markdown_table(["task", "backbone", "metric", "n", "val", "test"], baseline_rows)) + print("\nDelta vs matching classic") + print(_markdown_table([ + "task", "backbone", "compute", "metric", "n", "val score (delta)", "test score (delta)", "steps" + ], delta_rows)) + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--runs-dir", default="runs") + ap.add_argument("--glob", default="*.json") + ap.add_argument("--min-epochs", type=int, default=10) + ap.add_argument("--epochs", type=int) + ap.add_argument("--digits", type=int, default=4) + args = ap.parse_args() + print_tables(args) + + +if __name__ == "__main__": + main() |
