summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_ptrm_rollout_cache.py
blob: b7413e88584adc7ccf546883522df21aa9419b50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Analyze cached PTRM rollout/spectrum files without rerunning the model."""
from __future__ import annotations

import argparse
import csv
from pathlib import Path

import numpy as np


def _safe_corr(a: np.ndarray, b: np.ndarray) -> float:
    """Pearson correlation of *a* and *b* after flattening both to 1-D float64.

    Returns NaN when either input is empty or has zero standard deviation
    (a constant input), where the coefficient is undefined.
    """
    x = np.ravel(np.asarray(a, dtype=np.float64))
    y = np.ravel(np.asarray(b, dtype=np.float64))
    # Short-circuit so np.std is never evaluated on an empty array.
    degenerate = x.size == 0 or y.size == 0 or np.std(x) == 0 or np.std(y) == 0
    if degenerate:
        return float("nan")
    corr_matrix = np.corrcoef(x, y)
    return float(corr_matrix[0, 1])


def _take(values: np.ndarray, idx: np.ndarray) -> np.ndarray:
    """Select one entry per row: element ``i`` of the result is ``values[i, idx[i]]``."""
    row_ids = np.arange(values.shape[0])
    picked = values[row_ids, idx]
    return picked


def _selector_metrics(exact: np.ndarray, token_acc: np.ndarray, idx: np.ndarray, prefix: str) -> dict[str, float]:
    """Score a rollout selector by the mean quality of the rollouts it picked.

    ``idx[i]`` is the rollout chosen for row ``i``. Returns the mean of
    ``exact`` and of ``token_acc`` at those picks, under the keys
    ``"<prefix>/exact"`` and ``"<prefix>/token_acc"``.
    """
    metrics: dict[str, float] = {}
    exact_rows = np.arange(exact.shape[0])
    metrics[f"{prefix}/exact"] = float(exact[exact_rows, idx].mean())
    acc_rows = np.arange(token_acc.shape[0])
    metrics[f"{prefix}/token_acc"] = float(token_acc[acc_rows, idx].mean())
    return metrics


def _feature_table(data: np.lib.npyio.NpzFile) -> dict[str, np.ndarray]:
    """Build named selector features from the optional Lyapunov arrays in *data*.

    A missing or empty ``lyap`` / ``lyap_spec`` entry contributes no
    features. ``lyap`` is passed through unchanged as ``lambda1``. The
    ``lyap_spec`` summaries are all reduced over its last axis:

    - its first entry
    - the sum and the root-mean-square of its positive parts
    - its mean
    - the number of positive entries (as float32)
    - first-minus-last spread
    """
    names = data.files

    def _optional(key: str) -> np.ndarray:
        # Absent keys behave like an empty array so the size checks below skip them.
        return data[key] if key in names else np.asarray([])

    table: dict[str, np.ndarray] = {}

    lyap = _optional("lyap")
    if lyap.size:
        table["lambda1"] = lyap

    spec = _optional("lyap_spec")
    if not spec.size:
        return table

    positive = np.maximum(spec, 0)
    table["spec_lambda1"] = spec[..., 0]
    table["spec_pos_mass"] = positive.sum(axis=-1)
    table["spec_pos_l2"] = np.sqrt((positive ** 2).mean(axis=-1))
    table["spec_mean"] = spec.mean(axis=-1)
    table["spec_count_pos"] = (spec > 0).sum(axis=-1).astype(np.float32)
    table["spec_spread"] = spec[..., 0] - spec[..., -1]
    return table


def analyze(path: Path) -> dict[str, float | str]:
    """Summarize one cached rollout ``.npz`` archive as a flat metrics dict.

    Required arrays:
        ``exact``, ``token_acc`` -- indexed as ``[sample, rollout]``.
        ``q_halt`` -- one score per rollout; the ``q_max`` selector takes the
            ``argmax`` over the rollout axis.

    Optional arrays:
        ``det_exact`` -- used as a per-sample boolean mask and reported as
            ``deterministic/exact``.
        ``lyap`` / ``lyap_spec`` -- consumed by ``_feature_table``.

    The result contains:

    - the file path
    - statistics of the per-sample count of correct rollouts
    - the oracle rate (any rollout correct)
    - metrics for the ``q_max`` selector
    - for every feature:
        - a ``<name>_min`` selector (the rollout with the smallest value)
        - its correlations
        - win/loss counts against ``q_max``

    Raises:
        KeyError: if a required array is missing from the archive.
    """
    # np.load returns an NpzFile that keeps the archive open until closed.
    # Every member read below is materialized in memory, so it is safe to
    # close the file as soon as the reads are done.
    with np.load(path) as data:
        exact = data["exact"].astype(bool)
        token_acc = data["token_acc"].astype(np.float32)
        q_halt = data["q_halt"].astype(np.float32)
        det_exact = data["det_exact"].astype(bool) if "det_exact" in data.files else np.asarray([])
        features = _feature_table(data)

    n_rollouts = exact.shape[1]
    exact_f = exact.astype(np.float32)
    correct_count = exact.sum(axis=1).astype(np.float32)
    oracle = exact.any(axis=1)  # per sample: at least one rollout is exactly correct

    out: dict[str, float | str] = {
        "file": str(path),
        "n_samples": float(exact.shape[0]),
        "rollouts": float(n_rollouts),
        "deterministic/exact": float(det_exact.mean()) if det_exact.size else float("nan"),
        "mean_rollout/exact": float(exact.mean()),
        "oracle_pass/exact": float(oracle.mean()),
        "correct_count/mean": float(correct_count.mean()),
        "correct_count/std": float(correct_count.std()),
        "correct_count/median": float(np.median(correct_count)),
        "correct_count/q10": float(np.quantile(correct_count, 0.10)),
        "correct_count/q25": float(np.quantile(correct_count, 0.25)),
        "correct_count/q75": float(np.quantile(correct_count, 0.75)),
        "correct_count/q90": float(np.quantile(correct_count, 0.90)),
        "correct_count/zero_frac": float((correct_count == 0).mean()),
        "correct_count/full_frac": float((correct_count == n_rollouts).mean()),
    }
    # Only report thresholds that are reachable with this many rollouts.
    for threshold in (1, 5, 10, 25, 50, 75, 90):
        if threshold <= n_rollouts:
            out[f"correct_count/ge_{threshold}_frac"] = float((correct_count >= threshold).mean())

    # Baseline selector: the rollout with the highest q_halt score.
    q_idx = q_halt.argmax(axis=1)
    out.update(_selector_metrics(exact, token_acc, q_idx, "q_max"))
    # Pooled over all (sample, rollout) pairs, not computed per sample.
    out["corr_q_correct"] = _safe_corr(q_halt, exact_f)
    q_selected = _take(exact, q_idx)  # bool: did q_max pick a correct rollout?

    if det_exact.size:
        # Split the metrics by whether the deterministic run succeeded.
        for group, mask in (("det_success", det_exact), ("det_fail", ~det_exact)):
            if mask.any():  # skip empty groups to avoid mean-of-empty NaNs
                out[f"correct_count/{group}_mean"] = float(correct_count[mask].mean())
                out[f"oracle_pass/{group}_frac"] = float(oracle[mask].mean())
                out[f"q_max/{group}_frac"] = float(q_selected[mask].mean())

    for name, feature in features.items():
        # Feature selector: the rollout with the smallest feature value.
        f_idx = feature.argmin(axis=1)
        out.update(_selector_metrics(exact, token_acc, f_idx, f"{name}_min"))
        # Negated, so a positive value means "lower feature <-> more often correct".
        out[f"corr_neg_{name}_correct"] = _safe_corr(-feature, exact_f)
        out[f"corr_q_{name}"] = _safe_corr(q_halt, feature)

        # Head-to-head counts of q_max vs. this feature selector.
        f_selected = _take(exact, f_idx)
        out[f"{name}_q_wins"] = float((q_selected & ~f_selected).sum())
        out[f"{name}_feature_wins"] = float((~q_selected & f_selected).sum())
        out[f"{name}_both_fail"] = float((~q_selected & ~f_selected).sum())

    return out


def main() -> None:
    """Command-line entry point: analyze each cached ``.npz`` file and report its metrics.

    Every metric is printed per file. With ``--out-csv``, the rows are also
    written to a CSV. Its columns are the sorted union of all metric keys, and
    a file missing a metric gets an empty cell.
    """
    parser = argparse.ArgumentParser(
        description="Analyze cached PTRM rollout/spectrum files without rerunning the model."
    )
    parser.add_argument("npz", nargs="+", help="Cached .npz files from ptrm_rollout_selection.py")
    parser.add_argument("--out-csv", default=None, help="Optional CSV path to write all metrics to")
    args = parser.parse_args()

    rows = [analyze(Path(p)) for p in args.npz]
    # Files can produce different key sets (optional arrays, rollout-count
    # dependent thresholds), so use the union of keys as the columns.
    fieldnames = sorted({k for row in rows for k in row})

    if args.out_csv:
        out = Path(args.out_csv)
        out.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding so the output (including file paths) does not
        # depend on the platform's locale; newline="" as the csv module requires.
        with out.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        print(f"saved {out}")

    for row in rows:
        print(f"\n{row['file']}")
        for key in fieldnames:
            if key == "file" or key not in row:
                continue
            print(f"{key}: {row[key]}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()