summaryrefslogtreecommitdiff
path: root/scripts/analyze_ogb_hiv_log.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/analyze_ogb_hiv_log.py')
-rw-r--r--scripts/analyze_ogb_hiv_log.py312
1 files changed, 312 insertions, 0 deletions
diff --git a/scripts/analyze_ogb_hiv_log.py b/scripts/analyze_ogb_hiv_log.py
new file mode 100644
index 0000000..21253fe
--- /dev/null
+++ b/scripts/analyze_ogb_hiv_log.py
@@ -0,0 +1,312 @@
+#!/usr/bin/env python3
+import argparse
+import csv
+import json
+import math
+import re
+from collections import defaultdict
+from pathlib import Path
+
+
# Header line that opens a run section in the training log, e.g.
#   [run] <dataset> view=<view> compute=<compute> T=<int> ns=<int>
# Anything after the ``ns=`` field is ignored (the pattern is only anchored
# at the start of the line).
_RUN_PATTERN = (
    r"^\[run\]\s+"
    r"(?P<dataset>\S+)\s+view=(?P<view>\S+)\s+compute=(?P<compute>\S+)\s+"
    r"T=(?P<T>\d+)\s+ns=(?P<ns>\d+)"
)
RUN_RE = re.compile(_RUN_PATTERN)

# Per-epoch line, e.g.
#   ep<int> val_<metric>=<number> ... train_steps=<number>
# The metric name is whatever follows ``val_``; arbitrary text may sit between
# the validation value and the ``train_steps=`` field.
_EP_PATTERN = (
    r"^ep(?P<ep>\d+)\s+val_(?P<metric>\w+)=(?P<val>[-+0-9.eE]+).*"
    r"train_steps=(?P<steps>[-+0-9.eE]+)"
)
EP_RE = re.compile(_EP_PATTERN)
+
+
+def _load_json(path: Path) -> dict | None:
+ try:
+ with path.open() as f:
+ return json.load(f)
+ except (OSError, json.JSONDecodeError):
+ return None
+
+
+def _score(rep: dict, split: str) -> float | None:
+ metric = rep.get("metric")
+ if not metric:
+ return None
+ value = rep.get(split, {}).get(metric)
+ return None if value is None else float(value)
+
+
def parse_curves(path: Path):
    """Collect per-epoch validation curves from a training log.

    A line matching ``RUN_RE`` selects the current run; every following line
    matching ``EP_RE`` is appended to that run's curve. Epoch lines that
    appear before any run header are ignored.

    Args:
        path: Log file to read. Undecodable bytes are replaced rather than
            raising.

    Returns:
        A ``defaultdict(list)`` keyed by
        ``(dataset, view, compute, T, ns)`` whose values are lists of dicts
        with keys ``ep``, ``metric``, ``val`` and ``train_steps``, in the
        order the epoch lines appear in the log.

    Raises:
        OSError: If the log file cannot be opened.
    """
    curves = defaultdict(list)
    current = None
    with path.open(errors="replace") as f:
        for raw in f:
            line = raw.strip()
            run = RUN_RE.match(line)
            if run:
                current = (
                    run.group("dataset"),
                    run.group("view"),
                    run.group("compute"),
                    int(run.group("T")),
                    int(run.group("ns")),
                )
                continue
            if current is None:
                continue
            epoch = EP_RE.match(line)
            if epoch is None:
                continue
            try:
                point = {
                    "ep": int(epoch.group("ep")),
                    "metric": epoch.group("metric"),
                    "val": float(epoch.group("val")),
                    "train_steps": float(epoch.group("steps")),
                }
            except ValueError:
                # The numeric character class in EP_RE also matches strings
                # such as "--" or "1e" that float() rejects; skip those
                # lines instead of aborting the whole parse.
                continue
            curves[current].append(point)
    return curves
+
+
def load_runs(runs_dir: Path, dataset: str):
    """Index the JSON run reports of one dataset found in ``runs_dir``.

    Files named ``{dataset}_*.json`` are visited in sorted order. A file is
    skipped when it cannot be loaded, is not a JSON object, or its
    ``"dataset"`` field differs from ``dataset``.

    Args:
        runs_dir: Directory holding the report files.
        dataset: Dataset name used both in the file-name pattern and to
            filter on the report's ``"dataset"`` field.

    Returns:
        A dict mapping ``(view, compute, T, n_sup, seed)`` to the report
        dict. The three integer fields fall back to -1 when absent, null or
        not convertible to int. When several files produce the same key,
        the last one in sorted order wins.
    """

    def _int_field(rep: dict, name: str) -> int:
        # Treat JSON null / non-numeric values like a missing field instead
        # of letting one malformed report abort the whole scan.
        try:
            return int(rep.get(name, -1))
        except (TypeError, ValueError):
            return -1

    runs = {}
    for path in sorted(runs_dir.glob(f"{dataset}_*.json")):
        rep = _load_json(path)
        if not isinstance(rep, dict):
            continue
        if rep.get("dataset") != dataset:
            continue
        key = (
            rep.get("view"),
            rep.get("compute"),
            _int_field(rep, "T"),
            _int_field(rep, "n_sup"),
            _int_field(rep, "seed"),
        )
        runs[key] = rep
    return runs
+
+
def curve_stats(curve: list[dict]) -> dict:
    """Summarise a validation curve.

    Args:
        curve: Epoch records, each with at least ``"ep"`` and ``"val"``.

    Returns:
        A dict with the keys ``curve_best_ep``, ``curve_best_val``,
        ``curve_final_ep``, ``curve_final_val``, ``final_gap`` (best minus
        final value), ``late_slope`` (final value minus the one before it,
        None for a single-point curve) and ``range`` (max minus min value).
        Every entry is None when ``curve`` is empty.
    """
    keys = (
        "curve_best_ep",
        "curve_best_val",
        "curve_final_ep",
        "curve_final_val",
        "final_gap",
        "late_slope",
        "range",
    )
    if not curve:
        return dict.fromkeys(keys)

    # Earliest record holding the highest value (ties keep the first one).
    peak = curve[0]
    for point in curve[1:]:
        if point["val"] > peak["val"]:
            peak = point

    last = curve[-1]
    values = [point["val"] for point in curve]
    slope = last["val"] - curve[-2]["val"] if len(curve) >= 2 else None

    return {
        "curve_best_ep": peak["ep"],
        "curve_best_val": peak["val"],
        "curve_final_ep": last["ep"],
        "curve_final_val": last["val"],
        "final_gap": peak["val"] - last["val"],
        "late_slope": slope,
        "range": max(values) - min(values),
    }
+
+
def labels_for(
    row: dict,
    *,
    val_tol: float,
    test_tol: float,
    collapse_gap: float,
    lag_threshold: float = 0.02,
    late_window: int = 10,
    settled_gap: float = 0.01,
) -> str:
    """Build the comma-separated diagnostic labels for one comparison row.

    The first label is always exactly one of ``positive_or_tie``,
    ``val_up_test_down``, ``test_up_val_down`` or ``val_and_test_down``,
    chosen by comparing the row's deltas against the negated tolerances.
    Further labels are appended when their condition holds.

    Args:
        row: Must contain ``val_delta``, ``test_delta``, ``cand_final_gap``,
            ``cand_curve_best_ep`` and ``cand_curve_final_ep``; the last
            three may be None.
        val_tol: A validation delta of at least ``-val_tol`` counts as
            not worse.
        test_tol: A test delta of at least ``-test_tol`` counts as not worse.
        collapse_gap: ``late_collapse`` is added when ``cand_final_gap``
            exceeds this value.
        lag_threshold: ``fixed_val_lags_classic`` is added when the
            validation delta is below ``-lag_threshold``.
        late_window: ``maybe_undertrained`` requires the best epoch to be
            within this many epochs of the final epoch.
        settled_gap: ``maybe_undertrained`` also requires ``cand_final_gap``
            to be None or at most this value.

    Returns:
        The labels joined with ``","``.
    """
    val_ok = row["val_delta"] >= -val_tol
    test_ok = row["test_delta"] >= -test_tol
    if val_ok and test_ok:
        labels = ["positive_or_tie"]
    elif val_ok:
        labels = ["val_up_test_down"]
    elif test_ok:
        labels = ["test_up_val_down"]
    else:
        labels = ["val_and_test_down"]

    if row["val_delta"] < -lag_threshold:
        labels.append("fixed_val_lags_classic")

    gap = row["cand_final_gap"]
    if gap is not None and gap > collapse_gap:
        labels.append("late_collapse")

    best_ep = row["cand_curve_best_ep"]
    final_ep = row["cand_curve_final_ep"]
    if (
        best_ep is not None
        and final_ep is not None
        and best_ep >= final_ep - late_window
        and (gap is None or gap <= settled_gap)
    ):
        labels.append("maybe_undertrained")
    return ",".join(labels)
+
+
def fmt(x, digits=4):
    """Render a table cell: "" for None, ``str(x)`` for ints, else fixed-point.

    Args:
        x: Value to render.
        digits: Number of decimals used for non-int values.
    """
    if x is None:
        return ""
    return str(x) if isinstance(x, int) else f"{x:.{digits}f}"
+
+
def markdown_table(headers, rows):
    """Return a pipe-delimited Markdown table as a single string.

    Args:
        headers: Column titles (strings).
        rows: Iterable of rows, each an iterable of already-formatted
            strings.

    Returns:
        Header line, ``---`` separator line and one line per row, joined
        with newlines (no trailing newline).
    """

    def render(cells):
        return "| " + " | ".join(cells) + " |"

    lines = [render(headers), render(["---"] * len(headers))]
    for cells in rows:
        lines.append(render(cells))
    return "\n".join(lines)
+
+
def mean(xs):
    """Return the arithmetic mean of ``xs``, or 0.0 when ``xs`` is empty."""
    if not xs:
        return 0.0
    return sum(xs) / len(xs)
+
+
def corr(xs, ys):
    """Return the Pearson correlation coefficient of ``xs`` and ``ys``.

    Returns 0.0 when the inputs differ in length, hold fewer than two
    values, or when the denominator is zero (either input is constant).
    """
    n = len(xs)
    if n != len(ys) or n < 2:
        return 0.0
    mx, my = mean(xs), mean(ys)
    spread_x = sum((x - mx) ** 2 for x in xs)
    spread_y = sum((y - my) ** 2 for y in ys)
    den = math.sqrt(spread_x * spread_y)
    if den == 0:
        return 0.0
    covariation = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    return covariation / den
+
+
def main():
    """Compare a candidate configuration against a baseline, view by view.

    Reads the run reports from ``--runs-dir`` and the epoch curves from
    ``--log``. For every view that has a baseline-compute report with the
    requested seed, and for which both the exact baseline and candidate
    configurations exist with usable val/test scores, it builds one
    diagnostics row. The rows are written to a CSV file and a Markdown
    summary in ``--out-dir``, and a one-line summary is printed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="ogbg-molhiv")
    ap.add_argument("--runs-dir", default="runs")
    ap.add_argument("--log", default="logs/ogbg-molhiv_0.log")
    ap.add_argument("--out-dir", default="analysis")
    ap.add_argument("--candidate-compute", default="fixed-rrog")
    ap.add_argument("--candidate-T", type=int, default=3)
    ap.add_argument("--candidate-n-sup", type=int, default=3)
    ap.add_argument("--baseline-compute", default="classic")
    ap.add_argument("--baseline-T", type=int, default=0)
    ap.add_argument("--baseline-n-sup", type=int, default=1)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--val-tol", type=float, default=0.0)
    ap.add_argument("--test-tol", type=float, default=0.0)
    ap.add_argument("--collapse-gap", type=float, default=0.03)
    args = ap.parse_args()

    runs_dir = Path(args.runs_dir)
    curves = parse_curves(Path(args.log))
    runs = load_runs(runs_dir, args.dataset)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Configuration tuples without the view; the view is prepended per row.
    base_cfg = (args.baseline_compute, args.baseline_T, args.baseline_n_sup)
    cand_cfg = (args.candidate_compute, args.candidate_T, args.candidate_n_sup)

    rows = []
    views = sorted({
        view for (view, compute, _t, _ns, seed) in runs
        if seed == args.seed and compute == args.baseline_compute
    })
    for view in views:
        base = runs.get((view, *base_cfg, args.seed))
        cand = runs.get((view, *cand_cfg, args.seed))
        if base is None or cand is None:
            continue
        base_val = _score(base, "val")
        base_test = _score(base, "test")
        cand_val = _score(cand, "val")
        cand_test = _score(cand, "test")
        if None in {base_val, base_test, cand_val, cand_test}:
            continue
        # Read the metric name only after the scores were validated: _score
        # returns None when "metric" is missing, so the key is known to
        # exist here and a report without it is skipped rather than raising.
        metric = cand["metric"]
        # Curve keys carry the dataset but no seed (see parse_curves).
        base_curve = curve_stats(curves.get((args.dataset, view, *base_cfg), []))
        cand_curve = curve_stats(curves.get((args.dataset, view, *cand_cfg), []))
        row = {
            "view": view,
            "metric": metric,
            "base_ep": int(base.get("ep") or 0),
            "base_val": base_val,
            "base_test": base_test,
            "cand_ep": int(cand.get("ep") or 0),
            "cand_val": cand_val,
            "cand_test": cand_test,
            "val_delta": cand_val - base_val,
            "test_delta": cand_test - base_test,
            "base_curve_best_ep": base_curve["curve_best_ep"],
            "base_final_gap": base_curve["final_gap"],
            "cand_curve_best_ep": cand_curve["curve_best_ep"],
            "cand_curve_final_ep": cand_curve["curve_final_ep"],
            "cand_final_gap": cand_curve["final_gap"],
            "cand_late_slope": cand_curve["late_slope"],
            "cand_range": cand_curve["range"],
        }
        row["labels"] = labels_for(
            row,
            val_tol=args.val_tol,
            test_tol=args.test_tol,
            collapse_gap=args.collapse_gap,
        )
        rows.append(row)

    csv_path = out_dir / (
        f"{args.dataset}_{args.candidate_compute}"
        f"_T{args.candidate_T}_ns{args.candidate_n_sup}_diagnostics.csv"
    )
    fieldnames = [
        "view", "metric", "base_ep", "base_val", "base_test", "cand_ep", "cand_val", "cand_test",
        "val_delta", "test_delta", "base_curve_best_ep", "base_final_gap",
        "cand_curve_best_ep", "cand_curve_final_ep", "cand_final_gap", "cand_late_slope",
        "cand_range", "labels",
    ]
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    counts = defaultdict(int)
    # The first label of every row is its val/test outcome; bucketing by it
    # keeps the summary bullets in agreement with the label table even when
    # --val-tol / --test-tol are non-zero.
    by_outcome = defaultdict(list)
    for row in rows:
        row_labels = row["labels"].split(",")
        by_outcome[row_labels[0]].append(row)
        for label in row_labels:
            counts[label] += 1

    pos = by_outcome["positive_or_tie"]
    val_up_test_down = by_outcome["val_up_test_down"]
    val_down_test_up = by_outcome["test_up_val_down"]
    both_down = by_outcome["val_and_test_down"]
    late = [r for r in rows if "maybe_undertrained" in r["labels"]]
    collapse = [r for r in rows if "late_collapse" in r["labels"]]
    val_deltas = [r["val_delta"] for r in rows]
    test_deltas = [r["test_delta"] for r in rows]

    def row_line(r):
        return [
            r["view"],
            fmt(r["base_val"]),
            fmt(r["cand_val"]),
            fmt(r["val_delta"]),
            fmt(r["base_test"]),
            fmt(r["cand_test"]),
            fmt(r["test_delta"]),
            str(r["cand_ep"]),
            fmt(r["cand_final_gap"]),
            r["labels"],
        ]

    md_path = csv_path.with_suffix(".md")
    with md_path.open("w", encoding="utf-8") as f:
        f.write(f"# {args.dataset} {args.candidate_compute} T={args.candidate_T} ns={args.candidate_n_sup} diagnostics\n\n")
        f.write(f"- rows: {len(rows)}\n")
        f.write(f"- val+test positive: {len(pos)}\n")
        f.write(f"- val up, test down: {len(val_up_test_down)}\n")
        f.write(f"- val down, test up: {len(val_down_test_up)}\n")
        f.write(f"- val+test down: {len(both_down)}\n")
        f.write(f"- maybe undertrained: {len(late)}\n")
        f.write(f"- late collapse: {len(collapse)}\n\n")
        f.write(f"- mean val delta: {fmt(mean(val_deltas))}\n")
        f.write(f"- mean test delta: {fmt(mean(test_deltas))}\n")
        f.write(f"- val/test delta corr: {fmt(corr(val_deltas, test_deltas))}\n\n")
        f.write("## Label Counts\n\n")
        f.write(markdown_table(["label", "n"], [[k, str(v)] for k, v in sorted(counts.items())]))
        f.write("\n\n## Per-Backbone Rows\n\n")
        headers = [
            "view", "base_val", "cand_val", "d_val", "base_test", "cand_test",
            "d_test", "cand_ep", "cand_final_gap", "labels",
        ]
        f.write(markdown_table(headers, [row_line(r) for r in rows]))
        f.write("\n")

    print(f"wrote {csv_path}")
    print(f"wrote {md_path}")
    print(f"rows={len(rows)} val+test={len(pos)} val_up_test_down={len(val_up_test_down)} "
          f"val_down_test_up={len(val_down_test_up)} both_down={len(both_down)} "
          f"maybe_undertrained={len(late)} late_collapse={len(collapse)}")


if __name__ == "__main__":
    main()