1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
|
import argparse
import json
import math
from collections import defaultdict
from pathlib import Path
def _read(path: Path) -> dict | None:
try:
with path.open() as f:
rep = json.load(f)
except (OSError, json.JSONDecodeError):
return None
if rep.get("dataset") != "ZINC-cycle56":
return None
if rep.get("K") != 1 or rep.get("select") != "none" or float(rep.get("sigma", 0.0)) != 0.0:
return None
if "test_mae" not in rep or "val_mae" not in rep:
return None
return rep
def _compute(rep: dict) -> str:
    """Derive the compute-mode label of a report from its fields.

    Returns:
        ``"rrog-act-T<T>-ns<n_sup>"`` when ``act`` is truthy; otherwise
        ``"classic"`` when ``T`` is 0 and ``n_sup`` is 1; otherwise
        ``"fixed-rrog-T<T>-ns<n_sup>"``, with ``"+trace"`` appended when
        ``loss_mode`` equals ``"trace"``.
    """
    # Missing fields render as "None" in the label, exactly as ``dict.get`` yields.
    shape_tag = "T{}-ns{}".format(rep.get("T"), rep.get("n_sup"))

    if rep.get("act"):
        return "rrog-act-" + shape_tag

    # ``and`` keeps the short-circuit: ``n_sup`` is only parsed when T == 0.
    is_classic = int(rep.get("T", -1)) == 0 and int(rep.get("n_sup", -1)) == 1
    if is_classic:
        return "classic"

    trace_tag = "+trace" if rep.get("loss_mode") == "trace" else ""
    return "fixed-rrog-" + shape_tag + trace_tag
def _view(rep: dict) -> str:
    """Return the report's ``view`` field as a string, defaulting to ``"gin"``."""
    backbone = rep["view"] if "view" in rep else "gin"
    return str(backbone)
def _score(rep: dict, split: str) -> float:
    """Return the sum of the entries stored under ``"<split>_mae"`` as a float."""
    errors = rep["{}_mae".format(split)]
    # Keep the builtin ``sum`` so float summation behaves exactly as before.
    return float(sum(errors))
def _mean(xs: list[float]) -> float:
    """Return the arithmetic mean of ``xs``.

    Raises:
        ZeroDivisionError: If ``xs`` is empty.
    """
    total = sum(xs)
    count = len(xs)
    return total / count
def _std(xs: list[float]) -> float:
    """Return the sample standard deviation of ``xs`` (divisor ``n - 1``).

    Fewer than two values yield ``0.0``.
    """
    count = len(xs)
    if count <= 1:
        return 0.0
    center = _mean(xs)
    squared_devs = [(x - center) ** 2 for x in xs]
    return math.sqrt(sum(squared_devs) / (count - 1))
def _fmt(x: float, digits: int) -> str:
    """Format ``x`` in fixed-point notation with ``digits`` decimal places."""
    spec = ".{}f".format(digits)
    return format(x, spec)
def _markdown(headers: list[str], rows: list[list[str]]) -> str:
    """Render a pipe-delimited markdown table.

    Args:
        headers: Column titles for the first line.
        rows: One list of already-formatted cell strings per table row.

    Returns:
        The header line, a ``---`` separator line with one cell per header,
        and one line per row, joined by newlines (no trailing newline).
    """
    def render(cells: list[str]) -> str:
        return "| " + " | ".join(cells) + " |"

    separator = ["---" for _ in headers]
    table = [render(headers), render(separator)]
    for row in rows:
        table.append(render(row))
    return "\n".join(table)
def print_zinc(args) -> None:
    """Print ZINC-cycle56 summary tables for the run reports in ``args.runs_dir``.

    Two markdown tables are written to stdout: the per-backbone "classic"
    baseline (mean +/- sample std of the validation and test MAE sums), and,
    for every other compute mode, its mean scores together with the mean
    seed-paired improvement over the classic run of the same backbone
    (baseline score minus run score, so positive means a lower MAE sum).

    Args:
        args: Parsed options providing ``runs_dir`` (directory to scan),
            ``epochs`` (exact epoch count to keep, or ``None`` for any),
            ``min_epochs`` (minimum epoch count to keep) and ``digits``
            (decimal places in the printed numbers).
    """
    grouped = _zinc_group_runs(args)
    classic_by_view = {
        view: reps
        for (view, compute), reps in grouped.items()
        if compute == "classic"
    }
    base_rows = _zinc_baseline_rows(classic_by_view, args.digits)
    delta_rows = _zinc_delta_rows(grouped, classic_by_view, args.digits)

    print("\nZINC-cycle56 classic baseline")
    print(_markdown(["task", "backbone", "n", "val MAE-sum", "test MAE-sum"], base_rows))
    print("\nZINC-cycle56 delta vs matching classic")
    print(_markdown(["task", "backbone", "compute", "n", "val score (improvement)", "test score (improvement)"], delta_rows))


def _zinc_int(value, default: int) -> int:
    """Coerce ``value`` to int, returning ``default`` when it cannot be parsed."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _zinc_seed(rep: dict) -> int | None:
    """Return the report's integer seed, or ``None`` if it is missing or malformed."""
    try:
        return int(rep["seed"])
    except (KeyError, TypeError, ValueError, OverflowError):
        return None


def _zinc_group_runs(args) -> dict[tuple[str, str], list[dict]]:
    """Keep one report per (backbone, compute, seed) and group them by (backbone, compute)."""
    by_cell: dict[tuple[str, str, int], dict] = {}
    for path in sorted(Path(args.runs_dir).glob("rec_rrog*_sig0.0_K1_none_T*_s*.json")):
        rep = _read(path)
        if rep is None:
            continue
        # A report without a usable seed cannot be keyed or paired with a
        # baseline; skip it rather than abort the whole summary.
        seed = _zinc_seed(rep)
        if seed is None:
            continue
        # Missing/unparseable epochs count as -1 for the exact-match filter
        # and as 0 for the minimum filter, mirroring the defaults for a
        # missing field.
        if args.epochs is not None and _zinc_int(rep.get("epochs", -1), -1) != args.epochs:
            continue
        epochs = _zinc_int(rep.get("epochs", 0), 0)
        if epochs < args.min_epochs:
            continue
        key = (_view(rep), _compute(rep), seed)
        old = by_cell.get(key)
        # On duplicates keep the longest run; ties keep the first path in sorted order.
        if old is None or epochs > _zinc_int(old.get("epochs", 0), 0):
            by_cell[key] = rep

    grouped: dict[tuple[str, str], list[dict]] = defaultdict(list)
    for (view, compute, _seed), rep in by_cell.items():
        grouped[(view, compute)].append(rep)
    return grouped


def _zinc_baseline_rows(classic_by_view: dict[str, list[dict]], digits: int) -> list[list[str]]:
    """Build one table row per backbone summarising its classic runs."""
    rows = []
    for view in sorted(classic_by_view):
        classic = classic_by_view[view]
        vals = [_score(rep, "val") for rep in classic]
        tests = [_score(rep, "test") for rep in classic]
        rows.append([
            "zinc-cycle56",
            view,
            str(len(classic)),
            f"{_fmt(_mean(vals), digits)} +/- {_fmt(_std(vals), digits)}",
            f"{_fmt(_mean(tests), digits)} +/- {_fmt(_std(tests), digits)}",
        ])
    return rows


def _zinc_delta_rows(
    grouped: dict[tuple[str, str], list[dict]],
    classic_by_view: dict[str, list[dict]],
    digits: int,
) -> list[list[str]]:
    """Build one table row per non-classic (backbone, compute) group with seed-paired deltas."""
    # Index every backbone's classic runs by seed once, instead of per group.
    base_by_view = {
        view: {_zinc_seed(rep): rep for rep in reps}
        for view, reps in classic_by_view.items()
    }
    rows = []
    for view, compute in sorted(grouped):
        if compute == "classic":
            continue
        base_by_seed = base_by_view.get(view, {})
        # Only runs whose seed also has a classic run on the same backbone are compared.
        paired = [
            (rep, base_by_seed[_zinc_seed(rep)])
            for rep in grouped[(view, compute)]
            if _zinc_seed(rep) in base_by_seed
        ]
        if not paired:
            continue
        vals = [_score(rep, "val") for rep, _ in paired]
        tests = [_score(rep, "test") for rep, _ in paired]
        # Improvement = baseline - run, so a positive value means a lower MAE sum.
        val_deltas = [_score(base, "val") - _score(rep, "val") for rep, base in paired]
        test_deltas = [_score(base, "test") - _score(rep, "test") for rep, base in paired]
        rows.append([
            "zinc-cycle56",
            view,
            compute,
            str(len(paired)),
            f"{_fmt(_mean(vals), digits)} ({_fmt(_mean(val_deltas), digits)})",
            f"{_fmt(_mean(tests), digits)} ({_fmt(_mean(test_deltas), digits)})",
        ])
    return rows
def main() -> None:
    """Parse the command-line options and print the ZINC-cycle56 tables."""
    # (flag, add_argument keyword arguments) for every supported option.
    option_specs = (
        ("--runs-dir", {"default": "runs"}),
        ("--epochs", {"type": int}),
        ("--min-epochs", {"type": int, "default": 10}),
        ("--digits", {"type": int, "default": 4}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    options = parser.parse_args()
    print_zinc(options)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|