from __future__ import annotations from pathlib import Path import csv import matplotlib.pyplot as plt import numpy as np ROOT = Path("research/flossing") SPECTRUM_DIR = ROOT / "official_gbs768_spectrum" PTRM_DIR = ROOT / "ptrm_same_subset" PLOT_DIR = ROOT / "meeting_artifacts" def _read_csv_rows(path: Path) -> list[dict[str, str]]: with path.open(newline="") as fh: return list(csv.DictReader(fh)) def _load_spectrum_rows() -> list[tuple[str, int, np.lib.npyio.NpzFile]]: return [ ( "TRM baseline best", 58590, np.load(SPECTRUM_DIR / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"), ), ( "multi4 best", 35805, np.load(SPECTRUM_DIR / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"), ), ( "multi4 final", 65100, np.load(SPECTRUM_DIR / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz"), ), ] def plot_spectra() -> None: rows = _load_spectrum_rows() fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.2), sharey=True) for ax, (name, step, data) in zip(axes, rows): spec = data["lyap_spec"].astype(float) exact = data["exact_correct"].astype(bool) x = np.arange(1, spec.shape[1] + 1) for mask, color, label in [ (exact, "#1f77b4", "success"), (~exact, "#d62728", "failure"), ]: mean = spec[mask].mean(axis=0) se = spec[mask].std(axis=0) / max(mask.sum(), 1) ** 0.5 ax.plot(x, mean, marker="o", color=color, label=label) ax.fill_between(x, mean - se, mean + se, color=color, alpha=0.16, linewidth=0) ax.axhline(0, color="black", linewidth=0.8, alpha=0.55) ax.set_title(f"{name}\nstep {step}, acc={exact.mean():.3f}") ax.set_xlabel("Lyapunov spectrum rank") ax.grid(alpha=0.22) axes[0].set_ylabel("finite-time exponent") axes[0].legend(frameon=False) fig.suptitle("TRM official GBS768: spectrum separates success/failure; multi4 best suppresses successful spectrum") fig.tight_layout() fig.savefig(PLOT_DIR / "trm_gbs768_spectrum_success_failure_n512.png", dpi=220) plt.close(fig) def plot_headline() -> None: rows = _read_csv_rows(SPECTRUM_DIR / "headline_trm_multi4_dynamics_table.csv") labels = ["baseline\nbest", "multi4\nbest", "multi4\nfinal"] fig, axes = plt.subplots(1, 3, figsize=(12, 3.8)) metrics = [ ("full_exact", "Full test exact acc", "#2f6f4e"), ("mean8_all", "Mean top-8 exponent", "#8a4b2a"), ("pos_count_all", "Positive exponents / 8", "#4d5f8a"), ] for ax, (col, title, color) in zip(axes, metrics): values = np.asarray([float(row[col]) for row in rows]) ax.bar(labels, values, color=color, alpha=0.86) ax.set_title(title) ax.grid(axis="y", alpha=0.22) for i, v in enumerate(values): ax.text(i, v, f"{v:.3f}" if col != "pos_count_all" else f"{v:.2f}", ha="center", va="bottom", fontsize=9) fig.suptitle("Accuracy improves at multi4 best while chaotic volume drops; final collapse restores positive spectrum") fig.tight_layout() fig.savefig(PLOT_DIR / "trm_gbs768_headline_dynamics_bars.png", dpi=220) plt.close(fig) def plot_ptrm() -> None: row = _read_csv_rows(PTRM_DIR / "paired_ptrm_k100_n1000_seed0_summary.csv")[0] labels = ["deterministic", "mean rollout", "Q-selected", "oracle"] base = [ float(row["base_det"]), float(row["base_mean_rollout"]), float(row["base_qmax"]), float(row["base_oracle"]), ] multi = [ float(row["multi4_det"]), float(row["multi4_mean_rollout"]), float(row["multi4_qmax"]), float(row["multi4_oracle"]), ] x = np.arange(len(labels)) width = 0.36 fig, ax = plt.subplots(figsize=(8.4, 4.2)) ax.bar(x - width / 2, base, width, label="baseline best", color="#6b7280") ax.bar(x + width / 2, multi, width, label="multi4 best", color="#2563eb") ax.set_ylim(0.86, 1.0) ax.set_xticks(x, labels) ax.set_ylabel("exact accuracy") ax.set_title("PTRM same-subset comparison (n=1000, K=100, D=64, sigma=0.3, L-only)") ax.grid(axis="y", alpha=0.22) ax.legend(frameon=False) for xpos, vals in [(x - width / 2, base), (x + width / 2, multi)]: for xi, v in zip(xpos, vals): ax.text(xi, v + 0.002, f"{v:.3f}", ha="center", va="bottom", fontsize=8) fig.tight_layout() fig.savefig(PLOT_DIR / "ptrm_same_subset_k100_n1000_comparison.png", dpi=220) plt.close(fig) def main() -> None: PLOT_DIR.mkdir(parents=True, exist_ok=True) plot_spectra() plot_headline() plot_ptrm() print(f"wrote plots to {PLOT_DIR}") if __name__ == "__main__": main()