summaryrefslogtreecommitdiff
path: root/research/flossing/analysis_2x2/early_window_pairing.py
blob: 6277df2e3d26b1dc8f60e1cdd7eaf3176938043a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Early-window (first 4 ACT steps) vs full-window pairing.

Question: does the early-window FTLE forecast FINAL failure before convergence?
- join full/short npz per example via idx (same seed/n => same sampling)
- target label = final exact_correct from the FULL-window run
- report: AUC(-lam1_early -> final correct), same for early drift and q_halt@4,
  and the conditional version restricted to examples NOT yet correct at step 4
  (short run's own exact_correct == 0) — the practically relevant population for
  early-exit / compute-reallocation decisions.

Observational only. Usage: python early_window_pairing.py FULL.npz SHORT.npz TAG
"""
from __future__ import annotations

import sys
from pathlib import Path

import numpy as np

# Directory containing this script (symlinks resolved); the markdown report is written here.
HERE = Path(__file__).resolve().parents[0]


def auc_rank(score: np.ndarray, label: np.ndarray) -> float:
    """Return the ROC AUC of ``score`` for separating ``label == 1`` from ``label == 0``.

    Computed from the Mann-Whitney U statistic on mid-ranks, so each tied
    (positive, negative) pair contributes 1/2. Entries whose label is neither
    0 nor 1 are ignored. Entries whose score is NaN are also dropped: NaN has
    no meaningful rank, and if left in, ``argsort`` would place it above every
    finite value and silently bias the AUC.

    Args:
        score: 1-D array of real-valued scores; a higher score counts as
            "more positive".
        label: 1-D array of the same length holding 0/1 labels.

    Returns:
        The AUC in [0, 1], or NaN when either class is empty after filtering.
    """
    score = np.asarray(score, dtype=float)
    label = np.asarray(label)
    valid = ~np.isnan(score)
    score, label = score[valid], label[valid]

    pos, neg = score[label == 1], score[label == 0]
    n_pos, n_neg = len(pos), len(neg)
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Positives first, so the first n_pos ranks belong to the positive class.
    allv = np.concatenate([pos, neg])
    # Group equal values. A tie group of size c whose last 1-based rank is r
    # occupies ranks r-c+1 .. r, so its mid-rank is r - (c - 1) / 2.
    _, group, counts = np.unique(allv, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)
    ranks = (last_rank - (counts - 1) / 2.0)[group]

    u_pos = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_pos / (n_pos * n_neg))


def _mean_or_nan(x: np.ndarray) -> float:
    """Return the mean of ``x``, or NaN (without a RuntimeWarning) when ``x`` is empty."""
    return float(x.mean()) if x.size else float("nan")


def _median_or_nan(x: np.ndarray) -> float:
    """Return the median of ``x``, or NaN (without a RuntimeWarning) when ``x`` is empty."""
    return float(np.median(x)) if x.size else float("nan")


def main(full_path: str, short_path: str, tag: str) -> None:
    """Pair a full-window run with a short (early-window) run and report forecast AUCs.

    Examples are joined on the ``idx`` array stored in both ``.npz`` archives.
    The target is ``exact_correct`` from the full run. The predictors come from
    the short run:
      - the leading Lyapunov exponent ``lyap_spec[:, 0]``;
      - log10 of the last ``drift_zH`` entry;
      - the last ``q_halt`` entry.

    AUCs are reported over all paired examples, and again over the subset the
    short run had not yet gotten right. The report is printed and also written
    to ``early_pairing_<tag>.md`` in ``HERE``.

    Args:
        full_path: path to the full-window ``.npz`` archive.
        short_path: path to the short-window ``.npz`` archive.
        tag: label used in the report title and the output filename.
    """
    # Pull every needed array while the archives are open. The context managers
    # close the underlying files; an NpzFile otherwise keeps its handle open.
    with np.load(full_path) as f, np.load(short_path) as s:
        fi, si = f["idx"], s["idx"]
        # intersect1d pairs only the first occurrence of a repeated idx, so warn
        # rather than let a collapsed join go unnoticed.
        for run_name, ids in (("full", fi), ("short", si)):
            if np.unique(ids).size != ids.size:
                print(f"[{tag}] WARNING: duplicate idx values in {run_name} run; "
                      f"only the first occurrence of each is paired", file=sys.stderr)
        common, f_pos, s_pos = np.intersect1d(fi, si, return_indices=True)
        print(f"[{tag}] paired {len(common)} examples (full n={len(fi)}, short n={len(si)})")

        y_final = f["exact_correct"].astype(int)[f_pos]          # FINAL outcome (step 16)
        y_early = s["exact_correct"].astype(int)[s_pos]          # correct already at step 4
        l1_early = s["lyap_spec"][s_pos, 0]
        l1_full = f["lyap_spec"][f_pos, 0]
        # Last recorded drift of the short run, log10-scaled; the clip avoids log10(0).
        drift4 = np.log10(np.clip(s["drift_zH"][s_pos, -1], 1e-12, None))   # drift at ACT step 4
        q4 = s["q_halt"][s_pos, -1]

    final_ok, final_bad = y_final == 1, y_final == 0
    lines = [f"# Early-window pairing — {tag}",
             f"- paired n={len(common)}; final acc={_mean_or_nan(y_final):.4f}; "
             f"already-correct@step4={_mean_or_nan(y_early):.4f}",
             f"- of final-correct, fraction already correct@4: {_mean_or_nan(y_early[final_ok]):.4f}",
             f"- early-window lam1: final-correct med {_median_or_nan(l1_early[final_ok]):+.4f}, "
             f"final-wrong med {_median_or_nan(l1_early[final_bad]):+.4f}",
             "",
             "## Forecasting FINAL outcome from the first 4 ACT steps",
             f"- AUC(-lam1_early -> final correct) = {auc_rank(-l1_early, y_final):.3f}",
             f"- AUC(-drift@4    -> final correct) = {auc_rank(-drift4, y_final):.3f}",
             f"- AUC(q_halt@4    -> final correct) = {auc_rank(q4, y_final):.3f}",
             f"- reference: AUC(-lam1_full -> final correct) = {auc_rank(-l1_full, y_final):.3f}",
             "",
             "## Restricted to examples NOT yet correct at step 4 (the decision-relevant set)"]

    # Decision-relevant subset: examples the short run had not yet gotten right.
    m = y_early == 0
    n_m = int(m.sum())
    n_pos = int(y_final[m].sum())
    lines += [f"- n={n_m}, of which eventually correct: {n_pos} ({n_pos/max(n_m,1):.3f})",
              f"- AUC(-lam1_early -> eventually correct) = {auc_rank(-l1_early[m], y_final[m]):.3f}",
              f"- AUC(-drift@4    -> eventually correct) = {auc_rank(-drift4[m], y_final[m]):.3f}",
              f"- AUC(q_halt@4    -> eventually correct) = {auc_rank(q4[m], y_final[m]):.3f}",
              f"- early lam1 med: eventually-correct {_median_or_nan(l1_early[m & final_ok]):+.4f} vs "
              f"never-correct {_median_or_nan(l1_early[m & final_bad]):+.4f}"]

    report = "\n".join(lines)
    out = HERE / f"early_pairing_{tag}.md"
    # Explicit UTF-8: the title contains a non-ASCII em dash, which the locale
    # default encoding (e.g. ASCII under a C locale) may be unable to encode.
    out.write_text(report, encoding="utf-8")
    print(report)
    print("wrote", out)


if __name__ == "__main__":
    # Fail with a usage message instead of a bare IndexError when arguments are missing.
    if len(sys.argv) < 4:
        raise SystemExit("usage: python early_window_pairing.py FULL.npz SHORT.npz TAG")
    main(sys.argv[1], sys.argv[2], sys.argv[3])