summaryrefslogtreecommitdiff
path: root/research/flossing/plot_directional_lyap_perturb.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/plot_directional_lyap_perturb.py')
-rw-r--r--research/flossing/plot_directional_lyap_perturb.py136
1 files changed, 136 insertions, 0 deletions
diff --git a/research/flossing/plot_directional_lyap_perturb.py b/research/flossing/plot_directional_lyap_perturb.py
new file mode 100644
index 0000000..218c912
--- /dev/null
+++ b/research/flossing/plot_directional_lyap_perturb.py
@@ -0,0 +1,136 @@
+"""Plot directional finite-difference Lyapunov perturbation robustness."""
+from __future__ import annotations
+
+import argparse
+import csv
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+
+COLORS = {
+ "trm_baseline_best": "#334155",
+ "trm_multi4_best": "#0f766e",
+ "trm_multi4_final": "#dc2626",
+}
+MARKERS = {
+ "trm_baseline_best": "o",
+ "trm_multi4_best": "s",
+ "trm_multi4_final": "X",
+}
+
+
+def read_rows(paths: list[Path]) -> list[dict[str, str]]:
+ rows: list[dict[str, str]] = []
+ for path in paths:
+ with path.open() as f:
+ rows.extend(csv.DictReader(f))
+ return rows
+
+
+def f(row: dict[str, str], key: str) -> float:
+ return float(row[key])
+
+
+def write_combined(path: Path, rows: list[dict[str, str]]) -> None:
+ keys: list[str] = []
+ for row in rows:
+ for key in row:
+ if key not in keys:
+ keys.append(key)
+ with path.open("w", newline="") as out:
+ writer = csv.DictWriter(out, fieldnames=keys)
+ writer.writeheader()
+ writer.writerows(rows)
+
+
+def plot_metric_grid(rows: list[dict[str, str]], metric: str, ylabel: str, out: Path) -> None:
+ labels = sorted({r["label"] for r in rows})
+ afters = sorted({int(float(r["perturb_after"])) for r in rows})
+ fig, axes = plt.subplots(1, len(afters), figsize=(4.1 * len(afters), 4.25), sharey=True)
+ if len(afters) == 1:
+ axes = [axes]
+ for ax, after in zip(axes, afters):
+ for label in labels:
+ lr = [r for r in rows if r["label"] == label and int(float(r["perturb_after"])) == after]
+ lr.sort(key=lambda r: f(r, "sigma"))
+ ax.plot(
+ [f(r, "sigma") for r in lr],
+ [f(r, metric) for r in lr],
+ marker=MARKERS.get(label, "o"),
+ linewidth=2.0,
+ markersize=5,
+ color=COLORS.get(label),
+ label=label.replace("trm_", "").replace("_", " "),
+ )
+ ax.set_xscale("symlog", linthresh=1e-3)
+ ax.set_title(f"after {after}")
+ ax.set_xlabel("σ along selected direction")
+ ax.set_ylim(-0.02, 1.02)
+ ax.grid(alpha=0.23)
+ axes[0].set_ylabel(ylabel)
+ axes[-1].legend(frameon=False, fontsize=8, loc="best")
+ fig.suptitle(f"Finite-difference Lyapunov-direction perturbation: {ylabel}")
+ fig.tight_layout()
+ fig.savefig(out, dpi=220, bbox_inches="tight")
+ plt.close(fig)
+
+
+def plot_sigma_slice(rows: list[dict[str, str]], sigma: float, out: Path) -> None:
+ labels = sorted({r["label"] for r in rows})
+ metrics = [
+ ("retain_worst_on_clean_success", "Clean-success robust retention"),
+ ("rescue_best_on_clean_fail", "Clean-fail best-sign rescue"),
+ ]
+ fig, axes = plt.subplots(1, 2, figsize=(11.2, 4.6), sharex=True)
+ for ax, (metric, title) in zip(axes, metrics):
+ for label in labels:
+ lr = [
+ r
+ for r in rows
+ if r["label"] == label and abs(float(r["sigma"]) - sigma) <= max(1e-12, sigma * 1e-6)
+ ]
+ lr.sort(key=lambda r: int(float(r["perturb_after"])))
+ ax.plot(
+ [int(float(r["perturb_after"])) for r in lr],
+ [f(r, metric) for r in lr],
+ marker=MARKERS.get(label, "o"),
+ linewidth=2.1,
+ markersize=6,
+ color=COLORS.get(label),
+ label=label.replace("trm_", "").replace("_", " "),
+ )
+ ax.set_title(title)
+ ax.set_xlabel("Perturb after ACT step")
+ ax.set_ylim(-0.02, 1.02)
+ ax.grid(alpha=0.23)
+ axes[0].set_ylabel("Conditional probability")
+ axes[1].legend(frameon=False, loc="best")
+ fig.suptitle(f"Lyapunov-direction conditional behavior at σ={sigma:g}")
+ fig.tight_layout()
+ fig.savefig(out, dpi=220, bbox_inches="tight")
+ plt.close(fig)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--summaries", nargs="+", required=True)
+ parser.add_argument("--out-dir", required=True)
+ parser.add_argument("--slice-sigma", type=float, default=0.03)
+ args = parser.parse_args()
+
+ out_dir = Path(args.out_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+ rows = read_rows([Path(p) for p in args.summaries])
+ write_combined(out_dir / "directional_lyap_perturb_combined.csv", rows)
+ plot_metric_grid(rows, "mean_sign_exact", "Mean ± sign exact", out_dir / "directional_mean_sign_exact_grid.png")
+ plot_metric_grid(rows, "worst_sign_exact", "Worst-sign exact", out_dir / "directional_worst_sign_exact_grid.png")
+ plot_metric_grid(rows, "retain_worst_on_clean_success", "Clean-success worst-sign retention", out_dir / "directional_retention_worst_grid.png")
+ plot_metric_grid(rows, "rescue_best_on_clean_fail", "Clean-fail best-sign rescue", out_dir / "directional_rescue_best_grid.png")
+ plot_metric_grid(rows, "selected_growth_mean", "Selected direction finite-time growth", out_dir / "directional_selected_growth_grid.png")
+ plot_sigma_slice(rows, args.slice_sigma, out_dir / f"directional_retention_rescue_sigma{args.slice_sigma:g}.png")
+ print(f"wrote {out_dir}")
+
+
+if __name__ == "__main__":
+ main()