"""Compare a hand-written ("manual") step against the torch.func-based step.

The script loads a checkpoint into an ``LT.EQBlock`` and then runs three checks,
printing one summary line per check:

1. Confirm that ``torch.compile(..., fullgraph=True)`` of the torch.func step
   fails (a success is reported as unexpected).
2. Compare manual JVP/VJP products and a single manual step against the
   torch.func reference.
3. For several (seed, T2, K) configurations, compare the ``a`` returned by
   ``H.holo_a_track`` with the one returned by ``run_loop_from_step`` on the
   manual step, compare the parameter gradients derived from each, and report
   timings.
"""
import math
import time

import torch
import torch.func as tf

import lt_ep_train as LT
import holo_ep as H
from test_aselect_deepdive import (
    manual_nc_jvp_vjp_thick,
    make_manual_step,
    make_tf_step,
    run_loop_from_step,
    cosine,
)


def cosd(ga, gb, ps):
    """Return the cosine similarity between two per-parameter gradient dicts.

    Args:
        ga: Mapping from ``id(param)`` to a gradient tensor.
        gb: Mapping from ``id(param)`` to a gradient tensor.
        ps: Iterable of parameters; only those whose ``id`` is present in
            both ``ga`` and ``gb`` contribute.

    Returns:
        ``sum(a*b) / (sqrt(sum(a*a) * sum(b*b)) + 1e-30)`` accumulated over the
        shared parameters, as a Python float. If nothing is shared the result
        is ``0.0`` (the ``1e-30`` term keeps the denominator non-zero).
    """
    num = da = db = 0.0
    for p in ps:
        a = ga.get(id(p))
        b = gb.get(id(p))
        if a is None or b is None:
            # Parameter missing on either side: skip it.
            continue
        num += float((a * b).sum())
        da += float((a * a).sum())
        db += float((b * b).sum())
    return num / (math.sqrt(da * db) + 1e-30)


def grad_from_a(blk, zs, idx, a):
    """Return gradients of ``sum(a * force)`` with respect to ``blk.block``.

    ``force`` is ``blk.force(zs.detach(), blk.embed(idx), cg=True)``; both
    ``a`` and ``zs`` are detached, so gradients only flow through the
    embedding/force computation.

    Args:
        blk: Object providing ``embed``, ``force`` and the ``block`` parameter
            sequence.
        zs: State tensor passed (detached) to ``blk.force``.
        idx: Input passed to ``blk.embed``.
        a: Tensor multiplied elementwise with the force output.

    Returns:
        Dict mapping ``id(param)`` to its gradient, omitting parameters for
        which autograd returned ``None`` (``allow_unused=True``).
    """
    # Grad mode is forced on so this also works if a caller has disabled it.
    with torch.enable_grad():
        xin = blk.embed(idx)
        f = blk.force(zs.detach(), xin, cg=True)
        gs = torch.autograd.grad(
            (a.detach() * f).sum(), blk.block, allow_unused=True
        )
    return {id(p): g for p, g in zip(blk.block, gs) if g is not None}


dev = 'cpu'
blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True
blk.track = True
blk.navg = 1
blk.li_avg = 0

# NOTE(review): torch.load unpickles the file; only point this at trusted
# checkpoints. Newer PyTorch releases default to weights_only=True -- confirm
# this checkpoint still loads under that default.
ck = torch.load('runs/ep_resreg_warm.pt', map_location=dev)
with torch.no_grad():
    # NOTE(review): zip() stops at the shorter sequence, so a length mismatch
    # between blk.allp and ck['allp'] would go unnoticed -- confirm they match.
    for p, s in zip(blk.allp, ck['allp']):
        p.copy_(s.to(dev))
print(f"loaded step={ck.get('step')} best={ck.get('best')}", flush=True)

# 1) Verify direct compile(step over torch.func) FAILS (independent confirmation)
torch.manual_seed(0)
idx, y = LT.get_batch('train', 1, 256)
xin = blk.embed(idx).detach()
zs = LT.relax(blk, xin.clone(), xin, 1, 0.1)
B = zs.size(0)
Z0 = torch.cat([zs, zs], 0)  # two copies of zs stacked along dim 0
tf_step = make_tf_step(blk, zs, xin, y, 0.02, 0.1)
try:
    c = torch.compile(tf_step, fullgraph=True)
    c(Z0)
except Exception as e:
    # Deliberately broad: any failure type counts as the expected outcome and
    # is only reported, never re-raised.
    print("COMPILE compile∘func: FAILS ->", type(e).__name__, str(e)[:90], flush=True)
else:
    print("COMPILE compile∘func: UNEXPECTEDLY OK", flush=True)

# 2) one-step manual vs torch.func identity + J accuracy
v = torch.randn_like(Z0) * 1e-3
zbar = 0.5 * (Z0[:B] + Z0[B:])  # mean of the two halves of Z0
zb2 = torch.cat([zbar, zbar], 0)
_, Jv_ref = tf.jvp(lambda zz: blk.nc_force(zz), (zb2,), (v,))
JTv_ref = tf.vjp(lambda zz: blk.nc_force(zz), zb2)[1](v)[0]
Jv_m, JTv_m = manual_nc_jvp_vjp_thick(blk, zb2, v)
print(f"manual Jv cos={cosine(Jv_ref,Jv_m):.8f} JTv cos={cosine(JTv_ref,JTv_m):.8f}", flush=True)
with torch.no_grad():
    Zt = tf_step(Z0)
    Zm = make_manual_step(blk, zs, xin, y, 0.02, 0.1)(Z0)
print(f"one-step manual vs tf max_abs={float((Zt-Zm).abs().max()):.2e}", flush=True)

# 3) full a-select + gradient cosine, several configs
for seed, T2, K in [(1, 80, 10), (2, 80, 10), (3, 160, 10)]:
    torch.manual_seed(seed)
    idx, y = LT.get_batch('train', 1, 256)
    xin = blk.embed(idx).detach()
    zs = LT.relax(blk, xin.clone(), xin, 1, 0.1)

    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # so the elapsed intervals cannot be skewed by wall-clock adjustments.
    t = time.perf_counter()
    a0, tb = H.holo_a_track(blk, zs, xin, y, 0.02, T2, 0.1, K=K)
    sb = time.perf_counter() - t

    t = time.perf_counter()
    a1, tm = run_loop_from_step(
        make_manual_step(blk, zs, xin, y, 0.02, 0.1), zs, 0.02, T2, K=K
    )
    sm = time.perf_counter() - t

    g0 = grad_from_a(blk, zs, idx, a0)
    g1 = grad_from_a(blk, zs, idx, a1)
    print(f"seed{seed} T2={T2} K={K}: a_cos={cosine(a0,a1):.7f} grad_cos={cosd(g0,g1,blk.block):.7f} "
          f"tbest={tb}/{tm} base={sb:.2f}s man={sm:.2f}s spd={sb/sm:.2f}x", flush=True)
print("DONE", flush=True)