summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-29 18:55:22 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-29 18:55:22 -0500
commitf798b2069f7fdf0e288a6195fa193ecf50986c3f (patch)
treea7fee1a4a805e5c43183bfd95c32f878f14ae40f
parent03c6d41729077319c3493ed4055e9824c8eda3ba (diff)
Render full fixed RRoG delta matrix
-rw-r--r--README.md9
-rw-r--r--analysis/fixed_rrog_delta_matrix.pngbin0 -> 253924 bytes
-rwxr-xr-xscripts/render_delta_table.py52
3 files changed, 54 insertions, 7 deletions
diff --git a/README.md b/README.md
index 8aa9357..511e4a7 100644
--- a/README.md
+++ b/README.md
@@ -130,7 +130,16 @@ Render a delta table PNG:
```bash
python3 scripts/render_delta_table.py \
+ --title "Fixed RRoG Test Delta vs Classic" \
+ --out analysis/fixed_rrog_delta_matrix.png
+```
+
+Render the ZINC-only row:
+
+```bash
+python3 scripts/render_delta_table.py \
--compute-label fixed-rrog-T1-ns3+trace \
+ --label-overrides "" \
--datasets zinc-cycle56 \
--out analysis/zinc_delta_table.png
```
diff --git a/analysis/fixed_rrog_delta_matrix.png b/analysis/fixed_rrog_delta_matrix.png
new file mode 100644
index 0000000..19d8075
--- /dev/null
+++ b/analysis/fixed_rrog_delta_matrix.png
Binary files differ
diff --git a/scripts/render_delta_table.py b/scripts/render_delta_table.py
index f1177a6..7b0dda5 100755
--- a/scripts/render_delta_table.py
+++ b/scripts/render_delta_table.py
@@ -30,6 +30,19 @@ DEFAULT_BACKBONES = [
"appnp",
]
+DEFAULT_DATASETS = [
+ "zinc-cycle56",
+ "ogbg-molhiv",
+ "ogbg-molbbbp",
+ "ogbg-molbace",
+ "ogbg-moltox21",
+ "ogbg-molclintox",
+ "ogbg-molsider",
+ "ogbg-molesol",
+ "ogbg-molfreesolv",
+ "ogbg-mollipo",
+]
+
def mean(xs: list[float]) -> float:
return sum(xs) / len(xs) if xs else math.nan
@@ -39,7 +52,25 @@ def parse_list(value: str) -> list[str]:
return [x.strip() for x in value.replace(",", " ").split() if x.strip()]
-def load_cells(path: Path, compute_label: str, datasets: list[str]) -> dict[tuple[str, str], float]:
+def parse_label_overrides(value: str) -> dict[str, str]:
+ overrides = {}
+ for item in value.split(","):
+ item = item.strip()
+ if not item:
+ continue
+ if "=" not in item:
+ raise ValueError(f"label override must be DATASET=LABEL, got {item!r}")
+ dataset, label = item.split("=", 1)
+ overrides[dataset.strip().lower()] = label.strip()
+ return overrides
+
+
+def load_cells(
+ path: Path,
+ compute_label: str,
+ datasets: list[str],
+ label_overrides: dict[str, str],
+) -> dict[tuple[str, str], float]:
requested = {x.lower() for x in datasets}
values: dict[tuple[str, str], list[float]] = defaultdict(list)
with path.open() as f:
@@ -47,7 +78,8 @@ def load_cells(path: Path, compute_label: str, datasets: list[str]) -> dict[tupl
dataset = row["dataset"].lower()
if requested and dataset not in requested:
continue
- if row["compute_label"] != compute_label:
+ expected_label = label_overrides.get(dataset, compute_label)
+ if row["compute_label"] != expected_label:
continue
values[(dataset, row["view"])].append(float(row["test_delta"]))
return {key: mean(xs) for key, xs in values.items()}
@@ -126,18 +158,24 @@ def render_table(
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--deltas", default="analysis/paired_deltas.csv")
- ap.add_argument("--compute-label", default="fixed-rrog-T1-ns3+trace")
- ap.add_argument("--datasets", default="zinc-cycle56")
+ ap.add_argument("--compute-label", default="fixed-rrog-T1-ns3")
+ ap.add_argument(
+ "--label-overrides",
+ default="zinc-cycle56=fixed-rrog-T1-ns3+trace",
+ help="comma-separated DATASET=COMPUTE_LABEL overrides",
+ )
+ ap.add_argument("--datasets", default=" ".join(DEFAULT_DATASETS))
ap.add_argument("--backbones", default=" ".join(DEFAULT_BACKBONES))
- ap.add_argument("--out", default="analysis/zinc_delta_table.png")
+ ap.add_argument("--out", default="analysis/fixed_rrog_delta_matrix.png")
ap.add_argument("--title", default="")
ap.add_argument("--digits", type=int, default=3)
args = ap.parse_args()
datasets = parse_list(args.datasets)
backbones = parse_list(args.backbones)
- cells = load_cells(Path(args.deltas), args.compute_label, datasets)
- title = args.title or f"Test Delta vs Classic: {args.compute_label}"
+ label_overrides = parse_label_overrides(args.label_overrides)
+ cells = load_cells(Path(args.deltas), args.compute_label, datasets, label_overrides)
+ title = args.title or "Test Delta vs Classic: fixed RRoG"
render_table(cells, datasets, backbones, title, Path(args.out), args.digits)
print(f"wrote {args.out}")