diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/plot_initial_perturb_robustness.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/plot_initial_perturb_robustness.py')
| -rw-r--r-- | research/flossing/plot_initial_perturb_robustness.py | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/research/flossing/plot_initial_perturb_robustness.py b/research/flossing/plot_initial_perturb_robustness.py new file mode 100644 index 0000000..0c8e8c8 --- /dev/null +++ b/research/flossing/plot_initial_perturb_robustness.py @@ -0,0 +1,123 @@ +"""Plot initial recurrent-state perturbation robustness curves.""" +from __future__ import annotations + +import argparse +import csv +from pathlib import Path + +import matplotlib.pyplot as plt + + +def read_rows(paths: list[Path]) -> list[dict[str, str]]: + rows: list[dict[str, str]] = [] + for path in paths: + with path.open() as f: + rows.extend(csv.DictReader(f)) + return rows + + +def f(row: dict[str, str], key: str) -> float: + return float(row[key]) + + +def write_combined(path: Path, rows: list[dict[str, str]]) -> None: + keys: list[str] = [] + for row in rows: + for key in row: + if key not in keys: + keys.append(key) + with path.open("w", newline="") as out: + writer = csv.DictWriter(out, fieldnames=keys) + writer.writeheader() + writer.writerows(rows) + + +def plot_metric(rows: list[dict[str, str]], metric: str, ylabel: str, out: Path) -> None: + labels = [] + for row in rows: + label = row["label"] + if label not in labels: + labels.append(label) + + colors = { + "trm_baseline_best": "#334155", + "trm_multi4_best": "#0f766e", + "trm_multi4_final": "#dc2626", + } + markers = { + "trm_baseline_best": "o", + "trm_multi4_best": "s", + "trm_multi4_final": "X", + } + + fig, ax = plt.subplots(figsize=(8.2, 5.0)) + for label in labels: + lr = [r for r in rows if r["label"] == label] + lr.sort(key=lambda r: f(r, "sigma")) + xs = [f(r, "sigma") for r in lr] + ys = [f(r, metric) for r in lr] + ax.plot( + xs, + ys, + marker=markers.get(label, "o"), + linewidth=2.2, + markersize=6, + color=colors.get(label), + label=label.replace("trm_", "").replace("_", " "), + ) + ax.set_xscale("symlog", linthresh=3e-5) + ax.set_xlabel("Initial recurrent-state perturbation σ") + ax.set_ylabel(ylabel) + ax.set_ylim(-0.02, 1.02) + ax.grid(alpha=0.24) + ax.legend(frameon=False, loc="best") + ax.set_title("TRM robustness to initial latent trajectory perturbations") + ax.text( + 0.0, + -0.20, + "Perturbation is applied once to z_H/z_L after reset, then the model unrolls deterministically. " + "Mean rollout exact is per-trajectory accuracy over K=8 perturbed rollouts.", + transform=ax.transAxes, + ha="left", + va="top", + fontsize=9.2, + color="#475569", + ) + fig.tight_layout() + fig.savefig(out, dpi=220, bbox_inches="tight") + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--summaries", nargs="+", required=True) + parser.add_argument("--out-dir", required=True) + args = parser.parse_args() + + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + rows = read_rows([Path(p) for p in args.summaries]) + write_combined(out_dir / "initial_perturb_robustness_combined.csv", rows) + plot_metric( + rows, + "mean_rollout_exact", + "Mean perturbed-rollout exact accuracy", + out_dir / "initial_perturb_robustness_mean_rollout_exact.png", + ) + plot_metric( + rows, + "pass_at_k", + "Pass@K exact accuracy", + out_dir / "initial_perturb_robustness_pass_at_k.png", + ) + plot_metric( + rows, + "all_k", + "All-K exact accuracy", + out_dir / "initial_perturb_robustness_all_k.png", + ) + print(f"wrote {out_dir}") + + +if __name__ == "__main__": + main() |
