"""Analyze cached PTRM rollout/spectrum files without rerunning the model."""

from __future__ import annotations

import argparse
import csv
from pathlib import Path

import numpy as np


def _safe_corr(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Pearson correlation of the flattened inputs.

    Returns NaN instead of calling ``np.corrcoef`` when either input is empty
    or constant (zero standard deviation), where the correlation is undefined.
    """
    a = np.asarray(a, dtype=np.float64).reshape(-1)
    b = np.asarray(b, dtype=np.float64).reshape(-1)
    if a.size == 0 or b.size == 0 or np.std(a) == 0 or np.std(b) == 0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])


def _take(values: np.ndarray, idx: np.ndarray) -> np.ndarray:
    """Pick ``values[i, idx[i]]`` for every row ``i`` (one chosen column per row)."""
    return values[np.arange(values.shape[0]), idx]


def _selector_metrics(
    exact: np.ndarray, token_acc: np.ndarray, idx: np.ndarray, prefix: str
) -> dict[str, float]:
    """Score a per-sample rollout choice ``idx`` by mean exact match and token accuracy."""
    return {
        f"{prefix}/exact": float(_take(exact, idx).mean()),
        f"{prefix}/token_acc": float(_take(token_acc, idx).mean()),
    }


def _feature_table(data: np.lib.npyio.NpzFile) -> dict[str, np.ndarray]:
    """Build candidate per-rollout selection features from the cached arrays.

    ``lyap`` (if present and non-empty) is used as-is.  ``lyap_spec`` (if
    present and non-empty) is reduced over its last axis into several scalar
    summaries.  ``analyze`` arg-minimises every feature along axis 1 and
    indexes ``exact`` with the result, so each feature must be shaped
    ``[sample, rollout]`` like ``exact``.
    """
    features: dict[str, np.ndarray] = {}
    lyap = data["lyap"] if "lyap" in data.files else np.asarray([])
    if lyap.size:
        features["lambda1"] = lyap
    spec = data["lyap_spec"] if "lyap_spec" in data.files else np.asarray([])
    if spec.size:
        positive = np.maximum(spec, 0)
        features["spec_lambda1"] = spec[..., 0]
        features["spec_pos_mass"] = positive.sum(axis=-1)
        features["spec_pos_l2"] = np.sqrt((positive**2).mean(axis=-1))
        features["spec_mean"] = spec.mean(axis=-1)
        features["spec_count_pos"] = (spec > 0).sum(axis=-1).astype(np.float32)
        # First minus last entry.  This is the spectrum's range only if the
        # spectrum is stored sorted (presumably descending) -- TODO confirm.
        features["spec_spread"] = spec[..., 0] - spec[..., -1]
    return features


def _correct_count_stats(correct_count: np.ndarray, rollouts: int) -> dict[str, float]:
    """Summarize the distribution of the number of correct rollouts per sample."""
    stats: dict[str, float] = {
        "correct_count/mean": float(correct_count.mean()),
        "correct_count/std": float(correct_count.std()),
        "correct_count/median": float(np.median(correct_count)),
    }
    for label, q in (("q10", 0.10), ("q25", 0.25), ("q75", 0.75), ("q90", 0.90)):
        stats[f"correct_count/{label}"] = float(np.quantile(correct_count, q))
    stats["correct_count/zero_frac"] = float((correct_count == 0).mean())
    stats["correct_count/full_frac"] = float((correct_count == rollouts).mean())
    # Only report thresholds that are reachable with this many rollouts.
    for threshold in (1, 5, 10, 25, 50, 75, 90):
        if threshold <= rollouts:
            stats[f"correct_count/ge_{threshold}_frac"] = float((correct_count >= threshold).mean())
    return stats


def _det_split_stats(
    det_exact: np.ndarray,
    correct_count: np.ndarray,
    solved_any: np.ndarray,
    q_sel: np.ndarray,
) -> dict[str, float]:
    """Break sampling metrics down by whether the deterministic run was exact.

    Groups with no members are omitted rather than reported as NaN.
    """
    stats: dict[str, float] = {}
    for label, mask in (("success", det_exact), ("fail", ~det_exact)):
        if mask.any():
            stats[f"correct_count/det_{label}_mean"] = float(correct_count[mask].mean())
            stats[f"oracle_pass/det_{label}_frac"] = float(solved_any[mask].mean())
            stats[f"q_max/det_{label}_frac"] = float(q_sel[mask].mean())
    return stats


def _feature_selector_stats(
    features: dict[str, np.ndarray],
    exact: np.ndarray,
    exact_f: np.ndarray,
    token_acc: np.ndarray,
    q_halt: np.ndarray,
    q_sel: np.ndarray,
) -> dict[str, float]:
    """Evaluate "pick the rollout with the smallest feature" against the q_max pick."""
    stats: dict[str, float] = {}
    for name, feature in features.items():
        # NOTE(review): np.argmin picks a NaN entry if a row contains one;
        # confirm whether cached features can hold NaN.
        idx = feature.argmin(axis=1)
        f_sel = _take(exact, idx)
        stats.update(_selector_metrics(exact, token_acc, idx, f"{name}_min"))
        # Negated so a positive value means "smaller feature -> more often correct".
        stats[f"corr_neg_{name}_correct"] = _safe_corr(-feature, exact_f)
        stats[f"corr_q_{name}"] = _safe_corr(q_halt, feature)
        # Head-to-head sample counts between the two selectors.
        stats[f"{name}_q_wins"] = float((q_sel & ~f_sel).sum())
        stats[f"{name}_feature_wins"] = float((~q_sel & f_sel).sum())
        stats[f"{name}_both_fail"] = float((~q_sel & ~f_sel).sum())
    return stats


def analyze(path: Path) -> dict[str, float | str]:
    """Compute rollout-selection statistics for one cached ``.npz`` file.

    Required arrays: ``exact``, ``token_acc`` and ``q_halt``, indexed
    ``[sample, rollout]``.  Optional arrays: ``det_exact`` (one entry per
    sample), and ``lyap`` / ``lyap_spec`` (turned into selection features).

    Returns a flat ``{metric_name: value}`` dict; ``"file"`` holds the path.
    """
    # NpzFile keeps the archive open until it is closed; the context manager
    # prevents leaking one handle per analysed file.  Arrays read from it are
    # regular in-memory arrays and stay valid after the close.
    with np.load(path) as data:
        exact = data["exact"].astype(bool)
        token_acc = data["token_acc"].astype(np.float32)
        q_halt = data["q_halt"].astype(np.float32)
        det_exact = data["det_exact"].astype(bool) if "det_exact" in data.files else np.asarray([])
        features = _feature_table(data)

    rollouts = exact.shape[1]
    correct_count = exact.sum(axis=1).astype(np.float32)
    exact_f = exact.astype(np.float32)
    solved_any = exact.any(axis=1)

    out: dict[str, float | str] = {
        "file": str(path),
        "n_samples": float(exact.shape[0]),
        "rollouts": float(rollouts),
        "deterministic/exact": float(det_exact.mean()) if det_exact.size else float("nan"),
        "mean_rollout/exact": float(exact.mean()),
        "oracle_pass/exact": float(solved_any.mean()),
    }
    out.update(_correct_count_stats(correct_count, rollouts))

    # Baseline selector: the rollout with the largest q_halt score per sample.
    q_idx = q_halt.argmax(axis=1)
    q_sel = _take(exact, q_idx)
    out.update(_selector_metrics(exact, token_acc, q_idx, "q_max"))
    out["corr_q_correct"] = _safe_corr(q_halt, exact_f)

    if det_exact.size:
        out.update(_det_split_stats(det_exact, correct_count, solved_any, q_sel))
    out.update(_feature_selector_stats(features, exact, exact_f, token_acc, q_halt, q_sel))
    return out


def main() -> None:
    """CLI entry point: analyze each cached file, optionally save a CSV, print all metrics."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("npz", nargs="+", help="Cached .npz files from ptrm_rollout_selection.py")
    parser.add_argument("--out-csv", default=None)
    args = parser.parse_args()

    rows = [analyze(Path(p)) for p in args.npz]
    # Files may yield different key sets (optional arrays, rollout-dependent
    # thresholds), so use the union; DictWriter fills missing cells with "".
    fieldnames = sorted({key for row in rows for key in row})

    if args.out_csv:
        out = Path(args.out_csv)
        out.parent.mkdir(parents=True, exist_ok=True)
        with out.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        print(f"saved {out}")

    for row in rows:
        print(f"\n{row['file']}")
        for key in fieldnames:
            if key == "file" or key not in row:
                continue
            print(f"{key}: {row[key]}")


if __name__ == "__main__":
    main()