summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_ptrm_rollout_cache.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analyze_ptrm_rollout_cache.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analyze_ptrm_rollout_cache.py')
-rw-r--r--research/flossing/analyze_ptrm_rollout_cache.py135
1 files changed, 135 insertions, 0 deletions
diff --git a/research/flossing/analyze_ptrm_rollout_cache.py b/research/flossing/analyze_ptrm_rollout_cache.py
new file mode 100644
index 0000000..b7413e8
--- /dev/null
+++ b/research/flossing/analyze_ptrm_rollout_cache.py
@@ -0,0 +1,135 @@
+"""Analyze cached PTRM rollout/spectrum files without rerunning the model."""
+from __future__ import annotations
+
+import argparse
+import csv
+from pathlib import Path
+
+import numpy as np
+
+
+def _safe_corr(a: np.ndarray, b: np.ndarray) -> float:
+ a = np.asarray(a, dtype=np.float64).reshape(-1)
+ b = np.asarray(b, dtype=np.float64).reshape(-1)
+ if a.size == 0 or b.size == 0 or np.std(a) == 0 or np.std(b) == 0:
+ return float("nan")
+ return float(np.corrcoef(a, b)[0, 1])
+
+
+def _take(values: np.ndarray, idx: np.ndarray) -> np.ndarray:
+ return values[np.arange(values.shape[0]), idx]
+
+
+def _selector_metrics(exact: np.ndarray, token_acc: np.ndarray, idx: np.ndarray, prefix: str) -> dict[str, float]:
+ return {
+ f"{prefix}/exact": float(_take(exact, idx).mean()),
+ f"{prefix}/token_acc": float(_take(token_acc, idx).mean()),
+ }
+
+
+def _feature_table(data: np.lib.npyio.NpzFile) -> dict[str, np.ndarray]:
+ features: dict[str, np.ndarray] = {}
+ lyap = data["lyap"] if "lyap" in data.files else np.asarray([])
+ if lyap.size:
+ features["lambda1"] = lyap
+
+ spec = data["lyap_spec"] if "lyap_spec" in data.files else np.asarray([])
+ if spec.size:
+ features["spec_lambda1"] = spec[..., 0]
+ features["spec_pos_mass"] = np.maximum(spec, 0).sum(axis=-1)
+ features["spec_pos_l2"] = np.sqrt((np.maximum(spec, 0) ** 2).mean(axis=-1))
+ features["spec_mean"] = spec.mean(axis=-1)
+ features["spec_count_pos"] = (spec > 0).sum(axis=-1).astype(np.float32)
+ features["spec_spread"] = spec[..., 0] - spec[..., -1]
+ return features
+
+
+def analyze(path: Path) -> dict[str, float | str]:
+ data = np.load(path)
+ exact = data["exact"].astype(bool)
+ token_acc = data["token_acc"].astype(np.float32)
+ q_halt = data["q_halt"].astype(np.float32)
+ det_exact = data["det_exact"].astype(bool) if "det_exact" in data.files else np.asarray([])
+ correct_count = exact.sum(axis=1).astype(np.float32)
+
+ out: dict[str, float | str] = {
+ "file": str(path),
+ "n_samples": float(exact.shape[0]),
+ "rollouts": float(exact.shape[1]),
+ "deterministic/exact": float(det_exact.mean()) if det_exact.size else float("nan"),
+ "mean_rollout/exact": float(exact.mean()),
+ "oracle_pass/exact": float(exact.any(axis=1).mean()),
+ "correct_count/mean": float(correct_count.mean()),
+ "correct_count/std": float(correct_count.std()),
+ "correct_count/median": float(np.median(correct_count)),
+ "correct_count/q10": float(np.quantile(correct_count, 0.10)),
+ "correct_count/q25": float(np.quantile(correct_count, 0.25)),
+ "correct_count/q75": float(np.quantile(correct_count, 0.75)),
+ "correct_count/q90": float(np.quantile(correct_count, 0.90)),
+ "correct_count/zero_frac": float((correct_count == 0).mean()),
+ "correct_count/full_frac": float((correct_count == exact.shape[1]).mean()),
+ }
+ for threshold in (1, 5, 10, 25, 50, 75, 90):
+ if threshold <= exact.shape[1]:
+ out[f"correct_count/ge_{threshold}_frac"] = float((correct_count >= threshold).mean())
+
+ q_idx = q_halt.argmax(axis=1)
+ out.update(_selector_metrics(exact, token_acc, q_idx, "q_max"))
+ out["corr_q_correct"] = _safe_corr(q_halt, exact.astype(np.float32))
+ q_selected = _take(exact, q_idx).astype(bool)
+
+ if det_exact.size:
+ det_success = det_exact.astype(bool)
+ det_fail = ~det_success
+ if det_success.any():
+ out["correct_count/det_success_mean"] = float(correct_count[det_success].mean())
+ out["oracle_pass/det_success_frac"] = float(exact.any(axis=1)[det_success].mean())
+ out["q_max/det_success_frac"] = float(q_selected[det_success].mean())
+ if det_fail.any():
+ out["correct_count/det_fail_mean"] = float(correct_count[det_fail].mean())
+ out["oracle_pass/det_fail_frac"] = float(exact.any(axis=1)[det_fail].mean())
+ out["q_max/det_fail_frac"] = float(q_selected[det_fail].mean())
+
+ for name, feature in _feature_table(data).items():
+ idx = feature.argmin(axis=1)
+ out.update(_selector_metrics(exact, token_acc, idx, f"{name}_min"))
+ out[f"corr_neg_{name}_correct"] = _safe_corr(-feature, exact.astype(np.float32))
+ out[f"corr_q_{name}"] = _safe_corr(q_halt, feature)
+
+ q_sel = _take(exact, q_idx).astype(bool)
+ f_sel = _take(exact, idx).astype(bool)
+ out[f"{name}_q_wins"] = float((q_sel & ~f_sel).sum())
+ out[f"{name}_feature_wins"] = float((~q_sel & f_sel).sum())
+ out[f"{name}_both_fail"] = float((~q_sel & ~f_sel).sum())
+
+ return out
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("npz", nargs="+", help="Cached .npz files from ptrm_rollout_selection.py")
+ parser.add_argument("--out-csv", default=None)
+ args = parser.parse_args()
+
+ rows = [analyze(Path(p)) for p in args.npz]
+ fieldnames = sorted({k for row in rows for k in row})
+
+ if args.out_csv:
+ out = Path(args.out_csv)
+ out.parent.mkdir(parents=True, exist_ok=True)
+ with out.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(rows)
+ print(f"saved {out}")
+
+ for row in rows:
+ print(f"\n{row['file']}")
+ for key in fieldnames:
+ if key == "file" or key not in row:
+ continue
+ print(f"{key}: {row[key]}")
+
+
+if __name__ == "__main__":
+ main()