summaryrefslogtreecommitdiff
path: root/research/flossing/report_bundle_20260603/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/report_bundle_20260603/scripts')
-rw-r--r--research/flossing/report_bundle_20260603/scripts/make_meeting_artifacts_v2.py364
-rw-r--r--research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py263
2 files changed, 627 insertions, 0 deletions
diff --git a/research/flossing/report_bundle_20260603/scripts/make_meeting_artifacts_v2.py b/research/flossing/report_bundle_20260603/scripts/make_meeting_artifacts_v2.py
new file mode 100644
index 0000000..4c88412
--- /dev/null
+++ b/research/flossing/report_bundle_20260603/scripts/make_meeting_artifacts_v2.py
@@ -0,0 +1,364 @@
+from __future__ import annotations
+
+import csv
+import re
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+ROOT = Path("research/flossing")
+OUT = ROOT / "meeting_artifacts_v2"
+
+
+def read_csv(path: Path) -> list[dict[str, str]]:
+ with path.open(newline="") as f:
+ return list(csv.DictReader(f))
+
+
+def f(row: dict[str, str], key: str, default: float = float("nan")) -> float:
+ try:
+ val = row.get(key, "")
+ return float(val) if val != "" else default
+ except Exception:
+ return default
+
+
+def row_for(rows: list[dict[str, str]], step: int) -> dict[str, str]:
+ for row in rows:
+ if int(float(row["step"])) == step:
+ return row
+ raise KeyError(step)
+
+
+def spectrum_metrics(npz_path: Path) -> dict[str, float]:
+ data = np.load(npz_path)
+ spec = data["lyap_spec"].astype(float)
+ exact = data["exact_correct"].astype(bool)
+ mean8 = spec.mean(axis=1)
+ pos_count = (spec > 0).sum(axis=1)
+ return {
+ "sample_exact": float(exact.mean()),
+ "lambda1_all": float(spec[:, 0].mean()),
+ "mean8_all": float(mean8.mean()),
+ "pos_count_all": float(pos_count.mean()),
+ "lambda1_success": float(spec[exact, 0].mean()),
+ "lambda1_fail": float(spec[~exact, 0].mean()),
+ "mean8_success": float(mean8[exact].mean()),
+ "mean8_fail": float(mean8[~exact].mean()),
+ "pos_count_success": float(pos_count[exact].mean()),
+ "pos_count_fail": float(pos_count[~exact].mean()),
+ }
+
+
+def spectrum_arrays(npz_path: Path) -> tuple[np.ndarray, np.ndarray]:
+ data = np.load(npz_path)
+ return data["lyap_spec"].astype(float), data["exact_correct"].astype(bool)
+
+
+def auc_failure_score(score: np.ndarray, exact: np.ndarray) -> float:
+ """AUROC where higher score predicts failure."""
+ fail = ~exact
+ n_fail = int(fail.sum())
+ n_succ = int(exact.sum())
+ if n_fail == 0 or n_succ == 0:
+ return float("nan")
+ order = np.argsort(score)
+ ranks = np.empty_like(order, dtype=float)
+ ranks[order] = np.arange(1, len(score) + 1)
+ return float((ranks[fail].sum() - n_fail * (n_fail + 1) / 2) / (n_fail * n_succ))
+
+
+def get_data() -> dict[str, object]:
+ eval_dir = ROOT / "multi4_eval_compare"
+ spec_dir = ROOT / "official_gbs768_spectrum"
+
+ hrm_base = read_csv(eval_dir / "hrm_baseline_eval.csv")
+ hrm_multi = read_csv(eval_dir / "hrm_multi4_complete_eval.csv")
+ trm_base = read_csv(eval_dir / "trm_official_gbs768_eval.csv")
+ trm_multi = read_csv(eval_dir / "trm_official_gbs768_multi4_eval.csv")
+
+ hrm_points = [
+ {
+ "label": "baseline best",
+ "family": "baseline",
+ "step": 26040,
+ "full_exact": f(row_for(hrm_base, 26040), "all/exact_accuracy"),
+ **spectrum_metrics(ROOT / "diag_hrm_step_26040_512.npz"),
+ },
+ {
+ "label": "multi4 best",
+ "family": "multi4",
+ "step": 23436,
+ "full_exact": f(row_for(hrm_multi, 23436), "exact"),
+ **spectrum_metrics(ROOT / "diag_hrm_multi4_step_23436_512.npz"),
+ },
+ {
+ "label": "multi4 final",
+ "family": "multi4-final",
+ "step": 26040,
+ "full_exact": f(row_for(hrm_multi, 26040), "exact"),
+ **spectrum_metrics(ROOT / "diag_hrm_multi4_step_26040_512.npz"),
+ },
+ ]
+
+ trm_points = [
+ {
+ "label": "baseline best",
+ "family": "baseline",
+ "step": 58590,
+ "full_exact": f(row_for(trm_base, 58590), "all/exact_accuracy"),
+ **spectrum_metrics(spec_dir / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"),
+ },
+ {
+ "label": "multi4 best",
+ "family": "multi4",
+ "step": 35805,
+ "full_exact": f(row_for(trm_multi, 35805), "all/exact_accuracy"),
+ **spectrum_metrics(spec_dir / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"),
+ },
+ {
+ "label": "multi4 final",
+ "family": "multi4-final",
+ "step": 65100,
+ "full_exact": f(row_for(trm_multi, 65100), "all/exact_accuracy"),
+ **spectrum_metrics(spec_dir / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz"),
+ },
+ ]
+
+ return {
+ "hrm_base": hrm_base,
+ "hrm_multi": hrm_multi,
+ "trm_base": trm_base,
+ "trm_multi": trm_multi,
+ "hrm_points": hrm_points,
+ "trm_points": trm_points,
+ }
+
+
+def metric_series(rows: list[dict[str, str]], metric: str) -> tuple[np.ndarray, np.ndarray]:
+ xs, ys = [], []
+ for row in rows:
+ val = f(row, metric)
+ if np.isfinite(val):
+ xs.append(int(float(row["step"])))
+ ys.append(val)
+ order = np.argsort(xs)
+ return np.asarray(xs)[order], np.asarray(ys)[order]
+
+
+def plot_training_curves(data: dict[str, object]) -> None:
+ fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.2))
+ colors = {"baseline": "#475569", "multi4": "#2563eb"}
+
+ panels = [
+ ("HRM", data["hrm_base"], data["hrm_multi"], "all/exact_accuracy", "exact", 0.70),
+ ("TRM official GBS768", data["trm_base"], data["trm_multi"], "all/exact_accuracy", "all/exact_accuracy", 0.94),
+ ]
+ for ax, (title, base_rows, multi_rows, base_metric, multi_metric, ymax) in zip(axes, panels):
+ bx, by = metric_series(base_rows, base_metric) # type: ignore[arg-type]
+ mx, my = metric_series(multi_rows, multi_metric) # type: ignore[arg-type]
+ ax.plot(bx / 1000, by, marker="o", color=colors["baseline"], label="baseline", linewidth=2)
+ ax.plot(mx / 1000, my, marker="o", color=colors["multi4"], label="multi4", linewidth=2)
+ bi, mi = int(np.argmax(by)), int(np.argmax(my))
+ ax.scatter([bx[bi] / 1000], [by[bi]], s=95, color=colors["baseline"], edgecolor="white", zorder=5)
+ ax.scatter([mx[mi] / 1000], [my[mi]], s=95, color=colors["multi4"], edgecolor="white", zorder=5)
+ ax.annotate(f"best {by[bi]:.3f}", (bx[bi] / 1000, by[bi]), xytext=(6, -18), textcoords="offset points", fontsize=8)
+ ax.annotate(f"best {my[mi]:.3f}", (mx[mi] / 1000, my[mi]), xytext=(6, 8), textcoords="offset points", fontsize=8)
+ ax.set_title(title)
+ ax.set_xlabel("optimizer step (k)")
+ ax.set_ylabel("full test exact accuracy")
+ ax.set_ylim(0, ymax)
+ ax.grid(alpha=0.22)
+ ax.legend(frameon=False, loc="lower right")
+
+ fig.suptitle("Trajectory perturbation helps both HRM and TRM, but late checkpoints can collapse")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig1_hrm_trm_training_curves.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_lambda1_motivation() -> None:
+ entries = [
+ ("HRM baseline best\nstep 26040", ROOT / "diag_hrm_step_26040_512.npz"),
+ (
+ "TRM baseline best\nstep 58590",
+ ROOT / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz",
+ ),
+ ]
+ fig, axes = plt.subplots(1, 2, figsize=(11.6, 4.2), sharey=False)
+ for ax, (title, path) in zip(axes, entries):
+ spec, exact = spectrum_arrays(path)
+ lam1 = spec[:, 0]
+ succ = lam1[exact]
+ fail = lam1[~exact]
+ lo, hi = np.quantile(lam1, [0.01, 0.99])
+ pad = 0.08 * (hi - lo)
+ bins = np.linspace(lo - pad, hi + pad, 34)
+ ax.hist(succ, bins=bins, density=True, alpha=0.62, color="#2563eb", label=f"success (n={len(succ)})")
+ ax.hist(fail, bins=bins, density=True, alpha=0.58, color="#dc2626", label=f"failure (n={len(fail)})")
+ ax.axvline(succ.mean(), color="#1d4ed8", linewidth=2)
+ ax.axvline(fail.mean(), color="#b91c1c", linewidth=2)
+ auc = auc_failure_score(lam1, exact)
+ ax.set_title(f"{title}\nacc={exact.mean():.3f}, AUROC={auc:.3f}")
+ ax.set_xlabel("top Lyapunov exponent λ1")
+ ax.set_ylabel("density")
+ ax.grid(alpha=0.2)
+ ax.legend(frameon=False, fontsize=8)
+ fig.suptitle("Motivation: λ1 is a strong success/failure detector in both HRM and TRM")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig0_motivation_lambda1_success_failure_hrm_trm.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_phase(data: dict[str, object]) -> None:
+ fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=False)
+ panels = [("HRM", data["hrm_points"]), ("TRM official GBS768", data["trm_points"])]
+ style = {
+ "baseline": ("#475569", "o"),
+ "multi4": ("#2563eb", "o"),
+ "multi4-final": ("#dc2626", "X"),
+ }
+ for ax, (title, points) in zip(axes, panels):
+ for point in points: # type: ignore[assignment]
+ color, marker = style[point["family"]]
+ ax.scatter(point["mean8_all"], point["full_exact"], s=150, marker=marker, color=color, edgecolor="white", zorder=4)
+ ax.annotate(
+ f"{point['label']}\nstep {point['step']}",
+ (point["mean8_all"], point["full_exact"]),
+ xytext=(8, 5),
+ textcoords="offset points",
+ fontsize=8,
+ )
+ ax.set_title(title)
+ ax.set_xlabel("mean top-8 Lyapunov exponent (lower = more stable)")
+ ax.set_ylabel("full test exact accuracy")
+ ax.grid(alpha=0.22)
+ fig.suptitle("Accuracy-vs-chaotic-volume phase view: best multi4 moves up-left; collapse moves down-right")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig2_accuracy_vs_chaotic_volume_phase.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_spectrum_grid() -> None:
+ entries = [
+ ("HRM baseline best", ROOT / "diag_hrm_step_26040_512.npz"),
+ ("HRM multi4 best", ROOT / "diag_hrm_multi4_step_23436_512.npz"),
+ ("TRM baseline best", ROOT / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz"),
+ ("TRM multi4 best", ROOT / "official_gbs768_spectrum/trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"),
+ ]
+ fig, axes = plt.subplots(2, 2, figsize=(10.8, 7.2), sharey=False)
+ for ax, (title, path) in zip(axes.flat, entries):
+ spec, exact = spectrum_arrays(path)
+ x = np.arange(1, spec.shape[1] + 1)
+ for mask, color, label in [(exact, "#2563eb", "success"), (~exact, "#dc2626", "failure")]:
+ mean = spec[mask].mean(axis=0)
+ se = spec[mask].std(axis=0) / max(mask.sum(), 1) ** 0.5
+ ax.plot(x, mean, marker="o", color=color, label=label)
+ ax.fill_between(x, mean - se, mean + se, color=color, alpha=0.16, linewidth=0)
+ ax.axhline(0, color="black", linewidth=0.8, alpha=0.55)
+ ax.set_title(f"{title}\nN={len(exact)}, acc={exact.mean():.3f}")
+ ax.set_xlabel("spectrum rank")
+ ax.set_ylabel("finite-time exponent")
+ ax.grid(alpha=0.22)
+ axes.flat[0].legend(frameon=False)
+ fig.suptitle("Success/failure separation is present in both HRM and TRM; multi4 mainly stabilizes successful trajectories")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig3_hrm_trm_success_failure_spectra.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_ptrm() -> None:
+ summary = read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0]
+ labels = ["deterministic", "mean rollout", "Q-selected", "oracle"]
+ base = [float(summary[k]) for k in ["base_det", "base_mean_rollout", "base_qmax", "base_oracle"]]
+ multi = [float(summary[k]) for k in ["multi4_det", "multi4_mean_rollout", "multi4_qmax", "multi4_oracle"]]
+ x = np.arange(len(labels))
+ width = 0.36
+ fig, ax = plt.subplots(figsize=(8.7, 4.3))
+ ax.bar(x - width / 2, base, width, color="#64748b", label="baseline best")
+ ax.bar(x + width / 2, multi, width, color="#2563eb", label="multi4 best")
+ ax.set_xticks(x, labels)
+ ax.set_ylim(0.86, 1.0)
+ ax.set_ylabel("exact accuracy")
+ ax.set_title("PTRM same-subset comparison (n=1000, K=100, D=64, sigma=0.3, L-only)")
+ ax.grid(axis="y", alpha=0.22)
+ ax.legend(frameon=False)
+ for xpos, vals in [(x - width / 2, base), (x + width / 2, multi)]:
+ for xi, val in zip(xpos, vals):
+ ax.text(xi, val + 0.002, f"{val:.3f}", ha="center", va="bottom", fontsize=8)
+ fig.tight_layout()
+ fig.savefig(OUT / "fig4_ptrm_same_subset_comparison.png", dpi=220)
+ plt.close(fig)
+
+
+def write_summary(data: dict[str, object]) -> None:
+ rows = []
+ for model, points in [("HRM", data["hrm_points"]), ("TRM", data["trm_points"])]:
+ for point in points: # type: ignore[assignment]
+ rows.append(
+ {
+ "model": model,
+ "label": point["label"],
+ "step": point["step"],
+ "full_exact": point["full_exact"],
+ "sample_exact": point["sample_exact"],
+ "lambda1_all": point["lambda1_all"],
+ "mean8_all": point["mean8_all"],
+ "pos_count_all": point["pos_count_all"],
+ "lambda1_success": point["lambda1_success"],
+ "lambda1_fail": point["lambda1_fail"],
+ "mean8_success": point["mean8_success"],
+ "mean8_fail": point["mean8_fail"],
+ "pos_count_success": point["pos_count_success"],
+ "pos_count_fail": point["pos_count_fail"],
+ }
+ )
+ with (OUT / "hrm_trm_redesigned_summary.csv").open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
+ writer.writeheader()
+ writer.writerows(rows)
+
+ ptrm = read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0]
+ report = f"""# Meeting Figures v2
+
+## Figure Strategy
+
+0. `fig0_motivation_lambda1_success_failure_hrm_trm.png`: first-exponent success/failure distribution in HRM and TRM. This motivates chaos as a detector before introducing the method.
+1. `fig1_hrm_trm_training_curves.png`: performance over training for HRM and TRM. This answers whether the method improves accuracy and where best/final are.
+2. `fig2_accuracy_vs_chaotic_volume_phase.png`: phase view, with accuracy versus mean top-8 Lyapunov exponent. This answers whether better checkpoints are dynamically more stable.
+3. `fig3_hrm_trm_success_failure_spectra.png`: full success/failure spectrum separation for HRM and TRM best checkpoints. This extends Fig0 beyond λ1.
+4. `fig4_ptrm_same_subset_comparison.png`: PTRM same-subset result. This is a secondary inference-time story.
+
+## Key Numbers
+
+- HRM baseline best: 0.5265 exact. HRM multi4 best: 0.6443 exact. HRM multi4 final: 0.4624 exact.
+- TRM baseline best: 0.8686 exact. TRM multi4 best: 0.8965 exact. TRM multi4 final: 0.8351 exact.
+- HRM multi4 best dynamics sample: mean top-8 exponent {rows[1]['mean8_all']:+.4f}; final {rows[2]['mean8_all']:+.4f}.
+- TRM multi4 best dynamics sample: mean top-8 exponent {rows[4]['mean8_all']:+.4f}; final {rows[5]['mean8_all']:+.4f}.
+- PTRM same subset, K=100: Q-selected {float(ptrm['base_qmax']):.3f} -> {float(ptrm['multi4_qmax']):.3f}; mean rollout {float(ptrm['base_mean_rollout']):.3f} -> {float(ptrm['multi4_mean_rollout']):.3f}.
+
+## Caveats
+
+- Dynamics spectra use N=512 diagnostic samples, not the full test set.
+- PTRM numbers use a fixed N=1000 subset; do not mix its deterministic subset accuracy with full-test W&B exact accuracy.
+- Final checkpoints are collapse diagnostics, not the method's reported performance.
+"""
+ (OUT / "meeting_figures_v2_report.md").write_text(report)
+
+
+def main() -> None:
+ OUT.mkdir(parents=True, exist_ok=True)
+ data = get_data()
+ plot_lambda1_motivation()
+ plot_training_curves(data)
+ plot_phase(data)
+ plot_spectrum_grid()
+ plot_ptrm()
+ write_summary(data)
+ print(f"wrote redesigned artifacts to {OUT}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py b/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py
new file mode 100644
index 0000000..54ff8d2
--- /dev/null
+++ b/research/flossing/report_bundle_20260603/scripts/make_q_lambda_scatter.py
@@ -0,0 +1,263 @@
+from __future__ import annotations
+
+import csv
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+ROOT = Path("research/flossing")
+IN_DIR = ROOT / "q_lambda_scatter"
+OUT = ROOT / "meeting_artifacts_v2"
+
+RUNS = [
+ (
+ "TRM baseline + PTRM rollouts",
+ IN_DIR / "base58590_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz",
+ ),
+ (
+ "TRM multi4 + PTRM rollouts",
+ IN_DIR / "multi4_35805_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz",
+ ),
+]
+
+
+def rank_average(values: np.ndarray) -> np.ndarray:
+ order = np.argsort(values, kind="mergesort")
+ sorted_values = values[order]
+ ranks = np.empty(len(values), dtype=float)
+ i = 0
+ while i < len(values):
+ j = i + 1
+ while j < len(values) and sorted_values[j] == sorted_values[i]:
+ j += 1
+ ranks[order[i:j]] = 0.5 * (i + j - 1) + 1.0
+ i = j
+ return ranks
+
+
+def corr(x: np.ndarray, y: np.ndarray) -> float:
+ mask = np.isfinite(x) & np.isfinite(y)
+ x = x[mask]
+ y = y[mask]
+ if len(x) < 3:
+ return float("nan")
+ x = x - x.mean()
+ y = y - y.mean()
+ denom = np.sqrt(np.square(x).sum() * np.square(y).sum())
+ if denom <= 0:
+ return float("nan")
+ return float(np.dot(x, y) / denom)
+
+
+def spearman(x: np.ndarray, y: np.ndarray) -> float:
+ mask = np.isfinite(x) & np.isfinite(y)
+ x = x[mask]
+ y = y[mask]
+ if len(x) < 3:
+ return float("nan")
+ return corr(rank_average(x), rank_average(y))
+
+
+def mean_within_problem_corr(stability: np.ndarray, q_halt: np.ndarray, rank: bool) -> float:
+ vals: list[float] = []
+ for x, y in zip(stability, q_halt):
+ val = spearman(x, y) if rank else corr(x, y)
+ if np.isfinite(val):
+ vals.append(val)
+ return float(np.mean(vals)) if vals else float("nan")
+
+
+def summarize(name: str, path: Path) -> dict[str, float | str]:
+ data = np.load(path)
+ exact = data["exact"].astype(bool)
+ q_halt = data["q_halt"].astype(float)
+ lyap = data["lyap"].astype(float)
+ if lyap.size == 0:
+ raise ValueError(f"{path} does not contain lyap")
+ stability = -lyap
+
+ arange = np.arange(exact.shape[0])
+ q_idx = q_halt.argmax(axis=1)
+ lyap_idx = lyap.argmin(axis=1)
+ correct_count = exact.sum(axis=1)
+ mixed = (correct_count > 0) & (correct_count < exact.shape[1])
+
+ mixed_summary: dict[str, float] = {
+ "mixed_problem_count": float(mixed.sum()),
+ "zero_success_problem_count": float((correct_count == 0).sum()),
+ "full_success_problem_count": float((correct_count == exact.shape[1]).sum()),
+ "mixed_global_pearson_q_vs_stability": float("nan"),
+ "mixed_q_max_exact": float("nan"),
+ "mixed_lambda_min_exact": float("nan"),
+ "mixed_oracle_exact": float("nan"),
+ }
+ if mixed.any():
+ m_arange = np.arange(int(mixed.sum()))
+ m_exact = exact[mixed]
+ m_q = q_halt[mixed]
+ m_lyap = lyap[mixed]
+ mixed_summary.update(
+ {
+ "mixed_global_pearson_q_vs_stability": corr((-m_lyap).reshape(-1), m_q.reshape(-1)),
+ "mixed_q_max_exact": float(m_exact[m_arange, m_q.argmax(axis=1)].mean()),
+ "mixed_lambda_min_exact": float(m_exact[m_arange, m_lyap.argmin(axis=1)].mean()),
+ "mixed_oracle_exact": float(m_exact.any(axis=1).mean()),
+ }
+ )
+
+ out: dict[str, float | str] = {
+ "name": name,
+ "path": str(path),
+ "n_samples": float(exact.shape[0]),
+ "rollouts": float(exact.shape[1]),
+ "mean_rollout_exact": float(exact.mean()),
+ "q_max_exact": float(exact[arange, q_idx].mean()),
+ "lambda_min_exact": float(exact[arange, lyap_idx].mean()),
+ "oracle_pass_exact": float(exact.any(axis=1).mean()),
+ "q_lambda_same_argmax_frac": float((q_idx == lyap_idx).mean()),
+ "global_pearson_q_vs_stability": corr(stability.reshape(-1), q_halt.reshape(-1)),
+ "global_spearman_q_vs_stability": spearman(stability.reshape(-1), q_halt.reshape(-1)),
+ "within_problem_pearson_mean": mean_within_problem_corr(stability, q_halt, rank=False),
+ "within_problem_spearman_mean": mean_within_problem_corr(stability, q_halt, rank=True),
+ "q_success_mean": float(q_halt[exact].mean()) if exact.any() else float("nan"),
+ "q_fail_mean": float(q_halt[~exact].mean()) if (~exact).any() else float("nan"),
+ "lambda_success_mean": float(lyap[exact].mean()) if exact.any() else float("nan"),
+ "lambda_fail_mean": float(lyap[~exact].mean()) if (~exact).any() else float("nan"),
+ }
+ out.update(mixed_summary)
+ return out
+
+
+def scatter_panel(
+ ax: plt.Axes,
+ stability_2d: np.ndarray,
+ q_2d: np.ndarray,
+ exact_2d: np.ndarray,
+ title: str,
+) -> None:
+ stability = stability_2d.reshape(-1)
+ q_halt = q_2d.reshape(-1)
+ exact = exact_2d.reshape(-1)
+ finite = np.isfinite(stability) & np.isfinite(q_halt)
+ exact = exact[finite]
+ q_halt = q_halt[finite]
+ stability = stability[finite]
+ if len(stability) == 0:
+ ax.set_title(title + "\n(no points)")
+ return
+
+ xlo, xhi = np.quantile(stability, [0.005, 0.995])
+ ylo, yhi = np.quantile(q_halt, [0.005, 0.995])
+ visible = (stability >= xlo) & (stability <= xhi) & (q_halt >= ylo) & (q_halt <= yhi)
+
+ ax.scatter(
+ stability[visible & ~exact],
+ q_halt[visible & ~exact],
+ s=8,
+ alpha=0.22,
+ color="#dc2626",
+ linewidths=0,
+ label="incorrect rollout",
+ )
+ ax.scatter(
+ stability[visible & exact],
+ q_halt[visible & exact],
+ s=8,
+ alpha=0.17,
+ color="#2563eb",
+ linewidths=0,
+ label="correct rollout",
+ )
+
+ fit = visible
+ if int(fit.sum()) >= 3:
+ slope, intercept = np.polyfit(stability[fit], q_halt[fit], 1)
+ xs = np.linspace(xlo, xhi, 100)
+ ax.plot(xs, slope * xs + intercept, color="black", linewidth=1.8, alpha=0.75)
+
+ ax.set_title(title)
+ ax.set_xlim(xlo, xhi)
+ ax.set_ylim(ylo, yhi)
+ ax.grid(alpha=0.22)
+
+
+def plot() -> list[dict[str, float | str]]:
+ missing = [path for _name, path in RUNS if not path.exists()]
+ if missing:
+ raise SystemExit("missing input files:\n" + "\n".join(str(p) for p in missing))
+
+ OUT.mkdir(parents=True, exist_ok=True)
+ summaries = [summarize(name, path) for name, path in RUNS]
+
+ fig, axes = plt.subplots(2, len(RUNS), figsize=(12.0, 8.2), sharey="row")
+ if len(RUNS) == 1:
+ axes = np.asarray(axes).reshape(2, 1)
+
+ for col, ((name, path), summary) in enumerate(zip(RUNS, summaries)):
+ data = np.load(path)
+ exact = data["exact"].astype(bool)
+ q_halt = data["q_halt"].astype(float)
+ stability = -data["lyap"].astype(float)
+ correct_count = exact.sum(axis=1)
+ mixed = (correct_count > 0) & (correct_count < exact.shape[1])
+
+ scatter_panel(
+ axes[0, col],
+ stability,
+ q_halt,
+ exact,
+ f"{name}: all rollouts\n"
+ f"r={summary['global_pearson_q_vs_stability']:.2f}, "
+ f"rho={summary['global_spearman_q_vs_stability']:.2f}, "
+ f"Q exact={summary['q_max_exact']:.3f}",
+ )
+ scatter_panel(
+ axes[1, col],
+ stability[mixed],
+ q_halt[mixed],
+ exact[mixed],
+ f"mixed problems only (n={int(summary['mixed_problem_count'])})\n"
+ f"r={summary['mixed_global_pearson_q_vs_stability']:.2f}, "
+ f"Q={summary['mixed_q_max_exact']:.3f}, "
+ f"lambda-min={summary['mixed_lambda_min_exact']:.3f}",
+ )
+
+ for ax in axes[1, :]:
+ ax.set_xlabel("stability proxy = -lambda_1")
+ axes[0, 0].set_ylabel("Q-head halt logit")
+ axes[1, 0].set_ylabel("Q-head halt logit")
+ axes[0, 0].legend(frameon=False, loc="lower right", fontsize=8)
+ fig.suptitle("PTRM Q-head score contains a stability signal; mixed problems reveal selector behavior")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig5_qhead_vs_lambda1_ptrm.png", dpi=240)
+ plt.close(fig)
+
+ out_csv = OUT / "fig5_qhead_vs_lambda1_ptrm_summary.csv"
+ with out_csv.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=list(summaries[0].keys()))
+ writer.writeheader()
+ writer.writerows(summaries)
+ return summaries
+
+
+def main() -> None:
+ summaries = plot()
+ for row in summaries:
+ print(
+ f"{row['name']}: "
+ f"q_exact={row['q_max_exact']:.4f} "
+ f"lambda_min_exact={row['lambda_min_exact']:.4f} "
+ f"oracle={row['oracle_pass_exact']:.4f} "
+ f"pearson={row['global_pearson_q_vs_stability']:.4f} "
+ f"spearman={row['global_spearman_q_vs_stability']:.4f} "
+ f"within_spearman={row['within_problem_spearman_mean']:.4f} "
+ f"mixed_pearson={row['mixed_global_pearson_q_vs_stability']:.4f}"
+ )
+ print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm.png'}")
+ print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm_summary.csv'}")
+
+
+if __name__ == "__main__":
+ main()