summaryrefslogtreecommitdiff
path: root/ep_run/auto_probe.py
blob: 90b969ea540759ebbdd49394f7c191b213f6ebad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Wait for the first converged clean-EP ckpt, run the fixed oracle-adjoint probe on it, report g_EP vs g_transpose."""
import time, os, subprocess, shutil
WD = "/home/yurenh2/ept/ep_run"
os.chdir(WD)
CK, FROZEN = "runs/ep_clean.pt", "runs/ep_clean_probe.pt"
POLL_S, MAX_POLLS = 90, 45     # 45 polls x 90 s ~= 67 min wait budget
MIN_BYTES = 1_000_000          # smaller files are treated as not-yet-written
SETTLE_S = 10                  # size and mtime must hold still this long


def _ckpt_ready(path):
    """Return True if *path* is larger than MIN_BYTES and unchanged over SETTLE_S seconds.

    A size threshold alone can accept a checkpoint that is still being written.
    Re-checking size and mtime after a short pause rejects a file that is still
    changing. (Whether the trainer writes atomically is not visible here.)
    """
    try:
        before = os.stat(path)
    except FileNotFoundError:
        return False
    if before.st_size <= MIN_BYTES:
        return False
    time.sleep(SETTLE_S)
    try:
        after = os.stat(path)
    except FileNotFoundError:      # removed/replaced during the pause
        return False
    return (after.st_size, after.st_mtime_ns) == (before.st_size, before.st_mtime_ns)


for _ in range(MAX_POLLS):
    time.sleep(POLL_S)
    if _ckpt_ready(CK):
        break
else:  # loop exhausted without seeing a settled checkpoint
    print("=== AUTO-PROBE: ep_clean.pt never appeared in ~67min ===")
    raise SystemExit(1)            # non-zero status: the probe never ran
shutil.copy2(CK, FROZEN)           # freeze a private copy so later saves to CK can't race the probe
def _as_text(data):
    """Return captured subprocess output as str; *data* may be None, bytes, or str."""
    if data is None:
        return ""
    if isinstance(data, bytes):
        # TimeoutExpired carries raw bytes on POSIX even when run() used text=True.
        return data.decode(errors="replace")
    return data


env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")
probe_err = None    # why the probe did not finish cleanly, if it didn't
stderr_txt = ""
try:
    proc = subprocess.run(["python3", "asym_probe.py", "--ckpt", FROZEN, "--B", "8"],
                          env=env, capture_output=True, text=True, timeout=1200)
except subprocess.TimeoutExpired as e:
    # Keep whatever the probe printed before it was killed.
    stderr_txt = _as_text(e.stderr)
    txt = _as_text(e.stdout) + "\n" + stderr_txt
    probe_err = f"probe run error: {e}"
except Exception as e:  # e.g. interpreter not found; report instead of crashing
    txt = ""
    probe_err = f"probe run error: {e}"
else:
    stderr_txt = proc.stderr
    txt = proc.stdout + "\n" + proc.stderr
    if proc.returncode != 0:
        probe_err = f"probe exited with status {proc.returncode}"

print("=== AUTO ORACLE PROBE on first clean-EP ckpt ===")
# Lines containing any KEEP marker are reported, unless they also contain a DROP marker (warning noise).
KEEP = ("z* res", "GMRES", "resid", "cos(", "g_transpose", "g_EP", "g_BPTT", "interpret", "->", "exact", "AsymEP", "# ckpt", "step ")
DROP = ("UserWarning", "cuBLAS", "warnings.warn", "Triggered", "return Variable", "FutureWarning")
for line in txt.splitlines():
    if any(k in line for k in KEEP) and not any(b in line for b in DROP):
        print(line)

if probe_err is not None:
    # Error text and tracebacks rarely match KEEP, so print them unfiltered.
    print(probe_err)
    for line in stderr_txt.strip().splitlines()[-10:]:
        print(line)
    raise SystemExit(1)