summaryrefslogtreecommitdiff
path: root/research/flossing/plot_initial_perturb_robustness.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/plot_initial_perturb_robustness.py')
-rw-r--r--research/flossing/plot_initial_perturb_robustness.py123
1 files changed, 123 insertions, 0 deletions
diff --git a/research/flossing/plot_initial_perturb_robustness.py b/research/flossing/plot_initial_perturb_robustness.py
new file mode 100644
index 0000000..0c8e8c8
--- /dev/null
+++ b/research/flossing/plot_initial_perturb_robustness.py
@@ -0,0 +1,123 @@
+"""Plot initial recurrent-state perturbation robustness curves."""
+from __future__ import annotations
+
+import argparse
+import csv
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+
+def read_rows(paths: list[Path]) -> list[dict[str, str]]:
+ rows: list[dict[str, str]] = []
+ for path in paths:
+ with path.open() as f:
+ rows.extend(csv.DictReader(f))
+ return rows
+
+
+def f(row: dict[str, str], key: str) -> float:
+ return float(row[key])
+
+
+def write_combined(path: Path, rows: list[dict[str, str]]) -> None:
+ keys: list[str] = []
+ for row in rows:
+ for key in row:
+ if key not in keys:
+ keys.append(key)
+ with path.open("w", newline="") as out:
+ writer = csv.DictWriter(out, fieldnames=keys)
+ writer.writeheader()
+ writer.writerows(rows)
+
+
+def plot_metric(rows: list[dict[str, str]], metric: str, ylabel: str, out: Path) -> None:
+ labels = []
+ for row in rows:
+ label = row["label"]
+ if label not in labels:
+ labels.append(label)
+
+ colors = {
+ "trm_baseline_best": "#334155",
+ "trm_multi4_best": "#0f766e",
+ "trm_multi4_final": "#dc2626",
+ }
+ markers = {
+ "trm_baseline_best": "o",
+ "trm_multi4_best": "s",
+ "trm_multi4_final": "X",
+ }
+
+ fig, ax = plt.subplots(figsize=(8.2, 5.0))
+ for label in labels:
+ lr = [r for r in rows if r["label"] == label]
+ lr.sort(key=lambda r: f(r, "sigma"))
+ xs = [f(r, "sigma") for r in lr]
+ ys = [f(r, metric) for r in lr]
+ ax.plot(
+ xs,
+ ys,
+ marker=markers.get(label, "o"),
+ linewidth=2.2,
+ markersize=6,
+ color=colors.get(label),
+ label=label.replace("trm_", "").replace("_", " "),
+ )
+ ax.set_xscale("symlog", linthresh=3e-5)
+ ax.set_xlabel("Initial recurrent-state perturbation σ")
+ ax.set_ylabel(ylabel)
+ ax.set_ylim(-0.02, 1.02)
+ ax.grid(alpha=0.24)
+ ax.legend(frameon=False, loc="best")
+ ax.set_title("TRM robustness to initial latent trajectory perturbations")
+ ax.text(
+ 0.0,
+ -0.20,
+ "Perturbation is applied once to z_H/z_L after reset, then the model unrolls deterministically. "
+ "Mean rollout exact is per-trajectory accuracy over K=8 perturbed rollouts.",
+ transform=ax.transAxes,
+ ha="left",
+ va="top",
+ fontsize=9.2,
+ color="#475569",
+ )
+ fig.tight_layout()
+ fig.savefig(out, dpi=220, bbox_inches="tight")
+ plt.close(fig)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--summaries", nargs="+", required=True)
+ parser.add_argument("--out-dir", required=True)
+ args = parser.parse_args()
+
+ out_dir = Path(args.out_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+ rows = read_rows([Path(p) for p in args.summaries])
+ write_combined(out_dir / "initial_perturb_robustness_combined.csv", rows)
+ plot_metric(
+ rows,
+ "mean_rollout_exact",
+ "Mean perturbed-rollout exact accuracy",
+ out_dir / "initial_perturb_robustness_mean_rollout_exact.png",
+ )
+ plot_metric(
+ rows,
+ "pass_at_k",
+ "Pass@K exact accuracy",
+ out_dir / "initial_perturb_robustness_pass_at_k.png",
+ )
+ plot_metric(
+ rows,
+ "all_k",
+ "All-K exact accuracy",
+ out_dir / "initial_perturb_robustness_all_k.png",
+ )
+ print(f"wrote {out_dir}")
+
+
+if __name__ == "__main__":
+ main()