diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py')
| -rw-r--r-- | research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py | 263 |
1 files changed, 263 insertions, 0 deletions
diff --git a/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py b/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py new file mode 100644 index 0000000..54ff8d2 --- /dev/null +++ b/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import csv +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +ROOT = Path("research/flossing") +IN_DIR = ROOT / "q_lambda_scatter" +OUT = ROOT / "meeting_artifacts_v2" + +RUNS = [ + ( + "TRM baseline + PTRM rollouts", + IN_DIR / "base58590_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz", + ), + ( + "TRM multi4 + PTRM rollouts", + IN_DIR / "multi4_35805_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz", + ), +] + + +def rank_average(values: np.ndarray) -> np.ndarray: + order = np.argsort(values, kind="mergesort") + sorted_values = values[order] + ranks = np.empty(len(values), dtype=float) + i = 0 + while i < len(values): + j = i + 1 + while j < len(values) and sorted_values[j] == sorted_values[i]: + j += 1 + ranks[order[i:j]] = 0.5 * (i + j - 1) + 1.0 + i = j + return ranks + + +def corr(x: np.ndarray, y: np.ndarray) -> float: + mask = np.isfinite(x) & np.isfinite(y) + x = x[mask] + y = y[mask] + if len(x) < 3: + return float("nan") + x = x - x.mean() + y = y - y.mean() + denom = np.sqrt(np.square(x).sum() * np.square(y).sum()) + if denom <= 0: + return float("nan") + return float(np.dot(x, y) / denom) + + +def spearman(x: np.ndarray, y: np.ndarray) -> float: + mask = np.isfinite(x) & np.isfinite(y) + x = x[mask] + y = y[mask] + if len(x) < 3: + return float("nan") + return corr(rank_average(x), rank_average(y)) + + +def mean_within_problem_corr(stability: np.ndarray, q_halt: np.ndarray, rank: bool) -> float: + vals: list[float] = [] + for x, y in zip(stability, q_halt): + val = spearman(x, y) if rank else corr(x, y) + if np.isfinite(val): + vals.append(val) + return float(np.mean(vals)) if vals else float("nan") + + +def summarize(name: str, path: Path) -> dict[str, float | str]: + data = np.load(path) + exact = data["exact"].astype(bool) + q_halt = data["q_halt"].astype(float) + lyap = data["lyap"].astype(float) + if lyap.size == 0: + raise ValueError(f"{path} does not contain lyap") + stability = -lyap + + arange = np.arange(exact.shape[0]) + q_idx = q_halt.argmax(axis=1) + lyap_idx = lyap.argmin(axis=1) + correct_count = exact.sum(axis=1) + mixed = (correct_count > 0) & (correct_count < exact.shape[1]) + + mixed_summary: dict[str, float] = { + "mixed_problem_count": float(mixed.sum()), + "zero_success_problem_count": float((correct_count == 0).sum()), + "full_success_problem_count": float((correct_count == exact.shape[1]).sum()), + "mixed_global_pearson_q_vs_stability": float("nan"), + "mixed_q_max_exact": float("nan"), + "mixed_lambda_min_exact": float("nan"), + "mixed_oracle_exact": float("nan"), + } + if mixed.any(): + m_arange = np.arange(int(mixed.sum())) + m_exact = exact[mixed] + m_q = q_halt[mixed] + m_lyap = lyap[mixed] + mixed_summary.update( + { + "mixed_global_pearson_q_vs_stability": corr((-m_lyap).reshape(-1), m_q.reshape(-1)), + "mixed_q_max_exact": float(m_exact[m_arange, m_q.argmax(axis=1)].mean()), + "mixed_lambda_min_exact": float(m_exact[m_arange, m_lyap.argmin(axis=1)].mean()), + "mixed_oracle_exact": float(m_exact.any(axis=1).mean()), + } + ) + + out: dict[str, float | str] = { + "name": name, + "path": str(path), + "n_samples": float(exact.shape[0]), + "rollouts": float(exact.shape[1]), + "mean_rollout_exact": float(exact.mean()), + "q_max_exact": float(exact[arange, q_idx].mean()), + "lambda_min_exact": float(exact[arange, lyap_idx].mean()), + "oracle_pass_exact": float(exact.any(axis=1).mean()), + "q_lambda_same_argmax_frac": float((q_idx == lyap_idx).mean()), + "global_pearson_q_vs_stability": corr(stability.reshape(-1), q_halt.reshape(-1)), + "global_spearman_q_vs_stability": spearman(stability.reshape(-1), q_halt.reshape(-1)), + "within_problem_pearson_mean": mean_within_problem_corr(stability, q_halt, rank=False), + "within_problem_spearman_mean": mean_within_problem_corr(stability, q_halt, rank=True), + "q_success_mean": float(q_halt[exact].mean()) if exact.any() else float("nan"), + "q_fail_mean": float(q_halt[~exact].mean()) if (~exact).any() else float("nan"), + "lambda_success_mean": float(lyap[exact].mean()) if exact.any() else float("nan"), + "lambda_fail_mean": float(lyap[~exact].mean()) if (~exact).any() else float("nan"), + } + out.update(mixed_summary) + return out + + +def scatter_panel( + ax: plt.Axes, + stability_2d: np.ndarray, + q_2d: np.ndarray, + exact_2d: np.ndarray, + title: str, +) -> None: + stability = stability_2d.reshape(-1) + q_halt = q_2d.reshape(-1) + exact = exact_2d.reshape(-1) + finite = np.isfinite(stability) & np.isfinite(q_halt) + exact = exact[finite] + q_halt = q_halt[finite] + stability = stability[finite] + if len(stability) == 0: + ax.set_title(title + "\n(no points)") + return + + xlo, xhi = np.quantile(stability, [0.005, 0.995]) + ylo, yhi = np.quantile(q_halt, [0.005, 0.995]) + visible = (stability >= xlo) & (stability <= xhi) & (q_halt >= ylo) & (q_halt <= yhi) + + ax.scatter( + stability[visible & ~exact], + q_halt[visible & ~exact], + s=8, + alpha=0.22, + color="#dc2626", + linewidths=0, + label="incorrect rollout", + ) + ax.scatter( + stability[visible & exact], + q_halt[visible & exact], + s=8, + alpha=0.17, + color="#2563eb", + linewidths=0, + label="correct rollout", + ) + + fit = visible + if int(fit.sum()) >= 3: + slope, intercept = np.polyfit(stability[fit], q_halt[fit], 1) + xs = np.linspace(xlo, xhi, 100) + ax.plot(xs, slope * xs + intercept, color="black", linewidth=1.8, alpha=0.75) + + ax.set_title(title) + ax.set_xlim(xlo, xhi) + ax.set_ylim(ylo, yhi) + ax.grid(alpha=0.22) + + +def plot() -> list[dict[str, float | str]]: + missing = [path for _name, path in RUNS if not path.exists()] + if missing: + raise SystemExit("missing input files:\n" + "\n".join(str(p) for p in missing)) + + OUT.mkdir(parents=True, exist_ok=True) + summaries = [summarize(name, path) for name, path in RUNS] + + fig, axes = plt.subplots(2, len(RUNS), figsize=(12.0, 8.2), sharey="row") + if len(RUNS) == 1: + axes = np.asarray(axes).reshape(2, 1) + + for col, ((name, path), summary) in enumerate(zip(RUNS, summaries)): + data = np.load(path) + exact = data["exact"].astype(bool) + q_halt = data["q_halt"].astype(float) + stability = -data["lyap"].astype(float) + correct_count = exact.sum(axis=1) + mixed = (correct_count > 0) & (correct_count < exact.shape[1]) + + scatter_panel( + axes[0, col], + stability, + q_halt, + exact, + f"{name}: all rollouts\n" + f"r={summary['global_pearson_q_vs_stability']:.2f}, " + f"rho={summary['global_spearman_q_vs_stability']:.2f}, " + f"Q exact={summary['q_max_exact']:.3f}", + ) + scatter_panel( + axes[1, col], + stability[mixed], + q_halt[mixed], + exact[mixed], + f"mixed problems only (n={int(summary['mixed_problem_count'])})\n" + f"r={summary['mixed_global_pearson_q_vs_stability']:.2f}, " + f"Q={summary['mixed_q_max_exact']:.3f}, " + f"lambda-min={summary['mixed_lambda_min_exact']:.3f}", + ) + + for ax in axes[1, :]: + ax.set_xlabel("stability proxy = -lambda_1") + axes[0, 0].set_ylabel("Q-head halt logit") + axes[1, 0].set_ylabel("Q-head halt logit") + axes[0, 0].legend(frameon=False, loc="lower right", fontsize=8) + fig.suptitle("PTRM Q-head score contains a stability signal; mixed problems reveal selector behavior") + fig.tight_layout() + fig.savefig(OUT / "fig5_qhead_vs_lambda1_ptrm.png", dpi=240) + plt.close(fig) + + out_csv = OUT / "fig5_qhead_vs_lambda1_ptrm_summary.csv" + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(summaries[0].keys())) + writer.writeheader() + writer.writerows(summaries) + return summaries + + +def main() -> None: + summaries = plot() + for row in summaries: + print( + f"{row['name']}: " + f"q_exact={row['q_max_exact']:.4f} " + f"lambda_min_exact={row['lambda_min_exact']:.4f} " + f"oracle={row['oracle_pass_exact']:.4f} " + f"pearson={row['global_pearson_q_vs_stability']:.4f} " + f"spearman={row['global_spearman_q_vs_stability']:.4f} " + f"within_spearman={row['within_problem_spearman_mean']:.4f} " + f"mixed_pearson={row['mixed_global_pearson_q_vs_stability']:.4f}" + ) + print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm.png'}") + print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm_summary.csv'}") + + +if __name__ == "__main__": + main() |
