1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
|
"""Early-window (first 4 ACT steps) vs full-window pairing.
Question: does the early-window FTLE forecast FINAL failure before convergence?
- join full/short npz per example via idx (same seed/n => same sampling)
- target label = final exact_correct from the FULL-window run
- report: AUC(-lam1_early -> final correct), same for early drift and q_halt@4,
and the conditional version restricted to examples NOT yet correct at step 4
(short run's own exact_correct == 0) — the practically relevant population for
early-exit / compute-reallocation decisions.
Observational only. Usage: python early_window_pairing.py FULL.npz SHORT.npz TAG
"""
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
# Absolute directory that contains this script; main() writes its report here.
HERE = Path(__file__).resolve().parents[0]
def auc_rank(score: np.ndarray, label: np.ndarray) -> float:
    """Return the ROC AUC of ``score`` for separating ``label == 1`` from ``label == 0``.

    Uses the Mann-Whitney U statistic. Scores from both classes are pooled and
    ranked, and tied scores share the mean of the 1-based ranks they span. The
    positive-class rank sum is then converted to U / (n_pos * n_neg).

    Entries whose label is neither 0 nor 1 are ignored. Returns NaN when either
    class has no members.
    """
    positives = score[label == 1]
    negatives = score[label == 0]
    n_pos, n_neg = len(positives), len(negatives)
    if not (n_pos and n_neg):
        return float("nan")

    pooled = np.concatenate([positives, negatives])
    n_all = len(pooled)
    perm = np.argsort(pooled, kind="mergesort")
    ordered = pooled[perm]

    # Split the sorted values into runs of equal values.
    # Each run [lo, hi) gets the average of the 1-based ranks lo+1 .. hi.
    # NaN never compares equal to itself, so every NaN forms its own run.
    cuts = np.flatnonzero(ordered[1:] != ordered[:-1]) + 1
    lo = np.concatenate(([0], cuts))
    hi = np.concatenate((cuts, [n_all]))
    ranks = np.empty(n_all)
    ranks[perm] = np.repeat((lo + 1 + hi) / 2.0, hi - lo)

    # Positives occupy the first n_pos slots of `pooled`.
    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))
def _median_or_nan(x: np.ndarray) -> float:
    """Median of ``x``; NaN (without numpy's empty-slice RuntimeWarning) if ``x`` is empty."""
    return float(np.median(x)) if x.size else float("nan")


def _mean_or_nan(x: np.ndarray) -> float:
    """Mean of ``x``; NaN (without numpy's empty-slice RuntimeWarning) if ``x`` is empty."""
    return float(np.mean(x)) if x.size else float("nan")


def main(full_path: str, short_path: str, tag: str) -> None:
    """Pair a full-window run with an early-window run and report forecast AUCs.

    Args:
        full_path: ``.npz`` from the full-window run. Keys read: ``idx``,
            ``exact_correct``, ``lyap_spec``.
        short_path: ``.npz`` from the early-window run. Keys read: ``idx``,
            ``exact_correct``, ``lyap_spec``, ``drift_zH``, ``q_halt``.
        tag: label used in the report title and in the output file name.

    Examples are joined on ``idx``, and only ids present in both files are used.
    The markdown report is printed and also written (UTF-8) to
    ``HERE / f"early_pairing_{tag}.md"``.
    """
    # np.load on an .npz returns an NpzFile that holds an open zip handle.
    # Read every needed array inside the context so both handles are closed.
    with np.load(full_path) as f, np.load(short_path) as s:
        fi, si = f["idx"], s["idx"]
        # Common ids, plus the position of each one in fi and in si
        # (first occurrence if an id repeats).
        common, f_pos, s_pos = np.intersect1d(fi, si, return_indices=True)
        print(f"[{tag}] paired {len(common)} examples (full n={len(fi)}, short n={len(si)})")
        y_final = f["exact_correct"].astype(int)[f_pos]  # FINAL outcome (step 16)
        y_early = s["exact_correct"].astype(int)[s_pos]  # correct already at step 4
        l1_early = s["lyap_spec"][s_pos, 0]
        l1_full = f["lyap_spec"][f_pos, 0]
        # Drift at the short run's last recorded step (ACT step 4).
        # Clip before log10 so that a zero drift stays finite.
        drift4 = np.log10(np.clip(s["drift_zH"][s_pos, -1], 1e-12, None))
        q4 = s["q_halt"][s_pos, -1]

    final_ok = y_final == 1
    final_bad = y_final == 0
    lines = [f"# Early-window pairing — {tag}",
             f"- paired n={len(common)}; final acc={_mean_or_nan(y_final):.4f}; "
             f"already-correct@step4={_mean_or_nan(y_early):.4f}",
             f"- of final-correct, fraction already correct@4: {_mean_or_nan(y_early[final_ok]):.4f}",
             f"- early-window lam1: final-correct med {_median_or_nan(l1_early[final_ok]):+.4f}, "
             f"final-wrong med {_median_or_nan(l1_early[final_bad]):+.4f}",
             "",
             "## Forecasting FINAL outcome from the first 4 ACT steps",
             f"- AUC(-lam1_early -> final correct) = {auc_rank(-l1_early, y_final):.3f}",
             f"- AUC(-drift@4 -> final correct) = {auc_rank(-drift4, y_final):.3f}",
             f"- AUC(q_halt@4 -> final correct) = {auc_rank(q4, y_final):.3f}",
             f"- reference: AUC(-lam1_full -> final correct) = {auc_rank(-l1_full, y_final):.3f}",
             "",
             "## Restricted to examples NOT yet correct at step 4 (the decision-relevant set)"]

    # The decision-relevant population: examples not yet correct at step 4.
    m = y_early == 0
    n_m = int(m.sum())
    n_pos = int(y_final[m].sum())
    lines += [f"- n={n_m}, of which eventually correct: {n_pos} ({n_pos/max(n_m,1):.3f})",
              f"- AUC(-lam1_early -> eventually correct) = {auc_rank(-l1_early[m], y_final[m]):.3f}",
              f"- AUC(-drift@4 -> eventually correct) = {auc_rank(-drift4[m], y_final[m]):.3f}",
              f"- AUC(q_halt@4 -> eventually correct) = {auc_rank(q4[m], y_final[m]):.3f}",
              f"- early lam1 med: eventually-correct {_median_or_nan(l1_early[m & final_ok]):+.4f} vs "
              f"never-correct {_median_or_nan(l1_early[m & final_bad]):+.4f}"]

    out = HERE / f"early_pairing_{tag}.md"
    report = "\n".join(lines)
    # Use explicit UTF-8: the title contains a non-ASCII em dash, which the
    # locale's default encoding may be unable to represent.
    out.write_text(report + "\n", encoding="utf-8")
    print(report)
    print("wrote", out)
if __name__ == "__main__":
    # Exit with a usage line instead of an IndexError traceback when arguments
    # are missing. Extra trailing arguments are still ignored, as before.
    if len(sys.argv) < 4:
        print("usage: python early_window_pairing.py FULL.npz SHORT.npz TAG", file=sys.stderr)
        sys.exit(2)
    main(sys.argv[1], sys.argv[2], sys.argv[3])
|