"""Full-spectrum microscope for HRM/TRM joint Lyapunov diagnostics.

This script intentionally treats the finite-time QR columns as an unordered
top-k estimate per sample and analyzes both the raw column-0 value and the
sorted spectrum. The sorted features are safer for per-sample questions like
"how many unstable directions does this trajectory have?".
"""

from __future__ import annotations

import argparse
import csv
import json
import math
import re
from pathlib import Path

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError as _exc:
    # Keep the numeric helpers importable without matplotlib; the plotting
    # functions and main() re-raise this error before doing any work.
    plt = None
    _PYPLOT_IMPORT_ERROR = _exc
else:
    _PYPLOT_IMPORT_ERROR = None

ROOT = Path("/home/yurenh2/rrm/research/flossing")

# Features that get per-group means in the checkpoint summary and a line in
# the AUROC plot.
_KEY_FEATURES = (
    "raw_col0",
    "lambda_max",
    "mean8",
    "tail_mean_5_8",
    "positive_sum",
    "positive_count",
    "spread",
    "gap12",
)

# Features correlated against the continuous token accuracy.
_TOKEN_ACC_FEATURES = (
    "lambda_max",
    "mean8",
    "tail_mean_5_8",
    "positive_sum",
    "positive_count",
)


def _require_pyplot() -> None:
    """Raise ImportError if matplotlib could not be imported at module load."""
    if plt is None:
        raise ImportError(
            "matplotlib is required for plotting"
        ) from _PYPLOT_IMPORT_ERROR


def _load_arrays(path: Path, *keys: str) -> tuple:
    """Read the named arrays from an ``.npz`` archive and close the archive."""
    with np.load(path) as data:
        return tuple(np.asarray(data[key]) for key in keys)


def auc_score(score: np.ndarray, label: np.ndarray) -> float:
    """AUROC for score predicting label=True, with average ranks for ties.

    Returns NaN when either class is empty.
    """
    score = np.asarray(score, dtype=np.float64)
    label = np.asarray(label, dtype=bool)
    n_pos = int(label.sum())
    n_neg = int((~label).sum())
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    order = np.argsort(score)
    sorted_score = score[order]
    # A tie group starts wherever the sorted value changes. NaN != NaN, so NaN
    # scores never tie with each other.
    changes = np.flatnonzero(sorted_score[1:] != sorted_score[:-1]) + 1
    starts = np.concatenate(([0], changes))
    ends = np.concatenate((changes, [len(score)]))
    # 1-based ranks start+1 .. end average to (start + 1 + end) / 2.
    group_rank = (starts + 1 + ends) / 2.0
    ranks = np.empty(len(score), dtype=np.float64)
    ranks[order] = np.repeat(group_rank, ends - starts)

    rank_sum = ranks[label].sum()
    return float((rank_sum - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg))


def cohen_d(success: np.ndarray, failure: np.ndarray) -> float:
    """Cohen's d of failure minus success using the pooled sample variance.

    Returns NaN when either group has fewer than two samples or the pooled
    variance is not positive.
    """
    success = np.asarray(success, dtype=np.float64)
    failure = np.asarray(failure, dtype=np.float64)
    if len(success) < 2 or len(failure) < 2:
        return float("nan")
    pooled = (
        (len(success) - 1) * success.var(ddof=1)
        + (len(failure) - 1) * failure.var(ddof=1)
    ) / (len(success) + len(failure) - 2)
    if pooled <= 0:
        return float("nan")
    return float((failure.mean() - success.mean()) / math.sqrt(pooled))


def safe_mean(x: np.ndarray) -> float:
    """Mean of ``x``, or NaN for an empty array."""
    return float(np.mean(x)) if len(x) else float("nan")


def safe_std(x: np.ndarray) -> float:
    """Population standard deviation (ddof=0) of ``x``, or NaN if empty."""
    return float(np.std(x)) if len(x) else float("nan")


def parse_step(path: Path, kind: str) -> int:
    """Extract the checkpoint step from a diagnostic file name.

    ``kind == "HRM"`` expects ``step_<N>_512``; any other kind expects
    ``step<N>_512``.

    Raises:
        ValueError: if the file name does not contain the expected pattern.
    """
    pattern = r"step_(\d+)_512" if kind == "HRM" else r"step(\d+)_512"
    match = re.search(pattern, path.name)
    if match is None:
        raise ValueError(
            f"Cannot parse a checkpoint step from {path.name!r} (kind={kind!r})"
        )
    return int(match.group(1))


def discover(kind: str, root: Path = ROOT) -> list[Path]:
    """Return the diagnostic files for ``kind`` under ``root``, sorted by step.

    Raises:
        FileNotFoundError: if no matching file exists.
    """
    if kind == "HRM":
        pattern = "diag_hrm_step_*_512.npz"
    else:
        pattern = "diag_trm_singleGPU_step*_512.npz"
    files = sorted(Path(root).glob(pattern), key=lambda p: parse_step(p, kind))
    if not files:
        raise FileNotFoundError(f"No {kind} diagnostic files found under {root}")
    return files


def feature_dict(raw_lam: np.ndarray) -> dict[str, np.ndarray]:
    """Compute per-sample spectrum features from an ``(n, k)`` exponent array.

    Each row is sorted in descending order first; only ``raw_col0`` uses the
    original column order. Feature names carry an ``8`` suffix but the
    computation uses however many columns are present. ``tail_mean_5_8`` is
    NaN when there are four or fewer columns.

    Raises:
        ValueError: if ``raw_lam`` is not 2-D or has fewer than two columns.
    """
    raw_lam = np.asarray(raw_lam)
    if raw_lam.ndim != 2 or raw_lam.shape[1] < 2:
        raise ValueError(
            "raw_lam must be a 2-D array with at least two columns, "
            f"got shape {raw_lam.shape}"
        )
    sorted_lam = np.sort(raw_lam, axis=1)[:, ::-1]
    k = sorted_lam.shape[1]

    # Least-squares slope of the sorted spectrum against centred rank index.
    rank = np.arange(k, dtype=np.float64)
    rank = rank - rank.mean()
    slope = (sorted_lam @ rank) / float((rank**2).sum())

    if k > 4:
        tail_mean = sorted_lam[:, 4:].mean(axis=1)
    else:
        # No columns past rank 4: report NaN without an empty-slice warning.
        tail_mean = np.full(sorted_lam.shape[0], np.nan)

    return {
        "raw_col0": raw_lam[:, 0],
        "lambda_max": sorted_lam[:, 0],
        "lambda_2": sorted_lam[:, 1],
        "lambda_min8": sorted_lam[:, -1],
        "mean8": sorted_lam.mean(axis=1),
        "sum8": sorted_lam.sum(axis=1),
        "tail_mean_5_8": tail_mean,
        "positive_sum": np.clip(sorted_lam, 0.0, None).sum(axis=1),
        "positive_count": (sorted_lam > 0).sum(axis=1),
        "spread": sorted_lam[:, 0] - sorted_lam[:, -1],
        "gap12": sorted_lam[:, 0] - sorted_lam[:, 1],
        "std8": sorted_lam.std(axis=1),
        "linear_slope": slope,
    }


def summarize_file(kind: str, path: Path) -> tuple[dict, list[dict]]:
    """Summarize one checkpoint diagnostic archive.

    Reads ``lyap_spec``, ``exact_correct`` and ``token_acc`` from ``path``.

    Returns:
        ``(summary, rows)`` where ``summary`` is one flat dict for the
        checkpoint and ``rows`` holds one dict per feature followed by one
        dict per sorted spectrum rank.
    """
    raw_lam, exact, token_acc = _load_arrays(
        path, "lyap_spec", "exact_correct", "token_acc"
    )
    raw_lam = raw_lam.astype(np.float64)
    token_acc = token_acc.astype(np.float64)
    exact = exact.astype(bool)
    fail = ~exact
    sorted_lam = np.sort(raw_lam, axis=1)[:, ::-1]
    features = feature_dict(raw_lam)
    step = parse_step(path, kind)

    summary = {
        "kind": kind,
        "file": str(path),
        "step": step,
        "n": int(len(exact)),
        "k": int(raw_lam.shape[1]),
        "acc": safe_mean(exact),
        "n_success": int(exact.sum()),
        "n_failure": int(fail.sum()),
        # Fraction of adjacent raw column pairs that are non-increasing.
        "raw_monotone_adjacent_fraction": float(
            (np.diff(raw_lam, axis=1) <= 1e-5).mean()
        ),
        "raw_col0_is_sample_max_fraction": float(
            (raw_lam[:, 0] >= sorted_lam[:, 0] - 1e-7).mean()
        ),
    }

    # Group means for the most interpretable features.
    for name in _KEY_FEATURES:
        arr = features[name]
        success_mean = safe_mean(arr[exact])
        failure_mean = safe_mean(arr[fail])
        summary[f"{name}_success_mean"] = success_mean
        summary[f"{name}_failure_mean"] = failure_mean
        summary[f"{name}_delta_failure_minus_success"] = failure_mean - success_mean
        summary[f"{name}_auc_failure"] = auc_score(arr, fail)

    # Continuous token-accuracy correlations catch near-misses, not only
    # exact success.
    for name in _TOKEN_ACC_FEATURES:
        arr = features[name]
        if arr.std() > 0 and token_acc.std() > 0:
            corr = float(np.corrcoef(arr, token_acc)[0, 1])
        else:
            corr = float("nan")
        summary[f"{name}_corr_token_acc"] = corr

    feature_rows = []
    for name, arr in features.items():
        success_vals, failure_vals = arr[exact], arr[fail]
        success_mean = safe_mean(success_vals)
        failure_mean = safe_mean(failure_vals)
        feature_rows.append(
            {
                "kind": kind,
                "step": step,
                "feature": name,
                "success_mean": success_mean,
                "failure_mean": failure_mean,
                "success_std": safe_std(success_vals),
                "failure_std": safe_std(failure_vals),
                "delta_failure_minus_success": failure_mean - success_mean,
                "cohen_d_failure_minus_success": cohen_d(success_vals, failure_vals),
                "auc_failure": auc_score(arr, fail),
            }
        )

    spectrum_rows = []
    for i in range(raw_lam.shape[1]):
        success_mean = safe_mean(sorted_lam[exact, i])
        failure_mean = safe_mean(sorted_lam[fail, i])
        spectrum_rows.append(
            {
                "kind": kind,
                "step": step,
                "rank": i + 1,
                "success_mean": success_mean,
                "failure_mean": failure_mean,
                "delta_failure_minus_success": failure_mean - success_mean,
                "auc_failure": auc_score(sorted_lam[:, i], fail),
            }
        )

    return summary, feature_rows + spectrum_rows


def write_csv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as CSV; does nothing when ``rows`` is empty.

    The header is the union of all row keys in first-seen order; keys missing
    from a row are written as empty cells.
    """
    if not rows:
        return
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def plot_series(kind: str, summaries: list[dict], out_dir: Path) -> None:
    """Plot accuracy and success/failure feature means against checkpoint step.

    Saves ``<kind>_checkpoint_spectrum_features.png`` in ``out_dir``.
    """
    _require_pyplot()
    steps = np.array([r["step"] for r in summaries])
    acc = np.array([r["acc"] for r in summaries])
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    ax = axes[0, 0]
    ax.plot(steps, acc, "ko-", label="exact acc")
    ax.set_title(f"{kind}: accuracy")
    ax.set_xlabel("checkpoint step")
    ax.set_ylabel("accuracy")
    ax.grid(alpha=0.3)

    def plot_pair(ax, feature, labels, title, ylabel, zero_line):
        # One panel: success vs failure mean of a feature across checkpoints.
        success_label, failure_label = labels
        ax.plot(
            steps,
            [r[f"{feature}_success_mean"] for r in summaries],
            "C0-o",
            label=success_label,
        )
        ax.plot(
            steps,
            [r[f"{feature}_failure_mean"] for r in summaries],
            "C3-o",
            label=failure_label,
        )
        if zero_line:
            ax.axhline(0, color="k", lw=1, alpha=0.35)
        ax.set_title(title)
        ax.set_xlabel("checkpoint step")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(alpha=0.3)

    plot_pair(
        axes[0, 1],
        "lambda_max",
        ("success λmax", "failure λmax"),
        "Most unstable measured mode",
        "sorted λmax",
        True,
    )
    plot_pair(
        axes[1, 0],
        "mean8",
        ("success mean top-8", "failure mean top-8"),
        "Top-8 volume proxy",
        "mean sorted λ1..λ8",
        True,
    )
    plot_pair(
        axes[1, 1],
        "positive_count",
        ("success", "failure"),
        "Dimensionality of expansion",
        "# positive exponents among top 8",
        False,
    )

    fig.tight_layout()
    fig.savefig(out_dir / f"{kind.lower()}_checkpoint_spectrum_features.png", dpi=150)
    plt.close(fig)


def plot_spectra(kind: str, files: list[Path], out_dir: Path) -> None:
    """Plot the mean sorted spectrum of each checkpoint, split by outcome.

    Saves ``<kind>_mean_sorted_spectra_grid.png`` in ``out_dir``. Does nothing
    when ``files`` is empty.
    """
    _require_pyplot()
    if not files:
        return
    n = len(files)
    cols = min(5, n)
    rows = int(math.ceil(n / cols))
    fig, axes = plt.subplots(
        rows, cols, figsize=(3.4 * cols, 2.7 * rows), squeeze=False
    )

    # Collect legend handles from every panel: a checkpoint may contain only
    # successes or only failures, so one panel is not enough.
    legend_handles = {}
    for ax, path in zip(axes.flat, files):
        raw_lam, exact = _load_arrays(path, "lyap_spec", "exact_correct")
        sorted_lam = np.sort(raw_lam.astype(np.float64), axis=1)[:, ::-1]
        exact = exact.astype(bool)
        fail = ~exact
        x = np.arange(1, sorted_lam.shape[1] + 1)
        if exact.any():
            ax.plot(
                x, sorted_lam[exact].mean(axis=0), "C0-o", lw=1.6, ms=3, label="success"
            )
        if fail.any():
            ax.plot(
                x, sorted_lam[fail].mean(axis=0), "C3-o", lw=1.6, ms=3, label="failure"
            )
        ax.axhline(0, color="k", lw=0.8, alpha=0.35)
        ax.set_title(f"step {parse_step(path, kind)} acc={exact.mean():.2f}")
        ax.set_xlabel("sorted rank")
        ax.set_ylabel("λ")
        ax.grid(alpha=0.25)
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_handles.setdefault(label, handle)

    for ax in axes.flat[len(files):]:
        ax.axis("off")

    labels = [name for name in ("success", "failure") if name in legend_handles]
    if labels:
        fig.legend(
            [legend_handles[name] for name in labels],
            labels,
            loc="upper center",
            ncol=2,
        )
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(out_dir / f"{kind.lower()}_mean_sorted_spectra_grid.png", dpi=150)
    plt.close(fig)


def plot_feature_auc(kind: str, feature_rows: list[dict], out_dir: Path) -> None:
    """Plot per-feature failure AUROC against checkpoint step.

    Only rows that carry a ``feature`` key and match ``kind`` are used. Saves
    ``<kind>_feature_auc.png`` in ``out_dir``.
    """
    _require_pyplot()
    fig, ax = plt.subplots(figsize=(12, 5))
    for feature in _KEY_FEATURES:
        rows = sorted(
            (
                r
                for r in feature_rows
                if r.get("feature") == feature and r["kind"] == kind
            ),
            key=lambda r: r["step"],
        )
        ax.plot(
            [r["step"] for r in rows],
            [r["auc_failure"] for r in rows],
            marker="o",
            label=feature,
        )
    ax.axhline(0.5, color="k", lw=1, alpha=0.35)
    ax.set_title(f"{kind}: feature AUROC for predicting exact failure")
    ax.set_xlabel("checkpoint step")
    ax.set_ylabel("AUROC")
    ax.set_ylim(0.0, 1.02)
    ax.legend(ncol=4, fontsize=8)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_dir / f"{kind.lower()}_feature_auc.png", dpi=150)
    plt.close(fig)


def main(argv: list[str] | None = None) -> None:
    """Summarize and plot every HRM and TRM checkpoint found under the root.

    Args:
        argv: command-line arguments; ``None`` uses ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root",
        default=str(ROOT),
        help="directory containing the diag_*.npz files",
    )
    parser.add_argument(
        "--out-dir",
        default=None,
        help="output directory (default: <root>/spectrum_microscope)",
    )
    args = parser.parse_args(argv)

    # Fail before any analysis if the figures cannot be produced.
    _require_pyplot()

    root = Path(args.root)
    if args.out_dir is None:
        out_dir = root / "spectrum_microscope"
    else:
        out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    all_summaries: list[dict] = []
    all_feature_rows: list[dict] = []
    for kind in ["HRM", "TRM"]:
        files = discover(kind, root)
        summaries = []
        feature_rows = []
        for path in files:
            summary, rows = summarize_file(kind, path)
            summaries.append(summary)
            feature_rows.extend(rows)
        summaries = sorted(summaries, key=lambda r: r["step"])
        all_summaries.extend(summaries)
        all_feature_rows.extend(feature_rows)
        plot_series(kind, summaries, out_dir)
        plot_spectra(kind, files, out_dir)
        plot_feature_auc(kind, feature_rows, out_dir)

    write_csv(out_dir / "checkpoint_summary.csv", all_summaries)
    write_csv(out_dir / "feature_and_rank_summary.csv", all_feature_rows)
    # NOTE: NaN values are written as bare NaN tokens (json.dumps default),
    # which strict JSON parsers reject.
    (out_dir / "checkpoint_summary.json").write_text(
        json.dumps(all_summaries, indent=2), encoding="utf-8"
    )
    print(f"Wrote {out_dir}")
    print(f" {out_dir / 'checkpoint_summary.csv'}")
    print(f" {out_dir / 'feature_and_rank_summary.csv'}")


if __name__ == "__main__":
    main()