summaryrefslogtreecommitdiff
path: root/research/flossing/compare_ptrm_rollout_counts.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/compare_ptrm_rollout_counts.py')
-rw-r--r--research/flossing/compare_ptrm_rollout_counts.py203
1 files changed, 203 insertions, 0 deletions
diff --git a/research/flossing/compare_ptrm_rollout_counts.py b/research/flossing/compare_ptrm_rollout_counts.py
new file mode 100644
index 0000000..0528820
--- /dev/null
+++ b/research/flossing/compare_ptrm_rollout_counts.py
@@ -0,0 +1,203 @@
+"""Paired comparison of PTRM rollout caches.
+
+This focuses on per-problem correct rollout counts, not only best-Q accuracy.
+Use it when two caches were generated with the same sample seed so `idx` aligns.
+"""
+from __future__ import annotations
+
+import argparse
+import csv
+from pathlib import Path
+
+import numpy as np
+
+
+def _load(path: Path) -> dict[str, np.ndarray]:
+ data = np.load(path)
+ return {name: data[name] for name in data.files}
+
+
+def _take(values: np.ndarray, idx: np.ndarray) -> np.ndarray:
+ return values[np.arange(values.shape[0]), idx]
+
+
+def _summary(prefix: str, exact: np.ndarray, q_halt: np.ndarray, det_exact: np.ndarray) -> dict[str, float]:
+ correct_count = exact.sum(axis=1).astype(np.float32)
+ q_sel = _take(exact, q_halt.argmax(axis=1)).astype(bool)
+ oracle = exact.any(axis=1)
+ out = {
+ f"{prefix}/det_exact": float(det_exact.mean()),
+ f"{prefix}/q_max_exact": float(q_sel.mean()),
+ f"{prefix}/oracle_pass": float(oracle.mean()),
+ f"{prefix}/mean_rollout_exact": float(exact.mean()),
+ f"{prefix}/correct_count_mean": float(correct_count.mean()),
+ f"{prefix}/correct_count_median": float(np.median(correct_count)),
+ f"{prefix}/correct_count_q10": float(np.quantile(correct_count, 0.10)),
+ f"{prefix}/correct_count_q25": float(np.quantile(correct_count, 0.25)),
+ f"{prefix}/correct_count_q75": float(np.quantile(correct_count, 0.75)),
+ f"{prefix}/correct_count_q90": float(np.quantile(correct_count, 0.90)),
+ f"{prefix}/zero_correct_frac": float((correct_count == 0).mean()),
+ f"{prefix}/all_correct_frac": float((correct_count == exact.shape[1]).mean()),
+ }
+ for threshold in (1, 5, 10, 25, 50, 75, 90):
+ if threshold <= exact.shape[1]:
+ out[f"{prefix}/correct_count_ge_{threshold}_frac"] = float((correct_count >= threshold).mean())
+ for name, mask in {
+ "det_success": det_exact.astype(bool),
+ "det_fail": ~det_exact.astype(bool),
+ }.items():
+ if mask.any():
+ out[f"{prefix}/correct_count_{name}_mean"] = float(correct_count[mask].mean())
+ out[f"{prefix}/oracle_{name}_frac"] = float(oracle[mask].mean())
+ out[f"{prefix}/q_max_{name}_frac"] = float(q_sel[mask].mean())
+ return out
+
+
+def _prefix_curves(prefix: str, exact: np.ndarray, q_halt: np.ndarray) -> tuple[dict[str, float], list[dict[str, float]]]:
+ rows: list[dict[str, float]] = []
+ out: dict[str, float] = {}
+ max_k = exact.shape[1]
+ for k in range(1, max_k + 1):
+ exact_k = exact[:, :k]
+ q_k = q_halt[:, :k]
+ oracle = exact_k.any(axis=1)
+ q_sel = _take(exact_k, q_k.argmax(axis=1)).astype(bool)
+ count = exact_k.sum(axis=1).astype(np.float32)
+ row = {
+ "K": float(k),
+ f"{prefix}/oracle_pass": float(oracle.mean()),
+ f"{prefix}/q_max_exact": float(q_sel.mean()),
+ f"{prefix}/correct_count_mean": float(count.mean()),
+ f"{prefix}/zero_correct_frac": float((count == 0).mean()),
+ f"{prefix}/correct_count_ge_5_frac": float((count >= min(5, k)).mean()),
+ f"{prefix}/correct_count_ge_half_frac": float((count >= (k / 2.0)).mean()),
+ }
+ rows.append(row)
+
+ for target in (0.90, 0.95, 0.975, 0.99):
+ for metric in ("oracle_pass", "q_max_exact"):
+ needed = next(
+ (int(row["K"]) for row in rows if row[f"{prefix}/{metric}"] >= target),
+ None,
+ )
+ out[f"{prefix}/K_for_{metric}_{target:g}"] = float("nan") if needed is None else float(needed)
+ return out, rows
+
+
+def compare(base_path: Path, test_path: Path, base_label: str, test_label: str) -> tuple[dict[str, float | str], list[dict[str, float]]]:
+ base = _load(base_path)
+ test = _load(test_path)
+ if "idx" in base and "idx" in test and not np.array_equal(base["idx"], test["idx"]):
+ raise ValueError("Caches are not paired: idx arrays differ.")
+
+ base_exact = base["exact"].astype(bool)
+ test_exact = test["exact"].astype(bool)
+ base_q = base["q_halt"].astype(np.float32)
+ test_q = test["q_halt"].astype(np.float32)
+ base_det = base["det_exact"].astype(bool)
+ test_det = test["det_exact"].astype(bool)
+
+ base_count = base_exact.sum(axis=1).astype(np.float32)
+ test_count = test_exact.sum(axis=1).astype(np.float32)
+ delta = test_count - base_count
+ base_oracle = base_exact.any(axis=1)
+ test_oracle = test_exact.any(axis=1)
+ base_q_sel = _take(base_exact, base_q.argmax(axis=1)).astype(bool)
+ test_q_sel = _take(test_exact, test_q.argmax(axis=1)).astype(bool)
+
+ out: dict[str, float | str] = {
+ "base_file": str(base_path),
+ "test_file": str(test_path),
+ "n_samples": float(base_exact.shape[0]),
+ "rollouts": float(base_exact.shape[1]),
+ }
+ out.update(_summary(base_label, base_exact, base_q, base_det))
+ out.update(_summary(test_label, test_exact, test_q, test_det))
+ base_curve_summary, base_curve_rows = _prefix_curves(base_label, base_exact, base_q)
+ test_curve_summary, test_curve_rows = _prefix_curves(test_label, test_exact, test_q)
+ out.update(base_curve_summary)
+ out.update(test_curve_summary)
+ curve_rows: list[dict[str, float]] = []
+ for base_row, test_row in zip(base_curve_rows, test_curve_rows, strict=True):
+ row = {**base_row, **{k: v for k, v in test_row.items() if k != "K"}}
+ k = int(row["K"])
+ base_prefix = f"{base_label}/"
+ test_prefix = f"{test_label}/"
+ row["delta/oracle_pass"] = row[f"{test_prefix}oracle_pass"] - row[f"{base_prefix}oracle_pass"]
+ row["delta/q_max_exact"] = row[f"{test_prefix}q_max_exact"] - row[f"{base_prefix}q_max_exact"]
+ row["delta/correct_count_mean"] = row[f"{test_prefix}correct_count_mean"] - row[f"{base_prefix}correct_count_mean"]
+ curve_rows.append(row)
+ out.update(
+ {
+ "delta_correct_count_mean": float(delta.mean()),
+ "delta_correct_count_median": float(np.median(delta)),
+ "delta_correct_count_q10": float(np.quantile(delta, 0.10)),
+ "delta_correct_count_q25": float(np.quantile(delta, 0.25)),
+ "delta_correct_count_q75": float(np.quantile(delta, 0.75)),
+ "delta_correct_count_q90": float(np.quantile(delta, 0.90)),
+ "test_more_correct_frac": float((delta > 0).mean()),
+ "test_equal_correct_frac": float((delta == 0).mean()),
+ "test_fewer_correct_frac": float((delta < 0).mean()),
+ "base_zero_test_nonzero_frac": float(((base_count == 0) & (test_count > 0)).mean()),
+ "base_nonzero_test_zero_frac": float(((base_count > 0) & (test_count == 0)).mean()),
+ "both_oracle_success_test_more_frac": float((base_oracle & test_oracle & (delta > 0)).mean()),
+ "oracle_test_only_frac": float((~base_oracle & test_oracle).mean()),
+ "oracle_base_only_frac": float((base_oracle & ~test_oracle).mean()),
+ "oracle_both_success_frac": float((base_oracle & test_oracle).mean()),
+ "oracle_both_fail_frac": float((~base_oracle & ~test_oracle).mean()),
+ "q_max_test_only_frac": float((~base_q_sel & test_q_sel).mean()),
+ "q_max_base_only_frac": float((base_q_sel & ~test_q_sel).mean()),
+ "q_max_both_success_frac": float((base_q_sel & test_q_sel).mean()),
+ "q_max_both_fail_frac": float((~base_q_sel & ~test_q_sel).mean()),
+ }
+ )
+
+ hist_rows: list[dict[str, float]] = []
+ max_count = int(max(base_count.max(), test_count.max()))
+ for count in range(max_count + 1):
+ hist_rows.append(
+ {
+ "correct_count": float(count),
+ f"{base_label}_frac": float((base_count == count).mean()),
+ f"{test_label}_frac": float((test_count == count).mean()),
+ }
+ )
+ return out, hist_rows, curve_rows
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base", required=True)
+ parser.add_argument("--test", required=True)
+ parser.add_argument("--base-label", default="base")
+ parser.add_argument("--test-label", default="test")
+ parser.add_argument("--out-prefix", default=None)
+ args = parser.parse_args()
+
+ summary, hist_rows, curve_rows = compare(Path(args.base), Path(args.test), args.base_label, args.test_label)
+ for key in sorted(summary):
+ print(f"{key}: {summary[key]}")
+
+ if args.out_prefix:
+ prefix = Path(args.out_prefix)
+ prefix.parent.mkdir(parents=True, exist_ok=True)
+ with (prefix.with_suffix(".summary.csv")).open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=sorted(summary))
+ writer.writeheader()
+ writer.writerow(summary)
+ with (prefix.with_suffix(".hist.csv")).open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=list(hist_rows[0]))
+ writer.writeheader()
+ writer.writerows(hist_rows)
+ with (prefix.with_suffix(".kcurve.csv")).open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=list(curve_rows[0]))
+ writer.writeheader()
+ writer.writerows(curve_rows)
+ print(
+ f"saved {prefix.with_suffix('.summary.csv')}, "
+ f"{prefix.with_suffix('.hist.csv')}, and {prefix.with_suffix('.kcurve.csv')}"
+ )
+
+
+if __name__ == "__main__":
+ main()