summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-22 13:10:30 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-22 13:10:30 -0500
commit6301f7ef680527bc34b34c06bfc0eba567ddbc45 (patch)
tree6202c39f8ed0a15014582577fe67f6d8eb28f0f4
parent322a70e9ceaf2cd9d09e9847e291952c3dab9518 (diff)
Add OGB protocol diagnosticsHEADmain
-rw-r--r--README.md6
-rw-r--r--analysis/ogbg-molhiv_fixed-rrog_T3_ns3_diagnostics.csv18
-rw-r--r--analysis/ogbg-molhiv_fixed-rrog_T3_ns3_diagnostics.md47
-rw-r--r--rrog/collect_results.py6
-rw-r--r--scripts/analyze_ogb_hiv_log.py312
-rwxr-xr-xscripts/run_ogb_mol_task_full.sh4
6 files changed, 391 insertions, 2 deletions
diff --git a/README.md b/README.md
index f07c90f..2b55983 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,12 @@ Run one OGB molecular task:
TASK=ogbg-molhiv DEVICE=cuda:1 EPOCHS=100 ./scripts/run_ogb_mol_task_full.sh
```
+Run the same OGB task with the lighter fixed recursion used by the ZINC sweep:
+
+```bash
+TASK=ogbg-molhiv DEVICE=cuda:1 EPOCHS=100 FIXED_T=1 FIXED_NS=3 ./scripts/run_ogb_mol_task_full.sh
+```
+
Run all selected OGB molecular tasks serially on one GPU:
```bash
diff --git a/analysis/ogbg-molhiv_fixed-rrog_T3_ns3_diagnostics.csv b/analysis/ogbg-molhiv_fixed-rrog_T3_ns3_diagnostics.csv
new file mode 100644
index 0000000..8d8ff16
--- /dev/null
+++ b/analysis/ogbg-molhiv_fixed-rrog_T3_ns3_diagnostics.csv
@@ -0,0 +1,18 @@
+view,metric,base_ep,base_val,base_test,cand_ep,cand_val,cand_test,val_delta,test_delta,base_curve_best_ep,base_final_gap,cand_curve_best_ep,cand_curve_final_ep,cand_final_gap,cand_late_slope,cand_range,labels
+appnp,rocauc,30,0.7674851190476192,0.7000000000000001,30,0.7203422006662747,0.682502558952471,-0.047142918381344434,-0.01749744104752904,30,0.04686000000000001,30,100,0.047719999999999985,0.0032600000000000406,0.07535000000000003,"val_and_test_down,fixed_val_lags_classic,late_collapse"
+arma,rocauc,70,0.7872115667254557,0.7296336352575368,90,0.7926495443856555,0.7317541860599857,0.005437977660199822,0.002120550802448884,70,0.0020499999999999963,90,100,0.0011499999999999844,-0.0011499999999999844,0.13578999999999997,"positive_or_tie,maybe_undertrained"
+cheb,rocauc,60,0.783120835782873,0.7281735838853589,40,0.7734757740544776,0.7425713126943356,-0.009645061728395299,0.014397728808976717,60,0.021220000000000017,40,100,0.03462999999999994,-0.0027800000000000047,0.050449999999999995,"test_up_val_down,late_collapse"
+film,rocauc,20,0.7921535126396236,0.7842136773595473,30,0.7633101851851851,0.7727785395623709,-0.028843327454438517,-0.01143513779717642,20,0.043240000000000056,30,100,0.05358000000000007,-0.003670000000000062,0.07297000000000009,"val_and_test_down,fixed_val_lags_classic,late_collapse"
+gatv2,rocauc,20,0.7834576474622771,0.75625253481141,20,0.7835464432686655,0.746617354525966,8.879580638843088e-05,-0.009635180285443967,20,0.040959999999999996,20,100,0.09077000000000002,-0.0033800000000000496,0.09077000000000002,"val_up_test_down,late_collapse"
+gcn,rocauc,30,0.7754323437193807,0.719778288495336,40,0.7454683519498334,0.7482879159504818,-0.029963991769547338,0.028509627455145847,30,0.04236999999999991,40,100,0.010009999999999963,-0.006759999999999988,0.10924999999999996,"test_up_val_down,fixed_val_lags_classic"
+gen,rocauc,30,0.7530159954928474,0.7313524788041483,60,0.8127970066627475,0.7631375654222755,0.05978101116990009,0.03178508661812718,30,0.06394,60,100,0.07395999999999991,-0.015559999999999907,0.11785999999999996,"positive_or_tie,late_collapse"
+gin,rocauc,40,0.816746889084852,0.7723980764402558,30,0.7523760533019792,0.7324378609088626,-0.06437083578287284,-0.03996021553139317,40,0.032930000000000015,30,100,0.013860000000000094,0.0015699999999999603,0.09743000000000002,"val_and_test_down,fixed_val_lags_classic"
+gine,rocauc,70,0.7942448069762884,0.7445141852874717,40,0.7574588477366255,0.7400818864790746,-0.03678595923966288,-0.004432298808397128,70,0.008459999999999912,40,100,0.02404000000000006,-0.005790000000000073,0.14568000000000003,"val_and_test_down,fixed_val_lags_classic"
+graphconv,rocauc,80,0.7706618655692731,0.7108113327796983,80,0.7712528169704094,0.6979180748952278,0.0005909514011362971,-0.012893257884470488,80,0.009280000000000066,80,100,0.004809999999999981,-9.999999999998899e-05,0.07740000000000002,val_up_test_down
+graphsage,rocauc,40,0.7968596903782088,0.7625118291199134,20,0.7586759626690183,0.7641070704339598,-0.03818372770919054,0.001595241314046314,40,0.054950000000000054,20,100,0.022279999999999966,-0.008159999999999945,0.08989000000000003,"test_up_val_down,fixed_val_lags_classic"
+mf,rocauc,60,0.784538506760729,0.7073504702678692,50,0.7930782137958065,0.7216728789663763,0.008539707035077448,0.014322408698507094,60,0.015589999999999993,50,100,0.04754000000000003,0.00019000000000002348,0.08803000000000005,"positive_or_tie,late_collapse"
+pna,rocauc,60,0.7779523074661963,0.7647926765677203,60,0.7890395355673133,0.7629792000618011,0.011087228101117064,-0.0018134765059192315,60,0.05793999999999999,60,100,0.022619999999999973,-0.0008000000000000229,0.19358999999999993,val_up_test_down
+resgated,rocauc,80,0.8143738977072311,0.7241738928909403,80,0.8173500881834216,0.7054809865003185,0.0029761904761904656,-0.01869290639062171,80,0.006010000000000071,80,100,0.014419999999999988,0.0013400000000000079,0.13631000000000004,val_up_test_down
+sgc,rocauc,80,0.7478597148736037,0.7019100407501111,50,0.7481720311581422,0.7177311265184728,0.00031231628453853855,0.01582108576836172,80,0.01100000000000001,50,100,0.023460000000000036,-0.006759999999999988,0.06252000000000002,positive_or_tie
+tag,rocauc,70,0.7528292181069958,0.7255354487340427,70,0.7803834754066236,0.7351764228741382,0.027554257299627793,0.00964097414009546,70,0.009939999999999949,70,100,0.010239999999999916,-0.0008799999999999919,0.17215999999999998,positive_or_tie
+transformer,rocauc,30,0.7617332941407015,0.7547248884682979,40,0.8088103811483441,0.7413024585256572,0.04707708700764268,-0.013422429942640646,30,0.04059999999999997,40,100,0.09591000000000005,-0.0031600000000000517,0.10176000000000007,"val_up_test_down,late_collapse"
diff --git a/analysis/ogbg-molhiv_fixed-rrog_T3_ns3_diagnostics.md b/analysis/ogbg-molhiv_fixed-rrog_T3_ns3_diagnostics.md
new file mode 100644
index 0000000..0379e0e
--- /dev/null
+++ b/analysis/ogbg-molhiv_fixed-rrog_T3_ns3_diagnostics.md
@@ -0,0 +1,47 @@
+# ogbg-molhiv fixed-rrog T=3 ns=3 diagnostics
+
+- rows: 17
+- val+test positive: 5
+- val up, test down: 5
+- val down, test up: 3
+- val+test down: 4
+- maybe undertrained: 1
+- late collapse: 7
+
+- mean val delta: -0.0054
+- mean test delta: -0.0007
+- val/test delta corr: 0.4504
+
+## Label Counts
+
+| label | n |
+| --- | --- |
+| fixed_val_lags_classic | 6 |
+| late_collapse | 7 |
+| maybe_undertrained | 1 |
+| positive_or_tie | 5 |
+| test_up_val_down | 3 |
+| val_and_test_down | 4 |
+| val_up_test_down | 5 |
+
+## Per-Backbone Rows
+
+| view | base_val | cand_val | d_val | base_test | cand_test | d_test | cand_ep | cand_final_gap | labels |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| appnp | 0.7675 | 0.7203 | -0.0471 | 0.7000 | 0.6825 | -0.0175 | 30 | 0.0477 | val_and_test_down,fixed_val_lags_classic,late_collapse |
+| arma | 0.7872 | 0.7926 | 0.0054 | 0.7296 | 0.7318 | 0.0021 | 90 | 0.0011 | positive_or_tie,maybe_undertrained |
+| cheb | 0.7831 | 0.7735 | -0.0096 | 0.7282 | 0.7426 | 0.0144 | 40 | 0.0346 | test_up_val_down,late_collapse |
+| film | 0.7922 | 0.7633 | -0.0288 | 0.7842 | 0.7728 | -0.0114 | 30 | 0.0536 | val_and_test_down,fixed_val_lags_classic,late_collapse |
+| gatv2 | 0.7835 | 0.7835 | 0.0001 | 0.7563 | 0.7466 | -0.0096 | 20 | 0.0908 | val_up_test_down,late_collapse |
+| gcn | 0.7754 | 0.7455 | -0.0300 | 0.7198 | 0.7483 | 0.0285 | 40 | 0.0100 | test_up_val_down,fixed_val_lags_classic |
+| gen | 0.7530 | 0.8128 | 0.0598 | 0.7314 | 0.7631 | 0.0318 | 60 | 0.0740 | positive_or_tie,late_collapse |
+| gin | 0.8167 | 0.7524 | -0.0644 | 0.7724 | 0.7324 | -0.0400 | 30 | 0.0139 | val_and_test_down,fixed_val_lags_classic |
+| gine | 0.7942 | 0.7575 | -0.0368 | 0.7445 | 0.7401 | -0.0044 | 40 | 0.0240 | val_and_test_down,fixed_val_lags_classic |
+| graphconv | 0.7707 | 0.7713 | 0.0006 | 0.7108 | 0.6979 | -0.0129 | 80 | 0.0048 | val_up_test_down |
+| graphsage | 0.7969 | 0.7587 | -0.0382 | 0.7625 | 0.7641 | 0.0016 | 20 | 0.0223 | test_up_val_down,fixed_val_lags_classic |
+| mf | 0.7845 | 0.7931 | 0.0085 | 0.7074 | 0.7217 | 0.0143 | 50 | 0.0475 | positive_or_tie,late_collapse |
+| pna | 0.7780 | 0.7890 | 0.0111 | 0.7648 | 0.7630 | -0.0018 | 60 | 0.0226 | val_up_test_down |
+| resgated | 0.8144 | 0.8174 | 0.0030 | 0.7242 | 0.7055 | -0.0187 | 80 | 0.0144 | val_up_test_down |
+| sgc | 0.7479 | 0.7482 | 0.0003 | 0.7019 | 0.7177 | 0.0158 | 50 | 0.0235 | positive_or_tie |
+| tag | 0.7528 | 0.7804 | 0.0276 | 0.7255 | 0.7352 | 0.0096 | 70 | 0.0102 | positive_or_tie |
+| transformer | 0.7617 | 0.8088 | 0.0471 | 0.7547 | 0.7413 | -0.0134 | 40 | 0.0959 | val_up_test_down,late_collapse |
diff --git a/rrog/collect_results.py b/rrog/collect_results.py
index 125c0ab..4c1390c 100644
--- a/rrog/collect_results.py
+++ b/rrog/collect_results.py
@@ -84,7 +84,11 @@ def _is_classic_baseline(rep: dict) -> bool:
def _compute_label(rep: dict) -> str:
- label = str(rep["compute"])
+ compute = str(rep["compute"])
+ if _is_classic_baseline(rep):
+ label = "classic"
+ else:
+ label = f"{compute}-T{rep.get('T')}-ns{rep.get('n_sup')}"
ema = float(rep.get("ema", 0.0) or 0.0)
if ema > 0:
label += f"+ema{ema:g}"
diff --git a/scripts/analyze_ogb_hiv_log.py b/scripts/analyze_ogb_hiv_log.py
new file mode 100644
index 0000000..21253fe
--- /dev/null
+++ b/scripts/analyze_ogb_hiv_log.py
@@ -0,0 +1,312 @@
+#!/usr/bin/env python3
+import argparse
+import csv
+import json
+import math
+import re
+from collections import defaultdict
+from pathlib import Path
+
+
+RUN_RE = re.compile(
+ r"^\[run\]\s+"
+ r"(?P<dataset>\S+)\s+view=(?P<view>\S+)\s+compute=(?P<compute>\S+)\s+"
+ r"T=(?P<T>\d+)\s+ns=(?P<ns>\d+)"
+)
+EP_RE = re.compile(
+ r"^ep(?P<ep>\d+)\s+val_(?P<metric>\w+)=(?P<val>[-+0-9.eE]+).*"
+ r"train_steps=(?P<steps>[-+0-9.eE]+)"
+)
+
+
+def _load_json(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ return json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+
+
+def _score(rep: dict, split: str) -> float | None:
+ metric = rep.get("metric")
+ if not metric:
+ return None
+ value = rep.get(split, {}).get(metric)
+ return None if value is None else float(value)
+
+
+def parse_curves(path: Path):
+ curves = defaultdict(list)
+ current = None
+ with path.open(errors="replace") as f:
+ for line in f:
+ line = line.strip()
+ m = RUN_RE.match(line)
+ if m:
+ current = (
+ m.group("dataset"),
+ m.group("view"),
+ m.group("compute"),
+ int(m.group("T")),
+ int(m.group("ns")),
+ )
+ continue
+ m = EP_RE.match(line)
+ if current is not None and m:
+ curves[current].append({
+ "ep": int(m.group("ep")),
+ "metric": m.group("metric"),
+ "val": float(m.group("val")),
+ "train_steps": float(m.group("steps")),
+ })
+ return curves
+
+
+def load_runs(runs_dir: Path, dataset: str):
+ runs = {}
+ for path in sorted(runs_dir.glob(f"{dataset}_*.json")):
+ rep = _load_json(path)
+ if rep is None:
+ continue
+ if rep.get("dataset") != dataset:
+ continue
+ key = (
+ rep.get("view"),
+ rep.get("compute"),
+ int(rep.get("T", -1)),
+ int(rep.get("n_sup", -1)),
+ int(rep.get("seed", -1)),
+ )
+ runs[key] = rep
+ return runs
+
+
+def curve_stats(curve: list[dict]) -> dict:
+ if not curve:
+ return {
+ "curve_best_ep": None,
+ "curve_best_val": None,
+ "curve_final_ep": None,
+ "curve_final_val": None,
+ "final_gap": None,
+ "late_slope": None,
+ "range": None,
+ }
+ best = max(curve, key=lambda x: x["val"])
+ final = curve[-1]
+ late_slope = None
+ if len(curve) >= 2:
+ late_slope = final["val"] - curve[-2]["val"]
+ vals = [x["val"] for x in curve]
+ return {
+ "curve_best_ep": best["ep"],
+ "curve_best_val": best["val"],
+ "curve_final_ep": final["ep"],
+ "curve_final_val": final["val"],
+ "final_gap": best["val"] - final["val"],
+ "late_slope": late_slope,
+ "range": max(vals) - min(vals),
+ }
+
+
+def labels_for(row: dict, *, val_tol: float, test_tol: float, collapse_gap: float) -> str:
+ labels = []
+ if row["val_delta"] >= -val_tol and row["test_delta"] >= -test_tol:
+ labels.append("positive_or_tie")
+ elif row["val_delta"] >= -val_tol and row["test_delta"] < -test_tol:
+ labels.append("val_up_test_down")
+ elif row["val_delta"] < -val_tol and row["test_delta"] >= -test_tol:
+ labels.append("test_up_val_down")
+ else:
+ labels.append("val_and_test_down")
+
+ if row["val_delta"] < -0.02:
+ labels.append("fixed_val_lags_classic")
+ if row["cand_final_gap"] is not None and row["cand_final_gap"] > collapse_gap:
+ labels.append("late_collapse")
+ if (
+ row["cand_curve_best_ep"] is not None
+ and row["cand_curve_final_ep"] is not None
+ and row["cand_curve_best_ep"] >= row["cand_curve_final_ep"] - 10
+ and (row["cand_final_gap"] is None or row["cand_final_gap"] <= 0.01)
+ ):
+ labels.append("maybe_undertrained")
+ return ",".join(labels)
+
+
+def fmt(x, digits=4):
+ if x is None:
+ return ""
+ if isinstance(x, int):
+ return str(x)
+ return f"{x:.{digits}f}"
+
+
+def markdown_table(headers, rows):
+ out = [
+ "| " + " | ".join(headers) + " |",
+ "| " + " | ".join(["---"] * len(headers)) + " |",
+ ]
+ out.extend("| " + " | ".join(row) + " |" for row in rows)
+ return "\n".join(out)
+
+
+def mean(xs):
+ return sum(xs) / len(xs) if xs else 0.0
+
+
+def corr(xs, ys):
+ if len(xs) != len(ys) or len(xs) < 2:
+ return 0.0
+ mx = mean(xs)
+ my = mean(ys)
+ den = math.sqrt(sum((x - mx) ** 2 for x in xs) * sum((y - my) ** 2 for y in ys))
+ if den == 0:
+ return 0.0
+ return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / den
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--dataset", default="ogbg-molhiv")
+ ap.add_argument("--runs-dir", default="runs")
+ ap.add_argument("--log", default="logs/ogbg-molhiv_0.log")
+ ap.add_argument("--out-dir", default="analysis")
+ ap.add_argument("--candidate-compute", default="fixed-rrog")
+ ap.add_argument("--candidate-T", type=int, default=3)
+ ap.add_argument("--candidate-n-sup", type=int, default=3)
+ ap.add_argument("--baseline-compute", default="classic")
+ ap.add_argument("--baseline-T", type=int, default=0)
+ ap.add_argument("--baseline-n-sup", type=int, default=1)
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--val-tol", type=float, default=0.0)
+ ap.add_argument("--test-tol", type=float, default=0.0)
+ ap.add_argument("--collapse-gap", type=float, default=0.03)
+ args = ap.parse_args()
+
+ runs_dir = Path(args.runs_dir)
+ curves = parse_curves(Path(args.log))
+ runs = load_runs(runs_dir, args.dataset)
+ out_dir = Path(args.out_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ rows = []
+ views = sorted({
+ view for (view, compute, _t, _ns, seed) in runs
+ if seed == args.seed and compute == args.baseline_compute
+ })
+ for view in views:
+ base_key = (view, args.baseline_compute, args.baseline_T, args.baseline_n_sup, args.seed)
+ cand_key = (view, args.candidate_compute, args.candidate_T, args.candidate_n_sup, args.seed)
+ base = runs.get(base_key)
+ cand = runs.get(cand_key)
+ if base is None or cand is None:
+ continue
+ metric = cand["metric"]
+ base_val = _score(base, "val")
+ base_test = _score(base, "test")
+ cand_val = _score(cand, "val")
+ cand_test = _score(cand, "test")
+ if None in {base_val, base_test, cand_val, cand_test}:
+ continue
+ base_curve = curve_stats(curves.get((args.dataset, view, args.baseline_compute, args.baseline_T, args.baseline_n_sup), []))
+ cand_curve = curve_stats(curves.get((args.dataset, view, args.candidate_compute, args.candidate_T, args.candidate_n_sup), []))
+ row = {
+ "view": view,
+ "metric": metric,
+ "base_ep": int(base.get("ep") or 0),
+ "base_val": base_val,
+ "base_test": base_test,
+ "cand_ep": int(cand.get("ep") or 0),
+ "cand_val": cand_val,
+ "cand_test": cand_test,
+ "val_delta": cand_val - base_val,
+ "test_delta": cand_test - base_test,
+ "base_curve_best_ep": base_curve["curve_best_ep"],
+ "base_final_gap": base_curve["final_gap"],
+ "cand_curve_best_ep": cand_curve["curve_best_ep"],
+ "cand_curve_final_ep": cand_curve["curve_final_ep"],
+ "cand_final_gap": cand_curve["final_gap"],
+ "cand_late_slope": cand_curve["late_slope"],
+ "cand_range": cand_curve["range"],
+ }
+ row["labels"] = labels_for(
+ row,
+ val_tol=args.val_tol,
+ test_tol=args.test_tol,
+ collapse_gap=args.collapse_gap,
+ )
+ rows.append(row)
+
+ csv_path = out_dir / f"{args.dataset}_{args.candidate_compute}_T{args.candidate_T}_ns{args.candidate_n_sup}_diagnostics.csv"
+ fieldnames = [
+ "view", "metric", "base_ep", "base_val", "base_test", "cand_ep", "cand_val", "cand_test",
+ "val_delta", "test_delta", "base_curve_best_ep", "base_final_gap",
+ "cand_curve_best_ep", "cand_curve_final_ep", "cand_final_gap", "cand_late_slope",
+ "cand_range", "labels",
+ ]
+ with csv_path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(rows)
+
+ counts = defaultdict(int)
+ for row in rows:
+ for label in row["labels"].split(","):
+ counts[label] += 1
+
+ pos = [r for r in rows if r["val_delta"] >= 0 and r["test_delta"] >= 0]
+ val_up_test_down = [r for r in rows if r["val_delta"] >= 0 and r["test_delta"] < 0]
+ val_down_test_up = [r for r in rows if r["val_delta"] < 0 and r["test_delta"] >= 0]
+ both_down = [r for r in rows if r["val_delta"] < 0 and r["test_delta"] < 0]
+ late = [r for r in rows if "maybe_undertrained" in r["labels"]]
+ collapse = [r for r in rows if "late_collapse" in r["labels"]]
+ val_deltas = [r["val_delta"] for r in rows]
+ test_deltas = [r["test_delta"] for r in rows]
+
+ def row_line(r):
+ return [
+ r["view"],
+ fmt(r["base_val"]),
+ fmt(r["cand_val"]),
+ fmt(r["val_delta"]),
+ fmt(r["base_test"]),
+ fmt(r["cand_test"]),
+ fmt(r["test_delta"]),
+ str(r["cand_ep"]),
+ fmt(r["cand_final_gap"]),
+ r["labels"],
+ ]
+
+ md_path = csv_path.with_suffix(".md")
+ with md_path.open("w") as f:
+ f.write(f"# {args.dataset} {args.candidate_compute} T={args.candidate_T} ns={args.candidate_n_sup} diagnostics\n\n")
+ f.write(f"- rows: {len(rows)}\n")
+ f.write(f"- val+test positive: {len(pos)}\n")
+ f.write(f"- val up, test down: {len(val_up_test_down)}\n")
+ f.write(f"- val down, test up: {len(val_down_test_up)}\n")
+ f.write(f"- val+test down: {len(both_down)}\n")
+ f.write(f"- maybe undertrained: {len(late)}\n")
+ f.write(f"- late collapse: {len(collapse)}\n\n")
+ f.write(f"- mean val delta: {fmt(mean(val_deltas))}\n")
+ f.write(f"- mean test delta: {fmt(mean(test_deltas))}\n")
+ f.write(f"- val/test delta corr: {fmt(corr(val_deltas, test_deltas))}\n\n")
+ f.write("## Label Counts\n\n")
+ f.write(markdown_table(["label", "n"], [[k, str(v)] for k, v in sorted(counts.items())]))
+ f.write("\n\n## Per-Backbone Rows\n\n")
+ headers = [
+ "view", "base_val", "cand_val", "d_val", "base_test", "cand_test",
+ "d_test", "cand_ep", "cand_final_gap", "labels",
+ ]
+ f.write(markdown_table(headers, [row_line(r) for r in rows]))
+ f.write("\n")
+
+ print(f"wrote {csv_path}")
+ print(f"wrote {md_path}")
+ print(f"rows={len(rows)} val+test={len(pos)} val_up_test_down={len(val_up_test_down)} "
+ f"val_down_test_up={len(val_down_test_up)} both_down={len(both_down)} "
+ f"maybe_undertrained={len(late)} late_collapse={len(collapse)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/run_ogb_mol_task_full.sh b/scripts/run_ogb_mol_task_full.sh
index b25bff3..71ddba1 100755
--- a/scripts/run_ogb_mol_task_full.sh
+++ b/scripts/run_ogb_mol_task_full.sh
@@ -10,6 +10,8 @@ DEVICE="${DEVICE:-cuda:1}"
EPOCHS="${EPOCHS:-100}"
SEED="${SEED:-0}"
HIDDEN="${HIDDEN:-128}"
+FIXED_T="${FIXED_T:-3}"
+FIXED_NS="${FIXED_NS:-3}"
VIEWS="${VIEWS:-gin gine gcn graphsage gatv2 graphconv transformer pna gen film resgated tag sgc cheb arma mf appnp}"
mkdir -p runs logs
@@ -48,7 +50,7 @@ run_cell() {
for view in ${VIEWS}; do
run_cell "${view}" classic 0 1
- run_cell "${view}" fixed-rrog 3 3
+ run_cell "${view}" fixed-rrog "${FIXED_T}" "${FIXED_NS}"
done
python3 -m rrog.cli results --epochs "${EPOCHS}"