"""Build the v2 meeting figures and Markdown report for the flossing experiments.

The script reads evaluation CSV files and Lyapunov-spectrum ``.npz`` diagnostics
below ``ROOT`` and writes PNG figures, a summary CSV and a Markdown report into
``OUT``.
"""

from __future__ import annotations

import csv
import re  # not referenced by the code below; kept so the file's import set is unchanged
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

ROOT = Path("research/flossing")
OUT = ROOT / "meeting_artifacts_v2"


def read_csv(path: Path) -> list[dict[str, str]]:
    """Load a CSV file into a list of header-keyed row dictionaries."""
    with path.open(newline="") as handle:
        return list(csv.DictReader(handle))


def f(row: dict[str, str], key: str, default: float = float("nan")) -> float:
    """Return ``row[key]`` parsed as a float.

    ``default`` is returned when the key is missing, the cell is empty, or the
    cell cannot be parsed as a number.
    """
    raw = row.get(key, "")
    if raw == "":
        return default
    try:
        return float(raw)
    except (TypeError, ValueError):
        # TypeError covers ``None`` cells that ``csv.DictReader`` produces for
        # short rows; ValueError covers non-numeric text.
        return default


def row_for(rows: list[dict[str, str]], step: int) -> dict[str, str]:
    """Return the first row whose ``step`` column equals ``step``.

    Raises:
        KeyError: if no row matches.
    """
    for row in rows:
        # Steps may be serialised as "26040.0", hence the float round-trip.
        if int(float(row["step"])) == step:
            return row
    raise KeyError(step)


def spectrum_arrays(npz_path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load ``lyap_spec`` (as float) and ``exact_correct`` (as bool) from an ``.npz``.

    The archive is closed before returning; ``astype`` copies the data, so the
    returned arrays do not depend on the open file.
    """
    with np.load(npz_path) as data:
        spec = data["lyap_spec"].astype(float)
        exact = data["exact_correct"].astype(bool)
    return spec, exact


def _mean_or_nan(values: np.ndarray) -> float:
    """Return the mean of ``values``, or NaN (without a warning) when it is empty."""
    return float(values.mean()) if values.size else float("nan")


def spectrum_metrics(npz_path: Path) -> dict[str, float]:
    """Summarise one spectrum diagnostic file.

    For each of the groups ``all`` / ``success`` / ``fail`` (split by
    ``exact_correct``) the result holds the mean of the first spectrum column
    (``lambda1_*``), the mean over all spectrum columns (``mean8_*`` -- the name
    suggests 8 exponents per sample; TODO confirm), and the mean number of
    positive exponents (``pos_count_*``). ``sample_exact`` is the fraction of
    correct samples. Empty groups yield NaN.
    """
    spec, exact = spectrum_arrays(npz_path)
    lambda1 = spec[:, 0]
    mean8 = spec.mean(axis=1)
    pos_count = (spec > 0).sum(axis=1)
    failed = ~exact
    return {
        "sample_exact": _mean_or_nan(exact),
        "lambda1_all": _mean_or_nan(lambda1),
        "mean8_all": _mean_or_nan(mean8),
        "pos_count_all": _mean_or_nan(pos_count),
        "lambda1_success": _mean_or_nan(lambda1[exact]),
        "lambda1_fail": _mean_or_nan(lambda1[failed]),
        "mean8_success": _mean_or_nan(mean8[exact]),
        "mean8_fail": _mean_or_nan(mean8[failed]),
        "pos_count_success": _mean_or_nan(pos_count[exact]),
        "pos_count_fail": _mean_or_nan(pos_count[failed]),
    }


def _average_ranks(values: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of ``values`` where tied values share their mean rank."""
    _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)  # ordinal rank of the last member of each tie group
    mean_rank = last_rank - (counts - 1) / 2.0
    return mean_rank[inverse]


def auc_failure_score(score: np.ndarray, exact: np.ndarray) -> float:
    """AUROC where higher score predicts failure.

    Computed from the rank-sum (Mann-Whitney U) statistic. Tied scores receive
    their average rank, so the result does not depend on the order in which
    ties happen to be sorted. Returns NaN when either class is empty.
    """
    score = np.asarray(score, dtype=float).ravel()
    exact = np.asarray(exact, dtype=bool).ravel()
    fail = ~exact
    n_fail = int(fail.sum())
    n_succ = int(exact.sum())
    if n_fail == 0 or n_succ == 0:
        return float("nan")
    ranks = _average_ranks(score)
    u_fail = ranks[fail].sum() - n_fail * (n_fail + 1) / 2
    return float(u_fail / (n_fail * n_succ))


def _spectrum_paths() -> dict[str, Path]:
    """Map checkpoint nicknames to their spectrum ``.npz`` files under ``ROOT``."""
    spec_dir = ROOT / "official_gbs768_spectrum"
    return {
        "hrm_base": ROOT / "diag_hrm_step_26040_512.npz",
        "hrm_multi4_best": ROOT / "diag_hrm_multi4_step_23436_512.npz",
        "hrm_multi4_final": ROOT / "diag_hrm_multi4_step_26040_512.npz",
        "trm_base": spec_dir / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz",
        "trm_multi4_best": spec_dir / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz",
        "trm_multi4_final": spec_dir / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz",
    }


def _ptrm_summary() -> dict[str, str]:
    """Return the first row of the paired PTRM summary CSV."""
    return read_csv(ROOT / "ptrm_same_subset/paired_ptrm_k100_n1000_seed0_summary.csv")[0]


def _checkpoint_point(
    label: str,
    family: str,
    step: int,
    rows: list[dict[str, str]],
    metric: str,
    npz_path: Path,
) -> dict[str, object]:
    """Combine a checkpoint's full-test accuracy with its spectrum metrics."""
    return {
        "label": label,
        "family": family,
        "step": step,
        "full_exact": f(row_for(rows, step), metric),
        **spectrum_metrics(npz_path),
    }


def get_data() -> dict[str, object]:
    """Load every evaluation CSV and build the per-checkpoint summary points.

    Returns a dict with the raw row lists (``hrm_base``, ``hrm_multi``,
    ``trm_base``, ``trm_multi``) and the lists ``hrm_points`` / ``trm_points``,
    each holding baseline-best, multi4-best and multi4-final checkpoints in
    that order.
    """
    eval_dir = ROOT / "multi4_eval_compare"
    npz = _spectrum_paths()
    hrm_base = read_csv(eval_dir / "hrm_baseline_eval.csv")
    hrm_multi = read_csv(eval_dir / "hrm_multi4_complete_eval.csv")
    trm_base = read_csv(eval_dir / "trm_official_gbs768_eval.csv")
    trm_multi = read_csv(eval_dir / "trm_official_gbs768_multi4_eval.csv")

    # The HRM multi4 CSV stores accuracy under "exact"; the others use
    # "all/exact_accuracy".
    hrm_points = [
        _checkpoint_point(
            "baseline best", "baseline", 26040, hrm_base, "all/exact_accuracy", npz["hrm_base"]
        ),
        _checkpoint_point(
            "multi4 best", "multi4", 23436, hrm_multi, "exact", npz["hrm_multi4_best"]
        ),
        _checkpoint_point(
            "multi4 final", "multi4-final", 26040, hrm_multi, "exact", npz["hrm_multi4_final"]
        ),
    ]
    trm_points = [
        _checkpoint_point(
            "baseline best", "baseline", 58590, trm_base, "all/exact_accuracy", npz["trm_base"]
        ),
        _checkpoint_point(
            "multi4 best", "multi4", 35805, trm_multi, "all/exact_accuracy",
            npz["trm_multi4_best"],
        ),
        _checkpoint_point(
            "multi4 final", "multi4-final", 65100, trm_multi, "all/exact_accuracy",
            npz["trm_multi4_final"],
        ),
    ]
    return {
        "hrm_base": hrm_base,
        "hrm_multi": hrm_multi,
        "trm_base": trm_base,
        "trm_multi": trm_multi,
        "hrm_points": hrm_points,
        "trm_points": trm_points,
    }


def metric_series(rows: list[dict[str, str]], metric: str) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(steps, values)`` for ``metric``, sorted by step.

    Rows whose metric is missing, empty or non-finite are dropped.
    """
    steps: list[int] = []
    values: list[float] = []
    for row in rows:
        value = f(row, metric)
        if np.isfinite(value):
            steps.append(int(float(row["step"])))
            values.append(value)
    order = np.argsort(steps)
    return np.asarray(steps)[order], np.asarray(values)[order]


def _mark_best(
    ax, xs: np.ndarray, ys: np.ndarray, color: str, text_offset: tuple[int, int]
) -> None:
    """Highlight and annotate the maximum of a curve; do nothing for an empty curve."""
    if len(ys) == 0:
        # np.argmax raises on empty input; an empty series simply has no best point.
        return
    best = int(np.argmax(ys))
    x_best, y_best = xs[best] / 1000, ys[best]
    ax.scatter([x_best], [y_best], s=95, color=color, edgecolor="white", zorder=5)
    ax.annotate(
        f"best {y_best:.3f}",
        (x_best, y_best),
        xytext=text_offset,
        textcoords="offset points",
        fontsize=8,
    )


def plot_training_curves(data: dict[str, object]) -> None:
    """Write fig1: full-test exact accuracy over training for HRM and TRM."""
    fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.2))
    colors = {"baseline": "#475569", "multi4": "#2563eb"}
    # (title, baseline rows, multi4 rows, baseline metric, multi4 metric, y-axis max)
    panels = [
        ("HRM", data["hrm_base"], data["hrm_multi"], "all/exact_accuracy", "exact", 0.70),
        (
            "TRM official GBS768",
            data["trm_base"],
            data["trm_multi"],
            "all/exact_accuracy",
            "all/exact_accuracy",
            0.94,
        ),
    ]
    for ax, (title, base_rows, multi_rows, base_metric, multi_metric, ymax) in zip(axes, panels):
        bx, by = metric_series(base_rows, base_metric)  # type: ignore[arg-type]
        mx, my = metric_series(multi_rows, multi_metric)  # type: ignore[arg-type]
        # Steps are shown in thousands on the x axis.
        ax.plot(bx / 1000, by, marker="o", color=colors["baseline"], label="baseline", linewidth=2)
        ax.plot(mx / 1000, my, marker="o", color=colors["multi4"], label="multi4", linewidth=2)
        _mark_best(ax, bx, by, colors["baseline"], (6, -18))
        _mark_best(ax, mx, my, colors["multi4"], (6, 8))
        ax.set_title(title)
        ax.set_xlabel("optimizer step (k)")
        ax.set_ylabel("full test exact accuracy")
        ax.set_ylim(0, ymax)
        ax.grid(alpha=0.22)
        ax.legend(frameon=False, loc="lower right")
    fig.suptitle(
        "Trajectory perturbation helps both HRM and TRM, but late checkpoints can collapse"
    )
    fig.tight_layout()
    fig.savefig(OUT / "fig1_hrm_trm_training_curves.png", dpi=220)
    plt.close(fig)


def plot_lambda1_motivation() -> None:
    """Write fig0: λ1 histograms split by success/failure for the baseline checkpoints."""
    npz = _spectrum_paths()
    entries = [
        ("HRM baseline best\nstep 26040", npz["hrm_base"]),
        ("TRM baseline best\nstep 58590", npz["trm_base"]),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(11.6, 4.2), sharey=False)
    for ax, (title, path) in zip(axes, entries):
        spec, exact = spectrum_arrays(path)
        lam1 = spec[:, 0]
        succ = lam1[exact]
        fail = lam1[~exact]
        # Clip the histogram range to the central 98% so outliers do not
        # stretch the axis, then pad it slightly.
        lo, hi = np.quantile(lam1, [0.01, 0.99])
        pad = 0.08 * (hi - lo)
        bins = np.linspace(lo - pad, hi + pad, 34)
        ax.hist(
            succ, bins=bins, density=True, alpha=0.62, color="#2563eb",
            label=f"success (n={len(succ)})",
        )
        ax.hist(
            fail, bins=bins, density=True, alpha=0.58, color="#dc2626",
            label=f"failure (n={len(fail)})",
        )
        ax.axvline(succ.mean(), color="#1d4ed8", linewidth=2)
        ax.axvline(fail.mean(), color="#b91c1c", linewidth=2)
        auc = auc_failure_score(lam1, exact)
        ax.set_title(f"{title}\nacc={exact.mean():.3f}, AUROC={auc:.3f}")
        ax.set_xlabel("top Lyapunov exponent λ1")
        ax.set_ylabel("density")
        ax.grid(alpha=0.2)
        ax.legend(frameon=False, fontsize=8)
    fig.suptitle("Motivation: λ1 is a strong success/failure detector in both HRM and TRM")
    fig.tight_layout()
    fig.savefig(OUT / "fig0_motivation_lambda1_success_failure_hrm_trm.png", dpi=220)
    plt.close(fig)


def plot_phase(data: dict[str, object]) -> None:
    """Write fig2: full-test accuracy versus ``mean8_all`` for each checkpoint."""
    fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=False)
    panels = [("HRM", data["hrm_points"]), ("TRM official GBS768", data["trm_points"])]
    # family -> (colour, marker)
    style = {
        "baseline": ("#475569", "o"),
        "multi4": ("#2563eb", "o"),
        "multi4-final": ("#dc2626", "X"),
    }
    for ax, (title, points) in zip(axes, panels):
        for point in points:  # type: ignore[assignment]
            color, marker = style[point["family"]]
            xy = (point["mean8_all"], point["full_exact"])
            ax.scatter(
                xy[0], xy[1], s=150, marker=marker, color=color, edgecolor="white", zorder=4
            )
            ax.annotate(
                f"{point['label']}\nstep {point['step']}",
                xy,
                xytext=(8, 5),
                textcoords="offset points",
                fontsize=8,
            )
        ax.set_title(title)
        ax.set_xlabel("mean top-8 Lyapunov exponent (lower = more stable)")
        ax.set_ylabel("full test exact accuracy")
        ax.grid(alpha=0.22)
    fig.suptitle(
        "Accuracy-vs-chaotic-volume phase view: best multi4 moves up-left; "
        "collapse moves down-right"
    )
    fig.tight_layout()
    fig.savefig(OUT / "fig2_accuracy_vs_chaotic_volume_phase.png", dpi=220)
    plt.close(fig)


def plot_spectrum_grid() -> None:
    """Write fig3: mean spectrum (± standard error) for successes and failures."""
    npz = _spectrum_paths()
    entries = [
        ("HRM baseline best", npz["hrm_base"]),
        ("HRM multi4 best", npz["hrm_multi4_best"]),
        ("TRM baseline best", npz["trm_base"]),
        ("TRM multi4 best", npz["trm_multi4_best"]),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(10.8, 7.2), sharey=False)
    for ax, (title, path) in zip(axes.flat, entries):
        spec, exact = spectrum_arrays(path)
        x = np.arange(1, spec.shape[1] + 1)
        groups = [(exact, "#2563eb", "success"), (~exact, "#dc2626", "failure")]
        for mask, color, label in groups:
            mean = spec[mask].mean(axis=0)
            # Standard error of the mean; max(..., 1) avoids dividing by zero
            # for an empty group.
            se = spec[mask].std(axis=0) / max(mask.sum(), 1) ** 0.5
            ax.plot(x, mean, marker="o", color=color, label=label)
            ax.fill_between(x, mean - se, mean + se, color=color, alpha=0.16, linewidth=0)
        ax.axhline(0, color="black", linewidth=0.8, alpha=0.55)
        ax.set_title(f"{title}\nN={len(exact)}, acc={exact.mean():.3f}")
        ax.set_xlabel("spectrum rank")
        ax.set_ylabel("finite-time exponent")
        ax.grid(alpha=0.22)
    axes.flat[0].legend(frameon=False)
    fig.suptitle(
        "Success/failure separation is present in both HRM and TRM; "
        "multi4 mainly stabilizes successful trajectories"
    )
    fig.tight_layout()
    fig.savefig(OUT / "fig3_hrm_trm_success_failure_spectra.png", dpi=220)
    plt.close(fig)


def plot_ptrm() -> None:
    """Write fig4: grouped bars comparing baseline and multi4 PTRM accuracies."""
    summary = _ptrm_summary()
    labels = ["deterministic", "mean rollout", "Q-selected", "oracle"]
    suffixes = ["det", "mean_rollout", "qmax", "oracle"]
    base = [float(summary[f"base_{suffix}"]) for suffix in suffixes]
    multi = [float(summary[f"multi4_{suffix}"]) for suffix in suffixes]
    x = np.arange(len(labels))
    width = 0.36
    fig, ax = plt.subplots(figsize=(8.7, 4.3))
    ax.bar(x - width / 2, base, width, color="#64748b", label="baseline best")
    ax.bar(x + width / 2, multi, width, color="#2563eb", label="multi4 best")
    ax.set_xticks(x, labels)
    ax.set_ylim(0.86, 1.0)
    ax.set_ylabel("exact accuracy")
    ax.set_title("PTRM same-subset comparison (n=1000, K=100, D=64, sigma=0.3, L-only)")
    ax.grid(axis="y", alpha=0.22)
    ax.legend(frameon=False)
    # Print each bar's value just above it.
    for xpos, vals in [(x - width / 2, base), (x + width / 2, multi)]:
        for xi, val in zip(xpos, vals):
            ax.text(xi, val + 0.002, f"{val:.3f}", ha="center", va="bottom", fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / "fig4_ptrm_same_subset_comparison.png", dpi=220)
    plt.close(fig)


def write_summary(data: dict[str, object]) -> None:
    """Write the per-checkpoint summary CSV and the Markdown meeting report."""
    fields = [
        "label",
        "step",
        "full_exact",
        "sample_exact",
        "lambda1_all",
        "mean8_all",
        "pos_count_all",
        "lambda1_success",
        "lambda1_fail",
        "mean8_success",
        "mean8_fail",
        "pos_count_success",
        "pos_count_fail",
    ]
    rows = []
    for model, points in [("HRM", data["hrm_points"]), ("TRM", data["trm_points"])]:
        for point in points:  # type: ignore[assignment]
            rows.append({"model": model, **{name: point[name] for name in fields}})

    summary_path = OUT / "hrm_trm_redesigned_summary.csv"
    with summary_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=["model", *fields])
        writer.writeheader()
        writer.writerows(rows)

    # Look rows up by (model, label) instead of by position so the report
    # stays correct if the checkpoint lists are reordered or extended.
    by_key = {(row["model"], row["label"]): row for row in rows}
    hrm_best = by_key[("HRM", "multi4 best")]
    hrm_final = by_key[("HRM", "multi4 final")]
    trm_best = by_key[("TRM", "multi4 best")]
    trm_final = by_key[("TRM", "multi4 final")]
    ptrm = _ptrm_summary()

    # NOTE(review): the accuracies in the first two "Key Numbers" bullets are
    # hard-coded literals, not read from the CSVs -- confirm they still match
    # the evaluation files before presenting.
    report = f"""# Meeting Figures v2

## Figure Strategy

0. `fig0_motivation_lambda1_success_failure_hrm_trm.png`: first-exponent success/failure distribution in HRM and TRM. This motivates chaos as a detector before introducing the method.
1. `fig1_hrm_trm_training_curves.png`: performance over training for HRM and TRM. This answers whether the method improves accuracy and where best/final are.
2. `fig2_accuracy_vs_chaotic_volume_phase.png`: phase view, with accuracy versus mean top-8 Lyapunov exponent. This answers whether better checkpoints are dynamically more stable.
3. `fig3_hrm_trm_success_failure_spectra.png`: full success/failure spectrum separation for HRM and TRM best checkpoints. This extends Fig0 beyond λ1.
4. `fig4_ptrm_same_subset_comparison.png`: PTRM same-subset result. This is a secondary inference-time story.

## Key Numbers

- HRM baseline best: 0.5265 exact. HRM multi4 best: 0.6443 exact. HRM multi4 final: 0.4624 exact.
- TRM baseline best: 0.8686 exact. TRM multi4 best: 0.8965 exact. TRM multi4 final: 0.8351 exact.
- HRM multi4 best dynamics sample: mean top-8 exponent {hrm_best['mean8_all']:+.4f}; final {hrm_final['mean8_all']:+.4f}.
- TRM multi4 best dynamics sample: mean top-8 exponent {trm_best['mean8_all']:+.4f}; final {trm_final['mean8_all']:+.4f}.
- PTRM same subset, K=100: Q-selected {float(ptrm['base_qmax']):.3f} -> {float(ptrm['multi4_qmax']):.3f}; mean rollout {float(ptrm['base_mean_rollout']):.3f} -> {float(ptrm['multi4_mean_rollout']):.3f}.

## Caveats

- Dynamics spectra use N=512 diagnostic samples, not the full test set.
- PTRM numbers use a fixed N=1000 subset; do not mix its deterministic subset accuracy with full-test W&B exact accuracy.
- Final checkpoints are collapse diagnostics, not the method's reported performance.
"""
    # The report contains non-ASCII text (λ); pin UTF-8 so writing does not
    # depend on the platform's default encoding.
    (OUT / "meeting_figures_v2_report.md").write_text(report, encoding="utf-8")


def main() -> None:
    """Create the output directory, then render every figure and the report."""
    OUT.mkdir(parents=True, exist_ok=True)
    data = get_data()
    plot_lambda1_motivation()
    plot_training_curves(data)
    plot_phase(data)
    plot_spectrum_grid()
    plot_ptrm()
    write_summary(data)
    print(f"wrote redesigned artifacts to {OUT}")


if __name__ == "__main__":
    main()