diff options
Diffstat (limited to 'research/flossing/make_trm_multi4_meeting_artifacts.py')
| -rw-r--r-- | research/flossing/make_trm_multi4_meeting_artifacts.py | 133 |
1 files changed, 133 insertions, 0 deletions
diff --git a/research/flossing/make_trm_multi4_meeting_artifacts.py b/research/flossing/make_trm_multi4_meeting_artifacts.py new file mode 100644 index 0000000..5f964ed --- /dev/null +++ b/research/flossing/make_trm_multi4_meeting_artifacts.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from pathlib import Path +import csv + +import matplotlib.pyplot as plt +import numpy as np + + +ROOT = Path("research/flossing") +SPECTRUM_DIR = ROOT / "official_gbs768_spectrum" +PTRM_DIR = ROOT / "ptrm_same_subset" +PLOT_DIR = ROOT / "meeting_artifacts" + + +def _read_csv_rows(path: Path) -> list[dict[str, str]]: + with path.open(newline="") as fh: + return list(csv.DictReader(fh)) + + +def _load_spectrum_rows() -> list[tuple[str, int, np.lib.npyio.NpzFile]]: + return [ + ( + "TRM baseline best", + 58590, + np.load(SPECTRUM_DIR / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"), + ), + ( + "multi4 best", + 35805, + np.load(SPECTRUM_DIR / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"), + ), + ( + "multi4 final", + 65100, + np.load(SPECTRUM_DIR / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz"), + ), + ] + + +def plot_spectra() -> None: + rows = _load_spectrum_rows() + fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.2), sharey=True) + for ax, (name, step, data) in zip(axes, rows): + spec = data["lyap_spec"].astype(float) + exact = data["exact_correct"].astype(bool) + x = np.arange(1, spec.shape[1] + 1) + for mask, color, label in [ + (exact, "#1f77b4", "success"), + (~exact, "#d62728", "failure"), + ]: + mean = spec[mask].mean(axis=0) + se = spec[mask].std(axis=0) / max(mask.sum(), 1) ** 0.5 + ax.plot(x, mean, marker="o", color=color, label=label) + ax.fill_between(x, mean - se, mean + se, color=color, alpha=0.16, linewidth=0) + ax.axhline(0, color="black", linewidth=0.8, alpha=0.55) + ax.set_title(f"{name}\nstep {step}, acc={exact.mean():.3f}") + ax.set_xlabel("Lyapunov spectrum rank") + ax.grid(alpha=0.22) + axes[0].set_ylabel("finite-time exponent") + axes[0].legend(frameon=False) + fig.suptitle("TRM official GBS768: spectrum separates success/failure; multi4 best suppresses successful spectrum") + fig.tight_layout() + fig.savefig(PLOT_DIR / "trm_gbs768_spectrum_success_failure_n512.png", dpi=220) + plt.close(fig) + + +def plot_headline() -> None: + rows = _read_csv_rows(SPECTRUM_DIR / "headline_trm_multi4_dynamics_table.csv") + labels = ["baseline\nbest", "multi4\nbest", "multi4\nfinal"] + fig, axes = plt.subplots(1, 3, figsize=(12, 3.8)) + metrics = [ + ("full_exact", "Full test exact acc", "#2f6f4e"), + ("mean8_all", "Mean top-8 exponent", "#8a4b2a"), + ("pos_count_all", "Positive exponents / 8", "#4d5f8a"), + ] + for ax, (col, title, color) in zip(axes, metrics): + values = np.asarray([float(row[col]) for row in rows]) + ax.bar(labels, values, color=color, alpha=0.86) + ax.set_title(title) + ax.grid(axis="y", alpha=0.22) + for i, v in enumerate(values): + ax.text(i, v, f"{v:.3f}" if col != "pos_count_all" else f"{v:.2f}", ha="center", va="bottom", fontsize=9) + fig.suptitle("Accuracy improves at multi4 best while chaotic volume drops; final collapse restores positive spectrum") + fig.tight_layout() + fig.savefig(PLOT_DIR / "trm_gbs768_headline_dynamics_bars.png", dpi=220) + plt.close(fig) + + +def plot_ptrm() -> None: + row = _read_csv_rows(PTRM_DIR / "paired_ptrm_k100_n1000_seed0_summary.csv")[0] + labels = ["deterministic", "mean rollout", "Q-selected", "oracle"] + base = [ + float(row["base_det"]), + float(row["base_mean_rollout"]), + float(row["base_qmax"]), + float(row["base_oracle"]), + ] + multi = [ + float(row["multi4_det"]), + float(row["multi4_mean_rollout"]), + float(row["multi4_qmax"]), + float(row["multi4_oracle"]), + ] + x = np.arange(len(labels)) + width = 0.36 + fig, ax = plt.subplots(figsize=(8.4, 4.2)) + ax.bar(x - width / 2, base, width, label="baseline best", color="#6b7280") + ax.bar(x + width / 2, multi, width, label="multi4 best", color="#2563eb") + ax.set_ylim(0.86, 1.0) + ax.set_xticks(x, labels) + ax.set_ylabel("exact accuracy") + ax.set_title("PTRM same-subset comparison (n=1000, K=100, D=64, sigma=0.3, L-only)") + ax.grid(axis="y", alpha=0.22) + ax.legend(frameon=False) + for xpos, vals in [(x - width / 2, base), (x + width / 2, multi)]: + for xi, v in zip(xpos, vals): + ax.text(xi, v + 0.002, f"{v:.3f}", ha="center", va="bottom", fontsize=8) + fig.tight_layout() + fig.savefig(PLOT_DIR / "ptrm_same_subset_k100_n1000_comparison.png", dpi=220) + plt.close(fig) + + +def main() -> None: + PLOT_DIR.mkdir(parents=True, exist_ok=True) + plot_spectra() + plot_headline() + plot_ptrm() + print(f"wrote plots to {PLOT_DIR}") + + +if __name__ == "__main__": + main() |
