summaryrefslogtreecommitdiff
path: root/research/flossing/make_meeting_artifacts_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/make_meeting_artifacts_v2.py')
-rw-r--r--research/flossing/make_meeting_artifacts_v2.py364
1 files changed, 364 insertions, 0 deletions
diff --git a/research/flossing/make_meeting_artifacts_v2.py b/research/flossing/make_meeting_artifacts_v2.py
new file mode 100644
index 0000000..4c88412
--- /dev/null
+++ b/research/flossing/make_meeting_artifacts_v2.py
@@ -0,0 +1,364 @@
+from __future__ import annotations
+
+import csv
+import re
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+ROOT = Path("research/flossing")
+OUT = ROOT / "meeting_artifacts_v2"
+
+
+def read_csv(path: Path) -> list[dict[str, str]]:
+ with path.open(newline="") as f:
+ return list(csv.DictReader(f))
+
+
+def f(row: dict[str, str], key: str, default: float = float("nan")) -> float:
+ try:
+ val = row.get(key, "")
+ return float(val) if val != "" else default
+ except Exception:
+ return default
+
+
+def row_for(rows: list[dict[str, str]], step: int) -> dict[str, str]:
+ for row in rows:
+ if int(float(row["step"])) == step:
+ return row
+ raise KeyError(step)
+
+
+def spectrum_metrics(npz_path: Path) -> dict[str, float]:
+ data = np.load(npz_path)
+ spec = data["lyap_spec"].astype(float)
+ exact = data["exact_correct"].astype(bool)
+ mean8 = spec.mean(axis=1)
+ pos_count = (spec > 0).sum(axis=1)
+ return {
+ "sample_exact": float(exact.mean()),
+ "lambda1_all": float(spec[:, 0].mean()),
+ "mean8_all": float(mean8.mean()),
+ "pos_count_all": float(pos_count.mean()),
+ "lambda1_success": float(spec[exact, 0].mean()),
+ "lambda1_fail": float(spec[~exact, 0].mean()),
+ "mean8_success": float(mean8[exact].mean()),
+ "mean8_fail": float(mean8[~exact].mean()),
+ "pos_count_success": float(pos_count[exact].mean()),
+ "pos_count_fail": float(pos_count[~exact].mean()),
+ }
+
+
+def spectrum_arrays(npz_path: Path) -> tuple[np.ndarray, np.ndarray]:
+ data = np.load(npz_path)
+ return data["lyap_spec"].astype(float), data["exact_correct"].astype(bool)
+
+
+def auc_failure_score(score: np.ndarray, exact: np.ndarray) -> float:
+ """AUROC where higher score predicts failure."""
+ fail = ~exact
+ n_fail = int(fail.sum())
+ n_succ = int(exact.sum())
+ if n_fail == 0 or n_succ == 0:
+ return float("nan")
+ order = np.argsort(score)
+ ranks = np.empty_like(order, dtype=float)
+ ranks[order] = np.arange(1, len(score) + 1)
+ return float((ranks[fail].sum() - n_fail * (n_fail + 1) / 2) / (n_fail * n_succ))
+
+
+def get_data() -> dict[str, object]:
+ eval_dir = ROOT / "multi4_eval_compare"
+ spec_dir = ROOT / "official_gbs768_spectrum"
+
+ hrm_base = read_csv(eval_dir / "hrm_baseline_eval.csv")
+ hrm_multi = read_csv(eval_dir / "hrm_multi4_complete_eval.csv")
+ trm_base = read_csv(eval_dir / "trm_official_gbs768_eval.csv")
+ trm_multi = read_csv(eval_dir / "trm_official_gbs768_multi4_eval.csv")
+
+ hrm_points = [
+ {
+ "label": "baseline best",
+ "family": "baseline",
+ "step": 26040,
+ "full_exact": f(row_for(hrm_base, 26040), "all/exact_accuracy"),
+ **spectrum_metrics(ROOT / "diag_hrm_step_26040_512.npz"),
+ },
+ {
+ "label": "multi4 best",
+ "family": "multi4",
+ "step": 23436,
+ "full_exact": f(row_for(hrm_multi, 23436), "exact"),
+ **spectrum_metrics(ROOT / "diag_hrm_multi4_step_23436_512.npz"),
+ },
+ {
+ "label": "multi4 final",
+ "family": "multi4-final",
+ "step": 26040,
+ "full_exact": f(row_for(hrm_multi, 26040), "exact"),
+ **spectrum_metrics(ROOT / "diag_hrm_multi4_step_26040_512.npz"),
+ },
+ ]
+
+ trm_points = [
+ {
+ "label": "baseline best",
+ "family": "baseline",
+ "step": 58590,
+ "full_exact": f(row_for(trm_base, 58590), "all/exact_accuracy"),
+ **spectrum_metrics(spec_dir / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"),
+ },
+ {
+ "label": "multi4 best",
+ "family": "multi4",
+ "step": 35805,
+ "full_exact": f(row_for(trm_multi, 35805), "all/exact_accuracy"),
+ **spectrum_metrics(spec_dir / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"),
+ },
+ {
+ "label": "multi4 final",
+ "family": "multi4-final",
+ "step": 65100,
+ "full_exact": f(row_for(trm_multi, 65100), "all/exact_accuracy"),
+ **spectrum_metrics(spec_dir / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz"),
+ },
+ ]
+
+ return {
+ "hrm_base": hrm_base,
+ "hrm_multi": hrm_multi,
+ "trm_base": trm_base,
+ "trm_multi": trm_multi,
+ "hrm_points": hrm_points,
+ "trm_points": trm_points,
+ }
+
+
+def metric_series(rows: list[dict[str, str]], metric: str) -> tuple[np.ndarray, np.ndarray]:
+ xs, ys = [], []
+ for row in rows:
+ val = f(row, metric)
+ if np.isfinite(val):
+ xs.append(int(float(row["step"])))
+ ys.append(val)
+ order = np.argsort(xs)
+ return np.asarray(xs)[order], np.asarray(ys)[order]
+
+
+def plot_training_curves(data: dict[str, object]) -> None:
+ fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.2))
+ colors = {"baseline": "#475569", "multi4": "#2563eb"}
+
+ panels = [
+ ("HRM", data["hrm_base"], data["hrm_multi"], "all/exact_accuracy", "exact", 0.70),
+ ("TRM official GBS768", data["trm_base"], data["trm_multi"], "all/exact_accuracy", "all/exact_accuracy", 0.94),
+ ]
+ for ax, (title, base_rows, multi_rows, base_metric, multi_metric, ymax) in zip(axes, panels):
+ bx, by = metric_series(base_rows, base_metric) # type: ignore[arg-type]
+ mx, my = metric_series(multi_rows, multi_metric) # type: ignore[arg-type]
+ ax.plot(bx / 1000, by, marker="o", color=colors["baseline"], label="baseline", linewidth=2)
+ ax.plot(mx / 1000, my, marker="o", color=colors["multi4"], label="multi4", linewidth=2)
+ bi, mi = int(np.argmax(by)), int(np.argmax(my))
+ ax.scatter([bx[bi] / 1000], [by[bi]], s=95, color=colors["baseline"], edgecolor="white", zorder=5)
+ ax.scatter([mx[mi] / 1000], [my[mi]], s=95, color=colors["multi4"], edgecolor="white", zorder=5)
+ ax.annotate(f"best {by[bi]:.3f}", (bx[bi] / 1000, by[bi]), xytext=(6, -18), textcoords="offset points", fontsize=8)
+ ax.annotate(f"best {my[mi]:.3f}", (mx[mi] / 1000, my[mi]), xytext=(6, 8), textcoords="offset points", fontsize=8)
+ ax.set_title(title)
+ ax.set_xlabel("optimizer step (k)")
+ ax.set_ylabel("full test exact accuracy")
+ ax.set_ylim(0, ymax)
+ ax.grid(alpha=0.22)
+ ax.legend(frameon=False, loc="lower right")
+
+ fig.suptitle("Trajectory perturbation helps both HRM and TRM, but late checkpoints can collapse")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig1_hrm_trm_training_curves.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_lambda1_motivation() -> None:
+ entries = [
+ ("HRM baseline best\nstep 26040", ROOT / "diag_hrm_step_26040_512.npz"),
+ (
+ "TRM baseline best\nstep 58590",
+ ROOT / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz",
+ ),
+ ]
+ fig, axes = plt.subplots(1, 2, figsize=(11.6, 4.2), sharey=False)
+ for ax, (title, path) in zip(axes, entries):
+ spec, exact = spectrum_arrays(path)
+ lam1 = spec[:, 0]
+ succ = lam1[exact]
+ fail = lam1[~exact]
+ lo, hi = np.quantile(lam1, [0.01, 0.99])
+ pad = 0.08 * (hi - lo)
+ bins = np.linspace(lo - pad, hi + pad, 34)
+ ax.hist(succ, bins=bins, density=True, alpha=0.62, color="#2563eb", label=f"success (n={len(succ)})")
+ ax.hist(fail, bins=bins, density=True, alpha=0.58, color="#dc2626", label=f"failure (n={len(fail)})")
+ ax.axvline(succ.mean(), color="#1d4ed8", linewidth=2)
+ ax.axvline(fail.mean(), color="#b91c1c", linewidth=2)
+ auc = auc_failure_score(lam1, exact)
+ ax.set_title(f"{title}\nacc={exact.mean():.3f}, AUROC={auc:.3f}")
+ ax.set_xlabel("top Lyapunov exponent λ1")
+ ax.set_ylabel("density")
+ ax.grid(alpha=0.2)
+ ax.legend(frameon=False, fontsize=8)
+ fig.suptitle("Motivation: λ1 is a strong success/failure detector in both HRM and TRM")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig0_motivation_lambda1_success_failure_hrm_trm.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_phase(data: dict[str, object]) -> None:
+ fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=False)
+ panels = [("HRM", data["hrm_points"]), ("TRM official GBS768", data["trm_points"])]
+ style = {
+ "baseline": ("#475569", "o"),
+ "multi4": ("#2563eb", "o"),
+ "multi4-final": ("#dc2626", "X"),
+ }
+ for ax, (title, points) in zip(axes, panels):
+ for point in points: # type: ignore[assignment]
+ color, marker = style[point["family"]]
+ ax.scatter(point["mean8_all"], point["full_exact"], s=150, marker=marker, color=color, edgecolor="white", zorder=4)
+ ax.annotate(
+ f"{point['label']}\nstep {point['step']}",
+ (point["mean8_all"], point["full_exact"]),
+ xytext=(8, 5),
+ textcoords="offset points",
+ fontsize=8,
+ )
+ ax.set_title(title)
+ ax.set_xlabel("mean top-8 Lyapunov exponent (lower = more stable)")
+ ax.set_ylabel("full test exact accuracy")
+ ax.grid(alpha=0.22)
+ fig.suptitle("Accuracy-vs-chaotic-volume phase view: best multi4 moves up-left; collapse moves down-right")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig2_accuracy_vs_chaotic_volume_phase.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_spectrum_grid() -> None:
+ entries = [
+ ("HRM baseline best", ROOT / "diag_hrm_step_26040_512.npz"),
+ ("HRM multi4 best", ROOT / "diag_hrm_multi4_step_23436_512.npz"),
+ ("TRM baseline best", ROOT / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz"),
+ ("TRM multi4 best", ROOT / "official_gbs768_spectrum/trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"),
+ ]
+ fig, axes = plt.subplots(2, 2, figsize=(10.8, 7.2), sharey=False)
+ for ax, (title, path) in zip(axes.flat, entries):
+ spec, exact = spectrum_arrays(path)
+ x = np.arange(1, spec.shape[1] + 1)
+ for mask, color, label in [(exact, "#2563eb", "success"), (~exact, "#dc2626", "failure")]:
+ mean = spec[mask].mean(axis=0)
+ se = spec[mask].std(axis=0) / max(mask.sum(), 1) ** 0.5
+ ax.plot(x, mean, marker="o", color=color, label=label)
+ ax.fill_between(x, mean - se, mean + se, color=color, alpha=0.16, linewidth=0)
+ ax.axhline(0, color="black", linewidth=0.8, alpha=0.55)
+ ax.set_title(f"{title}\nN={len(exact)}, acc={exact.mean():.3f}")
+ ax.set_xlabel("spectrum rank")
+ ax.set_ylabel("finite-time exponent")
+ ax.grid(alpha=0.22)
+ axes.flat[0].legend(frameon=False)
+ fig.suptitle("Success/failure separation is present in both HRM and TRM; multi4 mainly stabilizes successful trajectories")
+ fig.tight_layout()
+ fig.savefig(OUT / "fig3_hrm_trm_success_failure_spectra.png", dpi=220)
+ plt.close(fig)
+
+
+def plot_ptrm() -> None:
+ summary = read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0]
+ labels = ["deterministic", "mean rollout", "Q-selected", "oracle"]
+ base = [float(summary[k]) for k in ["base_det", "base_mean_rollout", "base_qmax", "base_oracle"]]
+ multi = [float(summary[k]) for k in ["multi4_det", "multi4_mean_rollout", "multi4_qmax", "multi4_oracle"]]
+ x = np.arange(len(labels))
+ width = 0.36
+ fig, ax = plt.subplots(figsize=(8.7, 4.3))
+ ax.bar(x - width / 2, base, width, color="#64748b", label="baseline best")
+ ax.bar(x + width / 2, multi, width, color="#2563eb", label="multi4 best")
+ ax.set_xticks(x, labels)
+ ax.set_ylim(0.86, 1.0)
+ ax.set_ylabel("exact accuracy")
+ ax.set_title("PTRM same-subset comparison (n=1000, K=100, D=64, sigma=0.3, L-only)")
+ ax.grid(axis="y", alpha=0.22)
+ ax.legend(frameon=False)
+ for xpos, vals in [(x - width / 2, base), (x + width / 2, multi)]:
+ for xi, val in zip(xpos, vals):
+ ax.text(xi, val + 0.002, f"{val:.3f}", ha="center", va="bottom", fontsize=8)
+ fig.tight_layout()
+ fig.savefig(OUT / "fig4_ptrm_same_subset_comparison.png", dpi=220)
+ plt.close(fig)
+
+
+def write_summary(data: dict[str, object]) -> None:
+ rows = []
+ for model, points in [("HRM", data["hrm_points"]), ("TRM", data["trm_points"])]:
+ for point in points: # type: ignore[assignment]
+ rows.append(
+ {
+ "model": model,
+ "label": point["label"],
+ "step": point["step"],
+ "full_exact": point["full_exact"],
+ "sample_exact": point["sample_exact"],
+ "lambda1_all": point["lambda1_all"],
+ "mean8_all": point["mean8_all"],
+ "pos_count_all": point["pos_count_all"],
+ "lambda1_success": point["lambda1_success"],
+ "lambda1_fail": point["lambda1_fail"],
+ "mean8_success": point["mean8_success"],
+ "mean8_fail": point["mean8_fail"],
+ "pos_count_success": point["pos_count_success"],
+ "pos_count_fail": point["pos_count_fail"],
+ }
+ )
+ with (OUT / "hrm_trm_redesigned_summary.csv").open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
+ writer.writeheader()
+ writer.writerows(rows)
+
+ ptrm = read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0]
+ report = f"""# Meeting Figures v2
+
+## Figure Strategy
+
+0. `fig0_motivation_lambda1_success_failure_hrm_trm.png`: first-exponent success/failure distribution in HRM and TRM. This motivates chaos as a detector before introducing the method.
+1. `fig1_hrm_trm_training_curves.png`: performance over training for HRM and TRM. This answers whether the method improves accuracy and where best/final are.
+2. `fig2_accuracy_vs_chaotic_volume_phase.png`: phase view, with accuracy versus mean top-8 Lyapunov exponent. This answers whether better checkpoints are dynamically more stable.
+3. `fig3_hrm_trm_success_failure_spectra.png`: full success/failure spectrum separation for HRM and TRM best checkpoints. This extends Fig0 beyond λ1.
+4. `fig4_ptrm_same_subset_comparison.png`: PTRM same-subset result. This is a secondary inference-time story.
+
+## Key Numbers
+
+- HRM baseline best: 0.5265 exact. HRM multi4 best: 0.6443 exact. HRM multi4 final: 0.4624 exact.
+- TRM baseline best: 0.8686 exact. TRM multi4 best: 0.8965 exact. TRM multi4 final: 0.8351 exact.
+- HRM multi4 best dynamics sample: mean top-8 exponent {rows[1]['mean8_all']:+.4f}; final {rows[2]['mean8_all']:+.4f}.
+- TRM multi4 best dynamics sample: mean top-8 exponent {rows[4]['mean8_all']:+.4f}; final {rows[5]['mean8_all']:+.4f}.
+- PTRM same subset, K=100: Q-selected {float(ptrm['base_qmax']):.3f} -> {float(ptrm['multi4_qmax']):.3f}; mean rollout {float(ptrm['base_mean_rollout']):.3f} -> {float(ptrm['multi4_mean_rollout']):.3f}.
+
+## Caveats
+
+- Dynamics spectra use N=512 diagnostic samples, not the full test set.
+- PTRM numbers use a fixed N=1000 subset; do not mix its deterministic subset accuracy with full-test W&B exact accuracy.
+- Final checkpoints are collapse diagnostics, not the method's reported performance.
+"""
+ (OUT / "meeting_figures_v2_report.md").write_text(report)
+
+
+def main() -> None:
+ OUT.mkdir(parents=True, exist_ok=True)
+ data = get_data()
+ plot_lambda1_motivation()
+ plot_training_curves(data)
+ plot_phase(data)
+ plot_spectrum_grid()
+ plot_ptrm()
+ write_summary(data)
+ print(f"wrote redesigned artifacts to {OUT}")
+
+
+if __name__ == "__main__":
+ main()