summaryrefslogtreecommitdiff
path: root/research/flossing/make_trm_multi4_meeting_artifacts.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/make_trm_multi4_meeting_artifacts.py')
-rw-r--r--research/flossing/make_trm_multi4_meeting_artifacts.py133
1 files changed, 133 insertions, 0 deletions
diff --git a/research/flossing/make_trm_multi4_meeting_artifacts.py b/research/flossing/make_trm_multi4_meeting_artifacts.py
new file mode 100644
index 0000000..5f964ed
--- /dev/null
+++ b/research/flossing/make_trm_multi4_meeting_artifacts.py
@@ -0,0 +1,133 @@
+from __future__ import annotations
+
+from pathlib import Path
+import csv
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+ROOT = Path("research/flossing")
+SPECTRUM_DIR = ROOT / "official_gbs768_spectrum"
+PTRM_DIR = ROOT / "ptrm_same_subset"
+PLOT_DIR = ROOT / "meeting_artifacts"
+
+
+def _read_csv_rows(path: Path) -> list[dict[str, str]]:
+ with path.open(newline="") as fh:
+ return list(csv.DictReader(fh))
+
+
+def _load_spectrum_rows() -> list[tuple[str, int, np.lib.npyio.NpzFile]]:
+ return [
+ (
+ "TRM baseline best",
+ 58590,
+ np.load(SPECTRUM_DIR / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"),
+ ),
+ (
+ "multi4 best",
+ 35805,
+ np.load(SPECTRUM_DIR / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"),
+ ),
+ (
+ "multi4 final",
+ 65100,
+ np.load(SPECTRUM_DIR / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz"),
+ ),
+ ]
+
+
+def plot_spectra() -> None:
+ rows = _load_spectrum_rows()
+ fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.2), sharey=True)
+ for ax, (name, step, data) in zip(axes, rows):
+ spec = data["lyap_spec"].astype(float)
+ exact = data["exact_correct"].astype(bool)
+ x = np.arange(1, spec.shape[1] + 1)
+ for mask, color, label in [
+ (exact, "#1f77b4", "success"),
+ (~exact, "#d62728", "failure"),
+ ]:
+ mean = spec[mask].mean(axis=0)
+ se = spec[mask].std(axis=0) / max(mask.sum(), 1) ** 0.5
+ ax.plot(x, mean, marker="o", color=color, label=label)
+ ax.fill_between(x, mean - se, mean + se, color=color, alpha=0.16, linewidth=0)
+ ax.axhline(0, color="black", linewidth=0.8, alpha=0.55)
+ ax.set_title(f"{name}\nstep {step}, acc={exact.mean():.3f}")
+ ax.set_xlabel("Lyapunov spectrum rank")
+ ax.grid(alpha=0.22)
+ axes[0].set_ylabel("finite-time exponent")
+ axes[0].legend(frameon=False)
+ fig.suptitle("TRM official GBS768: spectrum separates success/failure; multi4 best suppresses successful spectrum")
+ fig.tight_layout()
+ fig.savefig(PLOT_DIR / "trm_gbs768_spectrum_success_failure_n512.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_headline() -> None:
+ rows = _read_csv_rows(SPECTRUM_DIR / "headline_trm_multi4_dynamics_table.csv")
+ labels = ["baseline\nbest", "multi4\nbest", "multi4\nfinal"]
+ fig, axes = plt.subplots(1, 3, figsize=(12, 3.8))
+ metrics = [
+ ("full_exact", "Full test exact acc", "#2f6f4e"),
+ ("mean8_all", "Mean top-8 exponent", "#8a4b2a"),
+ ("pos_count_all", "Positive exponents / 8", "#4d5f8a"),
+ ]
+ for ax, (col, title, color) in zip(axes, metrics):
+ values = np.asarray([float(row[col]) for row in rows])
+ ax.bar(labels, values, color=color, alpha=0.86)
+ ax.set_title(title)
+ ax.grid(axis="y", alpha=0.22)
+ for i, v in enumerate(values):
+ ax.text(i, v, f"{v:.3f}" if col != "pos_count_all" else f"{v:.2f}", ha="center", va="bottom", fontsize=9)
+ fig.suptitle("Accuracy improves at multi4 best while chaotic volume drops; final collapse restores positive spectrum")
+ fig.tight_layout()
+ fig.savefig(PLOT_DIR / "trm_gbs768_headline_dynamics_bars.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_ptrm() -> None:
+ row = _read_csv_rows(PTRM_DIR / "paired_ptrm_k100_n1000_seed0_summary.csv")[0]
+ labels = ["deterministic", "mean rollout", "Q-selected", "oracle"]
+ base = [
+ float(row["base_det"]),
+ float(row["base_mean_rollout"]),
+ float(row["base_qmax"]),
+ float(row["base_oracle"]),
+ ]
+ multi = [
+ float(row["multi4_det"]),
+ float(row["multi4_mean_rollout"]),
+ float(row["multi4_qmax"]),
+ float(row["multi4_oracle"]),
+ ]
+ x = np.arange(len(labels))
+ width = 0.36
+ fig, ax = plt.subplots(figsize=(8.4, 4.2))
+ ax.bar(x - width / 2, base, width, label="baseline best", color="#6b7280")
+ ax.bar(x + width / 2, multi, width, label="multi4 best", color="#2563eb")
+ ax.set_ylim(0.86, 1.0)
+ ax.set_xticks(x, labels)
+ ax.set_ylabel("exact accuracy")
+ ax.set_title("PTRM same-subset comparison (n=1000, K=100, D=64, sigma=0.3, L-only)")
+ ax.grid(axis="y", alpha=0.22)
+ ax.legend(frameon=False)
+ for xpos, vals in [(x - width / 2, base), (x + width / 2, multi)]:
+ for xi, v in zip(xpos, vals):
+ ax.text(xi, v + 0.002, f"{v:.3f}", ha="center", va="bottom", fontsize=8)
+ fig.tight_layout()
+ fig.savefig(PLOT_DIR / "ptrm_same_subset_k100_n1000_comparison.png", dpi=220)
+ plt.close(fig)
+
+
+def main() -> None:
+ PLOT_DIR.mkdir(parents=True, exist_ok=True)
+ plot_spectra()
+ plot_headline()
+ plot_ptrm()
+ print(f"wrote plots to {PLOT_DIR}")
+
+
+if __name__ == "__main__":
+ main()