#!/usr/bin/env python3 from __future__ import annotations import argparse import csv import math from collections import defaultdict from pathlib import Path import matplotlib.pyplot as plt DEFAULT_BACKBONES = [ "gin", "gine", "gcn", "graphsage", "gatv2", "graphconv", "transformer", "pna", "gen", "film", "resgated", "tag", "sgc", "cheb", "arma", "mf", "appnp", ] DEFAULT_DATASETS = [ "zinc-cycle56", "ogbg-molhiv", "ogbg-molbbbp", "ogbg-molbace", "ogbg-moltox21", "ogbg-molclintox", "ogbg-molsider", "ogbg-molesol", "ogbg-molfreesolv", "ogbg-mollipo", ] def mean(xs: list[float]) -> float: return sum(xs) / len(xs) if xs else math.nan def parse_list(value: str) -> list[str]: return [x.strip() for x in value.replace(",", " ").split() if x.strip()] def parse_label_overrides(value: str) -> dict[str, str]: overrides = {} for item in value.split(","): item = item.strip() if not item: continue if "=" not in item: raise ValueError(f"label override must be DATASET=LABEL, got {item!r}") dataset, label = item.split("=", 1) overrides[dataset.strip().lower()] = label.strip() return overrides def load_cells( path: Path, compute_label: str, datasets: list[str], label_overrides: dict[str, str], ) -> dict[tuple[str, str], float]: requested = {x.lower() for x in datasets} values: dict[tuple[str, str], list[float]] = defaultdict(list) with path.open() as f: for row in csv.DictReader(f): dataset = row["dataset"].lower() if requested and dataset not in requested: continue expected_label = label_overrides.get(dataset, compute_label) if row["compute_label"] != expected_label: continue values[(dataset, row["view"])].append(float(row["test_delta"])) return {key: mean(xs) for key, xs in values.items()} def render_table( cells: dict[tuple[str, str], float], datasets: list[str], backbones: list[str], title: str, out_path: Path, digits: int, ) -> None: n_rows = max(1, len(datasets)) n_cols = max(1, len(backbones)) fig_w = max(12.0, 1.05 * n_cols + 2.8) fig_h = max(2.4, 0.72 * n_rows + 1.8) fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=220) ax.axis("off") data = [] for dataset in datasets: row = [] for backbone in backbones: value = cells.get((dataset.lower(), backbone)) if value is None or math.isnan(value) or value < 0: row.append("-") else: row.append(f"{value:+.{digits}f}") data.append(row) table = ax.table( cellText=data, rowLabels=datasets, colLabels=backbones, cellLoc="center", rowLoc="center", loc="center", ) table.auto_set_font_size(False) table.set_fontsize(8.5) table.scale(1.0, 1.65) header_bg = "#f1f5f9" edge = "#cbd5e1" positive = "#15803d" missing = "#64748b" for (row_idx, col_idx), cell in table.get_celld().items(): cell.set_edgecolor(edge) cell.set_linewidth(0.7) text = cell.get_text() if row_idx == 0 or col_idx == -1: cell.set_facecolor(header_bg) text.set_fontweight("bold") text.set_color("#0f172a") continue raw = text.get_text() if raw == "-": text.set_text("-") text.set_color(missing) continue value = float(raw) if value > 0: text.set_color(positive) text.set_fontweight("bold") else: text.set_color("#334155") ax.set_title(title, fontsize=13, fontweight="bold", pad=14) fig.tight_layout(pad=0.8) out_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_path, bbox_inches="tight", facecolor="white") plt.close(fig) def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--deltas", default="analysis/paired_deltas.csv") ap.add_argument("--compute-label", default="fixed-rrog-T1-ns3") ap.add_argument( "--label-overrides", default="zinc-cycle56=fixed-rrog-T1-ns3+trace", help="comma-separated DATASET=COMPUTE_LABEL overrides", ) ap.add_argument("--datasets", default=" ".join(DEFAULT_DATASETS)) ap.add_argument("--backbones", default=" ".join(DEFAULT_BACKBONES)) ap.add_argument("--out", default="analysis/fixed_rrog_delta_matrix.png") ap.add_argument("--title", default="") ap.add_argument("--digits", type=int, default=3) args = ap.parse_args() datasets = parse_list(args.datasets) backbones = parse_list(args.backbones) label_overrides = parse_label_overrides(args.label_overrides) cells = load_cells(Path(args.deltas), args.compute_label, datasets, label_overrides) title = args.title or "Test Delta vs Classic: fixed RRoG" render_table(cells, datasets, backbones, title, Path(args.out), args.digits) print(f"wrote {args.out}") if __name__ == "__main__": main()