1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
|
"""Live trajectory prober for ep_redx: probe each frozen ckpt for cos(g_EP,g_transpose),
overlay with val/res from the training log -> the step|res|cos ordering through the divergence.
Probes live (shares GPU0, slows the run ~1.4x); finishes when ep_redx exits (diverges)."""
import time, os, re, subprocess, glob
# Work from the run directory so the relative "runs/..." paths used
# throughout this script resolve against it.
os.chdir("/home/yurenh2/ept/ep_run")

# Training log that resmap() scans for per-step val CE / res values.
LOG = "runs/ep_redx.log"
# Trajectory table this script appends to, one row per probed checkpoint.
OUT = "runs/redx_traj.log"
# Process id polled by alive(); the main loop stops once it is gone.
PID = 2497442
def alive(pid=None):
    """Return True if the monitored process still exists.

    Args:
        pid: Process id to check. Defaults to the module-level ``PID``.

    Returns:
        True when the process exists (even if it belongs to another user),
        False when it does not exist or the check itself fails.
    """
    target = PID if pid is None else pid
    try:
        # Signal 0 delivers nothing; the call only performs the existence
        # and permission checks.
        os.kill(target, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # EPERM means the process exists but we may not signal it.
        return True
    except OSError:
        # Any other OS-level failure: keep the original best-effort answer.
        return False
    return True
def resmap(log_path=None):
    """Parse the training log into ``{step: (val_ce, res)}``.

    Only lines beginning with ``"step"`` are considered. From each matching
    line the step number, the token after ``val CE`` and the token after
    ``res=`` are extracted; the two values are kept as strings exactly as
    they appear in the log.

    Args:
        log_path: Path of the log to read. Defaults to the module-level
            ``LOG``.

    Returns:
        Dict mapping int step -> (val_ce_str, res_str). Empty when the log
        is missing or unreadable; a later line for the same step overwrites
        an earlier one.
    """
    path = LOG if log_path is None else log_path
    # Besides plain/scientific numbers, accept nan/inf: the numeric character
    # class alone cannot match them, so such rows would otherwise be dropped.
    num = r"([+-]?(?:nan|inf)|[\d.eE+-]+)"
    pattern = re.compile(r"step (\d+)/.*val CE " + num + r".*res=" + num)
    M = {}
    try:
        # errors="replace": the log may be read while it is still being
        # written, so tolerate a partially written multi-byte character.
        with open(path, errors="replace") as fh:
            for l in fh:
                if not l.startswith("step"):
                    continue
                m = pattern.search(l)
                if m:
                    M[int(m.group(1))] = (m.group(2), m.group(3))
    except OSError:
        # Log not there (yet) or unreadable: return whatever was parsed.
        pass
    return M
def probe(ck, done):
    """Probe one checkpoint and append its row to the trajectory table.

    The step number is taken from the checkpoint's file name (``s<digits>``).
    Steps already in ``done`` are skipped; otherwise the step is added to
    ``done`` before probing, so each step is attempted at most once.

    Runs ``python3 asym_probe.py --ckpt <ck> --B 8`` (400 s timeout) and
    scans its combined stdout/stderr for the ``cos(g_EP,g_transpose)=`` and
    ``step_res=`` values. The val/res columns come from ``resmap()``. Any
    value that cannot be found is written as ``"?"``; if the subprocess
    itself fails or times out the cos column is ``"err"``.

    Args:
        ck: Path to a checkpoint file.
        done: Set of step numbers already handled; mutated in place.
    """
    # Look only at the file name so directory components can never be
    # mistaken for the step marker.
    step_match = re.search(r"s(\d+)", os.path.basename(ck))
    if step_match is None:
        return  # not a step-numbered checkpoint; nothing to probe
    step = int(step_match.group(1))
    if step in done:
        return
    done.add(step)

    val, res = resmap().get(step, ("?", "?"))
    cosv, zres = "?", "?"
    env = dict(os.environ, CUDA_VISIBLE_DEVICES="0", PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")
    try:
        r = subprocess.run(["python3", "asym_probe.py", "--ckpt", ck, "--B", "8"],
                           env=env, capture_output=True, text=True, timeout=400)
    except (OSError, subprocess.SubprocessError, ValueError):
        # Launch failure, timeout, or undecodable output: record the failure
        # in the row instead of aborting the whole monitor.
        cosv = "err"
    else:
        out = r.stdout + r.stderr
        m = re.search(r"cos\(g_EP, ?g_transpose\)=([+-][0-9.]+)", out)
        if m:
            cosv = m.group(1)
        z = re.search(r"step_res=([0-9.eE+-]+)", out)
        if z:
            zres = z.group(1)

    line = f" {step:5d} | val {val} | res(log) {res} | cos {cosv} | z*res {zres}"
    # Close the handle explicitly so each row is flushed to disk right away.
    with open(OUT, "a") as fh:
        fh.write(line + "\n")
    print(line, flush=True)
def cks():
    """Return the checkpoint paths under ``runs/redx_traj`` in step order.

    Globs ``runs/redx_traj/s*.pt`` (relative to the current directory) and
    orders the hits by the integer following ``s`` in the file name. Files
    matched by the glob whose name carries no ``s<digits>`` marker are
    skipped rather than raising.

    Returns:
        List of path strings sorted by ascending step; ties are broken by
        path so the order is deterministic.
    """
    numbered = []
    for path in glob.glob("runs/redx_traj/s*.pt"):
        m = re.search(r"s(\d+)", os.path.basename(path))
        if m is None:
            continue  # e.g. "summary.pt": matches the glob but has no step
        numbered.append((int(m.group(1)), path))
    numbered.sort()
    return [path for _, path in numbered]
# Start (or continue) the trajectory table with a header row naming the columns.
with open(OUT, "a") as out_f:
    out_f.write("# step | val | res(log) | cos(g_EP,g_transpose) | z*res(probe)\n")

done = set()  # step numbers already handed to probe()
t0 = time.time()
# Poll for new checkpoints every 30 s, for at most 6 hours.
while time.time() - t0 < 6 * 3600:
    time.sleep(30)
    for ck in cks():
        probe(ck, done)
    if not alive():
        time.sleep(45)  # let freezer catch the final ckpts
        # One last sweep for checkpoints that appeared after the process exited.
        for ck in cks():
            probe(ck, done)
        break

print("=== TRAJPROBE DONE — full trajectory ===")
# Echo the whole table, closing the file once it has been read.
with open(OUT) as out_f:
    for l in out_f:
        print(l.rstrip())
|