summaryrefslogtreecommitdiff
path: root/ep_run/verify_aep_manual.py
blob: 6c4b403ee6372d7458440cc13945d2f1048fd918 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import math, time, torch
import lt_ep_train as LT, holo_ep as H
from test_aselect_deepdive import (manual_nc_jvp_vjp_thick, make_manual_step,
                                   make_tf_step, run_loop_from_step, cosine)
import torch.func as tf

def cosd(ga, gb, ps):
    """Return the cosine similarity of two gradient dicts over the params ``ps``.

    ``ga`` and ``gb`` map ``id(param)`` to a tensor, as produced by
    ``grad_from_a``. A parameter is skipped when either dict lacks it or maps
    it to ``None``. Every elementwise product is reduced to a Python float
    before being accumulated. The ``1e-30`` term keeps the denominator
    non-zero when nothing is shared or either side is all zeros.
    """
    looked_up = [(ga.get(id(param)), gb.get(id(param))) for param in ps]
    shared = [(ta, tb) for ta, tb in looked_up if ta is not None and tb is not None]
    dot = sum(float((ta * tb).sum()) for ta, tb in shared)
    sq_a = sum(float((ta * ta).sum()) for ta, _ in shared)
    sq_b = sum(float((tb * tb).sum()) for _, tb in shared)
    return dot / (math.sqrt(sq_a * sq_b) + 1e-30)

def grad_from_a(blk, zs, idx, a):
    """Return d/dθ of ``sum(a * blk.force(zs, blk.embed(idx)))`` for θ in ``blk.block``.

    ``zs`` and ``a`` are detached before use, so they act as constants.
    Gradients are computed with grad mode forced on, even if the caller is
    inside ``torch.no_grad()``. Parameters that receive no gradient
    (``allow_unused=True`` yields ``None``) are left out of the result.

    Returns:
        dict mapping ``id(param)`` -> gradient tensor.
    """
    with torch.enable_grad():
        embedded = blk.embed(idx)
        # NOTE(review): the meaning of cg=True is defined by blk.force; not visible here.
        force = blk.force(zs.detach(), embedded, cg=True)
        objective = (a.detach() * force).sum()
        grads = torch.autograd.grad(objective, blk.block, allow_unused=True)
    result = {}
    for param, grad in zip(blk.block, grads):
        if grad is None:
            continue
        result[id(param)] = grad
    return result

dev = 'cpu'
blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True
blk.track = True
blk.navg = 1
blk.li_avg = 0
ck = torch.load('runs/ep_resreg_warm.pt', map_location=dev)
# Copy the checkpoint's 'allp' tensors into blk.allp positionally.
# zip() would silently drop trailing tensors on a count mismatch, and copy_()
# silently broadcasts compatible-but-different shapes. Either would leave the
# model only partially loaded, so fail loudly instead.
_model_params = list(blk.allp)
_saved_params = list(ck['allp'])
if len(_model_params) != len(_saved_params):
    raise ValueError(
        f"checkpoint has {len(_saved_params)} 'allp' tensors but the model has {len(_model_params)}")
with torch.no_grad():
    for i, (p, s) in enumerate(zip(_model_params, _saved_params)):
        if p.shape != s.shape:
            raise ValueError(
                f"allp[{i}]: checkpoint shape {tuple(s.shape)} != model shape {tuple(p.shape)}")
        p.copy_(s.to(dev))
print(f"loaded step={ck.get('step')} best={ck.get('best')}", flush=True)

# 1) Verify direct compile(step over torch.func) FAILS (independent confirmation)
# Builds one batch and its relaxed state, then tries torch.compile(fullgraph=True)
# on the torch.func-based step and runs it once on Z0. The outcome is only
# reported; the script continues either way.
# NOTE: keep statement order as-is. torch.manual_seed(0) precedes calls that
# presumably consume the RNG, and xin/y/zs/B/Z0/tf_step are reused by section 2.
torch.manual_seed(0)
idx,y=LT.get_batch('train',1,256); xin=blk.embed(idx).detach()
# Z0 stacks two copies of the relaxed state zs along dim 0 (rows [:B] and [B:]).
zs=LT.relax(blk,xin.clone(),xin,1,0.1); B=zs.size(0); Z0=torch.cat([zs,zs],0)
tf_step=make_tf_step(blk,zs,xin,y,0.02,0.1)
try:
    # Calling c(Z0) forces compilation, so compile errors surface here.
    c=torch.compile(tf_step, fullgraph=True); c(Z0)
    print("COMPILE compile∘func: UNEXPECTEDLY OK", flush=True)
except Exception as e:
    # Deliberately broad: any failure counts as the expected "FAILS" outcome.
    # Only the exception type and the first 90 chars of its message are printed.
    print("COMPILE compile∘func: FAILS ->", type(e).__name__, str(e)[:90], flush=True)

# 2) one-step manual vs torch.func identity + J accuracy
# (a) Compare the hand-written products from manual_nc_jvp_vjp_thick against
#     torch.func jvp/vjp of blk.nc_force at zb2, along a small random direction
#     v. The cosine similarity of each pair is printed.
# (b) From the same Z0 and under no_grad, take one torch.func step and one
#     manual step, then print their max absolute difference.
# NOTE: v is drawn from the RNG stream left by section 1; reordering changes it.
v=torch.randn_like(Z0)*1e-3
# zbar averages the two stacked halves of Z0; zb2 re-stacks it to Z0's shape.
zbar=0.5*(Z0[:B]+Z0[B:]); zb2=torch.cat([zbar,zbar],0)
_,Jv_ref=tf.jvp(lambda zz: blk.nc_force(zz),(zb2,),(v,))
JTv_ref=tf.vjp(lambda zz: blk.nc_force(zz),zb2)[1](v)[0]
# Presumably returns (J v, J^T v) for nc_force at zb2, matching the references above.
Jv_m,JTv_m=manual_nc_jvp_vjp_thick(blk,zb2,v)
print(f"manual Jv  cos={cosine(Jv_ref,Jv_m):.8f}  JTv cos={cosine(JTv_ref,JTv_m):.8f}", flush=True)
with torch.no_grad():
    Zt=tf_step(Z0); Zm=make_manual_step(blk,zs,xin,y,0.02,0.1)(Z0)
print(f"one-step manual vs tf max_abs={float((Zt-Zm).abs().max()):.2e}", flush=True)

# 3) full a-select + gradient cosine, several configs
# For each (seed, T2, K) config: draw a fresh batch and relax it. Then compute
# the a-selection twice, once with the baseline H.holo_a_track and once with
# the manual step driven by run_loop_from_step. The two results are compared
# directly (a_cos) and via the parameter gradients they induce (grad_cos).
# Durations use time.perf_counter(). It is monotonic and high-resolution,
# whereas time.time() follows the adjustable wall clock and can skew (or even
# negate) a measured interval.
for seed, T2, K in [(1, 80, 10), (2, 80, 10), (3, 160, 10)]:
    torch.manual_seed(seed)
    idx, y = LT.get_batch('train', 1, 256)
    xin = blk.embed(idx).detach()
    zs = LT.relax(blk, xin.clone(), xin, 1, 0.1)
    t = time.perf_counter()
    a0, tb = H.holo_a_track(blk, zs, xin, y, 0.02, T2, 0.1, K=K)
    sb = time.perf_counter() - t
    # The manual timing includes building the step, as before.
    t = time.perf_counter()
    a1, tm = run_loop_from_step(make_manual_step(blk, zs, xin, y, 0.02, 0.1), zs, 0.02, T2, K=K)
    sm = time.perf_counter() - t
    g0 = grad_from_a(blk, zs, idx, a0)
    g1 = grad_from_a(blk, zs, idx, a1)
    print(f"seed{seed} T2={T2} K={K}: a_cos={cosine(a0,a1):.7f} grad_cos={cosd(g0,g1,blk.block):.7f} "
          f"tbest={tb}/{tm} base={sb:.2f}s man={sm:.2f}s spd={sb/sm:.2f}x", flush=True)
print("DONE", flush=True)