"""Paired comparison of PTRM rollout caches. This focuses on per-problem correct rollout counts, not only best-Q accuracy. Use it when two caches were generated with the same sample seed so `idx` aligns. """ from __future__ import annotations import argparse import csv from pathlib import Path import numpy as np def _load(path: Path) -> dict[str, np.ndarray]: data = np.load(path) return {name: data[name] for name in data.files} def _take(values: np.ndarray, idx: np.ndarray) -> np.ndarray: return values[np.arange(values.shape[0]), idx] def _summary(prefix: str, exact: np.ndarray, q_halt: np.ndarray, det_exact: np.ndarray) -> dict[str, float]: correct_count = exact.sum(axis=1).astype(np.float32) q_sel = _take(exact, q_halt.argmax(axis=1)).astype(bool) oracle = exact.any(axis=1) out = { f"{prefix}/det_exact": float(det_exact.mean()), f"{prefix}/q_max_exact": float(q_sel.mean()), f"{prefix}/oracle_pass": float(oracle.mean()), f"{prefix}/mean_rollout_exact": float(exact.mean()), f"{prefix}/correct_count_mean": float(correct_count.mean()), f"{prefix}/correct_count_median": float(np.median(correct_count)), f"{prefix}/correct_count_q10": float(np.quantile(correct_count, 0.10)), f"{prefix}/correct_count_q25": float(np.quantile(correct_count, 0.25)), f"{prefix}/correct_count_q75": float(np.quantile(correct_count, 0.75)), f"{prefix}/correct_count_q90": float(np.quantile(correct_count, 0.90)), f"{prefix}/zero_correct_frac": float((correct_count == 0).mean()), f"{prefix}/all_correct_frac": float((correct_count == exact.shape[1]).mean()), } for threshold in (1, 5, 10, 25, 50, 75, 90): if threshold <= exact.shape[1]: out[f"{prefix}/correct_count_ge_{threshold}_frac"] = float((correct_count >= threshold).mean()) for name, mask in { "det_success": det_exact.astype(bool), "det_fail": ~det_exact.astype(bool), }.items(): if mask.any(): out[f"{prefix}/correct_count_{name}_mean"] = float(correct_count[mask].mean()) out[f"{prefix}/oracle_{name}_frac"] = float(oracle[mask].mean()) out[f"{prefix}/q_max_{name}_frac"] = float(q_sel[mask].mean()) return out def _prefix_curves(prefix: str, exact: np.ndarray, q_halt: np.ndarray) -> tuple[dict[str, float], list[dict[str, float]]]: rows: list[dict[str, float]] = [] out: dict[str, float] = {} max_k = exact.shape[1] for k in range(1, max_k + 1): exact_k = exact[:, :k] q_k = q_halt[:, :k] oracle = exact_k.any(axis=1) q_sel = _take(exact_k, q_k.argmax(axis=1)).astype(bool) count = exact_k.sum(axis=1).astype(np.float32) row = { "K": float(k), f"{prefix}/oracle_pass": float(oracle.mean()), f"{prefix}/q_max_exact": float(q_sel.mean()), f"{prefix}/correct_count_mean": float(count.mean()), f"{prefix}/zero_correct_frac": float((count == 0).mean()), f"{prefix}/correct_count_ge_5_frac": float((count >= min(5, k)).mean()), f"{prefix}/correct_count_ge_half_frac": float((count >= (k / 2.0)).mean()), } rows.append(row) for target in (0.90, 0.95, 0.975, 0.99): for metric in ("oracle_pass", "q_max_exact"): needed = next( (int(row["K"]) for row in rows if row[f"{prefix}/{metric}"] >= target), None, ) out[f"{prefix}/K_for_{metric}_{target:g}"] = float("nan") if needed is None else float(needed) return out, rows def compare(base_path: Path, test_path: Path, base_label: str, test_label: str) -> tuple[dict[str, float | str], list[dict[str, float]]]: base = _load(base_path) test = _load(test_path) if "idx" in base and "idx" in test and not np.array_equal(base["idx"], test["idx"]): raise ValueError("Caches are not paired: idx arrays differ.") base_exact = base["exact"].astype(bool) test_exact = test["exact"].astype(bool) base_q = base["q_halt"].astype(np.float32) test_q = test["q_halt"].astype(np.float32) base_det = base["det_exact"].astype(bool) test_det = test["det_exact"].astype(bool) base_count = base_exact.sum(axis=1).astype(np.float32) test_count = test_exact.sum(axis=1).astype(np.float32) delta = test_count - base_count base_oracle = base_exact.any(axis=1) test_oracle = test_exact.any(axis=1) base_q_sel = _take(base_exact, base_q.argmax(axis=1)).astype(bool) test_q_sel = _take(test_exact, test_q.argmax(axis=1)).astype(bool) out: dict[str, float | str] = { "base_file": str(base_path), "test_file": str(test_path), "n_samples": float(base_exact.shape[0]), "rollouts": float(base_exact.shape[1]), } out.update(_summary(base_label, base_exact, base_q, base_det)) out.update(_summary(test_label, test_exact, test_q, test_det)) base_curve_summary, base_curve_rows = _prefix_curves(base_label, base_exact, base_q) test_curve_summary, test_curve_rows = _prefix_curves(test_label, test_exact, test_q) out.update(base_curve_summary) out.update(test_curve_summary) curve_rows: list[dict[str, float]] = [] for base_row, test_row in zip(base_curve_rows, test_curve_rows, strict=True): row = {**base_row, **{k: v for k, v in test_row.items() if k != "K"}} k = int(row["K"]) base_prefix = f"{base_label}/" test_prefix = f"{test_label}/" row["delta/oracle_pass"] = row[f"{test_prefix}oracle_pass"] - row[f"{base_prefix}oracle_pass"] row["delta/q_max_exact"] = row[f"{test_prefix}q_max_exact"] - row[f"{base_prefix}q_max_exact"] row["delta/correct_count_mean"] = row[f"{test_prefix}correct_count_mean"] - row[f"{base_prefix}correct_count_mean"] curve_rows.append(row) out.update( { "delta_correct_count_mean": float(delta.mean()), "delta_correct_count_median": float(np.median(delta)), "delta_correct_count_q10": float(np.quantile(delta, 0.10)), "delta_correct_count_q25": float(np.quantile(delta, 0.25)), "delta_correct_count_q75": float(np.quantile(delta, 0.75)), "delta_correct_count_q90": float(np.quantile(delta, 0.90)), "test_more_correct_frac": float((delta > 0).mean()), "test_equal_correct_frac": float((delta == 0).mean()), "test_fewer_correct_frac": float((delta < 0).mean()), "base_zero_test_nonzero_frac": float(((base_count == 0) & (test_count > 0)).mean()), "base_nonzero_test_zero_frac": float(((base_count > 0) & (test_count == 0)).mean()), "both_oracle_success_test_more_frac": float((base_oracle & test_oracle & (delta > 0)).mean()), "oracle_test_only_frac": float((~base_oracle & test_oracle).mean()), "oracle_base_only_frac": float((base_oracle & ~test_oracle).mean()), "oracle_both_success_frac": float((base_oracle & test_oracle).mean()), "oracle_both_fail_frac": float((~base_oracle & ~test_oracle).mean()), "q_max_test_only_frac": float((~base_q_sel & test_q_sel).mean()), "q_max_base_only_frac": float((base_q_sel & ~test_q_sel).mean()), "q_max_both_success_frac": float((base_q_sel & test_q_sel).mean()), "q_max_both_fail_frac": float((~base_q_sel & ~test_q_sel).mean()), } ) hist_rows: list[dict[str, float]] = [] max_count = int(max(base_count.max(), test_count.max())) for count in range(max_count + 1): hist_rows.append( { "correct_count": float(count), f"{base_label}_frac": float((base_count == count).mean()), f"{test_label}_frac": float((test_count == count).mean()), } ) return out, hist_rows, curve_rows def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--base", required=True) parser.add_argument("--test", required=True) parser.add_argument("--base-label", default="base") parser.add_argument("--test-label", default="test") parser.add_argument("--out-prefix", default=None) args = parser.parse_args() summary, hist_rows, curve_rows = compare(Path(args.base), Path(args.test), args.base_label, args.test_label) for key in sorted(summary): print(f"{key}: {summary[key]}") if args.out_prefix: prefix = Path(args.out_prefix) prefix.parent.mkdir(parents=True, exist_ok=True) with (prefix.with_suffix(".summary.csv")).open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=sorted(summary)) writer.writeheader() writer.writerow(summary) with (prefix.with_suffix(".hist.csv")).open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=list(hist_rows[0])) writer.writeheader() writer.writerows(hist_rows) with (prefix.with_suffix(".kcurve.csv")).open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=list(curve_rows[0])) writer.writeheader() writer.writerows(curve_rows) print( f"saved {prefix.with_suffix('.summary.csv')}, " f"{prefix.with_suffix('.hist.csv')}, and {prefix.with_suffix('.kcurve.csv')}" ) if __name__ == "__main__": main()