diff options
Diffstat (limited to 'research/flossing/make_meeting_artifacts_v2.py')
| -rw-r--r-- | research/flossing/make_meeting_artifacts_v2.py | 364 |
1 files changed, 364 insertions, 0 deletions
diff --git a/research/flossing/make_meeting_artifacts_v2.py b/research/flossing/make_meeting_artifacts_v2.py new file mode 100644 index 0000000..4c88412 --- /dev/null +++ b/research/flossing/make_meeting_artifacts_v2.py @@ -0,0 +1,364 @@ +from __future__ import annotations + +import csv +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +ROOT = Path("research/flossing") +OUT = ROOT / "meeting_artifacts_v2" + + +def read_csv(path: Path) -> list[dict[str, str]]: + with path.open(newline="") as f: + return list(csv.DictReader(f)) + + +def f(row: dict[str, str], key: str, default: float = float("nan")) -> float: + try: + val = row.get(key, "") + return float(val) if val != "" else default + except Exception: + return default + + +def row_for(rows: list[dict[str, str]], step: int) -> dict[str, str]: + for row in rows: + if int(float(row["step"])) == step: + return row + raise KeyError(step) + + +def spectrum_metrics(npz_path: Path) -> dict[str, float]: + data = np.load(npz_path) + spec = data["lyap_spec"].astype(float) + exact = data["exact_correct"].astype(bool) + mean8 = spec.mean(axis=1) + pos_count = (spec > 0).sum(axis=1) + return { + "sample_exact": float(exact.mean()), + "lambda1_all": float(spec[:, 0].mean()), + "mean8_all": float(mean8.mean()), + "pos_count_all": float(pos_count.mean()), + "lambda1_success": float(spec[exact, 0].mean()), + "lambda1_fail": float(spec[~exact, 0].mean()), + "mean8_success": float(mean8[exact].mean()), + "mean8_fail": float(mean8[~exact].mean()), + "pos_count_success": float(pos_count[exact].mean()), + "pos_count_fail": float(pos_count[~exact].mean()), + } + + +def spectrum_arrays(npz_path: Path) -> tuple[np.ndarray, np.ndarray]: + data = np.load(npz_path) + return data["lyap_spec"].astype(float), data["exact_correct"].astype(bool) + + +def auc_failure_score(score: np.ndarray, exact: np.ndarray) -> float: + """AUROC where higher score predicts failure.""" + fail = ~exact + n_fail = int(fail.sum()) + n_succ = int(exact.sum()) + if n_fail == 0 or n_succ == 0: + return float("nan") + order = np.argsort(score) + ranks = np.empty_like(order, dtype=float) + ranks[order] = np.arange(1, len(score) + 1) + return float((ranks[fail].sum() - n_fail * (n_fail + 1) / 2) / (n_fail * n_succ)) + + +def get_data() -> dict[str, object]: + eval_dir = ROOT / "multi4_eval_compare" + spec_dir = ROOT / "official_gbs768_spectrum" + + hrm_base = read_csv(eval_dir / "hrm_baseline_eval.csv") + hrm_multi = read_csv(eval_dir / "hrm_multi4_complete_eval.csv") + trm_base = read_csv(eval_dir / "trm_official_gbs768_eval.csv") + trm_multi = read_csv(eval_dir / "trm_official_gbs768_multi4_eval.csv") + + hrm_points = [ + { + "label": "baseline best", + "family": "baseline", + "step": 26040, + "full_exact": f(row_for(hrm_base, 26040), "all/exact_accuracy"), + **spectrum_metrics(ROOT / "diag_hrm_step_26040_512.npz"), + }, + { + "label": "multi4 best", + "family": "multi4", + "step": 23436, + "full_exact": f(row_for(hrm_multi, 23436), "exact"), + **spectrum_metrics(ROOT / "diag_hrm_multi4_step_23436_512.npz"), + }, + { + "label": "multi4 final", + "family": "multi4-final", + "step": 26040, + "full_exact": f(row_for(hrm_multi, 26040), "exact"), + **spectrum_metrics(ROOT / "diag_hrm_multi4_step_26040_512.npz"), + }, + ] + + trm_points = [ + { + "label": "baseline best", + "family": "baseline", + "step": 58590, + "full_exact": f(row_for(trm_base, 58590), "all/exact_accuracy"), + **spectrum_metrics(spec_dir / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"), + }, + { + "label": "multi4 best", + "family": "multi4", + "step": 35805, + "full_exact": f(row_for(trm_multi, 35805), "all/exact_accuracy"), + **spectrum_metrics(spec_dir / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"), + }, + { + "label": "multi4 final", + "family": "multi4-final", + "step": 65100, + "full_exact": f(row_for(trm_multi, 65100), "all/exact_accuracy"), + **spectrum_metrics(spec_dir / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz"), + }, + ] + + return { + "hrm_base": hrm_base, + "hrm_multi": hrm_multi, + "trm_base": trm_base, + "trm_multi": trm_multi, + "hrm_points": hrm_points, + "trm_points": trm_points, + } + + +def metric_series(rows: list[dict[str, str]], metric: str) -> tuple[np.ndarray, np.ndarray]: + xs, ys = [], [] + for row in rows: + val = f(row, metric) + if np.isfinite(val): + xs.append(int(float(row["step"]))) + ys.append(val) + order = np.argsort(xs) + return np.asarray(xs)[order], np.asarray(ys)[order] + + +def plot_training_curves(data: dict[str, object]) -> None: + fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.2)) + colors = {"baseline": "#475569", "multi4": "#2563eb"} + + panels = [ + ("HRM", data["hrm_base"], data["hrm_multi"], "all/exact_accuracy", "exact", 0.70), + ("TRM official GBS768", data["trm_base"], data["trm_multi"], "all/exact_accuracy", "all/exact_accuracy", 0.94), + ] + for ax, (title, base_rows, multi_rows, base_metric, multi_metric, ymax) in zip(axes, panels): + bx, by = metric_series(base_rows, base_metric) # type: ignore[arg-type] + mx, my = metric_series(multi_rows, multi_metric) # type: ignore[arg-type] + ax.plot(bx / 1000, by, marker="o", color=colors["baseline"], label="baseline", linewidth=2) + ax.plot(mx / 1000, my, marker="o", color=colors["multi4"], label="multi4", linewidth=2) + bi, mi = int(np.argmax(by)), int(np.argmax(my)) + ax.scatter([bx[bi] / 1000], [by[bi]], s=95, color=colors["baseline"], edgecolor="white", zorder=5) + ax.scatter([mx[mi] / 1000], [my[mi]], s=95, color=colors["multi4"], edgecolor="white", zorder=5) + ax.annotate(f"best {by[bi]:.3f}", (bx[bi] / 1000, by[bi]), xytext=(6, -18), textcoords="offset points", fontsize=8) + ax.annotate(f"best {my[mi]:.3f}", (mx[mi] / 1000, my[mi]), xytext=(6, 8), textcoords="offset points", fontsize=8) + ax.set_title(title) + ax.set_xlabel("optimizer step (k)") + ax.set_ylabel("full test exact accuracy") + ax.set_ylim(0, ymax) + ax.grid(alpha=0.22) + ax.legend(frameon=False, loc="lower right") + + fig.suptitle("Trajectory perturbation helps both HRM and TRM, but late checkpoints can collapse") + fig.tight_layout() + fig.savefig(OUT / "fig1_hrm_trm_training_curves.png", dpi=220) + plt.close(fig) + + +def plot_lambda1_motivation() -> None: + entries = [ + ("HRM baseline best\nstep 26040", ROOT / "diag_hrm_step_26040_512.npz"), + ( + "TRM baseline best\nstep 58590", + ROOT / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz", + ), + ] + fig, axes = plt.subplots(1, 2, figsize=(11.6, 4.2), sharey=False) + for ax, (title, path) in zip(axes, entries): + spec, exact = spectrum_arrays(path) + lam1 = spec[:, 0] + succ = lam1[exact] + fail = lam1[~exact] + lo, hi = np.quantile(lam1, [0.01, 0.99]) + pad = 0.08 * (hi - lo) + bins = np.linspace(lo - pad, hi + pad, 34) + ax.hist(succ, bins=bins, density=True, alpha=0.62, color="#2563eb", label=f"success (n={len(succ)})") + ax.hist(fail, bins=bins, density=True, alpha=0.58, color="#dc2626", label=f"failure (n={len(fail)})") + ax.axvline(succ.mean(), color="#1d4ed8", linewidth=2) + ax.axvline(fail.mean(), color="#b91c1c", linewidth=2) + auc = auc_failure_score(lam1, exact) + ax.set_title(f"{title}\nacc={exact.mean():.3f}, AUROC={auc:.3f}") + ax.set_xlabel("top Lyapunov exponent λ1") + ax.set_ylabel("density") + ax.grid(alpha=0.2) + ax.legend(frameon=False, fontsize=8) + fig.suptitle("Motivation: λ1 is a strong success/failure detector in both HRM and TRM") + fig.tight_layout() + fig.savefig(OUT / "fig0_motivation_lambda1_success_failure_hrm_trm.png", dpi=220) + plt.close(fig) + + +def plot_phase(data: dict[str, object]) -> None: + fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=False) + panels = [("HRM", data["hrm_points"]), ("TRM official GBS768", data["trm_points"])] + style = { + "baseline": ("#475569", "o"), + "multi4": ("#2563eb", "o"), + "multi4-final": ("#dc2626", "X"), + } + for ax, (title, points) in zip(axes, panels): + for point in points: # type: ignore[assignment] + color, marker = style[point["family"]] + ax.scatter(point["mean8_all"], point["full_exact"], s=150, marker=marker, color=color, edgecolor="white", zorder=4) + ax.annotate( + f"{point['label']}\nstep {point['step']}", + (point["mean8_all"], point["full_exact"]), + xytext=(8, 5), + textcoords="offset points", + fontsize=8, + ) + ax.set_title(title) + ax.set_xlabel("mean top-8 Lyapunov exponent (lower = more stable)") + ax.set_ylabel("full test exact accuracy") + ax.grid(alpha=0.22) + fig.suptitle("Accuracy-vs-chaotic-volume phase view: best multi4 moves up-left; collapse moves down-right") + fig.tight_layout() + fig.savefig(OUT / "fig2_accuracy_vs_chaotic_volume_phase.png", dpi=220) + plt.close(fig) + + +def plot_spectrum_grid() -> None: + entries = [ + ("HRM baseline best", ROOT / "diag_hrm_step_26040_512.npz"), + ("HRM multi4 best", ROOT / "diag_hrm_multi4_step_23436_512.npz"), + ("TRM baseline best", ROOT / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz"), + ("TRM multi4 best", ROOT / "official_gbs768_spectrum/trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"), + ] + fig, axes = plt.subplots(2, 2, figsize=(10.8, 7.2), sharey=False) + for ax, (title, path) in zip(axes.flat, entries): + spec, exact = spectrum_arrays(path) + x = np.arange(1, spec.shape[1] + 1) + for mask, color, label in [(exact, "#2563eb", "success"), (~exact, "#dc2626", "failure")]: + mean = spec[mask].mean(axis=0) + se = spec[mask].std(axis=0) / max(mask.sum(), 1) ** 0.5 + ax.plot(x, mean, marker="o", color=color, label=label) + ax.fill_between(x, mean - se, mean + se, color=color, alpha=0.16, linewidth=0) + ax.axhline(0, color="black", linewidth=0.8, alpha=0.55) + ax.set_title(f"{title}\nN={len(exact)}, acc={exact.mean():.3f}") + ax.set_xlabel("spectrum rank") + ax.set_ylabel("finite-time exponent") + ax.grid(alpha=0.22) + axes.flat[0].legend(frameon=False) + fig.suptitle("Success/failure separation is present in both HRM and TRM; multi4 mainly stabilizes successful trajectories") + fig.tight_layout() + fig.savefig(OUT / "fig3_hrm_trm_success_failure_spectra.png", dpi=220) + plt.close(fig) + + +def plot_ptrm() -> None: + summary = read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0] + labels = ["deterministic", "mean rollout", "Q-selected", "oracle"] + base = [float(summary[k]) for k in ["base_det", "base_mean_rollout", "base_qmax", "base_oracle"]] + multi = [float(summary[k]) for k in ["multi4_det", "multi4_mean_rollout", "multi4_qmax", "multi4_oracle"]] + x = np.arange(len(labels)) + width = 0.36 + fig, ax = plt.subplots(figsize=(8.7, 4.3)) + ax.bar(x - width / 2, base, width, color="#64748b", label="baseline best") + ax.bar(x + width / 2, multi, width, color="#2563eb", label="multi4 best") + ax.set_xticks(x, labels) + ax.set_ylim(0.86, 1.0) + ax.set_ylabel("exact accuracy") + ax.set_title("PTRM same-subset comparison (n=1000, K=100, D=64, sigma=0.3, L-only)") + ax.grid(axis="y", alpha=0.22) + ax.legend(frameon=False) + for xpos, vals in [(x - width / 2, base), (x + width / 2, multi)]: + for xi, val in zip(xpos, vals): + ax.text(xi, val + 0.002, f"{val:.3f}", ha="center", va="bottom", fontsize=8) + fig.tight_layout() + fig.savefig(OUT / "fig4_ptrm_same_subset_comparison.png", dpi=220) + plt.close(fig) + + +def write_summary(data: dict[str, object]) -> None: + rows = [] + for model, points in [("HRM", data["hrm_points"]), ("TRM", data["trm_points"])]: + for point in points: # type: ignore[assignment] + rows.append( + { + "model": model, + "label": point["label"], + "step": point["step"], + "full_exact": point["full_exact"], + "sample_exact": point["sample_exact"], + "lambda1_all": point["lambda1_all"], + "mean8_all": point["mean8_all"], + "pos_count_all": point["pos_count_all"], + "lambda1_success": point["lambda1_success"], + "lambda1_fail": point["lambda1_fail"], + "mean8_success": point["mean8_success"], + "mean8_fail": point["mean8_fail"], + "pos_count_success": point["pos_count_success"], + "pos_count_fail": point["pos_count_fail"], + } + ) + with (OUT / "hrm_trm_redesigned_summary.csv").open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + ptrm = read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0] + report = f"""# Meeting Figures v2 + +## Figure Strategy + +0. `fig0_motivation_lambda1_success_failure_hrm_trm.png`: first-exponent success/failure distribution in HRM and TRM. This motivates chaos as a detector before introducing the method. +1. `fig1_hrm_trm_training_curves.png`: performance over training for HRM and TRM. This answers whether the method improves accuracy and where best/final are. +2. `fig2_accuracy_vs_chaotic_volume_phase.png`: phase view, with accuracy versus mean top-8 Lyapunov exponent. This answers whether better checkpoints are dynamically more stable. +3. `fig3_hrm_trm_success_failure_spectra.png`: full success/failure spectrum separation for HRM and TRM best checkpoints. This extends Fig0 beyond λ1. +4. `fig4_ptrm_same_subset_comparison.png`: PTRM same-subset result. This is a secondary inference-time story. + +## Key Numbers + +- HRM baseline best: 0.5265 exact. HRM multi4 best: 0.6443 exact. HRM multi4 final: 0.4624 exact. +- TRM baseline best: 0.8686 exact. TRM multi4 best: 0.8965 exact. TRM multi4 final: 0.8351 exact. +- HRM multi4 best dynamics sample: mean top-8 exponent {rows[1]['mean8_all']:+.4f}; final {rows[2]['mean8_all']:+.4f}. +- TRM multi4 best dynamics sample: mean top-8 exponent {rows[4]['mean8_all']:+.4f}; final {rows[5]['mean8_all']:+.4f}. +- PTRM same subset, K=100: Q-selected {float(ptrm['base_qmax']):.3f} -> {float(ptrm['multi4_qmax']):.3f}; mean rollout {float(ptrm['base_mean_rollout']):.3f} -> {float(ptrm['multi4_mean_rollout']):.3f}. + +## Caveats + +- Dynamics spectra use N=512 diagnostic samples, not the full test set. +- PTRM numbers use a fixed N=1000 subset; do not mix its deterministic subset accuracy with full-test W&B exact accuracy. +- Final checkpoints are collapse diagnostics, not the method's reported performance. +""" + (OUT / "meeting_figures_v2_report.md").write_text(report) + + +def main() -> None: + OUT.mkdir(parents=True, exist_ok=True) + data = get_data() + plot_lambda1_motivation() + plot_training_curves(data) + plot_phase(data) + plot_spectrum_grid() + plot_ptrm() + write_summary(data) + print(f"wrote redesigned artifacts to {OUT}") + + +if __name__ == "__main__": + main() |
