summaryrefslogtreecommitdiff
path: root/research/flossing/analysis_2x2/early_window_pairing.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/analysis_2x2/early_window_pairing.py')
-rw-r--r--research/flossing/analysis_2x2/early_window_pairing.py85
1 files changed, 85 insertions, 0 deletions
diff --git a/research/flossing/analysis_2x2/early_window_pairing.py b/research/flossing/analysis_2x2/early_window_pairing.py
new file mode 100644
index 0000000..6277df2
--- /dev/null
+++ b/research/flossing/analysis_2x2/early_window_pairing.py
@@ -0,0 +1,85 @@
+"""Early-window (first 4 ACT steps) vs full-window pairing.
+
+Question: does the early-window FTLE forecast FINAL failure before convergence?
+- join full/short npz per example via idx (same seed/n => same sampling)
+- target label = final exact_correct from the FULL-window run
+- report: AUC(-lam1_early -> final correct), same for early drift and q_halt@4,
+ and the conditional version restricted to examples NOT yet correct at step 4
+ (short run's own exact_correct == 0) — the practically relevant population for
+ early-exit / compute-reallocation decisions.
+
+Observational only. Usage: python early_window_pairing.py FULL.npz SHORT.npz TAG
+"""
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+import numpy as np
+
# Directory containing this script; main() writes its markdown report here.
HERE = Path(__file__).resolve().parent
+
+
def auc_rank(score: np.ndarray, label: np.ndarray) -> float:
    """Return the rank-based (Mann-Whitney U) AUC of *score* for a binary *label*.

    The result is the probability that a randomly chosen positive
    (``label == 1``) has a higher score than a randomly chosen negative
    (``label == 0``), with tied scores counted as one half.

    Args:
        score: 1-D array-like of scores; higher should indicate "positive".
        label: Array-like of the same length; entries equal to 1 are
            positives, entries equal to 0 are negatives, anything else is
            ignored.

    Returns:
        The AUC in ``[0, 1]``, or ``nan`` when either class is empty.
    """
    # Accept plain sequences as well as arrays; boolean-mask indexing below
    # needs real ndarrays.
    score = np.asarray(score)
    label = np.asarray(label)

    pos, neg = score[label == 1], score[label == 0]
    n_pos, n_neg = len(pos), len(neg)
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Positives first, so their ranks are the first n_pos entries of `ranks`.
    pooled = np.concatenate([pos, neg])
    n_all = len(pooled)
    order = np.argsort(pooled, kind="mergesort")
    sorted_vals = pooled[order]

    # Mark where each run of equal values starts in the sorted array. NaN
    # compares unequal to everything, so each NaN forms its own run.
    is_start = np.empty(n_all, dtype=bool)
    is_start[0] = True
    is_start[1:] = sorted_vals[1:] != sorted_vals[:-1]
    starts = np.flatnonzero(is_start)
    ends = np.append(starts[1:], n_all)  # exclusive end of each run

    # A run occupying sorted positions [s, e) holds 1-based ranks s+1..e,
    # whose mean is (s + 1 + e) / 2 -- the mid-rank shared by all ties.
    mid_rank = (starts + 1 + ends) / 2.0
    run_id = np.cumsum(is_start) - 1

    ranks = np.empty(n_all)
    ranks[order] = mid_rank[run_id]

    # U statistic of the positives, normalised by the number of pos/neg pairs.
    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))
+
+
def main(full_path: str, short_path: str, tag: str) -> None:
    """Pair a full-window and a short-window run and report early-forecast AUCs.

    Examples are joined on the ``idx`` array present in both archives. The
    target label is ``exact_correct`` from the full-window file; predictors
    come from the short-window file (``lyap_spec`` column 0, the last column
    of ``drift_zH`` and of ``q_halt``). The same AUCs are then recomputed on
    the subset whose short-run ``exact_correct`` is 0.

    The report is printed and written to ``early_pairing_<tag>.md`` next to
    this script.

    Args:
        full_path: Path to the full-window ``.npz`` file.
        short_path: Path to the short-window ``.npz`` file.
        tag: Label used in the report heading and the output file name.
    """
    # np.load keeps an .npz archive open until closed; the context managers
    # release both file handles once the needed arrays are materialised.
    with np.load(full_path) as full, np.load(short_path) as short:
        full_idx, short_idx = full["idx"], short["idx"]
        common, f_pos, s_pos = np.intersect1d(full_idx, short_idx, return_indices=True)
        print(f"[{tag}] paired {len(common)} examples (full n={len(full_idx)}, short n={len(short_idx)})")

        y_final = full["exact_correct"].astype(int)[f_pos]   # FINAL outcome (step 16)
        y_early = short["exact_correct"].astype(int)[s_pos]  # correct already at step 4
        l1_early = short["lyap_spec"][s_pos, 0]
        l1_full = full["lyap_spec"][f_pos, 0]
        # log10 of the last recorded drift (ACT step 4), floored at 1e-12 so
        # the logarithm stays finite.
        drift4 = np.log10(np.clip(short["drift_zH"][s_pos, -1], 1e-12, None))
        q4 = short["q_halt"][s_pos, -1]

    final_ok = y_final == 1
    final_bad = y_final == 0

    lines = [
        f"# Early-window pairing — {tag}",
        f"- paired n={len(common)}; final acc={y_final.mean():.4f}; already-correct@step4={y_early.mean():.4f}",
        f"- of final-correct, fraction already correct@4: {y_early[final_ok].mean():.4f}",
        f"- early-window lam1: final-correct med {np.median(l1_early[final_ok]):+.4f}, "
        f"final-wrong med {np.median(l1_early[final_bad]):+.4f}",
        "",
        "## Forecasting FINAL outcome from the first 4 ACT steps",
        f"- AUC(-lam1_early -> final correct) = {auc_rank(-l1_early, y_final):.3f}",
        f"- AUC(-drift@4 -> final correct) = {auc_rank(-drift4, y_final):.3f}",
        f"- AUC(q_halt@4 -> final correct) = {auc_rank(q4, y_final):.3f}",
        f"- reference: AUC(-lam1_full -> final correct) = {auc_rank(-l1_full, y_final):.3f}",
        "",
        "## Restricted to examples NOT yet correct at step 4 (the decision-relevant set)",
    ]

    # Conditional analysis: only examples the short run had not yet solved.
    m = y_early == 0
    n_m = int(m.sum())
    n_pos = int(y_final[m].sum())
    lines += [
        f"- n={n_m}, of which eventually correct: {n_pos} ({n_pos/max(n_m,1):.3f})",
        f"- AUC(-lam1_early -> eventually correct) = {auc_rank(-l1_early[m], y_final[m]):.3f}",
        f"- AUC(-drift@4 -> eventually correct) = {auc_rank(-drift4[m], y_final[m]):.3f}",
        f"- AUC(q_halt@4 -> eventually correct) = {auc_rank(q4[m], y_final[m]):.3f}",
        f"- early lam1 med: eventually-correct {np.median(l1_early[m & final_ok]):+.4f} vs "
        f"never-correct {np.median(l1_early[m & final_bad]):+.4f}",
    ]

    report = "\n".join(lines)
    out = HERE / f"early_pairing_{tag}.md"
    # The report contains non-ASCII text (em dash); pin the encoding so the
    # file does not depend on the platform's locale default.
    out.write_text(report, encoding="utf-8")
    print(report)
    print("wrote", out)
+
+
if __name__ == "__main__":
    # Three positional arguments are required; fail with a usage line rather
    # than an IndexError when any are missing. Extra arguments are ignored.
    if len(sys.argv) < 4:
        raise SystemExit("usage: python early_window_pairing.py FULL.npz SHORT.npz TAG")
    main(sys.argv[1], sys.argv[2], sys.argv[3])