summaryrefslogtreecommitdiff
path: root/research/flossing/plot_initial_perturb_robustness.py
blob: 0c8e8c8c33301e7de07f3aeb94096c1c2ed6f8a5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Plot initial recurrent-state perturbation robustness curves."""
from __future__ import annotations

import argparse
import csv
from pathlib import Path

import matplotlib.pyplot as plt


def read_rows(paths: list[Path]) -> list[dict[str, str]]:
    """Read and concatenate the rows of several CSV summary files.

    Each file is parsed with ``csv.DictReader``, so the first line of every file
    is that file's header and each later record becomes a ``{column: value}``
    dict. Rows are returned in the order of *paths*, then in record order.

    Args:
        paths: CSV files to read.

    Returns:
        All rows from all files, concatenated into one list (empty if *paths*
        is empty).
    """
    rows: list[dict[str, str]] = []
    for path in paths:
        # newline="" is required by the csv module. It lets the reader handle
        # \r\n terminators and quoted fields containing embedded newlines
        # itself, instead of universal-newline translation rewriting them.
        with path.open(newline="") as handle:
            rows.extend(csv.DictReader(handle))
    return rows


def f(row: dict[str, str], key: str) -> float:
    """Return the value stored under *key* in *row*, converted to ``float``.

    Raises ``KeyError`` if *key* is missing and ``ValueError`` if the text is
    not a valid float literal.
    """
    raw_value = row[key]
    return float(raw_value)


def write_combined(path: Path, rows: list[dict[str, str]]) -> None:
    """Write *rows* to one CSV file at *path*, overwriting it.

    The header is the union of all rows' keys, in order of first appearance.
    A row that lacks a column gets an empty cell (``DictWriter``'s default
    ``restval``).
    """
    # dict.fromkeys keeps insertion order, so this is an order-preserving dedupe.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with path.open("w", newline="") as handle:
        csv_out = csv.DictWriter(handle, fieldnames=fieldnames)
        csv_out.writeheader()
        csv_out.writerows(rows)


def plot_metric(rows: list[dict[str, str]], metric: str, ylabel: str, out: Path) -> None:
    """Plot *metric* against perturbation sigma, one line per run label, and save it.

    Args:
        rows: Summary rows. Each row is read for its ``"label"``, ``"sigma"``
            and *metric* columns; the latter two are parsed as floats.
        metric: Column plotted on the y axis.
        ylabel: Text for the y-axis label.
        out: Destination image file, saved at 220 dpi.
    """
    # Known runs get a fixed (color, marker) pair. Any other label gets
    # color=None (matplotlib's default cycle) and a circle marker.
    style_for = {
        "trm_baseline_best": ("#334155", "o"),
        "trm_multi4_best": ("#0f766e", "s"),
        "trm_multi4_final": ("#dc2626", "X"),
    }

    # Bucket rows by label. Dict insertion order keeps labels in order of
    # first appearance, and each bucket keeps the original row order.
    by_label: dict[str, list[dict[str, str]]] = {}
    for row in rows:
        by_label.setdefault(row["label"], []).append(row)

    fig, ax = plt.subplots(figsize=(8.2, 5.0))
    for label, series in by_label.items():
        # Stable sort by sigma, so rows with equal sigma keep their input order.
        ordered = sorted(series, key=lambda r: f(r, "sigma"))
        color, marker = style_for.get(label, (None, "o"))
        sigmas = [f(r, "sigma") for r in ordered]
        values = [f(r, metric) for r in ordered]
        ax.plot(
            sigmas,
            values,
            marker=marker,
            linewidth=2.2,
            markersize=6,
            color=color,
            label=label.replace("trm_", "").replace("_", " "),
        )

    # symlog keeps a sigma of 0 on the axis; the scale is linear below 3e-5.
    ax.set_xscale("symlog", linthresh=3e-5)
    ax.set_xlabel("Initial recurrent-state perturbation σ")
    ax.set_ylabel(ylabel)
    ax.set_ylim(-0.02, 1.02)
    ax.grid(alpha=0.24)
    ax.legend(frameon=False, loc="best")
    ax.set_title("TRM robustness to initial latent trajectory perturbations")

    footnote = (
        "Perturbation is applied once to z_H/z_L after reset, then the model unrolls deterministically. "
        "Mean rollout exact is per-trajectory accuracy over K=8 perturbed rollouts."
    )
    # Placed just below the axes, in axes-fraction coordinates.
    ax.text(
        0.0,
        -0.20,
        footnote,
        transform=ax.transAxes,
        ha="left",
        va="top",
        fontsize=9.2,
        color="#475569",
    )
    fig.tight_layout()
    fig.savefig(out, dpi=220, bbox_inches="tight")
    plt.close(fig)


def main() -> None:
    """Command-line entry point.

    Reads every ``--summaries`` CSV and creates ``--out-dir`` if needed. It then
    writes the concatenated rows to ``initial_perturb_robustness_combined.csv``
    and renders one PNG per plotted metric into that directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--summaries", nargs="+", required=True)
    parser.add_argument("--out-dir", required=True)
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    summary_paths = [Path(p) for p in args.summaries]
    rows = read_rows(summary_paths)
    write_combined(out_dir / "initial_perturb_robustness_combined.csv", rows)

    # (metric column, y-axis label, output file name), rendered in this order.
    figures = (
        (
            "mean_rollout_exact",
            "Mean perturbed-rollout exact accuracy",
            "initial_perturb_robustness_mean_rollout_exact.png",
        ),
        (
            "pass_at_k",
            "Pass@K exact accuracy",
            "initial_perturb_robustness_pass_at_k.png",
        ),
        (
            "all_k",
            "All-K exact accuracy",
            "initial_perturb_robustness_all_k.png",
        ),
    )
    for metric, ylabel, filename in figures:
        plot_metric(rows, metric, ylabel, out_dir / filename)

    print(f"wrote {out_dir}")


if __name__ == "__main__":
    main()