diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/report_bundle_20260603/scripts | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/report_bundle_20260603/scripts')
| -rw-r--r-- | research/flossing/report_bundle_20260603/scripts/make_meeting_artifacts_v2.py | 364 |
| -rw-r--r-- | research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py | 263 |
2 files changed, 627 insertions, 0 deletions
diff --git a/research/flossing/report_bundle_20260603/scripts/make_meeting_artifacts_v2.py b/research/flossing/report_bundle_20260603/scripts/make_meeting_artifacts_v2.py new file mode 100644 index 0000000..4c88412 --- /dev/null +++ b/research/flossing/report_bundle_20260603/scripts/make_meeting_artifacts_v2.py @@ -0,0 +1,364 @@ +from __future__ import annotations + +import csv +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +ROOT = Path("research/flossing") +OUT = ROOT / "meeting_artifacts_v2" + + +def read_csv(path: Path) -> list[dict[str, str]]: + with path.open(newline="") as f: + return list(csv.DictReader(f)) + + +def f(row: dict[str, str], key: str, default: float = float("nan")) -> float: + try: + val = row.get(key, "") + return float(val) if val != "" else default + except Exception: + return default + + +def row_for(rows: list[dict[str, str]], step: int) -> dict[str, str]: + for row in rows: + if int(float(row["step"])) == step: + return row + raise KeyError(step) + + +def spectrum_metrics(npz_path: Path) -> dict[str, float]: + data = np.load(npz_path) + spec = data["lyap_spec"].astype(float) + exact = data["exact_correct"].astype(bool) + mean8 = spec.mean(axis=1) + pos_count = (spec > 0).sum(axis=1) + return { + "sample_exact": float(exact.mean()), + "lambda1_all": float(spec[:, 0].mean()), + "mean8_all": float(mean8.mean()), + "pos_count_all": float(pos_count.mean()), + "lambda1_success": float(spec[exact, 0].mean()), + "lambda1_fail": float(spec[~exact, 0].mean()), + "mean8_success": float(mean8[exact].mean()), + "mean8_fail": float(mean8[~exact].mean()), + "pos_count_success": float(pos_count[exact].mean()), + "pos_count_fail": float(pos_count[~exact].mean()), + } + + +def spectrum_arrays(npz_path: Path) -> tuple[np.ndarray, np.ndarray]: + data = np.load(npz_path) + return data["lyap_spec"].astype(float), data["exact_correct"].astype(bool) + + +def auc_failure_score(score: np.ndarray, exact: np.ndarray) 
-> float: + """AUROC where higher score predicts failure.""" + fail = ~exact + n_fail = int(fail.sum()) + n_succ = int(exact.sum()) + if n_fail == 0 or n_succ == 0: + return float("nan") + order = np.argsort(score) + ranks = np.empty_like(order, dtype=float) + ranks[order] = np.arange(1, len(score) + 1) + return float((ranks[fail].sum() - n_fail * (n_fail + 1) / 2) / (n_fail * n_succ)) + + +def get_data() -> dict[str, object]: + eval_dir = ROOT / "multi4_eval_compare" + spec_dir = ROOT / "official_gbs768_spectrum" + + hrm_base = read_csv(eval_dir / "hrm_baseline_eval.csv") + hrm_multi = read_csv(eval_dir / "hrm_multi4_complete_eval.csv") + trm_base = read_csv(eval_dir / "trm_official_gbs768_eval.csv") + trm_multi = read_csv(eval_dir / "trm_official_gbs768_multi4_eval.csv") + + hrm_points = [ + { + "label": "baseline best", + "family": "baseline", + "step": 26040, + "full_exact": f(row_for(hrm_base, 26040), "all/exact_accuracy"), + **spectrum_metrics(ROOT / "diag_hrm_step_26040_512.npz"), + }, + { + "label": "multi4 best", + "family": "multi4", + "step": 23436, + "full_exact": f(row_for(hrm_multi, 23436), "exact"), + **spectrum_metrics(ROOT / "diag_hrm_multi4_step_23436_512.npz"), + }, + { + "label": "multi4 final", + "family": "multi4-final", + "step": 26040, + "full_exact": f(row_for(hrm_multi, 26040), "exact"), + **spectrum_metrics(ROOT / "diag_hrm_multi4_step_26040_512.npz"), + }, + ] + + trm_points = [ + { + "label": "baseline best", + "family": "baseline", + "step": 58590, + "full_exact": f(row_for(trm_base, 58590), "all/exact_accuracy"), + **spectrum_metrics(spec_dir / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"), + }, + { + "label": "multi4 best", + "family": "multi4", + "step": 35805, + "full_exact": f(row_for(trm_multi, 35805), "all/exact_accuracy"), + **spectrum_metrics(spec_dir / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"), + }, + { + "label": "multi4 final", + "family": "multi4-final", + "step": 65100, + "full_exact": 
f(row_for(trm_multi, 65100), "all/exact_accuracy"), + **spectrum_metrics(spec_dir / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz"), + }, + ] + + return { + "hrm_base": hrm_base, + "hrm_multi": hrm_multi, + "trm_base": trm_base, + "trm_multi": trm_multi, + "hrm_points": hrm_points, + "trm_points": trm_points, + } + + +def metric_series(rows: list[dict[str, str]], metric: str) -> tuple[np.ndarray, np.ndarray]: + xs, ys = [], [] + for row in rows: + val = f(row, metric) + if np.isfinite(val): + xs.append(int(float(row["step"]))) + ys.append(val) + order = np.argsort(xs) + return np.asarray(xs)[order], np.asarray(ys)[order] + + +def plot_training_curves(data: dict[str, object]) -> None: + fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.2)) + colors = {"baseline": "#475569", "multi4": "#2563eb"} + + panels = [ + ("HRM", data["hrm_base"], data["hrm_multi"], "all/exact_accuracy", "exact", 0.70), + ("TRM official GBS768", data["trm_base"], data["trm_multi"], "all/exact_accuracy", "all/exact_accuracy", 0.94), + ] + for ax, (title, base_rows, multi_rows, base_metric, multi_metric, ymax) in zip(axes, panels): + bx, by = metric_series(base_rows, base_metric) # type: ignore[arg-type] + mx, my = metric_series(multi_rows, multi_metric) # type: ignore[arg-type] + ax.plot(bx / 1000, by, marker="o", color=colors["baseline"], label="baseline", linewidth=2) + ax.plot(mx / 1000, my, marker="o", color=colors["multi4"], label="multi4", linewidth=2) + bi, mi = int(np.argmax(by)), int(np.argmax(my)) + ax.scatter([bx[bi] / 1000], [by[bi]], s=95, color=colors["baseline"], edgecolor="white", zorder=5) + ax.scatter([mx[mi] / 1000], [my[mi]], s=95, color=colors["multi4"], edgecolor="white", zorder=5) + ax.annotate(f"best {by[bi]:.3f}", (bx[bi] / 1000, by[bi]), xytext=(6, -18), textcoords="offset points", fontsize=8) + ax.annotate(f"best {my[mi]:.3f}", (mx[mi] / 1000, my[mi]), xytext=(6, 8), textcoords="offset points", fontsize=8) + ax.set_title(title) + ax.set_xlabel("optimizer step 
(k)") + ax.set_ylabel("full test exact accuracy") + ax.set_ylim(0, ymax) + ax.grid(alpha=0.22) + ax.legend(frameon=False, loc="lower right") + + fig.suptitle("Trajectory perturbation helps both HRM and TRM, but late checkpoints can collapse") + fig.tight_layout() + fig.savefig(OUT / "fig1_hrm_trm_training_curves.png", dpi=220) + plt.close(fig) + + +def plot_lambda1_motivation() -> None: + entries = [ + ("HRM baseline best\nstep 26040", ROOT / "diag_hrm_step_26040_512.npz"), + ( + "TRM baseline best\nstep 58590", + ROOT / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz", + ), + ] + fig, axes = plt.subplots(1, 2, figsize=(11.6, 4.2), sharey=False) + for ax, (title, path) in zip(axes, entries): + spec, exact = spectrum_arrays(path) + lam1 = spec[:, 0] + succ = lam1[exact] + fail = lam1[~exact] + lo, hi = np.quantile(lam1, [0.01, 0.99]) + pad = 0.08 * (hi - lo) + bins = np.linspace(lo - pad, hi + pad, 34) + ax.hist(succ, bins=bins, density=True, alpha=0.62, color="#2563eb", label=f"success (n={len(succ)})") + ax.hist(fail, bins=bins, density=True, alpha=0.58, color="#dc2626", label=f"failure (n={len(fail)})") + ax.axvline(succ.mean(), color="#1d4ed8", linewidth=2) + ax.axvline(fail.mean(), color="#b91c1c", linewidth=2) + auc = auc_failure_score(lam1, exact) + ax.set_title(f"{title}\nacc={exact.mean():.3f}, AUROC={auc:.3f}") + ax.set_xlabel("top Lyapunov exponent λ1") + ax.set_ylabel("density") + ax.grid(alpha=0.2) + ax.legend(frameon=False, fontsize=8) + fig.suptitle("Motivation: λ1 is a strong success/failure detector in both HRM and TRM") + fig.tight_layout() + fig.savefig(OUT / "fig0_motivation_lambda1_success_failure_hrm_trm.png", dpi=220) + plt.close(fig) + + +def plot_phase(data: dict[str, object]) -> None: + fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=False) + panels = [("HRM", data["hrm_points"]), ("TRM official GBS768", data["trm_points"])] + style = { + "baseline": ("#475569", "o"), + "multi4": ("#2563eb", "o"), + 
"multi4-final": ("#dc2626", "X"), + } + for ax, (title, points) in zip(axes, panels): + for point in points: # type: ignore[assignment] + color, marker = style[point["family"]] + ax.scatter(point["mean8_all"], point["full_exact"], s=150, marker=marker, color=color, edgecolor="white", zorder=4) + ax.annotate( + f"{point['label']}\nstep {point['step']}", + (point["mean8_all"], point["full_exact"]), + xytext=(8, 5), + textcoords="offset points", + fontsize=8, + ) + ax.set_title(title) + ax.set_xlabel("mean top-8 Lyapunov exponent (lower = more stable)") + ax.set_ylabel("full test exact accuracy") + ax.grid(alpha=0.22) + fig.suptitle("Accuracy-vs-chaotic-volume phase view: best multi4 moves up-left; collapse moves down-right") + fig.tight_layout() + fig.savefig(OUT / "fig2_accuracy_vs_chaotic_volume_phase.png", dpi=220) + plt.close(fig) + + +def plot_spectrum_grid() -> None: + entries = [ + ("HRM baseline best", ROOT / "diag_hrm_step_26040_512.npz"), + ("HRM multi4 best", ROOT / "diag_hrm_multi4_step_23436_512.npz"), + ("TRM baseline best", ROOT / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz"), + ("TRM multi4 best", ROOT / "official_gbs768_spectrum/trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"), + ] + fig, axes = plt.subplots(2, 2, figsize=(10.8, 7.2), sharey=False) + for ax, (title, path) in zip(axes.flat, entries): + spec, exact = spectrum_arrays(path) + x = np.arange(1, spec.shape[1] + 1) + for mask, color, label in [(exact, "#2563eb", "success"), (~exact, "#dc2626", "failure")]: + mean = spec[mask].mean(axis=0) + se = spec[mask].std(axis=0) / max(mask.sum(), 1) ** 0.5 + ax.plot(x, mean, marker="o", color=color, label=label) + ax.fill_between(x, mean - se, mean + se, color=color, alpha=0.16, linewidth=0) + ax.axhline(0, color="black", linewidth=0.8, alpha=0.55) + ax.set_title(f"{title}\nN={len(exact)}, acc={exact.mean():.3f}") + ax.set_xlabel("spectrum rank") + ax.set_ylabel("finite-time exponent") + ax.grid(alpha=0.22) + 
axes.flat[0].legend(frameon=False) + fig.suptitle("Success/failure separation is present in both HRM and TRM; multi4 mainly stabilizes successful trajectories") + fig.tight_layout() + fig.savefig(OUT / "fig3_hrm_trm_success_failure_spectra.png", dpi=220) + plt.close(fig) + + +def plot_ptrm() -> None: + summary = read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0] + labels = ["deterministic", "mean rollout", "Q-selected", "oracle"] + base = [float(summary[k]) for k in ["base_det", "base_mean_rollout", "base_qmax", "base_oracle"]] + multi = [float(summary[k]) for k in ["multi4_det", "multi4_mean_rollout", "multi4_qmax", "multi4_oracle"]] + x = np.arange(len(labels)) + width = 0.36 + fig, ax = plt.subplots(figsize=(8.7, 4.3)) + ax.bar(x - width / 2, base, width, color="#64748b", label="baseline best") + ax.bar(x + width / 2, multi, width, color="#2563eb", label="multi4 best") + ax.set_xticks(x, labels) + ax.set_ylim(0.86, 1.0) + ax.set_ylabel("exact accuracy") + ax.set_title("PTRM same-subset comparison (n=1000, K=100, D=64, sigma=0.3, L-only)") + ax.grid(axis="y", alpha=0.22) + ax.legend(frameon=False) + for xpos, vals in [(x - width / 2, base), (x + width / 2, multi)]: + for xi, val in zip(xpos, vals): + ax.text(xi, val + 0.002, f"{val:.3f}", ha="center", va="bottom", fontsize=8) + fig.tight_layout() + fig.savefig(OUT / "fig4_ptrm_same_subset_comparison.png", dpi=220) + plt.close(fig) + + +def write_summary(data: dict[str, object]) -> None: + rows = [] + for model, points in [("HRM", data["hrm_points"]), ("TRM", data["trm_points"])]: + for point in points: # type: ignore[assignment] + rows.append( + { + "model": model, + "label": point["label"], + "step": point["step"], + "full_exact": point["full_exact"], + "sample_exact": point["sample_exact"], + "lambda1_all": point["lambda1_all"], + "mean8_all": point["mean8_all"], + "pos_count_all": point["pos_count_all"], + "lambda1_success": point["lambda1_success"], + "lambda1_fail": 
point["lambda1_fail"], + "mean8_success": point["mean8_success"], + "mean8_fail": point["mean8_fail"], + "pos_count_success": point["pos_count_success"], + "pos_count_fail": point["pos_count_fail"], + } + ) + with (OUT / "hrm_trm_redesigned_summary.csv").open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + ptrm = read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0] + report = f"""# Meeting Figures v2 + +## Figure Strategy + +0. `fig0_motivation_lambda1_success_failure_hrm_trm.png`: first-exponent success/failure distribution in HRM and TRM. This motivates chaos as a detector before introducing the method. +1. `fig1_hrm_trm_training_curves.png`: performance over training for HRM and TRM. This answers whether the method improves accuracy and where best/final are. +2. `fig2_accuracy_vs_chaotic_volume_phase.png`: phase view, with accuracy versus mean top-8 Lyapunov exponent. This answers whether better checkpoints are dynamically more stable. +3. `fig3_hrm_trm_success_failure_spectra.png`: full success/failure spectrum separation for HRM and TRM best checkpoints. This extends Fig0 beyond λ1. +4. `fig4_ptrm_same_subset_comparison.png`: PTRM same-subset result. This is a secondary inference-time story. + +## Key Numbers + +- HRM baseline best: 0.5265 exact. HRM multi4 best: 0.6443 exact. HRM multi4 final: 0.4624 exact. +- TRM baseline best: 0.8686 exact. TRM multi4 best: 0.8965 exact. TRM multi4 final: 0.8351 exact. +- HRM multi4 best dynamics sample: mean top-8 exponent {rows[1]['mean8_all']:+.4f}; final {rows[2]['mean8_all']:+.4f}. +- TRM multi4 best dynamics sample: mean top-8 exponent {rows[4]['mean8_all']:+.4f}; final {rows[5]['mean8_all']:+.4f}. +- PTRM same subset, K=100: Q-selected {float(ptrm['base_qmax']):.3f} -> {float(ptrm['multi4_qmax']):.3f}; mean rollout {float(ptrm['base_mean_rollout']):.3f} -> {float(ptrm['multi4_mean_rollout']):.3f}. 
+ +## Caveats + +- Dynamics spectra use N=512 diagnostic samples, not the full test set. +- PTRM numbers use a fixed N=1000 subset; do not mix its deterministic subset accuracy with full-test W&B exact accuracy. +- Final checkpoints are collapse diagnostics, not the method's reported performance. +""" + (OUT / "meeting_figures_v2_report.md").write_text(report) + + +def main() -> None: + OUT.mkdir(parents=True, exist_ok=True) + data = get_data() + plot_lambda1_motivation() + plot_training_curves(data) + plot_phase(data) + plot_spectrum_grid() + plot_ptrm() + write_summary(data) + print(f"wrote redesigned artifacts to {OUT}") + + +if __name__ == "__main__": + main() diff --git a/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py b/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py new file mode 100644 index 0000000..54ff8d2 --- /dev/null +++ b/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import csv +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +ROOT = Path("research/flossing") +IN_DIR = ROOT / "q_lambda_scatter" +OUT = ROOT / "meeting_artifacts_v2" + +RUNS = [ + ( + "TRM baseline + PTRM rollouts", + IN_DIR / "base58590_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz", + ), + ( + "TRM multi4 + PTRM rollouts", + IN_DIR / "multi4_35805_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz", + ), +] + + +def rank_average(values: np.ndarray) -> np.ndarray: + order = np.argsort(values, kind="mergesort") + sorted_values = values[order] + ranks = np.empty(len(values), dtype=float) + i = 0 + while i < len(values): + j = i + 1 + while j < len(values) and sorted_values[j] == sorted_values[i]: + j += 1 + ranks[order[i:j]] = 0.5 * (i + j - 1) + 1.0 + i = j + return ranks + + +def corr(x: np.ndarray, y: np.ndarray) -> float: + mask = np.isfinite(x) & np.isfinite(y) + x = x[mask] + y = y[mask] + if len(x) < 3: + 
return float("nan") + x = x - x.mean() + y = y - y.mean() + denom = np.sqrt(np.square(x).sum() * np.square(y).sum()) + if denom <= 0: + return float("nan") + return float(np.dot(x, y) / denom) + + +def spearman(x: np.ndarray, y: np.ndarray) -> float: + mask = np.isfinite(x) & np.isfinite(y) + x = x[mask] + y = y[mask] + if len(x) < 3: + return float("nan") + return corr(rank_average(x), rank_average(y)) + + +def mean_within_problem_corr(stability: np.ndarray, q_halt: np.ndarray, rank: bool) -> float: + vals: list[float] = [] + for x, y in zip(stability, q_halt): + val = spearman(x, y) if rank else corr(x, y) + if np.isfinite(val): + vals.append(val) + return float(np.mean(vals)) if vals else float("nan") + + +def summarize(name: str, path: Path) -> dict[str, float | str]: + data = np.load(path) + exact = data["exact"].astype(bool) + q_halt = data["q_halt"].astype(float) + lyap = data["lyap"].astype(float) + if lyap.size == 0: + raise ValueError(f"{path} does not contain lyap") + stability = -lyap + + arange = np.arange(exact.shape[0]) + q_idx = q_halt.argmax(axis=1) + lyap_idx = lyap.argmin(axis=1) + correct_count = exact.sum(axis=1) + mixed = (correct_count > 0) & (correct_count < exact.shape[1]) + + mixed_summary: dict[str, float] = { + "mixed_problem_count": float(mixed.sum()), + "zero_success_problem_count": float((correct_count == 0).sum()), + "full_success_problem_count": float((correct_count == exact.shape[1]).sum()), + "mixed_global_pearson_q_vs_stability": float("nan"), + "mixed_q_max_exact": float("nan"), + "mixed_lambda_min_exact": float("nan"), + "mixed_oracle_exact": float("nan"), + } + if mixed.any(): + m_arange = np.arange(int(mixed.sum())) + m_exact = exact[mixed] + m_q = q_halt[mixed] + m_lyap = lyap[mixed] + mixed_summary.update( + { + "mixed_global_pearson_q_vs_stability": corr((-m_lyap).reshape(-1), m_q.reshape(-1)), + "mixed_q_max_exact": float(m_exact[m_arange, m_q.argmax(axis=1)].mean()), + "mixed_lambda_min_exact": float(m_exact[m_arange, 
m_lyap.argmin(axis=1)].mean()), + "mixed_oracle_exact": float(m_exact.any(axis=1).mean()), + } + ) + + out: dict[str, float | str] = { + "name": name, + "path": str(path), + "n_samples": float(exact.shape[0]), + "rollouts": float(exact.shape[1]), + "mean_rollout_exact": float(exact.mean()), + "q_max_exact": float(exact[arange, q_idx].mean()), + "lambda_min_exact": float(exact[arange, lyap_idx].mean()), + "oracle_pass_exact": float(exact.any(axis=1).mean()), + "q_lambda_same_argmax_frac": float((q_idx == lyap_idx).mean()), + "global_pearson_q_vs_stability": corr(stability.reshape(-1), q_halt.reshape(-1)), + "global_spearman_q_vs_stability": spearman(stability.reshape(-1), q_halt.reshape(-1)), + "within_problem_pearson_mean": mean_within_problem_corr(stability, q_halt, rank=False), + "within_problem_spearman_mean": mean_within_problem_corr(stability, q_halt, rank=True), + "q_success_mean": float(q_halt[exact].mean()) if exact.any() else float("nan"), + "q_fail_mean": float(q_halt[~exact].mean()) if (~exact).any() else float("nan"), + "lambda_success_mean": float(lyap[exact].mean()) if exact.any() else float("nan"), + "lambda_fail_mean": float(lyap[~exact].mean()) if (~exact).any() else float("nan"), + } + out.update(mixed_summary) + return out + + +def scatter_panel( + ax: plt.Axes, + stability_2d: np.ndarray, + q_2d: np.ndarray, + exact_2d: np.ndarray, + title: str, +) -> None: + stability = stability_2d.reshape(-1) + q_halt = q_2d.reshape(-1) + exact = exact_2d.reshape(-1) + finite = np.isfinite(stability) & np.isfinite(q_halt) + exact = exact[finite] + q_halt = q_halt[finite] + stability = stability[finite] + if len(stability) == 0: + ax.set_title(title + "\n(no points)") + return + + xlo, xhi = np.quantile(stability, [0.005, 0.995]) + ylo, yhi = np.quantile(q_halt, [0.005, 0.995]) + visible = (stability >= xlo) & (stability <= xhi) & (q_halt >= ylo) & (q_halt <= yhi) + + ax.scatter( + stability[visible & ~exact], + q_halt[visible & ~exact], + s=8, + alpha=0.22, + 
color="#dc2626", + linewidths=0, + label="incorrect rollout", + ) + ax.scatter( + stability[visible & exact], + q_halt[visible & exact], + s=8, + alpha=0.17, + color="#2563eb", + linewidths=0, + label="correct rollout", + ) + + fit = visible + if int(fit.sum()) >= 3: + slope, intercept = np.polyfit(stability[fit], q_halt[fit], 1) + xs = np.linspace(xlo, xhi, 100) + ax.plot(xs, slope * xs + intercept, color="black", linewidth=1.8, alpha=0.75) + + ax.set_title(title) + ax.set_xlim(xlo, xhi) + ax.set_ylim(ylo, yhi) + ax.grid(alpha=0.22) + + +def plot() -> list[dict[str, float | str]]: + missing = [path for _name, path in RUNS if not path.exists()] + if missing: + raise SystemExit("missing input files:\n" + "\n".join(str(p) for p in missing)) + + OUT.mkdir(parents=True, exist_ok=True) + summaries = [summarize(name, path) for name, path in RUNS] + + fig, axes = plt.subplots(2, len(RUNS), figsize=(12.0, 8.2), sharey="row") + if len(RUNS) == 1: + axes = np.asarray(axes).reshape(2, 1) + + for col, ((name, path), summary) in enumerate(zip(RUNS, summaries)): + data = np.load(path) + exact = data["exact"].astype(bool) + q_halt = data["q_halt"].astype(float) + stability = -data["lyap"].astype(float) + correct_count = exact.sum(axis=1) + mixed = (correct_count > 0) & (correct_count < exact.shape[1]) + + scatter_panel( + axes[0, col], + stability, + q_halt, + exact, + f"{name}: all rollouts\n" + f"r={summary['global_pearson_q_vs_stability']:.2f}, " + f"rho={summary['global_spearman_q_vs_stability']:.2f}, " + f"Q exact={summary['q_max_exact']:.3f}", + ) + scatter_panel( + axes[1, col], + stability[mixed], + q_halt[mixed], + exact[mixed], + f"mixed problems only (n={int(summary['mixed_problem_count'])})\n" + f"r={summary['mixed_global_pearson_q_vs_stability']:.2f}, " + f"Q={summary['mixed_q_max_exact']:.3f}, " + f"lambda-min={summary['mixed_lambda_min_exact']:.3f}", + ) + + for ax in axes[1, :]: + ax.set_xlabel("stability proxy = -lambda_1") + axes[0, 0].set_ylabel("Q-head halt 
logit") + axes[1, 0].set_ylabel("Q-head halt logit") + axes[0, 0].legend(frameon=False, loc="lower right", fontsize=8) + fig.suptitle("PTRM Q-head score contains a stability signal; mixed problems reveal selector behavior") + fig.tight_layout() + fig.savefig(OUT / "fig5_qhead_vs_lambda1_ptrm.png", dpi=240) + plt.close(fig) + + out_csv = OUT / "fig5_qhead_vs_lambda1_ptrm_summary.csv" + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(summaries[0].keys())) + writer.writeheader() + writer.writerows(summaries) + return summaries + + +def main() -> None: + summaries = plot() + for row in summaries: + print( + f"{row['name']}: " + f"q_exact={row['q_max_exact']:.4f} " + f"lambda_min_exact={row['lambda_min_exact']:.4f} " + f"oracle={row['oracle_pass_exact']:.4f} " + f"pearson={row['global_pearson_q_vs_stability']:.4f} " + f"spearman={row['global_spearman_q_vs_stability']:.4f} " + f"within_spearman={row['within_problem_spearman_mean']:.4f} " + f"mixed_pearson={row['mixed_global_pearson_q_vs_stability']:.4f}" + ) + print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm.png'}") + print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm_summary.csv'}") + + +if __name__ == "__main__": + main() |
