"""EP on Synthetic ladder. Same EP algorithm as ep_baseline.py but for StudentNet.

The script trains a residual ``StudentNet`` to imitate the argmax labels of a
frozen ``TeacherNet`` on Gaussian inputs, using a two-phase (free / nudged)
layer-local update, and then reports how well the EP credit signal aligns with
back-propagated gradients.
"""
import argparse
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Make the repository root importable so that ``metrics`` resolves when this
# file is run as a script from its own directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation


class StudentBlock(nn.Module):
    """One residual branch: LayerNorm -> (linear/tanh blend) -> bias-free Linear.

    ``alpha`` interpolates the activation between identity (``alpha=0``) and
    ``tanh`` (``alpha=1``).
    """

    def __init__(self, d, alpha=1.0):
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.w = nn.Linear(d, d, bias=False)
        # Small init keeps each residual branch close to zero at start.
        nn.init.normal_(self.w.weight, std=0.01)
        self.alpha = alpha

    def forward(self, h):
        """Return the branch output for hidden state ``h`` (residual NOT added)."""
        # Normalise once and reuse; the original evaluated LayerNorm twice.
        z = self.ln(h)
        return self.w((1 - self.alpha) * z + self.alpha * torch.tanh(z))


class StudentNet(nn.Module):
    """Stack of ``L`` residual ``StudentBlock``s followed by a linear read-out."""

    def __init__(self, d, C, L, alpha=1.0):
        super().__init__()
        self.blocks = nn.ModuleList([StudentBlock(d, alpha) for _ in range(L)])
        self.out_head = nn.Linear(d, C)
        self.num_blocks = L
        self.d_hidden = d

    def forward(self, x, return_hidden=False):
        """Run the network.

        Returns the logits, or ``(logits, hiddens)`` when ``return_hidden`` is
        true, where ``hiddens[0]`` is ``x`` and ``hiddens[l + 1]`` is the state
        after block ``l``.
        """
        h = x
        hiddens = [h] if return_hidden else None
        for block in self.blocks:
            h = h + block(h)
            if return_hidden:
                hiddens.append(h)
        logits = self.out_head(h)
        return (logits, hiddens) if return_hidden else logits


class TeacherNet(nn.Module):
    """Frozen random residual network that provides the target labels.

    All weights are drawn from a private generator seeded with ``seed`` and are
    registered with ``requires_grad=False``.
    """

    def __init__(self, d, C, L, alpha=1.0, seed=0):
        super().__init__()
        self.alpha = alpha
        rng = torch.Generator().manual_seed(seed)
        self.Ws = nn.ParameterList()
        for _ in range(L):
            W = torch.randn(d, d, generator=rng) * 0.3 / (d ** 0.5)
            # Clip singular values at 0.3 to bound each layer's spectral norm.
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            W_clipped = U @ torch.diag(S.clamp(max=0.3)) @ Vh
            self.Ws.append(nn.Parameter(W_clipped, requires_grad=False))
        self.U = nn.Parameter(
            torch.randn(C, d, generator=rng) / (d ** 0.5), requires_grad=False
        )

    def forward(self, x):
        """Return teacher logits for input ``x``."""
        h = x
        for W in self.Ws:
            h = h + ((1 - self.alpha) * h + self.alpha * torch.tanh(h)) @ W.T
        return h @ self.U.T


def _relax_nudged(model, h_free, y, L, beta, T_nudge, alpha_nudge):
    """Run the nudged phase and return the relaxed hidden states.

    Starting from copies of ``h_free``, performs ``T_nudge`` gradient-descent
    steps of size ``alpha_nudge`` on ``E + beta * cross_entropy`` with respect
    to hidden states ``1..L``, where ``E`` is the sum over layers of
    ``0.5 * ||h[l+1] - (h[l] + block_l(h[l]))||^2`` (batch-averaged) with the
    layer input ``h[l]`` detached. Entry 0 (the input) is never updated.

    The returned tensors are leaf tensors with ``requires_grad=True``; callers
    should ``detach()`` them before further use.
    """
    h_nudge = [h.clone().requires_grad_(True) for h in h_free]
    for _ in range(T_nudge):
        energy = 0.0
        for l in range(L):
            # The prediction only depends on the detached input, and only
            # gradients w.r.t. the hidden states are requested below, so no
            # autograd graph is needed for the block forward pass.
            with torch.no_grad():
                prev = h_nudge[l].detach()
                pred = prev + model.blocks[l](prev)
            res = h_nudge[l + 1] - pred
            energy = energy + 0.5 * (res ** 2).sum(-1).mean()
        cost = F.cross_entropy(model.out_head(h_nudge[-1]), y)
        total = energy + beta * cost
        grads = torch.autograd.grad(
            total, [h_nudge[l] for l in range(1, L + 1)], create_graph=False
        )
        with torch.no_grad():
            for l in range(L):
                h_nudge[l + 1] = h_nudge[l + 1] - alpha_nudge * grads[l]
                h_nudge[l + 1].requires_grad_(True)
    return h_nudge


def train_ep_synth(model, teacher, dev, d, C, L, epochs=80, steps=50, bs=256,
                   lr=1e-3, wd=0.01, beta=0.5, T_nudge=20, alpha_nudge=0.1):
    """Train ``model`` in place with the free/nudged two-phase update.

    Args:
        model: Network exposing ``blocks``, ``out_head`` and
            ``forward(x, return_hidden=True)`` (e.g. ``StudentNet``).
        teacher: Module whose argmax output supplies the labels.
        dev: Device on which the Gaussian input batches are created.
        d: Input / hidden width.
        C: Number of classes (unused here; kept for a uniform signature).
        L: Number of blocks to train.
        epochs, steps, bs: Epoch count, steps per epoch, batch size.
        lr, wd: AdamW learning rate and weight decay (one optimizer per block
            plus one for the head).
        beta: Nudging strength.
        T_nudge, alpha_nudge: Number and size of nudged relaxation steps.

    Prints held-out accuracy on a fresh batch of 512 samples every 20 epochs.
    """
    block_opts = [
        optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks
    ]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)

    for ep in range(1, epochs + 1):
        model.train()
        for _ in range(steps):
            x = torch.randn(bs, d, device=dev)
            with torch.no_grad():
                y = teacher(x).argmax(-1)

            # Free phase: plain forward pass, no nudging.
            with torch.no_grad():
                _, hiddens_free = model(x, return_hidden=True)
            h_free = [h.detach().clone() for h in hiddens_free]

            # Nudged phase: relax hidden states toward lower loss.
            h_nudge = [
                h.detach()
                for h in _relax_nudged(model, h_free, y, L, beta, T_nudge, alpha_nudge)
            ]

            # Weight updates: contrast nudged and free residuals per block.
            for l in range(L):
                block = model.blocks[l]
                # One forward pass per phase serves both the residual and the
                # loss term (the original evaluated each block twice).
                out_free = block(h_free[l])
                out_nudge = block(h_nudge[l])
                res_free = (h_free[l + 1] - (h_free[l] + out_free)).detach()
                res_nudge = (h_nudge[l + 1] - (h_nudge[l] + out_nudge)).detach()
                loss_free = (-res_free * out_free).sum(-1).mean()
                loss_nudge = (-res_nudge * out_nudge).sum(-1).mean()
                block_loss = (loss_nudge - loss_free) / beta
                block_opts[l].zero_grad()
                block_loss.backward()
                torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
                block_opts[l].step()

            # The read-out head is trained by cross-entropy on the nudged top state.
            head_loss = F.cross_entropy(model.out_head(h_nudge[-1]), y)
            head_opt.zero_grad()
            head_loss.backward()
            head_opt.step()

        if ep % 20 == 0:
            model.eval()
            with torch.no_grad():
                xt = torch.randn(512, d, device=dev)
                yt = teacher(xt).argmax(-1)
                acc = (model(xt).argmax(1) == yt).float().mean().item()
            print(f" Ep {ep}: acc={acc:.4f}", flush=True)


def compute_diagnostics(model, teacher, dev, d, C, L, beta=0.5, T_nudge=20,
                        alpha_nudge=0.1):
    """Compare the EP credit signal with back-prop gradients on a fresh batch.

    Draws 512 Gaussian inputs, labels them with ``teacher``, and for every
    hidden layer computes the EP credit ``-(h_nudge - h_free) / beta`` and
    scores it against the BP gradient of the cross-entropy loss.

    Returns:
        dict with keys ``'Gamma'`` (mean over layers of
        ``cosine_similarity_batch``), ``'rho'`` (mean over layers of
        ``perturbation_correlation``) and ``'acc'`` (student accuracy on the
        batch). The exact semantics of the two metrics are defined in
        ``metrics.credit_metrics``, which is not visible from this file.
    """
    model.eval()
    x = torch.randn(512, d, device=dev)
    with torch.no_grad():
        y = teacher(x).argmax(-1)

    # BP grads of the loss w.r.t. every hidden state.
    hs = [x.detach().requires_grad_(True)]
    for block in model.blocks:
        hs.append(hs[-1] + block(hs[-1]))
    loss = F.cross_entropy(model.out_head(hs[-1]), y)
    gs = torch.autograd.grad(loss, hs)
    bp = {l: gs[l].detach() for l in range(L + 1)}

    # EP credit: free phase followed by nudged relaxation.
    with torch.no_grad():
        _, hf = model(x, return_hidden=True)
    h_free = [h.detach().clone() for h in hf]
    h_nudge = _relax_nudged(model, h_free, y, L, beta, T_nudge, alpha_nudge)

    def tail_loss_fn(start):
        """Return a function mapping a hidden state at layer ``start`` to per-sample loss."""
        def per_sample_loss(h):
            with torch.no_grad():
                c = h
                for i in range(start, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(model.out_head(c), y, reduction='none')
        return per_sample_loss

    gammas, rhos = [], []
    for l in range(L):
        # EP nudge moves h toward lower loss, so (h_nudge - h_free) points opposite to BP grad.
        # Negate to align with BP gradient convention (pointing toward loss increase).
        a_ep = -(h_nudge[l + 1].detach() - h_free[l + 1]) / beta
        gammas.append(cosine_similarity_batch(a_ep, bp[l + 1]))
        # The free-phase states equal a fresh eval-mode forward pass, so they
        # are reused as the perturbation base point.
        rhos.append(
            perturbation_correlation(
                h_free[l + 1], a_ep, tail_loss_fn(l + 1), epsilon=1e-3, M=16
            )
        )

    with torch.no_grad():
        acc = (model(x).argmax(1) == y).float().mean().item()
    return {'Gamma': float(np.mean(gammas)), 'rho': float(np.mean(rhos)), 'acc': acc}


def main():
    """Parse CLI arguments, train, run diagnostics, and save weights + JSON results."""
    p = argparse.ArgumentParser()
    p.add_argument('--alpha', type=float, required=True)
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--depth', type=int, default=4)
    p.add_argument('--output_dir', type=str, default='results/ep_synthetic')
    args = p.parse_args()
    if args.depth < 1:
        # With zero blocks there are no hidden states to nudge or to score.
        p.error('--depth must be >= 1')
    os.makedirs(args.output_dir, exist_ok=True)

    if torch.cuda.is_available():
        dev = torch.device(f'cuda:{args.gpu}')
    else:
        # Fall back instead of crashing on machines without CUDA.
        print('CUDA not available; running on CPU.', file=sys.stderr, flush=True)
        dev = torch.device('cpu')

    d, C, L = 128, 10, args.depth
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # The teacher is always built with seed 0 so every run shares the same task.
    teacher = TeacherNet(d, C, L, args.alpha, seed=0).to(dev)
    model = StudentNet(d, C, L, args.alpha).to(dev)

    print(f"[EP synth a={args.alpha} L={L} s={args.seed}] Training...", flush=True)
    train_ep_synth(model, teacher, dev, d, C, L)
    diag = compute_diagnostics(model, teacher, dev, d, C, L)
    result = {'method': 'ep', 'alpha': args.alpha, 'depth': L, 'seed': args.seed,
              'acc': diag['acc'], 'Gamma': diag['Gamma'], 'rho': diag['rho']}

    stem = f'ep_a{args.alpha}_L{L}_s{args.seed}'
    torch.save(model.state_dict(), os.path.join(args.output_dir, stem + '.pt'))
    out = os.path.join(args.output_dir, stem + '.json')
    with open(out, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2, default=float)
    print(f" acc={diag['acc']:.4f} Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",
          flush=True)


if __name__ == '__main__':
    main()