diff options
Diffstat (limited to 'research/flossing/analyze_ptrm_rollout_cache.py')
| -rw-r--r-- | research/flossing/analyze_ptrm_rollout_cache.py | 135 |
1 files changed, 135 insertions, 0 deletions
"""Analyze cached PTRM rollout/spectrum files without rerunning the model."""
from __future__ import annotations

import argparse
import csv
from collections.abc import Sequence
from pathlib import Path

import numpy as np

# Absolute correct-rollout counts reported as "at least this many" fractions.
# Thresholds larger than the number of rollouts in a file are skipped.
_COUNT_THRESHOLDS = (1, 5, 10, 25, 50, 75, 90)


def _safe_corr(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Pearson correlation of the flattened inputs.

    Returns NaN instead of emitting a divide-by-zero warning when either
    input is empty or constant (zero standard deviation).
    """
    a = np.asarray(a, dtype=np.float64).reshape(-1)
    b = np.asarray(b, dtype=np.float64).reshape(-1)
    if a.size == 0 or b.size == 0 or np.std(a) == 0 or np.std(b) == 0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])


def _take(values: np.ndarray, idx: np.ndarray) -> np.ndarray:
    """Return ``values[i, idx[i]]`` for every row ``i``."""
    return values[np.arange(values.shape[0]), idx]


def _selector_metrics(
    exact: np.ndarray,
    token_acc: np.ndarray,
    idx: np.ndarray,
    prefix: str,
) -> dict[str, float]:
    """Score a selector that picks column ``idx[i]`` for each row ``i``.

    Returns the mean of the selected ``exact`` and ``token_acc`` entries under
    the keys ``"<prefix>/exact"`` and ``"<prefix>/token_acc"``.
    """
    return {
        f"{prefix}/exact": float(_take(exact, idx).mean()),
        f"{prefix}/token_acc": float(_take(token_acc, idx).mean()),
    }


def _feature_table(data: np.lib.npyio.NpzFile) -> dict[str, np.ndarray]:
    """Build the selection features available in an opened ``.npz`` archive.

    ``lyap`` (if present and non-empty) is exposed as ``lambda1``.
    ``lyap_spec`` (if present and non-empty) is reduced over its last axis
    into several summary features. Missing or empty arrays contribute nothing.
    """
    features: dict[str, np.ndarray] = {}

    lyap = data["lyap"] if "lyap" in data.files else np.asarray([])
    if lyap.size:
        features["lambda1"] = lyap

    spec = data["lyap_spec"] if "lyap_spec" in data.files else np.asarray([])
    if spec.size:
        positive = np.maximum(spec, 0)  # computed once, used by two features
        # NOTE(review): lambda1/spread read the first and last entries of the
        # last axis — presumably the spectrum is stored sorted; confirm.
        features["spec_lambda1"] = spec[..., 0]
        features["spec_pos_mass"] = positive.sum(axis=-1)
        features["spec_pos_l2"] = np.sqrt((positive ** 2).mean(axis=-1))
        features["spec_mean"] = spec.mean(axis=-1)
        features["spec_count_pos"] = (spec > 0).sum(axis=-1).astype(np.float32)
        features["spec_spread"] = spec[..., 0] - spec[..., -1]
    return features


def analyze(path: Path) -> dict[str, float | str]:
    """Summarize one cached rollout ``.npz`` file.

    The archive must contain ``exact``, ``token_acc`` and ``q_halt``; these are
    indexed as ``[sample, rollout]``. ``det_exact``, ``lyap`` and ``lyap_spec``
    are optional and only add metrics when present.

    Args:
        path: Location of the ``.npz`` archive.

    Returns:
        A flat mapping of metric name to value. ``"file"`` holds ``str(path)``;
        every other value is a float.

    Raises:
        KeyError: If a required array is missing from the archive.
    """
    # NpzFile keeps the underlying zip open until closed; read everything we
    # need inside the context manager so the handle is released right away.
    with np.load(path) as data:
        exact = data["exact"].astype(bool)
        token_acc = data["token_acc"].astype(np.float32)
        q_halt = data["q_halt"].astype(np.float32)
        if "det_exact" in data.files:
            det_exact = data["det_exact"].astype(bool)
        else:
            det_exact = np.zeros(0, dtype=bool)
        features = _feature_table(data)

    n_rollouts = exact.shape[1]
    correct_count = exact.sum(axis=1).astype(np.float32)
    any_correct = exact.any(axis=1)  # per-sample "oracle" success
    exact_f32 = exact.astype(np.float32)

    out: dict[str, float | str] = {
        "file": str(path),
        "n_samples": float(exact.shape[0]),
        "rollouts": float(n_rollouts),
        "deterministic/exact": float(det_exact.mean()) if det_exact.size else float("nan"),
        "mean_rollout/exact": float(exact.mean()),
        "oracle_pass/exact": float(any_correct.mean()),
        "correct_count/mean": float(correct_count.mean()),
        "correct_count/std": float(correct_count.std()),
        "correct_count/median": float(np.median(correct_count)),
        "correct_count/q10": float(np.quantile(correct_count, 0.10)),
        "correct_count/q25": float(np.quantile(correct_count, 0.25)),
        "correct_count/q75": float(np.quantile(correct_count, 0.75)),
        "correct_count/q90": float(np.quantile(correct_count, 0.90)),
        "correct_count/zero_frac": float((correct_count == 0).mean()),
        "correct_count/full_frac": float((correct_count == n_rollouts).mean()),
    }
    for threshold in _COUNT_THRESHOLDS:
        if threshold <= n_rollouts:
            out[f"correct_count/ge_{threshold}_frac"] = float(
                (correct_count >= threshold).mean()
            )

    # Baseline selector: pick the rollout with the largest q_halt per sample.
    q_idx = q_halt.argmax(axis=1)
    out.update(_selector_metrics(exact, token_acc, q_idx, "q_max"))
    out["corr_q_correct"] = _safe_corr(q_halt, exact_f32)
    q_selected = _take(exact, q_idx).astype(bool)

    if det_exact.size:
        # Split every headline metric by whether the deterministic run succeeded.
        for label, mask in (("det_success", det_exact), ("det_fail", ~det_exact)):
            if not mask.any():
                continue
            out[f"correct_count/{label}_mean"] = float(correct_count[mask].mean())
            out[f"oracle_pass/{label}_frac"] = float(any_correct[mask].mean())
            out[f"q_max/{label}_frac"] = float(q_selected[mask].mean())

    for name, feature in features.items():
        # Feature selector: pick the rollout with the smallest feature value.
        idx = feature.argmin(axis=1)
        out.update(_selector_metrics(exact, token_acc, idx, f"{name}_min"))
        out[f"corr_neg_{name}_correct"] = _safe_corr(-feature, exact_f32)
        out[f"corr_q_{name}"] = _safe_corr(q_halt, feature)

        # Head-to-head counts of samples where exactly one selector is right,
        # or neither is.
        f_selected = _take(exact, idx).astype(bool)
        out[f"{name}_q_wins"] = float((q_selected & ~f_selected).sum())
        out[f"{name}_feature_wins"] = float((~q_selected & f_selected).sum())
        out[f"{name}_both_fail"] = float((~q_selected & ~f_selected).sum())

    return out


def main(argv: Sequence[str] | None = None) -> None:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` (the default) uses ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("npz", nargs="+", help="Cached .npz files from ptrm_rollout_selection.py")
    parser.add_argument("--out-csv", default=None)
    args = parser.parse_args(argv)

    rows = [analyze(Path(p)) for p in args.npz]
    # Files may expose different optional metrics; use the union as the header.
    fieldnames = sorted({k for row in rows for k in row})

    if args.out_csv:
        out = Path(args.out_csv)
        out.parent.mkdir(parents=True, exist_ok=True)
        with out.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        print(f"saved {out}")

    for row in rows:
        print(f"\n{row['file']}")
        for key in fieldnames:
            if key == "file" or key not in row:
                continue
            print(f"{key}: {row[key]}")


if __name__ == "__main__":
    main()
