from __future__ import annotations import csv from pathlib import Path import matplotlib.pyplot as plt import numpy as np ROOT = Path("research/flossing") IN_DIR = ROOT / "q_lambda_scatter" OUT = ROOT / "meeting_artifacts_v2" RUNS = [ ( "TRM baseline + PTRM rollouts", IN_DIR / "base58590_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz", ), ( "TRM multi4 + PTRM rollouts", IN_DIR / "multi4_35805_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz", ), ] def rank_average(values: np.ndarray) -> np.ndarray: order = np.argsort(values, kind="mergesort") sorted_values = values[order] ranks = np.empty(len(values), dtype=float) i = 0 while i < len(values): j = i + 1 while j < len(values) and sorted_values[j] == sorted_values[i]: j += 1 ranks[order[i:j]] = 0.5 * (i + j - 1) + 1.0 i = j return ranks def corr(x: np.ndarray, y: np.ndarray) -> float: mask = np.isfinite(x) & np.isfinite(y) x = x[mask] y = y[mask] if len(x) < 3: return float("nan") x = x - x.mean() y = y - y.mean() denom = np.sqrt(np.square(x).sum() * np.square(y).sum()) if denom <= 0: return float("nan") return float(np.dot(x, y) / denom) def spearman(x: np.ndarray, y: np.ndarray) -> float: mask = np.isfinite(x) & np.isfinite(y) x = x[mask] y = y[mask] if len(x) < 3: return float("nan") return corr(rank_average(x), rank_average(y)) def mean_within_problem_corr(stability: np.ndarray, q_halt: np.ndarray, rank: bool) -> float: vals: list[float] = [] for x, y in zip(stability, q_halt): val = spearman(x, y) if rank else corr(x, y) if np.isfinite(val): vals.append(val) return float(np.mean(vals)) if vals else float("nan") def summarize(name: str, path: Path) -> dict[str, float | str]: data = np.load(path) exact = data["exact"].astype(bool) q_halt = data["q_halt"].astype(float) lyap = data["lyap"].astype(float) if lyap.size == 0: raise ValueError(f"{path} does not contain lyap") stability = -lyap arange = np.arange(exact.shape[0]) q_idx = q_halt.argmax(axis=1) lyap_idx = lyap.argmin(axis=1) correct_count = exact.sum(axis=1) mixed = (correct_count > 0) & (correct_count < exact.shape[1]) mixed_summary: dict[str, float] = { "mixed_problem_count": float(mixed.sum()), "zero_success_problem_count": float((correct_count == 0).sum()), "full_success_problem_count": float((correct_count == exact.shape[1]).sum()), "mixed_global_pearson_q_vs_stability": float("nan"), "mixed_q_max_exact": float("nan"), "mixed_lambda_min_exact": float("nan"), "mixed_oracle_exact": float("nan"), } if mixed.any(): m_arange = np.arange(int(mixed.sum())) m_exact = exact[mixed] m_q = q_halt[mixed] m_lyap = lyap[mixed] mixed_summary.update( { "mixed_global_pearson_q_vs_stability": corr((-m_lyap).reshape(-1), m_q.reshape(-1)), "mixed_q_max_exact": float(m_exact[m_arange, m_q.argmax(axis=1)].mean()), "mixed_lambda_min_exact": float(m_exact[m_arange, m_lyap.argmin(axis=1)].mean()), "mixed_oracle_exact": float(m_exact.any(axis=1).mean()), } ) out: dict[str, float | str] = { "name": name, "path": str(path), "n_samples": float(exact.shape[0]), "rollouts": float(exact.shape[1]), "mean_rollout_exact": float(exact.mean()), "q_max_exact": float(exact[arange, q_idx].mean()), "lambda_min_exact": float(exact[arange, lyap_idx].mean()), "oracle_pass_exact": float(exact.any(axis=1).mean()), "q_lambda_same_argmax_frac": float((q_idx == lyap_idx).mean()), "global_pearson_q_vs_stability": corr(stability.reshape(-1), q_halt.reshape(-1)), "global_spearman_q_vs_stability": spearman(stability.reshape(-1), q_halt.reshape(-1)), "within_problem_pearson_mean": mean_within_problem_corr(stability, q_halt, rank=False), "within_problem_spearman_mean": mean_within_problem_corr(stability, q_halt, rank=True), "q_success_mean": float(q_halt[exact].mean()) if exact.any() else float("nan"), "q_fail_mean": float(q_halt[~exact].mean()) if (~exact).any() else float("nan"), "lambda_success_mean": float(lyap[exact].mean()) if exact.any() else float("nan"), "lambda_fail_mean": float(lyap[~exact].mean()) if (~exact).any() else float("nan"), } out.update(mixed_summary) return out def scatter_panel( ax: plt.Axes, stability_2d: np.ndarray, q_2d: np.ndarray, exact_2d: np.ndarray, title: str, ) -> None: stability = stability_2d.reshape(-1) q_halt = q_2d.reshape(-1) exact = exact_2d.reshape(-1) finite = np.isfinite(stability) & np.isfinite(q_halt) exact = exact[finite] q_halt = q_halt[finite] stability = stability[finite] if len(stability) == 0: ax.set_title(title + "\n(no points)") return xlo, xhi = np.quantile(stability, [0.005, 0.995]) ylo, yhi = np.quantile(q_halt, [0.005, 0.995]) visible = (stability >= xlo) & (stability <= xhi) & (q_halt >= ylo) & (q_halt <= yhi) ax.scatter( stability[visible & ~exact], q_halt[visible & ~exact], s=8, alpha=0.22, color="#dc2626", linewidths=0, label="incorrect rollout", ) ax.scatter( stability[visible & exact], q_halt[visible & exact], s=8, alpha=0.17, color="#2563eb", linewidths=0, label="correct rollout", ) fit = visible if int(fit.sum()) >= 3: slope, intercept = np.polyfit(stability[fit], q_halt[fit], 1) xs = np.linspace(xlo, xhi, 100) ax.plot(xs, slope * xs + intercept, color="black", linewidth=1.8, alpha=0.75) ax.set_title(title) ax.set_xlim(xlo, xhi) ax.set_ylim(ylo, yhi) ax.grid(alpha=0.22) def plot() -> list[dict[str, float | str]]: missing = [path for _name, path in RUNS if not path.exists()] if missing: raise SystemExit("missing input files:\n" + "\n".join(str(p) for p in missing)) OUT.mkdir(parents=True, exist_ok=True) summaries = [summarize(name, path) for name, path in RUNS] fig, axes = plt.subplots(2, len(RUNS), figsize=(12.0, 8.2), sharey="row") if len(RUNS) == 1: axes = np.asarray(axes).reshape(2, 1) for col, ((name, path), summary) in enumerate(zip(RUNS, summaries)): data = np.load(path) exact = data["exact"].astype(bool) q_halt = data["q_halt"].astype(float) stability = -data["lyap"].astype(float) correct_count = exact.sum(axis=1) mixed = (correct_count > 0) & (correct_count < exact.shape[1]) scatter_panel( axes[0, col], stability, q_halt, exact, f"{name}: all rollouts\n" f"r={summary['global_pearson_q_vs_stability']:.2f}, " f"rho={summary['global_spearman_q_vs_stability']:.2f}, " f"Q exact={summary['q_max_exact']:.3f}", ) scatter_panel( axes[1, col], stability[mixed], q_halt[mixed], exact[mixed], f"mixed problems only (n={int(summary['mixed_problem_count'])})\n" f"r={summary['mixed_global_pearson_q_vs_stability']:.2f}, " f"Q={summary['mixed_q_max_exact']:.3f}, " f"lambda-min={summary['mixed_lambda_min_exact']:.3f}", ) for ax in axes[1, :]: ax.set_xlabel("stability proxy = -lambda_1") axes[0, 0].set_ylabel("Q-head halt logit") axes[1, 0].set_ylabel("Q-head halt logit") axes[0, 0].legend(frameon=False, loc="lower right", fontsize=8) fig.suptitle("PTRM Q-head score contains a stability signal; mixed problems reveal selector behavior") fig.tight_layout() fig.savefig(OUT / "fig5_qhead_vs_lambda1_ptrm.png", dpi=240) plt.close(fig) out_csv = OUT / "fig5_qhead_vs_lambda1_ptrm_summary.csv" with out_csv.open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=list(summaries[0].keys())) writer.writeheader() writer.writerows(summaries) return summaries def main() -> None: summaries = plot() for row in summaries: print( f"{row['name']}: " f"q_exact={row['q_max_exact']:.4f} " f"lambda_min_exact={row['lambda_min_exact']:.4f} " f"oracle={row['oracle_pass_exact']:.4f} " f"pearson={row['global_pearson_q_vs_stability']:.4f} " f"spearman={row['global_spearman_q_vs_stability']:.4f} " f"within_spearman={row['within_problem_spearman_mean']:.4f} " f"mixed_pearson={row['mixed_global_pearson_q_vs_stability']:.4f}" ) print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm.png'}") print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm_summary.csv'}") if __name__ == "__main__": main()