From 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Sat, 13 Jun 2026 12:35:36 -0500 Subject: rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 --- research/flossing/compare_ptrm_rollout_counts.py | 203 +++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 research/flossing/compare_ptrm_rollout_counts.py (limited to 'research/flossing/compare_ptrm_rollout_counts.py') diff --git a/research/flossing/compare_ptrm_rollout_counts.py b/research/flossing/compare_ptrm_rollout_counts.py new file mode 100644 index 0000000..0528820 --- /dev/null +++ b/research/flossing/compare_ptrm_rollout_counts.py @@ -0,0 +1,203 @@ +"""Paired comparison of PTRM rollout caches. + +This focuses on per-problem correct rollout counts, not only best-Q accuracy. +Use it when two caches were generated with the same sample seed so `idx` aligns. +""" +from __future__ import annotations + +import argparse +import csv +from pathlib import Path + +import numpy as np + + +def _load(path: Path) -> dict[str, np.ndarray]: + data = np.load(path) + return {name: data[name] for name in data.files} + + +def _take(values: np.ndarray, idx: np.ndarray) -> np.ndarray: + return values[np.arange(values.shape[0]), idx] + + +def _summary(prefix: str, exact: np.ndarray, q_halt: np.ndarray, det_exact: np.ndarray) -> dict[str, float]: + correct_count = exact.sum(axis=1).astype(np.float32) + q_sel = _take(exact, q_halt.argmax(axis=1)).astype(bool) + oracle = exact.any(axis=1) + out = { + f"{prefix}/det_exact": float(det_exact.mean()), + f"{prefix}/q_max_exact": float(q_sel.mean()), + f"{prefix}/oracle_pass": float(oracle.mean()), + f"{prefix}/mean_rollout_exact": float(exact.mean()), + f"{prefix}/correct_count_mean": float(correct_count.mean()), + f"{prefix}/correct_count_median": float(np.median(correct_count)), + f"{prefix}/correct_count_q10": float(np.quantile(correct_count, 0.10)), + f"{prefix}/correct_count_q25": float(np.quantile(correct_count, 0.25)), + f"{prefix}/correct_count_q75": float(np.quantile(correct_count, 0.75)), + f"{prefix}/correct_count_q90": float(np.quantile(correct_count, 0.90)), + f"{prefix}/zero_correct_frac": float((correct_count == 0).mean()), + f"{prefix}/all_correct_frac": float((correct_count == exact.shape[1]).mean()), + } + for threshold in (1, 5, 10, 25, 50, 75, 90): + if threshold <= exact.shape[1]: + out[f"{prefix}/correct_count_ge_{threshold}_frac"] = float((correct_count >= threshold).mean()) + for name, mask in { + "det_success": det_exact.astype(bool), + "det_fail": ~det_exact.astype(bool), + }.items(): + if mask.any(): + out[f"{prefix}/correct_count_{name}_mean"] = float(correct_count[mask].mean()) + out[f"{prefix}/oracle_{name}_frac"] = float(oracle[mask].mean()) + out[f"{prefix}/q_max_{name}_frac"] = float(q_sel[mask].mean()) + return out + + +def _prefix_curves(prefix: str, exact: np.ndarray, q_halt: np.ndarray) -> tuple[dict[str, float], list[dict[str, float]]]: + rows: list[dict[str, float]] = [] + out: dict[str, float] = {} + max_k = exact.shape[1] + for k in range(1, max_k + 1): + exact_k = exact[:, :k] + q_k = q_halt[:, :k] + oracle = exact_k.any(axis=1) + q_sel = _take(exact_k, q_k.argmax(axis=1)).astype(bool) + count = exact_k.sum(axis=1).astype(np.float32) + row = { + "K": float(k), + f"{prefix}/oracle_pass": float(oracle.mean()), + f"{prefix}/q_max_exact": float(q_sel.mean()), + f"{prefix}/correct_count_mean": float(count.mean()), + f"{prefix}/zero_correct_frac": float((count == 0).mean()), + f"{prefix}/correct_count_ge_5_frac": float((count >= min(5, k)).mean()), + f"{prefix}/correct_count_ge_half_frac": float((count >= (k / 2.0)).mean()), + } + rows.append(row) + + for target in (0.90, 0.95, 0.975, 0.99): + for metric in ("oracle_pass", "q_max_exact"): + needed = next( + (int(row["K"]) for row in rows if row[f"{prefix}/{metric}"] >= target), + None, + ) + out[f"{prefix}/K_for_{metric}_{target:g}"] = float("nan") if needed is None else float(needed) + return out, rows + + +def compare(base_path: Path, test_path: Path, base_label: str, test_label: str) -> tuple[dict[str, float | str], list[dict[str, float]]]: + base = _load(base_path) + test = _load(test_path) + if "idx" in base and "idx" in test and not np.array_equal(base["idx"], test["idx"]): + raise ValueError("Caches are not paired: idx arrays differ.") + + base_exact = base["exact"].astype(bool) + test_exact = test["exact"].astype(bool) + base_q = base["q_halt"].astype(np.float32) + test_q = test["q_halt"].astype(np.float32) + base_det = base["det_exact"].astype(bool) + test_det = test["det_exact"].astype(bool) + + base_count = base_exact.sum(axis=1).astype(np.float32) + test_count = test_exact.sum(axis=1).astype(np.float32) + delta = test_count - base_count + base_oracle = base_exact.any(axis=1) + test_oracle = test_exact.any(axis=1) + base_q_sel = _take(base_exact, base_q.argmax(axis=1)).astype(bool) + test_q_sel = _take(test_exact, test_q.argmax(axis=1)).astype(bool) + + out: dict[str, float | str] = { + "base_file": str(base_path), + "test_file": str(test_path), + "n_samples": float(base_exact.shape[0]), + "rollouts": float(base_exact.shape[1]), + } + out.update(_summary(base_label, base_exact, base_q, base_det)) + out.update(_summary(test_label, test_exact, test_q, test_det)) + base_curve_summary, base_curve_rows = _prefix_curves(base_label, base_exact, base_q) + test_curve_summary, test_curve_rows = _prefix_curves(test_label, test_exact, test_q) + out.update(base_curve_summary) + out.update(test_curve_summary) + curve_rows: list[dict[str, float]] = [] + for base_row, test_row in zip(base_curve_rows, test_curve_rows, strict=True): + row = {**base_row, **{k: v for k, v in test_row.items() if k != "K"}} + k = int(row["K"]) + base_prefix = f"{base_label}/" + test_prefix = f"{test_label}/" + row["delta/oracle_pass"] = row[f"{test_prefix}oracle_pass"] - row[f"{base_prefix}oracle_pass"] + row["delta/q_max_exact"] = row[f"{test_prefix}q_max_exact"] - row[f"{base_prefix}q_max_exact"] + row["delta/correct_count_mean"] = row[f"{test_prefix}correct_count_mean"] - row[f"{base_prefix}correct_count_mean"] + curve_rows.append(row) + out.update( + { + "delta_correct_count_mean": float(delta.mean()), + "delta_correct_count_median": float(np.median(delta)), + "delta_correct_count_q10": float(np.quantile(delta, 0.10)), + "delta_correct_count_q25": float(np.quantile(delta, 0.25)), + "delta_correct_count_q75": float(np.quantile(delta, 0.75)), + "delta_correct_count_q90": float(np.quantile(delta, 0.90)), + "test_more_correct_frac": float((delta > 0).mean()), + "test_equal_correct_frac": float((delta == 0).mean()), + "test_fewer_correct_frac": float((delta < 0).mean()), + "base_zero_test_nonzero_frac": float(((base_count == 0) & (test_count > 0)).mean()), + "base_nonzero_test_zero_frac": float(((base_count > 0) & (test_count == 0)).mean()), + "both_oracle_success_test_more_frac": float((base_oracle & test_oracle & (delta > 0)).mean()), + "oracle_test_only_frac": float((~base_oracle & test_oracle).mean()), + "oracle_base_only_frac": float((base_oracle & ~test_oracle).mean()), + "oracle_both_success_frac": float((base_oracle & test_oracle).mean()), + "oracle_both_fail_frac": float((~base_oracle & ~test_oracle).mean()), + "q_max_test_only_frac": float((~base_q_sel & test_q_sel).mean()), + "q_max_base_only_frac": float((base_q_sel & ~test_q_sel).mean()), + "q_max_both_success_frac": float((base_q_sel & test_q_sel).mean()), + "q_max_both_fail_frac": float((~base_q_sel & ~test_q_sel).mean()), + } + ) + + hist_rows: list[dict[str, float]] = [] + max_count = int(max(base_count.max(), test_count.max())) + for count in range(max_count + 1): + hist_rows.append( + { + "correct_count": float(count), + f"{base_label}_frac": float((base_count == count).mean()), + f"{test_label}_frac": float((test_count == count).mean()), + } + ) + return out, hist_rows, curve_rows + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--base", required=True) + parser.add_argument("--test", required=True) + parser.add_argument("--base-label", default="base") + parser.add_argument("--test-label", default="test") + parser.add_argument("--out-prefix", default=None) + args = parser.parse_args() + + summary, hist_rows, curve_rows = compare(Path(args.base), Path(args.test), args.base_label, args.test_label) + for key in sorted(summary): + print(f"{key}: {summary[key]}") + + if args.out_prefix: + prefix = Path(args.out_prefix) + prefix.parent.mkdir(parents=True, exist_ok=True) + with (prefix.with_suffix(".summary.csv")).open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=sorted(summary)) + writer.writeheader() + writer.writerow(summary) + with (prefix.with_suffix(".hist.csv")).open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(hist_rows[0])) + writer.writeheader() + writer.writerows(hist_rows) + with (prefix.with_suffix(".kcurve.csv")).open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(curve_rows[0])) + writer.writeheader() + writer.writerows(curve_rows) + print( + f"saved {prefix.with_suffix('.summary.csv')}, " + f"{prefix.with_suffix('.hist.csv')}, and {prefix.with_suffix('.kcurve.csv')}" + ) + + +if __name__ == "__main__": + main() -- cgit v1.2.3