diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/plot_directional_lyap_perturb.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/plot_directional_lyap_perturb.py')
| -rw-r--r-- | research/flossing/plot_directional_lyap_perturb.py | 136 |
1 files changed, 136 insertions, 0 deletions
diff --git a/research/flossing/plot_directional_lyap_perturb.py b/research/flossing/plot_directional_lyap_perturb.py new file mode 100644 index 0000000..218c912 --- /dev/null +++ b/research/flossing/plot_directional_lyap_perturb.py @@ -0,0 +1,136 @@ +"""Plot directional finite-difference Lyapunov perturbation robustness.""" +from __future__ import annotations + +import argparse +import csv +from pathlib import Path + +import matplotlib.pyplot as plt + + +COLORS = { + "trm_baseline_best": "#334155", + "trm_multi4_best": "#0f766e", + "trm_multi4_final": "#dc2626", +} +MARKERS = { + "trm_baseline_best": "o", + "trm_multi4_best": "s", + "trm_multi4_final": "X", +} + + +def read_rows(paths: list[Path]) -> list[dict[str, str]]: + rows: list[dict[str, str]] = [] + for path in paths: + with path.open() as f: + rows.extend(csv.DictReader(f)) + return rows + + +def f(row: dict[str, str], key: str) -> float: + return float(row[key]) + + +def write_combined(path: Path, rows: list[dict[str, str]]) -> None: + keys: list[str] = [] + for row in rows: + for key in row: + if key not in keys: + keys.append(key) + with path.open("w", newline="") as out: + writer = csv.DictWriter(out, fieldnames=keys) + writer.writeheader() + writer.writerows(rows) + + +def plot_metric_grid(rows: list[dict[str, str]], metric: str, ylabel: str, out: Path) -> None: + labels = sorted({r["label"] for r in rows}) + afters = sorted({int(float(r["perturb_after"])) for r in rows}) + fig, axes = plt.subplots(1, len(afters), figsize=(4.1 * len(afters), 4.25), sharey=True) + if len(afters) == 1: + axes = [axes] + for ax, after in zip(axes, afters): + for label in labels: + lr = [r for r in rows if r["label"] == label and int(float(r["perturb_after"])) == after] + lr.sort(key=lambda r: f(r, "sigma")) + ax.plot( + [f(r, "sigma") for r in lr], + [f(r, metric) for r in lr], + marker=MARKERS.get(label, "o"), + linewidth=2.0, + markersize=5, + color=COLORS.get(label), + label=label.replace("trm_", "").replace("_", " "), + ) + ax.set_xscale("symlog", linthresh=1e-3) + ax.set_title(f"after {after}") + ax.set_xlabel("σ along selected direction") + ax.set_ylim(-0.02, 1.02) + ax.grid(alpha=0.23) + axes[0].set_ylabel(ylabel) + axes[-1].legend(frameon=False, fontsize=8, loc="best") + fig.suptitle(f"Finite-difference Lyapunov-direction perturbation: {ylabel}") + fig.tight_layout() + fig.savefig(out, dpi=220, bbox_inches="tight") + plt.close(fig) + + +def plot_sigma_slice(rows: list[dict[str, str]], sigma: float, out: Path) -> None: + labels = sorted({r["label"] for r in rows}) + metrics = [ + ("retain_worst_on_clean_success", "Clean-success robust retention"), + ("rescue_best_on_clean_fail", "Clean-fail best-sign rescue"), + ] + fig, axes = plt.subplots(1, 2, figsize=(11.2, 4.6), sharex=True) + for ax, (metric, title) in zip(axes, metrics): + for label in labels: + lr = [ + r + for r in rows + if r["label"] == label and abs(float(r["sigma"]) - sigma) <= max(1e-12, sigma * 1e-6) + ] + lr.sort(key=lambda r: int(float(r["perturb_after"]))) + ax.plot( + [int(float(r["perturb_after"])) for r in lr], + [f(r, metric) for r in lr], + marker=MARKERS.get(label, "o"), + linewidth=2.1, + markersize=6, + color=COLORS.get(label), + label=label.replace("trm_", "").replace("_", " "), + ) + ax.set_title(title) + ax.set_xlabel("Perturb after ACT step") + ax.set_ylim(-0.02, 1.02) + ax.grid(alpha=0.23) + axes[0].set_ylabel("Conditional probability") + axes[1].legend(frameon=False, loc="best") + fig.suptitle(f"Lyapunov-direction conditional behavior at σ={sigma:g}") + fig.tight_layout() + fig.savefig(out, dpi=220, bbox_inches="tight") + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--summaries", nargs="+", required=True) + parser.add_argument("--out-dir", required=True) + parser.add_argument("--slice-sigma", type=float, default=0.03) + args = parser.parse_args() + + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + rows = read_rows([Path(p) for p in args.summaries]) + write_combined(out_dir / "directional_lyap_perturb_combined.csv", rows) + plot_metric_grid(rows, "mean_sign_exact", "Mean ± sign exact", out_dir / "directional_mean_sign_exact_grid.png") + plot_metric_grid(rows, "worst_sign_exact", "Worst-sign exact", out_dir / "directional_worst_sign_exact_grid.png") + plot_metric_grid(rows, "retain_worst_on_clean_success", "Clean-success worst-sign retention", out_dir / "directional_retention_worst_grid.png") + plot_metric_grid(rows, "rescue_best_on_clean_fail", "Clean-fail best-sign rescue", out_dir / "directional_rescue_best_grid.png") + plot_metric_grid(rows, "selected_growth_mean", "Selected direction finite-time growth", out_dir / "directional_selected_growth_grid.png") + plot_sigma_slice(rows, args.slice_sigma, out_dir / f"directional_retention_rescue_sigma{args.slice_sigma:g}.png") + print(f"wrote {out_dir}") + + +if __name__ == "__main__": + main() |
