#!/usr/bin/env python3
"""Render mean paired test deltas as a colored dataset-by-backbone table image.

Reads a CSV with ``dataset``, ``view``, ``compute_label`` and ``test_delta``
columns, averages ``test_delta`` per (dataset, view) for a single compute
label, and saves the result as a matplotlib table image.
"""
from __future__ import annotations

import argparse
import csv
import math
from collections import defaultdict
from pathlib import Path

DEFAULT_BACKBONES = [
    "gin",
    "gine",
    "gcn",
    "graphsage",
    "gatv2",
    "graphconv",
    "transformer",
    "pna",
    "gen",
    "film",
    "resgated",
    "tag",
    "sgc",
    "cheb",
    "arma",
    "mf",
    "appnp",
]


def mean(xs: list[float]) -> float:
    """Return the arithmetic mean of ``xs``, or NaN when ``xs`` is empty."""
    return sum(xs) / len(xs) if xs else math.nan


def parse_list(value: str) -> list[str]:
    """Split a comma- and/or whitespace-separated string into its items.

    Empty items are dropped, e.g. ``"gin, gcn  pna,,"`` -> ``["gin", "gcn", "pna"]``.
    """
    # str.split() with no argument already discards empty pieces and
    # surrounding whitespace, so no extra strip/filter is needed.
    return value.replace(",", " ").split()


def load_cells(
    path: Path, compute_label: str, datasets: list[str]
) -> dict[tuple[str, str], float]:
    """Average ``test_delta`` per (dataset, view) for one compute label.

    Args:
        path: CSV file with at least ``dataset``, ``view``, ``compute_label``
            and ``test_delta`` columns.
        compute_label: Only rows whose ``compute_label`` equals this exactly
            are used.
        datasets: Dataset names to keep, compared case-insensitively. An
            empty list keeps every dataset.

    Returns:
        Mapping from ``(lower-cased dataset, view)`` to the mean delta. Rows
        whose ``test_delta`` is blank or absent are treated as missing.

    Raises:
        KeyError: if a required column is missing from the header.
        ValueError: if a non-blank ``test_delta`` is not a number.
    """
    requested = {x.lower() for x in datasets}
    values: dict[tuple[str, str], list[float]] = defaultdict(list)
    # newline="" is what the csv module requires so quoted fields that
    # contain newlines are parsed correctly.
    with path.open(newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            dataset = row["dataset"].lower()
            if requested and dataset not in requested:
                continue
            if row["compute_label"] != compute_label:
                continue
            # DictReader fills short rows with None; treat that like a blank.
            raw = (row["test_delta"] or "").strip()
            if not raw:
                continue
            try:
                delta = float(raw)
            except ValueError as err:
                raise ValueError(
                    f"{path}:{reader.line_num}: test_delta {raw!r} is not a number"
                ) from err
            values[(dataset, row["view"])].append(delta)
    return {key: mean(xs) for key, xs in values.items()}


def _cell_label(value: float | None, digits: int) -> str:
    """Format one delta with an explicit sign, or "" when missing or NaN."""
    if value is None or math.isnan(value):
        return ""
    return f"{value:+.{digits}f}"


def _style_cells(table) -> None:
    """Color the header cells, missing cells and signed values of ``table``.

    ``table`` is the object returned by matplotlib's ``Axes.table``.
    """
    header_bg = "#f1f5f9"
    edge = "#cbd5e1"
    positive = "#15803d"
    negative = "#b91c1c"
    missing = "#64748b"
    for (row_idx, col_idx), cell in table.get_celld().items():
        cell.set_edgecolor(edge)
        cell.set_linewidth(0.7)
        text = cell.get_text()
        # Row 0 holds the column labels and column -1 the row labels.
        if row_idx == 0 or col_idx == -1:
            cell.set_facecolor(header_bg)
            text.set_fontweight("bold")
            text.set_color("#0f172a")
            continue
        raw = text.get_text()
        if not raw:
            text.set_text("-")
            text.set_color(missing)
            continue
        # Color by the displayed (rounded) value so "+0.000" stays neutral.
        value = float(raw)
        if value > 0:
            text.set_color(positive)
            text.set_fontweight("bold")
        elif value < 0:
            text.set_color(negative)
        else:
            text.set_color("#334155")


def render_table(
    cells: dict[tuple[str, str], float],
    datasets: list[str],
    backbones: list[str],
    title: str,
    out_path: Path,
    digits: int,
) -> None:
    """Draw ``cells`` as a datasets x backbones table and save it to ``out_path``.

    Args:
        cells: Mapping from ``(lower-cased dataset, backbone)`` to a delta,
            as returned by :func:`load_cells`.
        datasets: Row labels, in display order (looked up lower-cased).
        backbones: Column labels, in display order.
        title: Figure title.
        out_path: Destination image; parent directories are created.
        digits: Decimal places shown for each delta.

    Raises:
        ValueError: if ``datasets`` or ``backbones`` is empty (matplotlib
            cannot build a table without at least one row and column).
    """
    if not datasets or not backbones:
        raise ValueError("render_table needs at least one dataset and one backbone")

    # Imported lazily so the CSV helpers (and `--help`) work without loading
    # the plotting stack.
    import matplotlib.pyplot as plt

    fig_w = max(12.0, 1.05 * len(backbones) + 2.8)
    fig_h = max(2.4, 0.72 * len(datasets) + 1.8)
    data = [
        [_cell_label(cells.get((dataset.lower(), backbone)), digits) for backbone in backbones]
        for dataset in datasets
    ]

    fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=220)
    try:
        ax.axis("off")
        table = ax.table(
            cellText=data,
            rowLabels=datasets,
            colLabels=backbones,
            cellLoc="center",
            rowLoc="center",
            loc="center",
        )
        table.auto_set_font_size(False)
        table.set_fontsize(8.5)
        table.scale(1.0, 1.65)
        _style_cells(table)

        ax.set_title(title, fontsize=13, fontweight="bold", pad=14)
        fig.tight_layout(pad=0.8)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", facecolor="white")
    finally:
        # Release the figure even if layout or saving fails.
        plt.close(fig)


def _datasets_in(cells: dict[tuple[str, str], float]) -> list[str]:
    """Return the distinct dataset names present in ``cells``, sorted."""
    return sorted({dataset for dataset, _ in cells})


def main() -> None:
    """Parse CLI arguments, load the deltas CSV and write the table image."""
    ap = argparse.ArgumentParser(
        description="Render mean test deltas per dataset and backbone as a table image."
    )
    ap.add_argument(
        "--deltas",
        default="analysis/paired_deltas.csv",
        help="CSV with dataset, view, compute_label and test_delta columns",
    )
    ap.add_argument(
        "--compute-label",
        default="fixed-rrog-T1-ns3+trace",
        help="only rows whose compute_label equals this are used",
    )
    ap.add_argument(
        "--datasets",
        default="zinc-cycle56",
        help="comma/space separated datasets (case-insensitive); empty means all in the CSV",
    )
    ap.add_argument(
        "--backbones",
        default=" ".join(DEFAULT_BACKBONES),
        help="comma/space separated views, one table column each",
    )
    ap.add_argument("--out", default="analysis/zinc_delta_table.png", help="output image path")
    ap.add_argument("--title", default="", help="figure title (defaults to one built from the compute label)")
    ap.add_argument("--digits", type=int, default=3, help="decimal places shown per cell")
    args = ap.parse_args()

    if args.digits < 0:
        ap.error("--digits must be non-negative")

    datasets = parse_list(args.datasets)
    backbones = parse_list(args.backbones)
    cells = load_cells(Path(args.deltas), args.compute_label, datasets)
    if not datasets:
        # load_cells treats an empty filter as "every dataset", so show
        # every dataset it actually found.
        datasets = _datasets_in(cells)
    if not datasets:
        ap.error(f"no rows in {args.deltas} match compute label {args.compute_label!r}")
    if not backbones:
        ap.error("--backbones must name at least one view")

    title = args.title or f"Test Delta vs Classic: {args.compute_label}"
    render_table(cells, datasets, backbones, title, Path(args.out), args.digits)
    print(f"wrote {args.out}")


if __name__ == "__main__":
    main()