diff options
Diffstat (limited to 'ep_run/redx_trajprobe.py')
| -rw-r--r-- | ep_run/redx_trajprobe.py | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/ep_run/redx_trajprobe.py b/ep_run/redx_trajprobe.py new file mode 100644 index 0000000..ee88098 --- /dev/null +++ b/ep_run/redx_trajprobe.py @@ -0,0 +1,47 @@ +"""Live trajectory prober for ep_redx: probe each frozen ckpt for cos(g_EP,g_transpose), +overlay with val/res from the training log -> the step|res|cos ordering through the divergence. +Probes live (shares GPU0, slows the run ~1.4x); finishes when ep_redx exits (diverges).""" +import time, os, re, subprocess, glob +os.chdir("/home/yurenh2/ept/ep_run") +LOG, OUT, PID = "runs/ep_redx.log", "runs/redx_traj.log", 2497442 +def alive(): + try: os.kill(PID, 0); return True + except Exception: return False +def resmap(): + M = {} + try: + for l in open(LOG): + if l.startswith("step"): + m = re.search(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)", l) + if m: M[int(m.group(1))] = (m.group(2), m.group(3)) + except Exception: pass + return M +def probe(ck, done): + step = int(re.search(r"s(\d+)", ck).group(1)) + if step in done: return + done.add(step) + val, res = resmap().get(step, ("?", "?")) + cosv, zres = "?", "?" + env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True") + try: + r = subprocess.run(["python3", "asym_probe.py", "--ckpt", ck, "--B", "8"], + env=env, capture_output=True, text=True, timeout=400) + out = r.stdout + r.stderr + m = re.search(r"cos\(g_EP, ?g_transpose\)=([+-][0-9.]+)", out); cosv = m.group(1) if m else "?" + z = re.search(r"step_res=([0-9.eE+-]+)", out); zres = z.group(1) if z else "?" + except Exception: cosv = "err" + line = f" {step:5d} | val {val} | res(log) {res} | cos {cosv} | z*res {zres}" + open(OUT, "a").write(line + "\n"); print(line, flush=True) +def cks(): + return sorted(glob.glob("runs/redx_traj/s*.pt"), key=lambda p: int(re.search(r"s(\d+)", p).group(1))) +open(OUT, "a").write("# step | val | res(log) | cos(g_EP,g_transpose) | z*res(probe)\n") +done = set(); t0 = time.time() +while time.time() - t0 < 6 * 3600: + time.sleep(30) + for ck in cks(): probe(ck, done) + if not alive(): + time.sleep(45) # let freezer catch the final ckpts + for ck in cks(): probe(ck, done) + break +print("=== TRAJPROBE DONE — full trajectory ===") +for l in open(OUT): print(l.rstrip()) |
