summaryrefslogtreecommitdiff
path: root/scripts/render_delta_table.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/render_delta_table.py')
-rwxr-xr-xscripts/render_delta_table.py52
1 files changed, 45 insertions, 7 deletions
diff --git a/scripts/render_delta_table.py b/scripts/render_delta_table.py
index f1177a6..7b0dda5 100755
--- a/scripts/render_delta_table.py
+++ b/scripts/render_delta_table.py
@@ -30,6 +30,19 @@ DEFAULT_BACKBONES = [
"appnp",
]
+DEFAULT_DATASETS = [
+ "zinc-cycle56",
+ "ogbg-molhiv",
+ "ogbg-molbbbp",
+ "ogbg-molbace",
+ "ogbg-moltox21",
+ "ogbg-molclintox",
+ "ogbg-molsider",
+ "ogbg-molesol",
+ "ogbg-molfreesolv",
+ "ogbg-mollipo",
+]
+
def mean(xs: list[float]) -> float:
return sum(xs) / len(xs) if xs else math.nan
@@ -39,7 +52,25 @@ def parse_list(value: str) -> list[str]:
return [x.strip() for x in value.replace(",", " ").split() if x.strip()]
-def load_cells(path: Path, compute_label: str, datasets: list[str]) -> dict[tuple[str, str], float]:
+def parse_label_overrides(value: str) -> dict[str, str]:
+ overrides = {}
+ for item in value.split(","):
+ item = item.strip()
+ if not item:
+ continue
+ if "=" not in item:
+ raise ValueError(f"label override must be DATASET=LABEL, got {item!r}")
+ dataset, label = item.split("=", 1)
+ overrides[dataset.strip().lower()] = label.strip()
+ return overrides
+
+
+def load_cells(
+ path: Path,
+ compute_label: str,
+ datasets: list[str],
+ label_overrides: dict[str, str],
+) -> dict[tuple[str, str], float]:
requested = {x.lower() for x in datasets}
values: dict[tuple[str, str], list[float]] = defaultdict(list)
with path.open() as f:
@@ -47,7 +78,8 @@ def load_cells(path: Path, compute_label: str, datasets: list[str]) -> dict[tupl
dataset = row["dataset"].lower()
if requested and dataset not in requested:
continue
- if row["compute_label"] != compute_label:
+ expected_label = label_overrides.get(dataset, compute_label)
+ if row["compute_label"] != expected_label:
continue
values[(dataset, row["view"])].append(float(row["test_delta"]))
return {key: mean(xs) for key, xs in values.items()}
@@ -126,18 +158,24 @@ def render_table(
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--deltas", default="analysis/paired_deltas.csv")
- ap.add_argument("--compute-label", default="fixed-rrog-T1-ns3+trace")
- ap.add_argument("--datasets", default="zinc-cycle56")
+ ap.add_argument("--compute-label", default="fixed-rrog-T1-ns3")
+ ap.add_argument(
+ "--label-overrides",
+ default="zinc-cycle56=fixed-rrog-T1-ns3+trace",
+ help="comma-separated DATASET=COMPUTE_LABEL overrides",
+ )
+ ap.add_argument("--datasets", default=" ".join(DEFAULT_DATASETS))
ap.add_argument("--backbones", default=" ".join(DEFAULT_BACKBONES))
- ap.add_argument("--out", default="analysis/zinc_delta_table.png")
+ ap.add_argument("--out", default="analysis/fixed_rrog_delta_matrix.png")
ap.add_argument("--title", default="")
ap.add_argument("--digits", type=int, default=3)
args = ap.parse_args()
datasets = parse_list(args.datasets)
backbones = parse_list(args.backbones)
- cells = load_cells(Path(args.deltas), args.compute_label, datasets)
- title = args.title or f"Test Delta vs Classic: {args.compute_label}"
+ label_overrides = parse_label_overrides(args.label_overrides)
+ cells = load_cells(Path(args.deltas), args.compute_label, datasets, label_overrides)
+ title = args.title or "Test Delta vs Classic: fixed RRoG"
render_table(cells, datasets, backbones, title, Path(args.out), args.digits)
print(f"wrote {args.out}")