diff options
Diffstat (limited to 'ep_run/verify_aep_manual.py')
| -rw-r--r-- | ep_run/verify_aep_manual.py | 62 |
1 files changed, 62 insertions, 0 deletions
"""Verification script ``ep_run/verify_aep_manual.py``.

Compares a hand-written ("manual") step against the ``torch.func``-based
step on a warm checkpoint, in three stages:

1. Try ``torch.compile`` directly on the ``torch.func`` step and report
   whether it fails.
2. Compare one-step outputs and the JVP / VJP of ``blk.nc_force`` between
   ``torch.func`` and the manual implementation.
3. Run the full a-select loop both ways for several seeds / lengths and
   report the cosine of the resulting ``a`` and of the parameter gradients
   derived from it, together with wall-clock timings.

The script runs entirely at import time and prints its results.
"""
import math
import time

import torch
import torch.func as tf

import lt_ep_train as LT
import holo_ep as H
from test_aselect_deepdive import (manual_nc_jvp_vjp_thick, make_manual_step,
                                   make_tf_step, run_loop_from_step, cosine)


def cosd(ga, gb, ps):
    """Return the cosine similarity between two per-parameter gradient dicts.

    Args:
        ga: Mapping ``id(param) -> gradient tensor``.
        gb: Mapping ``id(param) -> gradient tensor``.
        ps: Iterable of parameters; only parameters present in both ``ga``
            and ``gb`` contribute.

    Returns:
        ``<ga, gb> / (sqrt(|ga|^2 * |gb|^2) + 1e-30)`` accumulated over the
        shared parameters, as a Python float.
    """
    dot = norm_a_sq = norm_b_sq = 0.0
    for param in ps:
        grad_a = ga.get(id(param))
        grad_b = gb.get(id(param))
        if grad_a is None or grad_b is None:
            # Parameter has no gradient on at least one side; skip it.
            continue
        dot += float((grad_a * grad_b).sum())
        norm_a_sq += float((grad_a * grad_a).sum())
        norm_b_sq += float((grad_b * grad_b).sum())
    # The tiny epsilon keeps the division finite when both norms are zero.
    return dot / (math.sqrt(norm_a_sq * norm_b_sq) + 1e-30)


def grad_from_a(blk, zs, idx, a):
    """Return parameter gradients of ``sum(a * force)`` for ``blk.block``.

    The force is ``blk.force(zs.detach(), blk.embed(idx), cg=True)``; both
    ``a`` and ``zs`` are detached, so gradients reach the parameters only
    through the embedding and the force computation.

    Args:
        blk: Block object exposing ``embed``, ``force`` and ``block``.
        zs: State passed (detached) to ``blk.force``.
        idx: Input passed to ``blk.embed``.
        a: Tensor multiplied element-wise with the force before summing.

    Returns:
        Dict ``id(param) -> gradient`` for every parameter in ``blk.block``
        whose gradient is not ``None``.
    """
    # enable_grad so this also works when called from a no_grad context.
    with torch.enable_grad():
        xin = blk.embed(idx)
        force = blk.force(zs.detach(), xin, cg=True)
        grads = torch.autograd.grad((a.detach() * force).sum(), blk.block,
                                    allow_unused=True)
    return {id(param): grad
            for param, grad in zip(blk.block, grads)
            if grad is not None}


dev = 'cpu'
blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True
blk.track = True
blk.navg = 1
blk.li_avg = 0
ck = torch.load('runs/ep_resreg_warm.pt', map_location=dev)
with torch.no_grad():
    # Copy checkpoint tensors into the block's parameters, pairwise in order.
    for p, s in zip(blk.allp, ck['allp']):
        p.copy_(s.to(dev))
print(f"loaded step={ck.get('step')} best={ck.get('best')}", flush=True)

# 1) Verify direct compile(step over torch.func) FAILS (independent confirmation)
torch.manual_seed(0)
idx, y = LT.get_batch('train', 1, 256)
xin = blk.embed(idx).detach()
zs = LT.relax(blk, xin.clone(), xin, 1, 0.1)
B = zs.size(0)
Z0 = torch.cat([zs, zs], 0)
tf_step = make_tf_step(blk, zs, xin, y, 0.02, 0.1)
try:
    c = torch.compile(tf_step, fullgraph=True)
    c(Z0)
    print("COMPILE compile∘func: UNEXPECTEDLY OK", flush=True)
except Exception as e:
    # Any failure is the expected outcome here; report it and carry on.
    print("COMPILE compile∘func: FAILS ->", type(e).__name__, str(e)[:90], flush=True)

# 2) one-step manual vs torch.func identity + J accuracy
v = torch.randn_like(Z0) * 1e-3
zbar = 0.5 * (Z0[:B] + Z0[B:])
zb2 = torch.cat([zbar, zbar], 0)
_, Jv_ref = tf.jvp(lambda zz: blk.nc_force(zz), (zb2,), (v,))
JTv_ref = tf.vjp(lambda zz: blk.nc_force(zz), zb2)[1](v)[0]
Jv_m, JTv_m = manual_nc_jvp_vjp_thick(blk, zb2, v)
print(f"manual Jv cos={cosine(Jv_ref,Jv_m):.8f} JTv cos={cosine(JTv_ref,JTv_m):.8f}", flush=True)
with torch.no_grad():
    Zt = tf_step(Z0)
    Zm = make_manual_step(blk, zs, xin, y, 0.02, 0.1)(Z0)
print(f"one-step manual vs tf max_abs={float((Zt-Zm).abs().max()):.2e}", flush=True)

# 3) full a-select + gradient cosine, several configs
for seed, T2, K in [(1, 80, 10), (2, 80, 10), (3, 160, 10)]:
    torch.manual_seed(seed)
    idx, y = LT.get_batch('train', 1, 256)
    xin = blk.embed(idx).detach()
    zs = LT.relax(blk, xin.clone(), xin, 1, 0.1)

    # perf_counter is the monotonic, high-resolution clock meant for timing
    # short intervals; time.time() can jump with system clock adjustments.
    t = time.perf_counter()
    a0, tb = H.holo_a_track(blk, zs, xin, y, 0.02, T2, 0.1, K=K)
    sb = time.perf_counter() - t

    t = time.perf_counter()
    a1, tm = run_loop_from_step(make_manual_step(blk, zs, xin, y, 0.02, 0.1),
                                zs, 0.02, T2, K=K)
    sm = time.perf_counter() - t

    g0 = grad_from_a(blk, zs, idx, a0)
    g1 = grad_from_a(blk, zs, idx, a1)
    print(f"seed{seed} T2={T2} K={K}: a_cos={cosine(a0,a1):.7f} grad_cos={cosd(g0,g1,blk.block):.7f} "
          f"tbest={tb}/{tm} base={sb:.2f}s man={sm:.2f}s spd={sb/sm:.2f}x", flush=True)
print("DONE", flush=True)
