"""Plot late recurrent-state perturbation robustness curves.""" from __future__ import annotations import argparse import csv from pathlib import Path import matplotlib.pyplot as plt def read_rows(paths: list[Path]) -> list[dict[str, str]]: rows: list[dict[str, str]] = [] for path in paths: with path.open() as f: rows.extend(csv.DictReader(f)) return rows def f(row: dict[str, str], key: str) -> float: return float(row[key]) def write_combined(path: Path, rows: list[dict[str, str]]) -> None: keys: list[str] = [] for row in rows: for key in row: if key not in keys: keys.append(key) with path.open("w", newline="") as out: writer = csv.DictWriter(out, fieldnames=keys) writer.writeheader() writer.writerows(rows) COLORS = { "trm_baseline_best": "#334155", "trm_multi4_best": "#0f766e", "trm_multi4_final": "#dc2626", } MARKERS = { "trm_baseline_best": "o", "trm_multi4_best": "s", "trm_multi4_final": "X", } def plot_metric_grid(rows: list[dict[str, str]], metric: str, ylabel: str, out: Path) -> None: labels = sorted({r["label"] for r in rows}) afters = sorted({int(float(r["perturb_after"])) for r in rows}) fig, axes = plt.subplots(1, len(afters), figsize=(4.0 * len(afters), 4.2), sharey=True) if len(afters) == 1: axes = [axes] for ax, after in zip(axes, afters): for label in labels: lr = [r for r in rows if r["label"] == label and int(float(r["perturb_after"])) == after] lr.sort(key=lambda r: f(r, "sigma")) xs = [f(r, "sigma") for r in lr] ys = [f(r, metric) for r in lr] ax.plot( xs, ys, marker=MARKERS.get(label, "o"), linewidth=2.0, markersize=5, color=COLORS.get(label), label=label.replace("trm_", "").replace("_", " "), ) ax.set_xscale("symlog", linthresh=1e-3) ax.set_title(f"after {after}") ax.set_xlabel("σ") ax.grid(alpha=0.23) ax.set_ylim(-0.02, 1.02) axes[0].set_ylabel(ylabel) axes[-1].legend(frameon=False, fontsize=8, loc="best") fig.suptitle(f"Late-state perturbation robustness: {ylabel}") fig.tight_layout() fig.savefig(out, dpi=220, bbox_inches="tight") plt.close(fig) def plot_sigma_slice(rows: list[dict[str, str]], sigma: float, out: Path) -> None: labels = sorted({r["label"] for r in rows}) metrics = [ ("retain_mean_on_clean_success", "Clean-success retention"), ("rescue_mean_on_clean_fail", "Clean-fail rescue"), ] fig, axes = plt.subplots(1, 2, figsize=(11.2, 4.6), sharex=True) for ax, (metric, title) in zip(axes, metrics): for label in labels: lr = [ r for r in rows if r["label"] == label and abs(float(r["sigma"]) - sigma) <= max(1e-12, sigma * 1e-6) ] lr.sort(key=lambda r: int(float(r["perturb_after"]))) xs = [int(float(r["perturb_after"])) for r in lr] ys = [f(r, metric) for r in lr] ax.plot( xs, ys, marker=MARKERS.get(label, "o"), linewidth=2.1, markersize=6, color=COLORS.get(label), label=label.replace("trm_", "").replace("_", " "), ) ax.set_title(title) ax.set_xlabel("Perturb after ACT step") ax.set_ylim(-0.02, 1.02) ax.grid(alpha=0.23) axes[0].set_ylabel("Conditional probability") axes[1].legend(frameon=False, loc="best") fig.suptitle(f"Late perturbation conditional behavior at σ={sigma:g}") fig.tight_layout() fig.savefig(out, dpi=220, bbox_inches="tight") plt.close(fig) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--summaries", nargs="+", required=True) parser.add_argument("--out-dir", required=True) parser.add_argument("--slice-sigma", type=float, default=0.1) args = parser.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) rows = read_rows([Path(p) for p in args.summaries]) write_combined(out_dir / "late_perturb_robustness_combined.csv", rows) plot_metric_grid(rows, "mean_rollout_exact", "Mean perturbed-rollout exact", out_dir / "late_perturb_mean_rollout_exact_grid.png") plot_metric_grid(rows, "pass_at_k", "Pass@K exact", out_dir / "late_perturb_pass_at_k_grid.png") plot_metric_grid(rows, "all_k", "All-K exact", out_dir / "late_perturb_all_k_grid.png") plot_metric_grid(rows, "retain_mean_on_clean_success", "Clean-success retention", out_dir / "late_perturb_retention_grid.png") plot_metric_grid(rows, "rescue_mean_on_clean_fail", "Clean-fail rescue", out_dir / "late_perturb_rescue_grid.png") plot_sigma_slice(rows, args.slice_sigma, out_dir / f"late_perturb_retention_rescue_sigma{args.slice_sigma:g}.png") print(f"wrote {out_dir}") if __name__ == "__main__": main()