summaryrefslogtreecommitdiff
path: root/research/flossing/make_q_lambda_scatter.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/make_q_lambda_scatter.py')
-rw-r--r--research/flossing/make_q_lambda_scatter.py263
1 files changed, 263 insertions, 0 deletions
diff --git a/research/flossing/make_q_lambda_scatter.py b/research/flossing/make_q_lambda_scatter.py
new file mode 100644
index 0000000..54ff8d2
--- /dev/null
+++ b/research/flossing/make_q_lambda_scatter.py
@@ -0,0 +1,263 @@
+from __future__ import annotations
+
+import csv
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+ROOT = Path("research/flossing")
+IN_DIR = ROOT / "q_lambda_scatter"
+OUT = ROOT / "meeting_artifacts_v2"
+
+RUNS = [
+ (
+ "TRM baseline + PTRM rollouts",
+ IN_DIR / "base58590_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz",
+ ),
+ (
+ "TRM multi4 + PTRM rollouts",
+ IN_DIR / "multi4_35805_k25_d64_sigma03_Lonly_fdlyap_n512_seed20260602.npz",
+ ),
+]
+
+
+def rank_average(values: np.ndarray) -> np.ndarray:
+ order = np.argsort(values, kind="mergesort")
+ sorted_values = values[order]
+ ranks = np.empty(len(values), dtype=float)
+ i = 0
+ while i < len(values):
+ j = i + 1
+ while j < len(values) and sorted_values[j] == sorted_values[i]:
+ j += 1
+ ranks[order[i:j]] = 0.5 * (i + j - 1) + 1.0
+ i = j
+ return ranks
+
+
+def corr(x: np.ndarray, y: np.ndarray) -> float:
+ mask = np.isfinite(x) & np.isfinite(y)
+ x = x[mask]
+ y = y[mask]
+ if len(x) < 3:
+ return float("nan")
+ x = x - x.mean()
+ y = y - y.mean()
+ denom = np.sqrt(np.square(x).sum() * np.square(y).sum())
+ if denom <= 0:
+ return float("nan")
+ return float(np.dot(x, y) / denom)
+
+
+def spearman(x: np.ndarray, y: np.ndarray) -> float:
+ mask = np.isfinite(x) & np.isfinite(y)
+ x = x[mask]
+ y = y[mask]
+ if len(x) < 3:
+ return float("nan")
+ return corr(rank_average(x), rank_average(y))
+
+
+def mean_within_problem_corr(stability: np.ndarray, q_halt: np.ndarray, rank: bool) -> float:
+ vals: list[float] = []
+ for x, y in zip(stability, q_halt):
+ val = spearman(x, y) if rank else corr(x, y)
+ if np.isfinite(val):
+ vals.append(val)
+ return float(np.mean(vals)) if vals else float("nan")
+
+
+def summarize(name: str, path: Path) -> dict[str, float | str]:
+ data = np.load(path)
+ exact = data["exact"].astype(bool)
+ q_halt = data["q_halt"].astype(float)
+ lyap = data["lyap"].astype(float)
+ if lyap.size == 0:
+ raise ValueError(f"{path} does not contain lyap")
+ stability = -lyap
+
+ arange = np.arange(exact.shape[0])
+ q_idx = q_halt.argmax(axis=1)
+ lyap_idx = lyap.argmin(axis=1)
+ correct_count = exact.sum(axis=1)
+ mixed = (correct_count > 0) & (correct_count < exact.shape[1])
+
+ mixed_summary: dict[str, float] = {
+ "mixed_problem_count": float(mixed.sum()),
+ "zero_success_problem_count": float((correct_count == 0).sum()),
+ "full_success_problem_count": float((correct_count == exact.shape[1]).sum()),
+ "mixed_global_pearson_q_vs_stability": float("nan"),
+ "mixed_q_max_exact": float("nan"),
+ "mixed_lambda_min_exact": float("nan"),
+ "mixed_oracle_exact": float("nan"),
+ }
+ if mixed.any():
+ m_arange = np.arange(int(mixed.sum()))
+ m_exact = exact[mixed]
+ m_q = q_halt[mixed]
+ m_lyap = lyap[mixed]
+ mixed_summary.update(
+ {
+ "mixed_global_pearson_q_vs_stability": corr((-m_lyap).reshape(-1), m_q.reshape(-1)),
+ "mixed_q_max_exact": float(m_exact[m_arange, m_q.argmax(axis=1)].mean()),
+ "mixed_lambda_min_exact": float(m_exact[m_arange, m_lyap.argmin(axis=1)].mean()),
+ "mixed_oracle_exact": float(m_exact.any(axis=1).mean()),
+ }
+ )
+
+ out: dict[str, float | str] = {
+ "name": name,
+ "path": str(path),
+ "n_samples": float(exact.shape[0]),
+ "rollouts": float(exact.shape[1]),
+ "mean_rollout_exact": float(exact.mean()),
+ "q_max_exact": float(exact[arange, q_idx].mean()),
+ "lambda_min_exact": float(exact[arange, lyap_idx].mean()),
+ "oracle_pass_exact": float(exact.any(axis=1).mean()),
+ "q_lambda_same_argmax_frac": float((q_idx == lyap_idx).mean()),
+ "global_pearson_q_vs_stability": corr(stability.reshape(-1), q_halt.reshape(-1)),
+ "global_spearman_q_vs_stability": spearman(stability.reshape(-1), q_halt.reshape(-1)),
+ "within_problem_pearson_mean": mean_within_problem_corr(stability, q_halt, rank=False),
+ "within_problem_spearman_mean": mean_within_problem_corr(stability, q_halt, rank=True),
+ "q_success_mean": float(q_halt[exact].mean()) if exact.any() else float("nan"),
+ "q_fail_mean": float(q_halt[~exact].mean()) if (~exact).any() else float("nan"),
+ "lambda_success_mean": float(lyap[exact].mean()) if exact.any() else float("nan"),
+ "lambda_fail_mean": float(lyap[~exact].mean()) if (~exact).any() else float("nan"),
+ }
+ out.update(mixed_summary)
+ return out
+
+
+def scatter_panel(
+ ax: plt.Axes,
+ stability_2d: np.ndarray,
+ q_2d: np.ndarray,
+ exact_2d: np.ndarray,
+ title: str,
+) -> None:
+ stability = stability_2d.reshape(-1)
+ q_halt = q_2d.reshape(-1)
+ exact = exact_2d.reshape(-1)
+ finite = np.isfinite(stability) & np.isfinite(q_halt)
+ exact = exact[finite]
+ q_halt = q_halt[finite]
+ stability = stability[finite]
+ if len(stability) == 0:
+ ax.set_title(title + "\n(no points)")
+ return
+
+ xlo, xhi = np.quantile(stability, [0.005, 0.995])
+ ylo, yhi = np.quantile(q_halt, [0.005, 0.995])
+ visible = (stability >= xlo) & (stability <= xhi) & (q_halt >= ylo) & (q_halt <= yhi)
+
+ ax.scatter(
+ stability[visible & ~exact],
+ q_halt[visible & ~exact],
+ s=8,
+ alpha=0.22,
+ color="#dc2626",
+ linewidths=0,
+ label="incorrect rollout",
+ )
+ ax.scatter(
+ stability[visible & exact],
+ q_halt[visible & exact],
+ s=8,
+ alpha=0.17,
+ color="#2563eb",
+ linewidths=0,
+ label="correct rollout",
+ )
+
+ fit = visible
+ if int(fit.sum()) >= 3:
+ slope, intercept = np.polyfit(stability[fit], q_halt[fit], 1)
+ xs = np.linspace(xlo, xhi, 100)
+ ax.plot(xs, slope * xs + intercept, color="black", linewidth=1.8, alpha=0.75)
+
+ ax.set_title(title)
+ ax.set_xlim(xlo, xhi)
+ ax.set_ylim(ylo, yhi)
+ ax.grid(alpha=0.22)
+
+
+def plot() -> list[dict[str, float | str]]:
+ missing = [path for _name, path in RUNS if not path.exists()]
+ if missing:
+ raise SystemExit("missing input files:\n" + "\n".join(str(p) for p in missing))
+
+ OUT.mkdir(parents=True, exist_ok=True)
+ summaries = [summarize(name, path) for name, path in RUNS]
+
+ fig, axes = plt.subplots(2, len(RUNS), figsize=(12.0, 8.2), sharey="row")
+ if len(RUNS) == 1:
+ axes = np.asarray(axes).reshape(2, 1)
+
+ for col, ((name, path), summary) in enumerate(zip(RUNS, summaries)):
+ data = np.load(path)
+ exact = data["exact"].astype(bool)
+ q_halt = data["q_halt"].astype(float)
+ stability = -data["lyap"].astype(float)
+ correct_count = exact.sum(axis=1)
+ mixed = (correct_count > 0) & (correct_count < exact.shape[1])
+
+ scatter_panel(
+ axes[0, col],
+ stability,
+ q_halt,
+ exact,
+ f"{name}: all rollouts\n"
+ f"r={summary['global_pearson_q_vs_stability']:.2f}, "
+ f"rho={summary['global_spearman_q_vs_stability']:.2f}, "
+ f"Q exact={summary['q_max_exact']:.3f}",
+ )
+ scatter_panel(
+ axes[1, col],
+ stability[mixed],
+ q_halt[mixed],
+ exact[mixed],
+ f"mixed problems only (n={int(summary['mixed_problem_count'])})\n"
+ f"r={summary['mixed_global_pearson_q_vs_stability']:.2f}, "
+ f"Q={summary['mixed_q_max_exact']:.3f}, "
+ f"lambda-min={summary['mixed_lambda_min_exact']:.3f}",
+ )
+
+ for ax in axes[1, :]:
+ ax.set_xlabel("stability proxy = -lambda_1")
+ axes[0, 0].set_ylabel("Q-head halt logit")
+ axes[1, 0].set_ylabel("Q-head halt logit")
+ axes[0, 0].legend(frameon=False, loc="lower right", fontsize=8)
+ fig.suptitle("PTRM Q-head score contains a stability signal; mixed problems reveal selector behavior")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig5_qhead_vs_lambda1_ptrm.png", dpi=240)
+ plt.close(fig)
+
+ out_csv = OUT / "fig5_qhead_vs_lambda1_ptrm_summary.csv"
+ with out_csv.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=list(summaries[0].keys()))
+ writer.writeheader()
+ writer.writerows(summaries)
+ return summaries
+
+
+def main() -> None:
+ summaries = plot()
+ for row in summaries:
+ print(
+ f"{row['name']}: "
+ f"q_exact={row['q_max_exact']:.4f} "
+ f"lambda_min_exact={row['lambda_min_exact']:.4f} "
+ f"oracle={row['oracle_pass_exact']:.4f} "
+ f"pearson={row['global_pearson_q_vs_stability']:.4f} "
+ f"spearman={row['global_spearman_q_vs_stability']:.4f} "
+ f"within_spearman={row['within_problem_spearman_mean']:.4f} "
+ f"mixed_pearson={row['mixed_global_pearson_q_vs_stability']:.4f}"
+ )
+ print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm.png'}")
+ print(f"wrote {OUT / 'fig5_qhead_vs_lambda1_ptrm_summary.csv'}")
+
+
+if __name__ == "__main__":
+ main()