summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md9
-rw-r--r--analysis/zinc_delta_table.pngbin0 -> 61924 bytes
-rwxr-xr-xscripts/render_delta_table.py146
3 files changed, 155 insertions, 0 deletions
diff --git a/README.md b/README.md
index 1e840d3..8aa9357 100644
--- a/README.md
+++ b/README.md
@@ -126,6 +126,15 @@ This writes `analysis/result_audit.md`, `analysis/result_cells.csv`,
`analysis/paired_deltas.csv`, `analysis/delta_summary.csv`, and
`analysis/coverage.csv`.
+Render a delta table PNG:
+
+```bash
+python3 scripts/render_delta_table.py \
+ --compute-label fixed-rrog-T1-ns3+trace \
+ --datasets zinc-cycle56 \
+ --out analysis/zinc_delta_table.png
+```
+
## Backbones
The implemented 2D view/backbone list is shared across ZINC and OGB:
diff --git a/analysis/zinc_delta_table.png b/analysis/zinc_delta_table.png
new file mode 100644
index 0000000..9c3b9c4
--- /dev/null
+++ b/analysis/zinc_delta_table.png
Binary files differ
diff --git a/scripts/render_delta_table.py b/scripts/render_delta_table.py
new file mode 100755
index 0000000..f1177a6
--- /dev/null
+++ b/scripts/render_delta_table.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import csv
+import math
+from collections import defaultdict
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+
+DEFAULT_BACKBONES = [
+ "gin",
+ "gine",
+ "gcn",
+ "graphsage",
+ "gatv2",
+ "graphconv",
+ "transformer",
+ "pna",
+ "gen",
+ "film",
+ "resgated",
+ "tag",
+ "sgc",
+ "cheb",
+ "arma",
+ "mf",
+ "appnp",
+]
+
+
+def mean(xs: list[float]) -> float:
+ return sum(xs) / len(xs) if xs else math.nan
+
+
+def parse_list(value: str) -> list[str]:
+ return [x.strip() for x in value.replace(",", " ").split() if x.strip()]
+
+
+def load_cells(path: Path, compute_label: str, datasets: list[str]) -> dict[tuple[str, str], float]:
+ requested = {x.lower() for x in datasets}
+ values: dict[tuple[str, str], list[float]] = defaultdict(list)
+ with path.open() as f:
+ for row in csv.DictReader(f):
+ dataset = row["dataset"].lower()
+ if requested and dataset not in requested:
+ continue
+ if row["compute_label"] != compute_label:
+ continue
+ values[(dataset, row["view"])].append(float(row["test_delta"]))
+ return {key: mean(xs) for key, xs in values.items()}
+
+
+def render_table(
+ cells: dict[tuple[str, str], float],
+ datasets: list[str],
+ backbones: list[str],
+ title: str,
+ out_path: Path,
+ digits: int,
+) -> None:
+ n_rows = max(1, len(datasets))
+ n_cols = max(1, len(backbones))
+ fig_w = max(12.0, 1.05 * n_cols + 2.8)
+ fig_h = max(2.4, 0.72 * n_rows + 1.8)
+ fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=220)
+ ax.axis("off")
+
+ data = []
+ for dataset in datasets:
+ row = []
+ for backbone in backbones:
+ value = cells.get((dataset.lower(), backbone))
+ row.append("" if value is None or math.isnan(value) else f"{value:+.{digits}f}")
+ data.append(row)
+
+ table = ax.table(
+ cellText=data,
+ rowLabels=datasets,
+ colLabels=backbones,
+ cellLoc="center",
+ rowLoc="center",
+ loc="center",
+ )
+ table.auto_set_font_size(False)
+ table.set_fontsize(8.5)
+ table.scale(1.0, 1.65)
+
+ header_bg = "#f1f5f9"
+ edge = "#cbd5e1"
+ positive = "#15803d"
+ negative = "#b91c1c"
+ missing = "#64748b"
+ for (row_idx, col_idx), cell in table.get_celld().items():
+ cell.set_edgecolor(edge)
+ cell.set_linewidth(0.7)
+ text = cell.get_text()
+ if row_idx == 0 or col_idx == -1:
+ cell.set_facecolor(header_bg)
+ text.set_fontweight("bold")
+ text.set_color("#0f172a")
+ continue
+ raw = text.get_text()
+ if not raw:
+ text.set_text("-")
+ text.set_color(missing)
+ continue
+ value = float(raw)
+ if value > 0:
+ text.set_color(positive)
+ text.set_fontweight("bold")
+ elif value < 0:
+ text.set_color(negative)
+ else:
+ text.set_color("#334155")
+
+ ax.set_title(title, fontsize=13, fontweight="bold", pad=14)
+ fig.tight_layout(pad=0.8)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ fig.savefig(out_path, bbox_inches="tight", facecolor="white")
+ plt.close(fig)
+
+
+def main() -> None:
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--deltas", default="analysis/paired_deltas.csv")
+ ap.add_argument("--compute-label", default="fixed-rrog-T1-ns3+trace")
+ ap.add_argument("--datasets", default="zinc-cycle56")
+ ap.add_argument("--backbones", default=" ".join(DEFAULT_BACKBONES))
+ ap.add_argument("--out", default="analysis/zinc_delta_table.png")
+ ap.add_argument("--title", default="")
+ ap.add_argument("--digits", type=int, default=3)
+ args = ap.parse_args()
+
+ datasets = parse_list(args.datasets)
+ backbones = parse_list(args.backbones)
+ cells = load_cells(Path(args.deltas), args.compute_label, datasets)
+ title = args.title or f"Test Delta vs Classic: {args.compute_label}"
+ render_table(cells, datasets, backbones, title, Path(args.out), args.digits)
+ print(f"wrote {args.out}")
+
+
+if __name__ == "__main__":
+ main()