diff options
| -rw-r--r-- | README.md | 9 | ||||
| -rw-r--r-- | analysis/fixed_rrog_delta_matrix.png | bin | 0 -> 253924 bytes | |||
| -rwxr-xr-x | scripts/render_delta_table.py | 52 |
3 files changed, 54 insertions, 7 deletions
@@ -130,7 +130,16 @@ Render a delta table PNG: ```bash python3 scripts/render_delta_table.py \ + --title "Fixed RRoG Test Delta vs Classic" \ + --out analysis/fixed_rrog_delta_matrix.png +``` + +Render the ZINC-only row: + +```bash +python3 scripts/render_delta_table.py \ --compute-label fixed-rrog-T1-ns3+trace \ + --label-overrides "" \ --datasets zinc-cycle56 \ --out analysis/zinc_delta_table.png ``` diff --git a/analysis/fixed_rrog_delta_matrix.png b/analysis/fixed_rrog_delta_matrix.png Binary files differnew file mode 100644 index 0000000..19d8075 --- /dev/null +++ b/analysis/fixed_rrog_delta_matrix.png diff --git a/scripts/render_delta_table.py b/scripts/render_delta_table.py index f1177a6..7b0dda5 100755 --- a/scripts/render_delta_table.py +++ b/scripts/render_delta_table.py @@ -30,6 +30,19 @@ DEFAULT_BACKBONES = [ "appnp", ] +DEFAULT_DATASETS = [ + "zinc-cycle56", + "ogbg-molhiv", + "ogbg-molbbbp", + "ogbg-molbace", + "ogbg-moltox21", + "ogbg-molclintox", + "ogbg-molsider", + "ogbg-molesol", + "ogbg-molfreesolv", + "ogbg-mollipo", +] + def mean(xs: list[float]) -> float: return sum(xs) / len(xs) if xs else math.nan @@ -39,7 +52,25 @@ def parse_list(value: str) -> list[str]: return [x.strip() for x in value.replace(",", " ").split() if x.strip()] -def load_cells(path: Path, compute_label: str, datasets: list[str]) -> dict[tuple[str, str], float]: +def parse_label_overrides(value: str) -> dict[str, str]: + overrides = {} + for item in value.split(","): + item = item.strip() + if not item: + continue + if "=" not in item: + raise ValueError(f"label override must be DATASET=LABEL, got {item!r}") + dataset, label = item.split("=", 1) + overrides[dataset.strip().lower()] = label.strip() + return overrides + + +def load_cells( + path: Path, + compute_label: str, + datasets: list[str], + label_overrides: dict[str, str], +) -> dict[tuple[str, str], float]: requested = {x.lower() for x in datasets} values: dict[tuple[str, str], list[float]] = defaultdict(list) with path.open() as f: @@ -47,7 +78,8 @@ def load_cells(path: Path, compute_label: str, datasets: list[str]) -> dict[tupl dataset = row["dataset"].lower() if requested and dataset not in requested: continue - if row["compute_label"] != compute_label: + expected_label = label_overrides.get(dataset, compute_label) + if row["compute_label"] != expected_label: continue values[(dataset, row["view"])].append(float(row["test_delta"])) return {key: mean(xs) for key, xs in values.items()} @@ -126,18 +158,24 @@ def render_table( def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--deltas", default="analysis/paired_deltas.csv") - ap.add_argument("--compute-label", default="fixed-rrog-T1-ns3+trace") - ap.add_argument("--datasets", default="zinc-cycle56") + ap.add_argument("--compute-label", default="fixed-rrog-T1-ns3") + ap.add_argument( + "--label-overrides", + default="zinc-cycle56=fixed-rrog-T1-ns3+trace", + help="comma-separated DATASET=COMPUTE_LABEL overrides", + ) + ap.add_argument("--datasets", default=" ".join(DEFAULT_DATASETS)) ap.add_argument("--backbones", default=" ".join(DEFAULT_BACKBONES)) - ap.add_argument("--out", default="analysis/zinc_delta_table.png") + ap.add_argument("--out", default="analysis/fixed_rrog_delta_matrix.png") ap.add_argument("--title", default="") ap.add_argument("--digits", type=int, default=3) args = ap.parse_args() datasets = parse_list(args.datasets) backbones = parse_list(args.backbones) - cells = load_cells(Path(args.deltas), args.compute_label, datasets) - title = args.title or f"Test Delta vs Classic: {args.compute_label}" + label_overrides = parse_label_overrides(args.label_overrides) + cells = load_cells(Path(args.deltas), args.compute_label, datasets, label_overrides) + title = args.title or "Test Delta vs Classic: fixed RRoG" render_table(cells, datasets, backbones, title, Path(args.out), args.digits) print(f"wrote {args.out}") |
