summaryrefslogtreecommitdiff
path: root/ep_run/verify_aep_manual.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/verify_aep_manual.py')
-rw-r--r--ep_run/verify_aep_manual.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/ep_run/verify_aep_manual.py b/ep_run/verify_aep_manual.py
new file mode 100644
index 0000000..6c4b403
--- /dev/null
+++ b/ep_run/verify_aep_manual.py
@@ -0,0 +1,62 @@
+import math, time, torch
+import lt_ep_train as LT, holo_ep as H
+from test_aselect_deepdive import (manual_nc_jvp_vjp_thick, make_manual_step,
+ make_tf_step, run_loop_from_step, cosine)
+import torch.func as tf
+
+def cosd(ga, gb, ps):
+ num=da=db=0.0
+ for p in ps:
+ a=ga.get(id(p)); b=gb.get(id(p))
+ if a is None or b is None: continue
+ num+=float((a*b).sum()); da+=float((a*a).sum()); db+=float((b*b).sum())
+ return num/(math.sqrt(da*db)+1e-30)
+
+def grad_from_a(blk, zs, idx, a):
+ with torch.enable_grad():
+ xin=blk.embed(idx)
+ f=blk.force(zs.detach(), xin, cg=True)
+ gs=torch.autograd.grad((a.detach()*f).sum(), blk.block, allow_unused=True)
+ return {id(p):g for p,g in zip(blk.block,gs) if g is not None}
+
+dev='cpu'
+blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick')
+blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0
+ck=torch.load('runs/ep_resreg_warm.pt', map_location=dev)
+with torch.no_grad():
+ for p,s in zip(blk.allp, ck['allp']): p.copy_(s.to(dev))
+print(f"loaded step={ck.get('step')} best={ck.get('best')}", flush=True)
+
+# 1) Verify direct compile(step over torch.func) FAILS (independent confirmation)
+torch.manual_seed(0)
+idx,y=LT.get_batch('train',1,256); xin=blk.embed(idx).detach()
+zs=LT.relax(blk,xin.clone(),xin,1,0.1); B=zs.size(0); Z0=torch.cat([zs,zs],0)
+tf_step=make_tf_step(blk,zs,xin,y,0.02,0.1)
+try:
+ c=torch.compile(tf_step, fullgraph=True); c(Z0)
+ print("COMPILE compile∘func: UNEXPECTEDLY OK", flush=True)
+except Exception as e:
+ print("COMPILE compile∘func: FAILS ->", type(e).__name__, str(e)[:90], flush=True)
+
+# 2) one-step manual vs torch.func identity + J accuracy
+v=torch.randn_like(Z0)*1e-3
+zbar=0.5*(Z0[:B]+Z0[B:]); zb2=torch.cat([zbar,zbar],0)
+_,Jv_ref=tf.jvp(lambda zz: blk.nc_force(zz),(zb2,),(v,))
+JTv_ref=tf.vjp(lambda zz: blk.nc_force(zz),zb2)[1](v)[0]
+Jv_m,JTv_m=manual_nc_jvp_vjp_thick(blk,zb2,v)
+print(f"manual Jv cos={cosine(Jv_ref,Jv_m):.8f} JTv cos={cosine(JTv_ref,JTv_m):.8f}", flush=True)
+with torch.no_grad():
+ Zt=tf_step(Z0); Zm=make_manual_step(blk,zs,xin,y,0.02,0.1)(Z0)
+print(f"one-step manual vs tf max_abs={float((Zt-Zm).abs().max()):.2e}", flush=True)
+
+# 3) full a-select + gradient cosine, several configs
+for seed,T2,K in [(1,80,10),(2,80,10),(3,160,10)]:
+ torch.manual_seed(seed)
+ idx,y=LT.get_batch('train',1,256); xin=blk.embed(idx).detach()
+ zs=LT.relax(blk,xin.clone(),xin,1,0.1)
+ t=time.time(); a0,tb=H.holo_a_track(blk,zs,xin,y,0.02,T2,0.1,K=K); sb=time.time()-t
+ t=time.time(); a1,tm=run_loop_from_step(make_manual_step(blk,zs,xin,y,0.02,0.1),zs,0.02,T2,K=K); sm=time.time()-t
+ g0=grad_from_a(blk,zs,idx,a0); g1=grad_from_a(blk,zs,idx,a1)
+ print(f"seed{seed} T2={T2} K={K}: a_cos={cosine(a0,a1):.7f} grad_cos={cosd(g0,g1,blk.block):.7f} "
+ f"tbest={tb}/{tm} base={sb:.2f}s man={sm:.2f}s spd={sb/sm:.2f}x", flush=True)
+print("DONE", flush=True)