diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 18:11:23 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 18:11:23 -0500 |
| commit | 1a917bca503c3a53610f363929f24e44a77fcb48 (patch) | |
| tree | 8f3d61610189b56eecf4da0bc69c143d5e100768 /experiments | |
| parent | 01f6b347194caefa1b2cb023755ac5088577f66f (diff) | |
Add EP synthetic ladder script
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/ep_synthetic.py | 154 |
1 files changed, 154 insertions, 0 deletions
diff --git a/experiments/ep_synthetic.py b/experiments/ep_synthetic.py new file mode 100644 index 0000000..bb7f4af --- /dev/null +++ b/experiments/ep_synthetic.py @@ -0,0 +1,154 @@ +"""EP on Synthetic ladder. Same EP algorithm as ep_baseline.py but for StudentNet.""" +import os, sys, json, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F +import torch.optim as optim +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation + +class StudentBlock(nn.Module): + def __init__(self, d, alpha=1.0): + super().__init__() + self.ln=nn.LayerNorm(d);self.w=nn.Linear(d,d,bias=False) + nn.init.normal_(self.w.weight,std=0.01);self.alpha=alpha + def forward(self, h): + return self.w(((1-self.alpha)*self.ln(h)+self.alpha*torch.tanh(self.ln(h)))) + +class StudentNet(nn.Module): + def __init__(self, d, C, L, alpha=1.0): + super().__init__() + self.blocks=nn.ModuleList([StudentBlock(d,alpha) for _ in range(L)]) + self.out_head=nn.Linear(d,C);self.num_blocks=L;self.d_hidden=d + def forward(self, x, return_hidden=False): + h=x;hi=[h] if return_hidden else None + for b in self.blocks: + h=h+b(h) + if return_hidden:hi.append(h) + lo=self.out_head(h) + return (lo,hi) if return_hidden else lo + +class TeacherNet(nn.Module): + def __init__(self, d, C, L, alpha=1.0, seed=0): + super().__init__() + self.alpha=alpha;rng=torch.Generator().manual_seed(seed) + self.Ws=nn.ParameterList() + for _ in range(L): + W=torch.randn(d,d,generator=rng)*0.3/(d**0.5) + U,S,Vh=torch.linalg.svd(W,full_matrices=False) + self.Ws.append(nn.Parameter(U@torch.diag(S.clamp(max=0.3))@Vh,requires_grad=False)) + self.U=nn.Parameter(torch.randn(C,d,generator=rng)/(d**0.5),requires_grad=False) + def forward(self, x): + h=x + for W in self.Ws:h=h+((1-self.alpha)*h+self.alpha*torch.tanh(h))@W.T + return h@self.U.T + +def train_ep_synth(model, teacher, dev, d, C, L, epochs=80, steps=50, bs=256, + lr=1e-3, wd=0.01, beta=0.5, T_nudge=20, alpha_nudge=0.1): + block_opts=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks] + head_opt=optim.AdamW(model.out_head.parameters(),lr=lr,weight_decay=wd) + for ep in range(1,epochs+1): + model.train() + for _ in range(steps): + x=torch.randn(bs,d,device=dev) + with torch.no_grad():y=teacher(x).argmax(-1) + # Free phase + with torch.no_grad():_,hiddens_free=model(x,return_hidden=True) + h_free=[h.detach().clone() for h in hiddens_free] + # Nudged phase + h_nudge=[h.clone().requires_grad_(True) for h in h_free] + for t in range(T_nudge): + E=0.0 + for l in range(L): + res=h_nudge[l+1]-(h_nudge[l].detach()+model.blocks[l](h_nudge[l].detach())) + E+=0.5*(res**2).sum(-1).mean() + lo=model.out_head(h_nudge[-1]) + cost=F.cross_entropy(lo,y) + total=E+beta*cost + grads=torch.autograd.grad(total,[h_nudge[l] for l in range(1,L+1)],create_graph=False) + with torch.no_grad(): + for l in range(L): + h_nudge[l+1]=h_nudge[l+1]-alpha_nudge*grads[l] + h_nudge[l+1].requires_grad_(True) + # Weight updates + for l in range(L): + res_free=h_free[l+1]-(h_free[l]+model.blocks[l](h_free[l])) + res_nudge=h_nudge[l+1].detach()-(h_nudge[l].detach()+model.blocks[l](h_nudge[l].detach())) + loss_free=(-res_free.detach()*model.blocks[l](h_free[l])).sum(-1).mean() + loss_nudge=(-res_nudge.detach()*model.blocks[l](h_nudge[l].detach())).sum(-1).mean() + block_loss=(loss_nudge-loss_free)/beta + block_opts[l].zero_grad();block_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0) + block_opts[l].step() + lo_n=model.out_head(h_nudge[-1].detach()) + head_loss=F.cross_entropy(lo_n,y) + head_opt.zero_grad();head_loss.backward();head_opt.step() + if ep%20==0: + model.eval() + with torch.no_grad(): + xt=torch.randn(512,d,device=dev);yt=teacher(xt).argmax(-1) + acc=(model(xt).argmax(1)==yt).float().mean().item() + print(f" Ep {ep}: acc={acc:.4f}",flush=True) + +def compute_diagnostics(model, teacher, dev, d, C, L, beta=0.5, T_nudge=20, alpha_nudge=0.1): + model.eval() + x=torch.randn(512,d,device=dev) + with torch.no_grad():y=teacher(x).argmax(-1) + # BP grads + hs=[x.detach().requires_grad_(True)] + for b in model.blocks:hs.append(hs[-1]+b(hs[-1])) + lo=model.out_head(hs[-1]);loss=F.cross_entropy(lo,y) + gs=torch.autograd.grad(loss,hs);bp={l:gs[l].detach() for l in range(L)} + # EP credit + with torch.no_grad():_,hf=model(x,return_hidden=True) + h_free=[h.detach().clone() for h in hf] + h_nudge=[h.clone().requires_grad_(True) for h in h_free] + for t in range(T_nudge): + E=0.0 + for l in range(L): + res=h_nudge[l+1]-(h_nudge[l].detach()+model.blocks[l](h_nudge[l].detach())) + E+=0.5*(res**2).sum(-1).mean() + lo2=model.out_head(h_nudge[-1]);cost=F.cross_entropy(lo2,y) + total=E+beta*cost + grads=torch.autograd.grad(total,[h_nudge[l] for l in range(1,L+1)],create_graph=False) + with torch.no_grad(): + for l in range(L): + h_nudge[l+1]=h_nudge[l+1]-alpha_nudge*grads[l] + h_nudge[l+1].requires_grad_(True) + gammas,rhos=[],[] + with torch.no_grad():_,hi=model(x,return_hidden=True) + for l in range(L): + a_ep=(h_nudge[l+1].detach()-h_free[l+1])/beta + gammas.append(cosine_similarity_batch(a_ep,bp[l+1])) + def mk(sl): + def f(h): + with torch.no_grad(): + c=h + for i in range(sl,L):c=c+model.blocks[i](c) + return F.cross_entropy(model.out_head(c),y,reduction='none') + return f + rhos.append(perturbation_correlation(hi[l+1].detach(),a_ep,mk(l+1),epsilon=1e-3,M=16)) + acc=(model(x).argmax(1)==y).float().mean().item() + return {'Gamma':float(np.mean(gammas)),'rho':float(np.mean(rhos)),'acc':acc} + +def main(): + p=argparse.ArgumentParser() + p.add_argument('--alpha',type=float,required=True) + p.add_argument('--seed',type=int,required=True) + p.add_argument('--gpu',type=int,default=0) + p.add_argument('--depth',type=int,default=4) + p.add_argument('--output_dir',type=str,default='results/ep_synthetic') + args=p.parse_args() + os.makedirs(args.output_dir,exist_ok=True) + dev=torch.device(f'cuda:{args.gpu}') + d,C,L=128,10,args.depth + torch.manual_seed(args.seed);np.random.seed(args.seed);torch.cuda.manual_seed_all(args.seed) + teacher=TeacherNet(d,C,L,args.alpha,seed=0).to(dev) + model=StudentNet(d,C,L,args.alpha).to(dev) + print(f"[EP synth a={args.alpha} L={L} s={args.seed}] Training...",flush=True) + train_ep_synth(model,teacher,dev,d,C,L) + diag=compute_diagnostics(model,teacher,dev,d,C,L) + result={'method':'ep','alpha':args.alpha,'depth':L,'seed':args.seed, + 'acc':diag['acc'],'Gamma':diag['Gamma'],'rho':diag['rho']} + out=os.path.join(args.output_dir,f'ep_a{args.alpha}_L{L}_s{args.seed}.json') + with open(out,'w') as f:json.dump(result,f,indent=2,default=float) + print(f" acc={diag['acc']:.4f} Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",flush=True) + +if __name__=='__main__':main() |
