diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analysis_2x2/early_window_pairing.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analysis_2x2/early_window_pairing.py')
| -rw-r--r-- | research/flossing/analysis_2x2/early_window_pairing.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/research/flossing/analysis_2x2/early_window_pairing.py b/research/flossing/analysis_2x2/early_window_pairing.py new file mode 100644 index 0000000..6277df2 --- /dev/null +++ b/research/flossing/analysis_2x2/early_window_pairing.py @@ -0,0 +1,85 @@ +"""Early-window (first 4 ACT steps) vs full-window pairing. + +Question: does the early-window FTLE forecast FINAL failure before convergence? +- join full/short npz per example via idx (same seed/n => same sampling) +- target label = final exact_correct from the FULL-window run +- report: AUC(-lam1_early -> final correct), same for early drift and q_halt@4, + and the conditional version restricted to examples NOT yet correct at step 4 + (short run's own exact_correct == 0) — the practically relevant population for + early-exit / compute-reallocation decisions. + +Observational only. Usage: python early_window_pairing.py FULL.npz SHORT.npz TAG +""" +from __future__ import annotations + +import sys +from pathlib import Path + +import numpy as np + +HERE = Path(__file__).resolve().parent + + +def auc_rank(score: np.ndarray, label: np.ndarray) -> float: + pos, neg = score[label == 1], score[label == 0] + if len(pos) == 0 or len(neg) == 0: + return float("nan") + allv = np.concatenate([pos, neg]) + order = np.argsort(allv, kind="mergesort") + ranks = np.empty(len(allv)); ranks[order] = np.arange(1, len(allv) + 1) + sv = allv[order]; i = 0 + while i < len(sv): + j = i + while j + 1 < len(sv) and sv[j + 1] == sv[i]: + j += 1 + if j > i: + ranks[order[i:j + 1]] = ranks[order[i:j + 1]].mean() + i = j + 1 + return float((ranks[:len(pos)].sum() - len(pos) * (len(pos) + 1) / 2) / (len(pos) * len(neg))) + + +def main(full_path: str, short_path: str, tag: str) -> None: + f = np.load(full_path) + s = np.load(short_path) + + fi, si = f["idx"], s["idx"] + common, f_pos, s_pos = np.intersect1d(fi, si, return_indices=True) + print(f"[{tag}] paired {len(common)} examples (full n={len(fi)}, short n={len(si)})") + + y_final = f["exact_correct"].astype(int)[f_pos] # FINAL outcome (step 16) + y_early = s["exact_correct"].astype(int)[s_pos] # correct already at step 4 + l1_early = s["lyap_spec"][s_pos, 0] + l1_full = f["lyap_spec"][f_pos, 0] + drift4 = np.log10(np.clip(s["drift_zH"][s_pos, -1], 1e-12, None)) # drift at ACT step 4 + q4 = s["q_halt"][s_pos, -1] + + lines = [f"# Early-window pairing — {tag}", + f"- paired n={len(common)}; final acc={y_final.mean():.4f}; already-correct@step4={y_early.mean():.4f}", + f"- of final-correct, fraction already correct@4: {y_early[y_final==1].mean():.4f}", + f"- early-window lam1: final-correct med {np.median(l1_early[y_final==1]):+.4f}, " + f"final-wrong med {np.median(l1_early[y_final==0]):+.4f}", + "", + "## Forecasting FINAL outcome from the first 4 ACT steps", + f"- AUC(-lam1_early -> final correct) = {auc_rank(-l1_early, y_final):.3f}", + f"- AUC(-drift@4 -> final correct) = {auc_rank(-drift4, y_final):.3f}", + f"- AUC(q_halt@4 -> final correct) = {auc_rank(q4, y_final):.3f}", + f"- reference: AUC(-lam1_full -> final correct) = {auc_rank(-l1_full, y_final):.3f}", + "", + "## Restricted to examples NOT yet correct at step 4 (the decision-relevant set)"] + m = y_early == 0 + n_m = int(m.sum()); n_pos = int(y_final[m].sum()) + lines += [f"- n={n_m}, of which eventually correct: {n_pos} ({n_pos/max(n_m,1):.3f})", + f"- AUC(-lam1_early -> eventually correct) = {auc_rank(-l1_early[m], y_final[m]):.3f}", + f"- AUC(-drift@4 -> eventually correct) = {auc_rank(-drift4[m], y_final[m]):.3f}", + f"- AUC(q_halt@4 -> eventually correct) = {auc_rank(q4[m], y_final[m]):.3f}", + f"- early lam1 med: eventually-correct {np.median(l1_early[m & (y_final==1)]):+.4f} vs " + f"never-correct {np.median(l1_early[m & (y_final==0)]):+.4f}"] + + out = HERE / f"early_pairing_{tag}.md" + out.write_text("\n".join(lines)) + print("\n".join(lines)) + print("wrote", out) + + +if __name__ == "__main__": + main(sys.argv[1], sys.argv[2], sys.argv[3]) |
