summary | refs | log | tree | commit | diff
path: root/ep_run/redx_trajprobe.py
blob: ee88098f90b571c6c7786e73bf34cf557c6d7589 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""Live trajectory prober for ep_redx: probe each frozen ckpt for cos(g_EP,g_transpose),
then overlay that with val/res from the training log, giving the step | res | cos ordering through
the divergence. Probes live (shares GPU0, slows the run ~1.4x); finishes when ep_redx exits
(diverges) or when the 6-hour polling cap is reached."""
import time, os, re, subprocess, glob
# Every relative path below resolves against the run directory.
os.chdir("/home/yurenh2/ept/ep_run")
LOG = "runs/ep_redx.log"      # training log parsed by resmap()
OUT = "runs/redx_traj.log"    # trajectory table that probe() appends to
PID = 2497442                 # process id polled by alive()
def alive():
    """Return True while the process identified by PID still exists.

    Signal 0 delivers nothing; it only runs the kernel's existence and
    permission checks, so the outcome is read from the exception raised.
    """
    try:
        os.kill(PID, 0)
    except ProcessLookupError:
        return False    # no such process: the run has exited
    except PermissionError:
        return True     # process exists but we are not allowed to signal it
    except OSError:
        return False    # any other failure: keep the "treat as gone" fallback
    return True
def resmap():
    """Parse the training log into ``{step: (val_ce, res)}``.

    Only lines that begin with "step" are examined.  Both values are kept as
    the raw strings captured from the log; nothing is converted to float.
    The log is re-read in full on every call.

    Returns:
        dict mapping int step -> (val CE string, res string).  If the log
        cannot be opened or read, whatever was parsed so far (possibly an
        empty dict) is returned instead of raising.
    """
    # NOTE(review): the numeric character class cannot match "nan"/"inf"; if
    # the trainer logs non-finite values, those steps will be missing from the
    # map -- confirm against the actual log format.
    pattern = re.compile(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)")
    found = {}
    try:
        # errors="replace": the log may be read while it is still being written,
        # so a half-written multi-byte character must not abort the whole parse.
        with open(LOG, errors="replace") as fh:
            for line in fh:
                if not line.startswith("step"):
                    continue
                m = pattern.search(line)
                if m:
                    found[int(m.group(1))] = (m.group(2), m.group(3))
    except OSError:
        pass  # best-effort: a missing or unreadable log just yields fewer entries
    return found
def probe(ck, done):
    """Probe one checkpoint with asym_probe.py and append a trajectory row to OUT.

    The row is also printed to stdout.  Columns that could not be determined
    are written as "?".  If launching or waiting on the probe raises (for
    example the 400 s timeout), the cos column is "err" and the reason is
    reported on stderr.

    Args:
        ck: Checkpoint path; the step is the first ``s<digits>`` run found in it.
        done: Set of step numbers already handled.  Mutated in place: the step
            is added before the probe runs, so a probe that fails or times out
            is not retried on later polls.
    """
    import sys  # local import: only needed for the failure diagnostic below

    step = int(re.search(r"s(\d+)", ck).group(1))
    if step in done:
        return
    done.add(step)
    val, res = resmap().get(step, ("?", "?"))
    cosv, zres = "?", "?"
    env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")
    try:
        r = subprocess.run(["python3", "asym_probe.py", "--ckpt", ck, "--B", "8"],
                           env=env, capture_output=True, text=True, timeout=400)
    except Exception as exc:
        # Deliberately broad: one bad probe must not stop the long-running monitor,
        # but the reason is surfaced instead of being silently dropped.
        cosv = "err"
        print(f"probe of step {step} failed: {exc!r}", file=sys.stderr, flush=True)
    else:
        out = r.stdout + r.stderr
        # NOTE(review): this pattern requires an explicit sign and has no exponent
        # part -- confirm it matches how asym_probe.py prints the cosine.
        m = re.search(r"cos\(g_EP, ?g_transpose\)=([+-][0-9.]+)", out)
        if m:
            cosv = m.group(1)
        z = re.search(r"step_res=([0-9.eE+-]+)", out)
        if z:
            zres = z.group(1)
    line = f"  {step:5d} | val {val} | res(log) {res} | cos {cosv} | z*res {zres}"
    with open(OUT, "a") as fh:
        fh.write(line + "\n")
    print(line, flush=True)
def cks():
    """Return the checkpoint paths under runs/redx_traj/, ordered by ascending step.

    The step is the integer after the first "s" that is followed by digits in
    the path.  The glob ``s*.pt`` can also match names that carry no such
    number (e.g. ``stale.pt``); those are skipped rather than letting the sort
    key fail on a missing regex match.
    """
    step_of = {}
    for path in glob.glob("runs/redx_traj/s*.pt"):
        m = re.search(r"s(\d+)", path)
        if m:
            step_of[path] = int(m.group(1))
    # dict preserves glob order, and sorted() is stable, so ties keep that order.
    return sorted(step_of, key=step_of.get)
# Header row.  OUT is opened in append mode, so rows from earlier invocations are kept.
with open(OUT, "a") as fh:
    fh.write("# step | val | res(log) | cos(g_EP,g_transpose) | z*res(probe)\n")
done = set()
# Monotonic clock: the 6-hour cap is an elapsed-time budget and must not be
# affected by wall-clock adjustments.
t0 = time.monotonic()
while time.monotonic() - t0 < 6 * 3600:
    time.sleep(30)
    for ck in cks():
        probe(ck, done)
    if not alive():
        time.sleep(45)                      # let freezer catch the final ckpts
        for ck in cks():
            probe(ck, done)
        break
print("=== TRAJPROBE DONE — full trajectory ===")
# Replay the whole table (including any rows from earlier invocations).
with open(OUT) as fh:
    for row in fh:
        print(row.rstrip())