summaryrefslogtreecommitdiff
path: root/ep_run/redx_trajprobe.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/redx_trajprobe.py')
-rw-r--r--ep_run/redx_trajprobe.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/ep_run/redx_trajprobe.py b/ep_run/redx_trajprobe.py
new file mode 100644
index 0000000..ee88098
--- /dev/null
+++ b/ep_run/redx_trajprobe.py
@@ -0,0 +1,47 @@
+"""Live trajectory prober for ep_redx: probe each frozen ckpt for cos(g_EP,g_transpose),
+overlay with val/res from the training log -> the step|res|cos ordering through the divergence.
+Probes live (shares GPU0, slows the run ~1.4x); finishes when ep_redx exits (diverges) or after 6 hours, whichever comes first."""
+import time, os, re, subprocess, glob
+# Every relative path used below (LOG, OUT, the checkpoint glob, asym_probe.py) resolves from here.
+os.chdir("/home/yurenh2/ept/ep_run")
+LOG = "runs/ep_redx.log"    # training log scanned for per-step val CE / res
+OUT = "runs/redx_traj.log"  # trajectory table this script appends to
+PID = 2497442               # process id polled by alive(); hard-coded for this launch
+def alive():
+    """Return True while the process with id PID still exists.
+
+    Signal 0 makes os.kill perform only the existence/permission check; no
+    signal is delivered to the process.
+    """
+    try:
+        os.kill(PID, 0)
+    except ProcessLookupError:
+        return False  # no such process: the watched run has exited
+    except PermissionError:
+        return True  # process exists but is owned by another user
+    except OSError:
+        return False  # any other OS-level failure: keep the "treat as gone" fallback
+    return True
+def resmap():
+    """Parse LOG into a dict mapping step -> (val CE, res).
+
+    Only lines that start with "step" are considered. Both values are kept as
+    the raw strings found in the log. If a step is logged more than once, the
+    last occurrence wins.
+
+    Best-effort: when the log is missing or cannot be read, whatever was
+    parsed up to that point (possibly nothing) is returned.
+    """
+    pat = re.compile(r"step (\d+)/.*val CE ([\d.eE+-]+).*res=([\d.eE+-]+)")
+    found = {}
+    try:
+        # errors="replace": a stray undecodable byte must not abort the scan.
+        with open(LOG, errors="replace") as fh:
+            for row in fh:
+                if not row.startswith("step"):
+                    continue
+                m = pat.search(row)
+                if m:
+                    found[int(m.group(1))] = (m.group(2), m.group(3))
+    except OSError:
+        pass  # log not created yet or unreadable: return what we have
+    return found
+def probe(ck, done):
+    """Probe checkpoint `ck` once and append one row to OUT (the row is also printed).
+
+    ck:   checkpoint path; the step is read from the "s<digits>" part of its
+          file name. Paths without such a part are ignored.
+    done: set of step numbers already handled, updated in place. A step is
+          marked before the probe runs, so a failed probe is not retried.
+
+    Row fields: step; val and res from the training log ("?" when that step is
+    not in the log); cos(g_EP, g_transpose) and step_res parsed from the
+    combined stdout/stderr of asym_probe.py ("?" when not found, cos "err"
+    when the subprocess could not be started or timed out).
+    """
+    m = re.search(r"s(\d+)", os.path.basename(ck))
+    if m is None:
+        return  # not a step checkpoint; nothing to probe
+    step = int(m.group(1))
+    if step in done:
+        return
+    done.add(step)
+    val, res = resmap().get(step, ("?", "?"))
+    cosv, zres = "?", "?"
+    # Optional sign, then nan/inf or a decimal with optional exponent, so values
+    # such as "+1.2e-05" are captured whole and non-finite values are not lost.
+    num = r"[+-]?(?:nan|inf|(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][+-]?[0-9]+)?)"
+    env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")
+    try:
+        # errors="replace": undecodable bytes in the probe output must not raise.
+        r = subprocess.run(["python3", "asym_probe.py", "--ckpt", ck, "--B", "8"],
+                           env=env, capture_output=True, text=True,
+                           errors="replace", timeout=400)
+    except (subprocess.SubprocessError, OSError):
+        cosv = "err"  # timed out, or the interpreter could not be launched
+    else:
+        out = r.stdout + r.stderr
+        c = re.search(r"cos\(g_EP, ?g_transpose\)=(" + num + ")", out)
+        z = re.search(r"step_res=(" + num + ")", out)
+        cosv = c.group(1) if c else "?"
+        zres = z.group(1) if z else "?"
+    line = f" {step:5d} | val {val} | res(log) {res} | cos {cosv} | z*res {zres}"
+    with open(OUT, "a") as fh:
+        fh.write(line + "\n")
+    print(line, flush=True)
+def cks():
+    """Return the checkpoint paths matching runs/redx_traj/s*.pt, sorted by step number.
+
+    The step is the first "s<digits>" run in the file name. Files that match
+    the glob but carry no such run (e.g. "summary.pt") are skipped instead of
+    breaking the numeric sort.
+    """
+    keyed = []
+    for path in glob.glob("runs/redx_traj/s*.pt"):
+        m = re.search(r"s(\d+)", os.path.basename(path))
+        if m:
+            keyed.append((int(m.group(1)), path))
+    keyed.sort(key=lambda pair: pair[0])  # stable: ties keep glob order
+    return [path for _, path in keyed]
+# Header for this session. OUT is opened in append mode, so rows written by
+# earlier sessions are kept and will also show up in the final dump.
+with open(OUT, "a") as fh:
+    fh.write("# step | val | res(log) | cos(g_EP,g_transpose) | z*res(probe)\n")
+done = set()  # steps already handed to probe() in this session
+t0 = time.time()
+# Poll every 30 s for new checkpoints; stop after 6 hours even if the run is still alive.
+while time.time() - t0 < 6 * 3600:
+    time.sleep(30)
+    for ck in cks():
+        probe(ck, done)
+    if not alive():
+        time.sleep(45)  # let freezer catch the final ckpts
+        for ck in cks():
+            probe(ck, done)
+        break
+print("=== TRAJPROBE DONE — full trajectory ===")
+with open(OUT) as fh:
+    for row in fh:
+        print(row.rstrip())