summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_spectrum_microscope.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/analyze_spectrum_microscope.py')
-rw-r--r--research/flossing/analyze_spectrum_microscope.py360
1 files changed, 360 insertions, 0 deletions
diff --git a/research/flossing/analyze_spectrum_microscope.py b/research/flossing/analyze_spectrum_microscope.py
new file mode 100644
index 0000000..d2b4d82
--- /dev/null
+++ b/research/flossing/analyze_spectrum_microscope.py
@@ -0,0 +1,360 @@
+"""Full-spectrum microscope for HRM/TRM joint Lyapunov diagnostics.
+
+This script intentionally treats the finite-time QR columns as an unordered
+top-k estimate per sample and analyzes both the raw column-0 value and the
+sorted spectrum. The sorted features are safer for per-sample questions like
+"how many unstable directions does this trajectory have?".
+"""
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import math
+import re
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+ROOT = Path("/home/yurenh2/rrm/research/flossing")
+
+
+def auc_score(score: np.ndarray, label: np.ndarray) -> float:
+ """AUROC for score predicting label=True, with average ranks for ties."""
+ score = np.asarray(score, dtype=np.float64)
+ label = np.asarray(label, dtype=bool)
+ n_pos = int(label.sum())
+ n_neg = int((~label).sum())
+ if n_pos == 0 or n_neg == 0:
+ return float("nan")
+
+ order = np.argsort(score)
+ ranks = np.empty_like(order, dtype=np.float64)
+ ranks[order] = np.arange(1, len(score) + 1, dtype=np.float64)
+
+ sorted_score = score[order]
+ i = 0
+ while i < len(sorted_score):
+ j = i + 1
+ while j < len(sorted_score) and sorted_score[j] == sorted_score[i]:
+ j += 1
+ if j - i > 1:
+ ranks[order[i:j]] = (i + 1 + j) / 2.0
+ i = j
+
+ rank_sum = ranks[label].sum()
+ return float((rank_sum - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg))
+
+
+def cohen_d(success: np.ndarray, failure: np.ndarray) -> float:
+ success = np.asarray(success, dtype=np.float64)
+ failure = np.asarray(failure, dtype=np.float64)
+ if len(success) < 2 or len(failure) < 2:
+ return float("nan")
+ pooled = (
+ (len(success) - 1) * success.var(ddof=1)
+ + (len(failure) - 1) * failure.var(ddof=1)
+ ) / (len(success) + len(failure) - 2)
+ if pooled <= 0:
+ return float("nan")
+ return float((failure.mean() - success.mean()) / math.sqrt(pooled))
+
+
+def safe_mean(x: np.ndarray) -> float:
+ return float(np.mean(x)) if len(x) else float("nan")
+
+
+def safe_std(x: np.ndarray) -> float:
+ return float(np.std(x)) if len(x) else float("nan")
+
+
+def parse_step(path: Path, kind: str) -> int:
+ if kind == "HRM":
+ return int(re.search(r"step_(\d+)_512", path.name).group(1))
+ return int(re.search(r"step(\d+)_512", path.name).group(1))
+
+
+def discover(kind: str) -> list[Path]:
+ if kind == "HRM":
+ files = sorted(
+ ROOT.glob("diag_hrm_step_*_512.npz"),
+ key=lambda p: parse_step(p, kind),
+ )
+ else:
+ files = sorted(
+ ROOT.glob("diag_trm_singleGPU_step*_512.npz"),
+ key=lambda p: parse_step(p, kind),
+ )
+ if not files:
+ raise FileNotFoundError(f"No {kind} diagnostic files found under {ROOT}")
+ return files
+
+
+def feature_dict(raw_lam: np.ndarray) -> dict[str, np.ndarray]:
+ sorted_lam = np.sort(raw_lam, axis=1)[:, ::-1]
+ positive = np.clip(sorted_lam, 0.0, None)
+ k = sorted_lam.shape[1]
+ x = np.arange(k, dtype=np.float64)
+ x = x - x.mean()
+ denom = float((x**2).sum())
+ slope = (sorted_lam @ x) / denom
+ return {
+ "raw_col0": raw_lam[:, 0],
+ "lambda_max": sorted_lam[:, 0],
+ "lambda_2": sorted_lam[:, 1],
+ "lambda_min8": sorted_lam[:, -1],
+ "mean8": sorted_lam.mean(axis=1),
+ "sum8": sorted_lam.sum(axis=1),
+ "tail_mean_5_8": sorted_lam[:, 4:].mean(axis=1),
+ "positive_sum": positive.sum(axis=1),
+ "positive_count": (sorted_lam > 0).sum(axis=1),
+ "spread": sorted_lam[:, 0] - sorted_lam[:, -1],
+ "gap12": sorted_lam[:, 0] - sorted_lam[:, 1],
+ "std8": sorted_lam.std(axis=1),
+ "linear_slope": slope,
+ }
+
+
+def summarize_file(kind: str, path: Path) -> tuple[dict, list[dict]]:
+ data = np.load(path)
+ raw_lam = np.asarray(data["lyap_spec"], dtype=np.float64)
+ sorted_lam = np.sort(raw_lam, axis=1)[:, ::-1]
+ exact = data["exact_correct"].astype(bool)
+ fail = ~exact
+ token_acc = np.asarray(data["token_acc"], dtype=np.float64)
+ features = feature_dict(raw_lam)
+
+ monotone_frac = float((np.diff(raw_lam, axis=1) <= 1e-5).mean())
+ summary = {
+ "kind": kind,
+ "file": str(path),
+ "step": parse_step(path, kind),
+ "n": int(len(exact)),
+ "k": int(raw_lam.shape[1]),
+ "acc": float(exact.mean()),
+ "n_success": int(exact.sum()),
+ "n_failure": int(fail.sum()),
+ "raw_monotone_adjacent_fraction": monotone_frac,
+ "raw_col0_is_sample_max_fraction": float((raw_lam[:, 0] >= sorted_lam[:, 0] - 1e-7).mean()),
+ }
+
+ # Group means for the most interpretable features.
+ for name in [
+ "raw_col0",
+ "lambda_max",
+ "mean8",
+ "tail_mean_5_8",
+ "positive_sum",
+ "positive_count",
+ "spread",
+ "gap12",
+ ]:
+ arr = features[name]
+ summary[f"{name}_success_mean"] = safe_mean(arr[exact])
+ summary[f"{name}_failure_mean"] = safe_mean(arr[fail])
+ summary[f"{name}_delta_failure_minus_success"] = (
+ summary[f"{name}_failure_mean"] - summary[f"{name}_success_mean"]
+ )
+ summary[f"{name}_auc_failure"] = auc_score(arr, fail)
+
+ # Continuous token-accuracy correlations catch near-misses, not only exact success.
+ for name in ["lambda_max", "mean8", "tail_mean_5_8", "positive_sum", "positive_count"]:
+ arr = features[name]
+ if arr.std() > 0 and token_acc.std() > 0:
+ summary[f"{name}_corr_token_acc"] = float(np.corrcoef(arr, token_acc)[0, 1])
+ else:
+ summary[f"{name}_corr_token_acc"] = float("nan")
+
+ feature_rows = []
+ for name, arr in features.items():
+ feature_rows.append(
+ {
+ "kind": kind,
+ "step": summary["step"],
+ "feature": name,
+ "success_mean": safe_mean(arr[exact]),
+ "failure_mean": safe_mean(arr[fail]),
+ "success_std": safe_std(arr[exact]),
+ "failure_std": safe_std(arr[fail]),
+ "delta_failure_minus_success": safe_mean(arr[fail]) - safe_mean(arr[exact]),
+ "cohen_d_failure_minus_success": cohen_d(arr[exact], arr[fail]),
+ "auc_failure": auc_score(arr, fail),
+ }
+ )
+
+ spectrum_rows = []
+ for i in range(raw_lam.shape[1]):
+ spectrum_rows.append(
+ {
+ "kind": kind,
+ "step": summary["step"],
+ "rank": i + 1,
+ "success_mean": safe_mean(sorted_lam[exact, i]),
+ "failure_mean": safe_mean(sorted_lam[fail, i]),
+ "delta_failure_minus_success": safe_mean(sorted_lam[fail, i])
+ - safe_mean(sorted_lam[exact, i]),
+ "auc_failure": auc_score(sorted_lam[:, i], fail),
+ }
+ )
+
+ return summary, feature_rows + spectrum_rows
+
+
+def write_csv(path: Path, rows: list[dict]) -> None:
+ if not rows:
+ return
+ keys: list[str] = []
+ for row in rows:
+ for key in row:
+ if key not in keys:
+ keys.append(key)
+ with path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=keys)
+ writer.writeheader()
+ writer.writerows(rows)
+
+
+def plot_series(kind: str, summaries: list[dict], out_dir: Path) -> None:
+ steps = np.array([r["step"] for r in summaries])
+ acc = np.array([r["acc"] for r in summaries])
+
+ fig, axes = plt.subplots(2, 2, figsize=(12, 8))
+ ax = axes[0, 0]
+ ax.plot(steps, acc, "ko-", label="exact acc")
+ ax.set_title(f"{kind}: accuracy")
+ ax.set_xlabel("checkpoint step")
+ ax.set_ylabel("accuracy")
+ ax.grid(alpha=0.3)
+
+ ax = axes[0, 1]
+ ax.plot(steps, [r["lambda_max_success_mean"] for r in summaries], "C0-o", label="success λmax")
+ ax.plot(steps, [r["lambda_max_failure_mean"] for r in summaries], "C3-o", label="failure λmax")
+ ax.axhline(0, color="k", lw=1, alpha=0.35)
+ ax.set_title("Most unstable measured mode")
+ ax.set_xlabel("checkpoint step")
+ ax.set_ylabel("sorted λmax")
+ ax.legend()
+ ax.grid(alpha=0.3)
+
+ ax = axes[1, 0]
+ ax.plot(steps, [r["mean8_success_mean"] for r in summaries], "C0-o", label="success mean top-8")
+ ax.plot(steps, [r["mean8_failure_mean"] for r in summaries], "C3-o", label="failure mean top-8")
+ ax.axhline(0, color="k", lw=1, alpha=0.35)
+ ax.set_title("Top-8 volume proxy")
+ ax.set_xlabel("checkpoint step")
+ ax.set_ylabel("mean sorted λ1..λ8")
+ ax.legend()
+ ax.grid(alpha=0.3)
+
+ ax = axes[1, 1]
+ ax.plot(steps, [r["positive_count_success_mean"] for r in summaries], "C0-o", label="success")
+ ax.plot(steps, [r["positive_count_failure_mean"] for r in summaries], "C3-o", label="failure")
+ ax.set_title("Dimensionality of expansion")
+ ax.set_xlabel("checkpoint step")
+ ax.set_ylabel("# positive exponents among top 8")
+ ax.legend()
+ ax.grid(alpha=0.3)
+
+ fig.tight_layout()
+ fig.savefig(out_dir / f"{kind.lower()}_checkpoint_spectrum_features.png", dpi=150)
+ plt.close(fig)
+
+
+def plot_spectra(kind: str, files: list[Path], out_dir: Path) -> None:
+ n = len(files)
+ cols = min(5, n)
+ rows = int(math.ceil(n / cols))
+ fig, axes = plt.subplots(rows, cols, figsize=(3.4 * cols, 2.7 * rows), squeeze=False)
+ for ax, path in zip(axes.flat, files):
+ data = np.load(path)
+ raw_lam = np.asarray(data["lyap_spec"], dtype=np.float64)
+ sorted_lam = np.sort(raw_lam, axis=1)[:, ::-1]
+ exact = data["exact_correct"].astype(bool)
+ fail = ~exact
+ x = np.arange(1, sorted_lam.shape[1] + 1)
+ if exact.any():
+ ax.plot(x, sorted_lam[exact].mean(axis=0), "C0-o", lw=1.6, ms=3, label="success")
+ if fail.any():
+ ax.plot(x, sorted_lam[fail].mean(axis=0), "C3-o", lw=1.6, ms=3, label="failure")
+ ax.axhline(0, color="k", lw=0.8, alpha=0.35)
+ ax.set_title(f"step {parse_step(path, kind)} acc={exact.mean():.2f}")
+ ax.set_xlabel("sorted rank")
+ ax.set_ylabel("λ")
+ ax.grid(alpha=0.25)
+ for ax in axes.flat[len(files) :]:
+ ax.axis("off")
+ handles, labels = axes.flat[0].get_legend_handles_labels()
+ if handles:
+ fig.legend(handles, labels, loc="upper center", ncol=2)
+ fig.tight_layout(rect=[0, 0, 1, 0.96])
+ fig.savefig(out_dir / f"{kind.lower()}_mean_sorted_spectra_grid.png", dpi=150)
+ plt.close(fig)
+
+
+def plot_feature_auc(kind: str, feature_rows: list[dict], out_dir: Path) -> None:
+ selected = [
+ "raw_col0",
+ "lambda_max",
+ "mean8",
+ "tail_mean_5_8",
+ "positive_sum",
+ "positive_count",
+ "spread",
+ "gap12",
+ ]
+ fig, ax = plt.subplots(figsize=(12, 5))
+ for feature in selected:
+ rows = [r for r in feature_rows if r.get("feature") == feature and r["kind"] == kind]
+ rows = sorted(rows, key=lambda r: r["step"])
+ ax.plot([r["step"] for r in rows], [r["auc_failure"] for r in rows], marker="o", label=feature)
+ ax.axhline(0.5, color="k", lw=1, alpha=0.35)
+ ax.set_title(f"{kind}: feature AUROC for predicting exact failure")
+ ax.set_xlabel("checkpoint step")
+ ax.set_ylabel("AUROC")
+ ax.set_ylim(0.0, 1.02)
+ ax.legend(ncol=4, fontsize=8)
+ ax.grid(alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(out_dir / f"{kind.lower()}_feature_auc.png", dpi=150)
+ plt.close(fig)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--out-dir", default=str(ROOT / "spectrum_microscope"))
+ args = parser.parse_args()
+
+ out_dir = Path(args.out_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ all_summaries: list[dict] = []
+ all_feature_rows: list[dict] = []
+ for kind in ["HRM", "TRM"]:
+ files = discover(kind)
+ summaries = []
+ feature_rows = []
+ for path in files:
+ summary, rows = summarize_file(kind, path)
+ summaries.append(summary)
+ feature_rows.extend(rows)
+ summaries = sorted(summaries, key=lambda r: r["step"])
+ all_summaries.extend(summaries)
+ all_feature_rows.extend(feature_rows)
+
+ plot_series(kind, summaries, out_dir)
+ plot_spectra(kind, files, out_dir)
+ plot_feature_auc(kind, feature_rows, out_dir)
+
+ write_csv(out_dir / "checkpoint_summary.csv", all_summaries)
+ write_csv(out_dir / "feature_and_rank_summary.csv", all_feature_rows)
+ (out_dir / "checkpoint_summary.json").write_text(json.dumps(all_summaries, indent=2))
+
+ print(f"Wrote {out_dir}")
+ print(f" {out_dir / 'checkpoint_summary.csv'}")
+ print(f" {out_dir / 'feature_and_rank_summary.csv'}")
+
+
+if __name__ == "__main__":
+ main()