summaryrefslogtreecommitdiff
path: root/rrog/collect_zinc.py
blob: 49ae1e97bdee8b0bf8c8bfbd8cc91a0de388f10f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
import json
import math
from collections import defaultdict
from pathlib import Path


def _read(path: Path) -> dict | None:
    try:
        with path.open() as f:
            rep = json.load(f)
    except (OSError, json.JSONDecodeError):
        return None
    if rep.get("dataset") != "ZINC-cycle56":
        return None
    if rep.get("K") != 1 or rep.get("select") != "none" or float(rep.get("sigma", 0.0)) != 0.0:
        return None
    if "test_mae" not in rep or "val_mae" not in rep:
        return None
    return rep


def _compute(rep: dict) -> str:
    if rep.get("act"):
        return f"rrog-act-T{rep.get('T')}-ns{rep.get('n_sup')}"
    if int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1:
        return "classic"
    label = f"fixed-rrog-T{rep.get('T')}-ns{rep.get('n_sup')}"
    if rep.get("loss_mode") == "trace":
        label += "+trace"
    return label


def _view(rep: dict) -> str:
    return str(rep.get("view", "gin"))


def _score(rep: dict, split: str) -> float:
    return float(sum(rep[f"{split}_mae"]))


def _mean(xs: list[float]) -> float:
    return sum(xs) / len(xs)


def _std(xs: list[float]) -> float:
    if len(xs) < 2:
        return 0.0
    mu = _mean(xs)
    return math.sqrt(sum((x - mu) ** 2 for x in xs) / (len(xs) - 1))


def _fmt(x: float, digits: int) -> str:
    return f"{x:.{digits}f}"


def _markdown(headers: list[str], rows: list[list[str]]) -> str:
    lines = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"]
    lines.extend("| " + " | ".join(row) + " |" for row in rows)
    return "\n".join(lines)


def _collect_latest(args) -> dict[tuple[str, str, int], dict]:
    """Return the longest qualifying run report for each (view, compute, seed) cell.

    Scans ``args.runs_dir`` for ``rec_rrog*_sig0.0_K1_none_T*_s*.json`` in sorted
    path order and keeps reports accepted by ``_read`` that carry a ``seed``.
    When ``args.epochs`` is set, only runs with exactly that many epochs are kept;
    otherwise runs with fewer than ``args.min_epochs`` epochs are dropped. Among
    duplicates of a cell the run with the most epochs wins; ties keep the first
    path seen.
    """
    by_cell: dict[tuple[str, str, int], dict] = {}
    for path in sorted(Path(args.runs_dir).glob("rec_rrog*_sig0.0_K1_none_T*_s*.json")):
        rep = _read(path)
        # Skip filtered/unreadable files and reports that cannot be paired by seed.
        if rep is None or "seed" not in rep:
            continue
        epochs = int(rep.get("epochs", 0))
        if args.epochs is not None:
            # An explicit --epochs selects exactly that length and supersedes
            # --min-epochs; otherwise e.g. `--epochs 5` with the default
            # --min-epochs 10 would silently match nothing.
            if int(rep.get("epochs", -1)) != args.epochs:
                continue
        elif epochs < args.min_epochs:
            continue
        key = (_view(rep), _compute(rep), int(rep["seed"]))
        old = by_cell.get(key)
        if old is None or epochs > int(old.get("epochs", 0)):
            by_cell[key] = rep
    return by_cell


def _mean_std(xs: list[float], digits: int) -> str:
    """Format ``xs`` as ``"<mean> +/- <sample std>"`` with ``digits`` decimals."""
    return f"{_fmt(_mean(xs), digits)} +/- {_fmt(_std(xs), digits)}"


def _mean_delta(xs: list[float], deltas: list[float], digits: int) -> str:
    """Format as ``"<mean of xs> (<mean of deltas>)"`` with ``digits`` decimals."""
    return f"{_fmt(_mean(xs), digits)} ({_fmt(_mean(deltas), digits)})"


def _baseline_rows(classic_by_view: dict[str, list[dict]], digits: int) -> list[list[str]]:
    """Build one baseline-table row per backbone from its classic runs, sorted by backbone."""
    rows = []
    for view, classic in sorted(classic_by_view.items()):
        vals = [_score(rep, "val") for rep in classic]
        tests = [_score(rep, "test") for rep in classic]
        rows.append([
            "zinc-cycle56",
            view,
            str(len(classic)),
            _mean_std(vals, digits),
            _mean_std(tests, digits),
        ])
    return rows


def _delta_rows(
    grouped: dict[tuple[str, str], list[dict]],
    classic_by_view: dict[str, list[dict]],
    digits: int,
) -> list[list[str]]:
    """Build one delta-table row per non-classic (view, compute) group with seed-paired classic runs."""
    # Index each backbone's classic runs by seed once, instead of once per group.
    base_by_view = {
        view: {int(rep["seed"]): rep for rep in classic}
        for view, classic in classic_by_view.items()
    }
    rows = []
    for (view, compute), reps in sorted(grouped.items()):
        if compute == "classic":
            continue
        base_by_seed = base_by_view.get(view, {})
        paired = [(rep, base_by_seed[int(rep["seed"])]) for rep in reps if int(rep["seed"]) in base_by_seed]
        if not paired:
            continue
        vals = [_score(rep, "val") for rep, _ in paired]
        tests = [_score(rep, "test") for rep, _ in paired]
        # Improvement = classic MAE - variant MAE, so a positive value means the variant is better.
        val_deltas = [_score(base, "val") - _score(rep, "val") for rep, base in paired]
        test_deltas = [_score(base, "test") - _score(rep, "test") for rep, base in paired]
        rows.append([
            "zinc-cycle56",
            view,
            compute,
            str(len(paired)),
            _mean_delta(vals, val_deltas, digits),
            _mean_delta(tests, test_deltas, digits),
        ])
    return rows


def print_zinc(args) -> None:
    """Print the ZINC-cycle56 classic-baseline and delta-vs-classic Markdown tables.

    ``args`` must provide ``runs_dir``, ``epochs`` (int or None), ``min_epochs`` and
    ``digits``, as produced by ``main``'s argument parser.

    The first table reports, per backbone, the mean +/- sample std of the summed
    val/test MAE over the "classic" runs. The second reports, for every other
    compute configuration, the mean MAE-sum over seeds that also have a classic run
    on the same backbone. The mean paired improvement follows in parentheses: classic
    MAE minus variant MAE, so a positive value means the variant has lower error.
    """
    by_cell = _collect_latest(args)

    grouped: dict[tuple[str, str], list[dict]] = defaultdict(list)
    for (view, compute, _seed), rep in by_cell.items():
        grouped[(view, compute)].append(rep)

    classic_by_view = {
        view: reps
        for (view, compute), reps in grouped.items()
        if compute == "classic"
    }

    base_rows = _baseline_rows(classic_by_view, args.digits)
    delta_rows = _delta_rows(grouped, classic_by_view, args.digits)

    print("\nZINC-cycle56 classic baseline")
    print(_markdown(["task", "backbone", "n", "val MAE-sum", "test MAE-sum"], base_rows))
    print("\nZINC-cycle56 delta vs matching classic")
    print(_markdown(["task", "backbone", "compute", "n", "val score (improvement)", "test score (improvement)"], delta_rows))


def main() -> None:
    """Command-line entry point: parse arguments and print the ZINC-cycle56 tables."""
    ap = argparse.ArgumentParser(
        description="Summarise ZINC-cycle56 run reports (K=1, select=none, sigma=0) as Markdown tables.",
    )
    ap.add_argument("--runs-dir", default="runs", help="directory containing the rec_rrog*.json run reports (default: runs)")
    ap.add_argument("--epochs", type=int, help="only use runs trained for exactly this many epochs")
    ap.add_argument("--min-epochs", type=int, default=10, help="minimum run length in epochs used to filter runs (default: 10)")
    ap.add_argument("--digits", type=int, default=4, help="decimal places in the printed numbers (default: 4)")
    args = ap.parse_args()
    # A negative precision would otherwise crash later with an obscure format-spec ValueError.
    if args.digits < 0:
        ap.error("--digits must be non-negative")
    print_zinc(args)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()