summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
author: YurenHao0426 <Blackhao0426@gmail.com>, 2026-04-02 18:11:23 -0500
committer: YurenHao0426 <Blackhao0426@gmail.com>, 2026-04-02 18:11:23 -0500
commit: 1a917bca503c3a53610f363929f24e44a77fcb48 (patch)
tree: 8f3d61610189b56eecf4da0bc69c143d5e100768 /experiments
parent: 01f6b347194caefa1b2cb023755ac5088577f66f (diff)
Add EP synthetic ladder script
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r-- experiments/ep_synthetic.py 154
1 files changed, 154 insertions, 0 deletions
diff --git a/experiments/ep_synthetic.py b/experiments/ep_synthetic.py
new file mode 100644
index 0000000..bb7f4af
--- /dev/null
+++ b/experiments/ep_synthetic.py
@@ -0,0 +1,154 @@
+"""EP on Synthetic ladder. Same EP algorithm as ep_baseline.py but for StudentNet."""
+import os, sys, json, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+import torch.optim as optim
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+
class StudentBlock(nn.Module):
    """Residual branch of the student: LayerNorm -> linear/tanh blend -> linear map.

    The branch computes ``w((1 - alpha) * n + alpha * tanh(n))`` with
    ``n = LayerNorm(h)``. The skip connection is NOT applied here; the caller
    adds the result to ``h``.

    Args:
        d: Feature width of the input and output.
        alpha: Blend factor between the identity path (``alpha=0``) and the
            ``tanh`` path (``alpha=1``).
    """

    def __init__(self, d, alpha=1.0):
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.w = nn.Linear(d, d, bias=False)
        # Small init keeps the residual branch close to zero at start.
        nn.init.normal_(self.w.weight, std=0.01)
        self.alpha = alpha

    def forward(self, h):
        """Return the branch output for ``h`` (same trailing width ``d``)."""
        # Normalise once and reuse: both terms of the blend read the same
        # LayerNorm output, so there is no need to run the layer twice.
        normed = self.ln(h)
        blended = (1 - self.alpha) * normed + self.alpha * torch.tanh(normed)
        return self.w(blended)
+
class StudentNet(nn.Module):
    """Stack of ``L`` residual ``StudentBlock`` layers followed by a linear head.

    Args:
        d: Hidden width (input width as well, since blocks are residual).
        C: Number of output logits.
        L: Number of residual blocks.
        alpha: Blend factor forwarded to every ``StudentBlock``.
    """

    def __init__(self, d, C, L, alpha=1.0):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(L):
            self.blocks.append(StudentBlock(d, alpha))
        self.out_head = nn.Linear(d, C)
        self.num_blocks = L
        self.d_hidden = d

    def forward(self, x, return_hidden=False):
        """Run the network on ``x``.

        Returns the logits, or ``(logits, hidden)`` when ``return_hidden`` is
        true, where ``hidden`` is the list ``[x, h_1, ..., h_L]`` of states
        after each residual update.
        """
        state = x
        trace = [state]
        for block in self.blocks:
            state = state + block(state)
            trace.append(state)
        logits = self.out_head(state)
        if return_hidden:
            return logits, trace
        return logits
+
class TeacherNet(nn.Module):
    """Fixed random residual network used to label synthetic inputs.

    All weights are drawn from a private ``torch.Generator`` seeded with
    ``seed`` and stored as parameters with ``requires_grad=False``. Each
    layer matrix has its singular values clamped to at most 0.3.

    Args:
        d: Feature width.
        C: Number of output logits.
        L: Number of residual layers.
        alpha: Blend between the identity path and the ``tanh`` path.
        seed: Seed for the weight generator.
    """

    def __init__(self, d, C, L, alpha=1.0, seed=0):
        super().__init__()
        self.alpha = alpha
        gen = torch.Generator().manual_seed(seed)
        self.Ws = nn.ParameterList()
        for _ in range(L):
            raw = torch.randn(d, d, generator=gen) * 0.3 / (d ** 0.5)
            left, sing, right_h = torch.linalg.svd(raw, full_matrices=False)
            # Rebuild the matrix with every singular value capped at 0.3.
            capped = torch.diag(sing.clamp(max=0.3))
            layer_w = left @ capped @ right_h
            self.Ws.append(nn.Parameter(layer_w, requires_grad=False))
        # Readout is drawn last so the generator sequence matches layer order.
        readout = torch.randn(C, d, generator=gen) / (d ** 0.5)
        self.U = nn.Parameter(readout, requires_grad=False)

    def forward(self, x):
        """Return teacher logits of shape ``(..., C)`` for inputs ``x``."""
        state = x
        for layer_w in self.Ws:
            blended = (1 - self.alpha) * state + self.alpha * torch.tanh(state)
            state = state + blended @ layer_w.T
        return state @ self.U.T
+
def train_ep_synth(model, teacher, dev, d, C, L, epochs=80, steps=50, bs=256,
                   lr=1e-3, wd=0.01, beta=0.5, T_nudge=20, alpha_nudge=0.1,
                   eval_every=20, eval_bs=512):
    """Train ``model`` with a two-phase (free / nudged) EP-style local update.

    Every step draws a fresh Gaussian batch, labels it with
    ``teacher(x).argmax(-1)``, records the free-phase hidden states from a
    plain forward pass, relaxes a copy of those states for ``T_nudge``
    gradient steps on ``E + beta * cross_entropy``, and then updates each
    block from the contrast between nudged and free residuals. The output
    head is trained with cross-entropy on the final nudged state.

    Args:
        model: Network exposing ``blocks``, ``out_head`` and
            ``model(x, return_hidden=True) -> (logits, [h_0, ..., h_L])``.
        teacher: Callable mapping a batch to logits; used only for labels.
        dev: Device on which batches are sampled.
        d: Input width.
        C: Number of classes (unused here; kept for interface parity).
        L: Number of blocks in ``model``.
        epochs: Number of epochs.
        steps: Optimisation steps per epoch.
        bs: Training batch size.
        lr: AdamW learning rate (blocks and head).
        wd: AdamW weight decay (blocks and head).
        beta: Nudging strength multiplying the cost in the nudged phase.
        T_nudge: Number of relaxation steps in the nudged phase.
        alpha_nudge: Step size of the relaxation.
        eval_every: Print accuracy every this many epochs; a falsy value
            disables the periodic evaluation.
        eval_bs: Batch size of the periodic evaluation.

    Returns:
        None. ``model`` is updated in place and progress is printed.
    """
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    for ep in range(1, epochs + 1):
        model.train()
        for _ in range(steps):
            x = torch.randn(bs, d, device=dev)
            with torch.no_grad():
                y = teacher(x).argmax(-1)
                # Free phase
                _, hiddens_free = model(x, return_hidden=True)
            h_free = [h.detach().clone() for h in hiddens_free]

            # Nudged phase
            # NOTE(review): each residual treats its input state as a constant,
            # so the only gradient reaching h_nudge[l+1] is its own residual
            # (plus the cost for the last state). Starting from the free state
            # those residuals are zero, which appears to leave every state but
            # the last one un-nudged -- confirm this matches ep_baseline.py.
            h_nudge = [h.clone().requires_grad_(True) for h in h_free]
            for _ in range(T_nudge):
                # Predictions are constants w.r.t. the relaxed states, so they
                # are computed without building a graph through the weights.
                with torch.no_grad():
                    targets = [h_nudge[l] + model.blocks[l](h_nudge[l]) for l in range(L)]
                E = 0.0
                for l in range(L):
                    res = h_nudge[l + 1] - targets[l]
                    E = E + 0.5 * (res ** 2).sum(-1).mean()
                cost = F.cross_entropy(model.out_head(h_nudge[-1]), y)
                total = E + beta * cost
                grads = torch.autograd.grad(total, h_nudge[1:L + 1], create_graph=False)
                with torch.no_grad():
                    for l in range(L):
                        h_nudge[l + 1] = h_nudge[l + 1] - alpha_nudge * grads[l]
                        h_nudge[l + 1].requires_grad_(True)

            # Weight updates
            for l in range(L):
                block = model.blocks[l]
                in_free = h_free[l]
                in_nudge = h_nudge[l].detach()
                # One forward per phase; the residual reuses it as a constant.
                out_free = block(in_free)
                out_nudge = block(in_nudge)
                res_free = (h_free[l + 1] - (in_free + out_free)).detach()
                res_nudge = (h_nudge[l + 1].detach() - (in_nudge + out_nudge)).detach()
                loss_free = (-res_free * out_free).sum(-1).mean()
                loss_nudge = (-res_nudge * out_nudge).sum(-1).mean()
                block_loss = (loss_nudge - loss_free) / beta
                block_opts[l].zero_grad()
                block_loss.backward()
                torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
                block_opts[l].step()

            # The head is fit on the final nudged state, detached from the blocks.
            head_loss = F.cross_entropy(model.out_head(h_nudge[-1].detach()), y)
            head_opt.zero_grad()
            head_loss.backward()
            head_opt.step()

        if eval_every and ep % eval_every == 0:
            model.eval()
            with torch.no_grad():
                xt = torch.randn(eval_bs, d, device=dev)
                yt = teacher(xt).argmax(-1)
                acc = (model(xt).argmax(1) == yt).float().mean().item()
            print(f"  Ep {ep}: acc={acc:.4f}", flush=True)
+
def compute_diagnostics(model, teacher, dev, d, C, L, beta=0.5, T_nudge=20, alpha_nudge=0.1):
    """Compare the EP credit signal with backprop gradients on a fresh batch.

    A batch of 512 Gaussian inputs is labelled by ``teacher``. Backprop
    gradients of the cross-entropy loss are taken w.r.t. every hidden state.
    The EP signal for hidden state ``l+1`` is ``(h_nudge - h_free) / beta``
    after the same nudged relaxation used in training. Each layer's signal is
    passed to ``cosine_similarity_batch`` (against the backprop gradient) and
    to ``perturbation_correlation`` (against the loss of the remaining
    layers); the per-layer values are averaged.

    Args:
        model: Network exposing ``blocks``, ``out_head`` and
            ``model(x, return_hidden=True)``.
        teacher: Callable mapping a batch to logits; used only for labels.
        dev: Device on which the batch is sampled.
        d: Input width.
        C: Number of classes (unused here; kept for interface parity).
        L: Number of blocks in ``model``.
        beta: Nudging strength.
        T_nudge: Number of relaxation steps.
        alpha_nudge: Relaxation step size.

    Returns:
        Dict with keys ``'Gamma'`` (mean of the cosine metric), ``'rho'``
        (mean of the perturbation metric) and ``'acc'`` (accuracy of
        ``model`` against the teacher labels on this batch).
    """
    model.eval()
    x = torch.randn(512, d, device=dev)
    with torch.no_grad():
        y = teacher(x).argmax(-1)

    # BP grads
    hs = [x.detach().requires_grad_(True)]
    for b in model.blocks:
        hs.append(hs[-1] + b(hs[-1]))
    loss = F.cross_entropy(model.out_head(hs[-1]), y)
    # Keep all L+1 gradients: the loop below indexes bp[l + 1] for l up to
    # L - 1, i.e. up to bp[L]. Storing only range(L) made that lookup fail.
    bp = [g.detach() for g in torch.autograd.grad(loss, hs)]

    # EP credit
    with torch.no_grad():
        _, hf = model(x, return_hidden=True)
    h_free = [h.detach().clone() for h in hf]
    # NOTE(review): inputs to each residual are treated as constants, so
    # starting from the free state only the last hidden state appears to
    # move; a_ep would then be zero for the earlier layers -- confirm intended.
    h_nudge = [h.clone().requires_grad_(True) for h in h_free]
    for _ in range(T_nudge):
        # Block predictions are constants w.r.t. the relaxed states.
        with torch.no_grad():
            targets = [h_nudge[l] + model.blocks[l](h_nudge[l]) for l in range(L)]
        E = 0.0
        for l in range(L):
            res = h_nudge[l + 1] - targets[l]
            E = E + 0.5 * (res ** 2).sum(-1).mean()
        cost = F.cross_entropy(model.out_head(h_nudge[-1]), y)
        total = E + beta * cost
        grads = torch.autograd.grad(total, h_nudge[1:L + 1], create_graph=False)
        with torch.no_grad():
            for l in range(L):
                h_nudge[l + 1] = h_nudge[l + 1] - alpha_nudge * grads[l]
                h_nudge[l + 1].requires_grad_(True)

    def tail_loss_from(start):
        """Return f(h): per-sample loss after running blocks ``start..L-1`` on h."""
        def f(h):
            with torch.no_grad():
                c = h
                for i in range(start, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(model.out_head(c), y, reduction='none')
        return f

    gammas, rhos = [], []
    for l in range(L):
        a_ep = (h_nudge[l + 1].detach() - h_free[l + 1]) / beta
        gammas.append(cosine_similarity_batch(a_ep, bp[l + 1]))
        # h_free already holds the free-phase states, so no extra forward
        # pass is needed to obtain the perturbation base point.
        rhos.append(perturbation_correlation(h_free[l + 1], a_ep, tail_loss_from(l + 1),
                                             epsilon=1e-3, M=16))
    with torch.no_grad():
        acc = (model(x).argmax(1) == y).float().mean().item()
    return {'Gamma': float(np.mean(gammas)), 'rho': float(np.mean(rhos)), 'acc': acc}
+
def main():
    """Train one EP student against the synthetic teacher and save diagnostics.

    Parses ``--alpha``, ``--seed``, ``--gpu``, ``--depth`` and
    ``--output_dir``, trains a ``StudentNet`` with ``train_ep_synth``, runs
    ``compute_diagnostics`` and writes the result as JSON to
    ``<output_dir>/ep_a<alpha>_L<depth>_s<seed>.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--alpha', type=float, required=True)
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--depth', type=int, default=4)
    p.add_argument('--output_dir', type=str, default='results/ep_synthetic')
    args = p.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Use the requested GPU when CUDA is present; otherwise run on CPU
    # instead of failing when the first tensor is moved to the device.
    if torch.cuda.is_available():
        dev = torch.device(f'cuda:{args.gpu}')
    else:
        print("  CUDA not available; running on CPU", flush=True)
        dev = torch.device('cpu')

    d, C, L = 128, 10, args.depth
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # The teacher always uses seed 0, so every run labels with the same teacher.
    teacher = TeacherNet(d, C, L, args.alpha, seed=0).to(dev)
    model = StudentNet(d, C, L, args.alpha).to(dev)
    print(f"[EP synth a={args.alpha} L={L} s={args.seed}] Training...", flush=True)
    train_ep_synth(model, teacher, dev, d, C, L)
    diag = compute_diagnostics(model, teacher, dev, d, C, L)
    result = {'method': 'ep', 'alpha': args.alpha, 'depth': L, 'seed': args.seed,
              'acc': diag['acc'], 'Gamma': diag['Gamma'], 'rho': diag['rho']}
    out = os.path.join(args.output_dir, f'ep_a{args.alpha}_L{L}_s{args.seed}.json')
    with open(out, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2, default=float)
    print(f"  acc={diag['acc']:.4f} Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}", flush=True)


if __name__ == '__main__':
    main()