summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 11:29:51 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 11:29:51 -0500
commit100ae0eb429774ae9cd4f3085de7c87bf9d56d45 (patch)
tree924cb02ebf69c8f1b7765e41bca1744fb440c98a /experiments
parentef80d52840a1c6fb7f9a22985784ce311edc59a4 (diff)
Add GELU/ReLU ablation script for CIFAR MLP
Note: existing ResidualMLP already uses GELU. This adds ResidualMLPReLU variant. Ablation compares ReLU vs GELU for BP/DFA/SB/CB. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/gelu_ablation.py254
1 files changed, 254 insertions, 0 deletions
diff --git a/experiments/gelu_ablation.py b/experiments/gelu_ablation.py
new file mode 100644
index 0000000..bef35b1
--- /dev/null
+++ b/experiments/gelu_ablation.py
@@ -0,0 +1,254 @@
+"""
+GELU activation ablation: replace ReLU with GELU in ResidualMLP.
+Run BP/DFA/SB/CB on CIFAR-10, L=4, d=256, independent process per method+seed.
+Usage: python gelu_ablation.py --method bp --seed 42 --gpu 0
+"""
+import os, sys, json, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# NOTE: ResidualMLP already uses GELU! Check:
+from models.residual_mlp import ResidualBlock, ResidualMLP
+# ResidualBlock.forward: z = self.w2(F.gelu(self.w1(self.ln(h))))
+# So the default architecture IS GELU. The ablation should test ReLU instead.
+
+class ResidualBlockReLU(nn.Module):
+ """ResidualBlock with ReLU instead of GELU."""
+ def __init__(self, d_hidden):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w1 = nn.Linear(d_hidden, d_hidden)
+ self.w2 = nn.Linear(d_hidden, d_hidden)
+ nn.init.normal_(self.w2.weight, std=0.01)
+ nn.init.zeros_(self.w2.bias)
+ def forward(self, h):
+ z = self.ln(h)
+ z = self.w1(z)
+ z = F.relu(z) # ReLU instead of GELU
+ z = self.w2(z)
+ return z
+
+class ResidualMLPReLU(nn.Module):
+ """ResidualMLP with ReLU blocks."""
+ def __init__(self, input_dim, d_hidden, num_classes, num_blocks):
+ super().__init__()
+ self.embed = nn.Linear(input_dim, d_hidden)
+ self.blocks = nn.ModuleList([ResidualBlockReLU(d_hidden) for _ in range(num_blocks)])
+ self.out_ln = nn.LayerNorm(d_hidden)
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+ def forward(self, x, return_hidden=False):
+ h = self.embed(x)
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ h = h + block(h)
+ if return_hidden: hiddens.append(h)
+ logits = self.out_head(self.out_ln(h))
+ return (logits, hiddens) if return_hidden else logits
+ def forward_from_layer(self, h, start_layer):
+ for i in range(start_layer, self.num_blocks):
+ h = h + self.blocks[i](h)
+ return self.out_head(self.out_ln(h))
+
+# Import training functions from cifar_d512_confirmatory and adapt
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+import torchvision, torchvision.transforms as transforms
+
+def get_cifar10(bs=128):
+ tt=transforms.Compose([transforms.RandomCrop(32,4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ tv=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ return (DataLoader(torchvision.datasets.CIFAR10('./data',True,download=True,transform=tt),bs,True,num_workers=4,pin_memory=True),
+ DataLoader(torchvision.datasets.CIFAR10('./data',False,download=True,transform=tv),bs,False,num_workers=4,pin_memory=True))
+
+def evaluate(m,tl,dev):
+ m.eval();c,t=0,0
+ with torch.no_grad():
+ for x,y in tl:x=x.view(x.size(0),-1).to(dev);y=y.to(dev);c+=(m(x).argmax(1)==y).sum().item();t+=x.size(0)
+ return c/t
+
+# Reuse exact training functions from cifar_d512_confirmatory but with model_cls parameter
+def train_bp(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
+ opt=optim.AdamW(model.parameters(),lr=lr,weight_decay=wd)
+ sch=optim.lr_scheduler.CosineAnnealingLR(opt,T_max=epochs)
+ for ep in range(1,epochs+1):
+ model.train()
+ for x,y in trl:
+ x=x.view(x.size(0),-1).to(dev);y=y.to(dev)
+ F.cross_entropy(model(x),y).backward();opt.step();opt.zero_grad()
+ sch.step()
+ if ep%20==0:print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f}",flush=True)
+
+def train_dfa(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
+ d=model.d_hidden;L=model.num_blocks;C=10
+ Bs=[torch.randn(d,C,device=dev)/np.sqrt(C) for _ in range(L)]
+ bops=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks]
+ eop=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd)
+ hop=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd)
+ schs=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop,T_max=epochs),optim.lr_scheduler.CosineAnnealingLR(hop,T_max=epochs)]
+ for ep in range(1,epochs+1):
+ model.train()
+ for x,y in trl:
+ x=x.view(x.size(0),-1).to(dev);y=y.to(dev);b=x.size(0)
+ with torch.no_grad():lo,hi=model(x,return_hidden=True);eT=lo.softmax(-1);eT[torch.arange(b),y]-=1
+ hL=hi[-1].detach()
+ F.cross_entropy(model.out_head(model.out_ln(hL)),y).backward();hop.step();hop.zero_grad()
+ for l in range(L):
+ a=(eT@Bs[l].T).detach();rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll=(model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad();ll.backward();torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);bops[l].step()
+ a0=(eT@Bs[0].T).detach();r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ (model.embed(x)*(a0/r0)).sum(-1).mean().backward();eop.step();eop.zero_grad()
+ for s in schs:s.step()
+ if ep%20==0:print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f}",flush=True)
+
+def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01):
+ d=model.d_hidden;L=model.num_blocks;C=10
+ sp=StateBridgeNet(d_hidden=d,s_dim=C).to(dev)
+ bops=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks]
+ eop=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd)
+ hop=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd)
+ sop=optim.Adam(sp.parameters(),lr=lr_fb)
+ schs=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop,T_max=epochs),optim.lr_scheduler.CosineAnnealingLR(hop,T_max=epochs)]
+ se_final=0
+ for ep in range(1,epochs+1):
+ model.train();sp.train();se_ep=0;n_ep=0
+ for x,y in trl:
+ x=x.view(x.size(0),-1).to(dev);y=y.to(dev);b=x.size(0)
+ with torch.no_grad():lo,hi=model(x,return_hidden=True);eT=lo.softmax(-1);eT[torch.arange(b),y]-=1;s=eT.detach()
+ hL=hi[-1].detach()
+ sl=0.0
+ for l in range(L):
+ tl=torch.full((b,),l/L,device=dev);pred=sp(hi[l].detach(),tl,s)
+ tn=hL.norm(-1,keepdim=True).clamp(min=1.0);sl+=(((pred-hL)/tn)**2).sum(-1).mean()
+ sl/=L;sop.zero_grad();sl.backward();sop.step();se_ep+=sl.item()*b;n_ep+=b
+ credits=[]
+ for l in range(L):
+ hl=hi[l].detach().requires_grad_(True);tl=torch.full((b,),l/L,device=dev)
+ pl=F.cross_entropy(model.out_head(model.out_ln(sp(hl,tl,s))),y,reduction='sum')
+ credits.append(torch.autograd.grad(pl,hl,create_graph=False)[0].detach())
+ F.cross_entropy(model.out_head(model.out_ln(hL)),y).backward();hop.step();hop.zero_grad()
+ for l in range(L):
+ a=credits[l];rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll=(model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad();ll.backward();torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);bops[l].step()
+ a0=credits[0];r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ (model.embed(x)*(a0/r0)).sum(-1).mean().backward();eop.step();eop.zero_grad()
+ for s in schs:s.step()
+ se_final=se_ep/n_ep
+ if ep%20==0:print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f} se={se_final:.6f}",flush=True)
+ return se_final
+
+def train_credit_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01):
+ d=model.d_hidden;L=model.num_blocks;C=10
+ vn=ValueNet(d_hidden=d,s_dim=C).to(dev);ve=create_ema_model(vn)
+ Bs=[torch.randn(d,C,device=dev)/np.sqrt(C) for _ in range(L)]
+ bops=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks]
+ eop=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd)
+ hop=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd)
+ vop=optim.Adam(vn.parameters(),lr=lr_fb)
+ schs=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop,T_max=epochs),optim.lr_scheduler.CosineAnnealingLR(hop,T_max=epochs)]
+ warmup=max(1,epochs//5)
+ for ep in range(1,epochs+1):
+ model.train();vn.train()
+ blend=0.0 if ep<=warmup else min(1.0,(ep-warmup)/max(1,warmup))
+ for x,y in trl:
+ x=x.view(x.size(0),-1).to(dev);y=y.to(dev);b=x.size(0)
+ with torch.no_grad():lo,hi=model(x,return_hidden=True);eT=lo.softmax(-1);eT[torch.arange(b),y]-=1;s=eT.detach();tlv=F.cross_entropy(lo,y,reduction='none').detach()
+ hL=hi[-1].detach();t_L=torch.ones(b,device=dev)
+ lt=((vn(hL,t_L,s)-tlv)**2).mean()
+ hLr=hL.clone().requires_grad_(True);VL=vn(hLr,t_L,s);gV=torch.autograd.grad(VL.sum(),hLr,create_graph=True)[0]
+ hLr2=hL.clone().requires_grad_(True);ce=F.cross_entropy(model.out_head(model.out_ln(hLr2)),y,reduction='sum')
+ aLe=torch.autograd.grad(ce,hLr2,create_graph=False)[0].detach()
+ ltg=((gV-aLe)**2).sum(-1).mean()
+ lb=0.0
+ for l in range(L):
+ hl=hi[l].detach();tl=torch.full((b,),l/L,device=dev);tn=torch.full((b,),(l+1)/L,device=dev)
+ Vl=vn(hl,tl,s)
+ with torch.no_grad():
+ hn=hi[l+1].detach();lts=[]
+ for k in range(4):lts.append(-ve(hn+0.05*torch.randn_like(hn),tn,s)/0.1)
+ Vt=-0.1*(torch.logsumexp(torch.stack(lts,-1),-1)-np.log(4))
+ lb+=((Vl-Vt.detach())**2).mean()
+ lb/=L;vl=lt+lb+1.0*ltg
+ vop.zero_grad();vl.backward();torch.nn.utils.clip_grad_norm_(vn.parameters(),1.0);vop.step()
+ update_ema(vn,ve,0.995)
+ cbc=[]
+ for l in range(L):
+ hl=hi[l].detach().requires_grad_(True);tl=torch.full((b,),l/L,device=dev)
+ Vl=vn(hl,tl,s);cbc.append(torch.autograd.grad(Vl.sum(),hl,create_graph=False)[0].detach())
+ dfac=[(eT@Bs[l].T).detach() for l in range(L)]
+ credits=[]
+ for l in range(L):
+ if blend>=1:credits.append(cbc[l])
+ elif blend<=0:credits.append(dfac[l])
+ else:
+ cr=(cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6;dr=(dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr)
+ F.cross_entropy(model.out_head(model.out_ln(hL)),y).backward();hop.step();hop.zero_grad()
+ for l in range(L):
+ a=credits[l];rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll=(model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad();ll.backward();torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);bops[l].step()
+ a0=credits[0];r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ (model.embed(x)*(a0/r0)).sum(-1).mean().backward();eop.step();eop.zero_grad()
+ for s in schs:s.step()
+ if ep%20==0:print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f}",flush=True)
+
+def compute_diagnostics(model, tel, dev):
+ model.eval();L=model.num_blocks
+ for x,y in tel:x=x.view(x.size(0),-1).to(dev);y=y.to(dev);break
+ h0=model.embed(x.detach());hs=[h0.clone().requires_grad_(True)]
+ for bl in model.blocks:hs.append(hs[-1]+bl(hs[-1]))
+ lo=model.out_head(model.out_ln(hs[-1]));loss=F.cross_entropy(lo,y)
+ gs=torch.autograd.grad(loss,hs);bp={l:gs[l].detach() for l in range(L)}
+ with torch.no_grad():_,hi=model(x,return_hidden=True)
+ nse=((hi[L//2]-hi[-1]).norm(-1)/hi[-1].norm(-1).clamp(min=1e-8)).mean().item()
+ rhos=[]
+ for l in range(L):
+ h_l=hi[l].detach();a_l=bp[l]
+ def mk(sl):
+ def f(h):
+ with torch.no_grad():
+ c=h
+ for i in range(sl,L):c=c+model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none')
+ return f
+ rhos.append(perturbation_correlation(h_l,a_l,mk(l),epsilon=1e-3,M=16))
+ return {'Gamma':1.0,'rho':np.mean(rhos),'naive_StateErr':nse}
+
+def main():
+ p=argparse.ArgumentParser()
+ p.add_argument('--method',type=str,required=True,choices=['bp','dfa','state_bridge','credit_bridge'])
+ p.add_argument('--activation',type=str,default='relu',choices=['relu','gelu'])
+ p.add_argument('--seed',type=int,required=True)
+ p.add_argument('--gpu',type=int,default=0)
+ p.add_argument('--output_dir',type=str,default='results/gelu_ablation')
+ args=p.parse_args()
+ os.makedirs(args.output_dir,exist_ok=True)
+ dev=torch.device(f'cuda:{args.gpu}')
+ torch.manual_seed(args.seed);np.random.seed(args.seed);torch.cuda.manual_seed_all(args.seed)
+ trl,tel=get_cifar10()
+ L,d=4,256
+ if args.activation=='relu':
+ model=ResidualMLPReLU(3072,d,10,L).to(dev)
+ else:
+ model=ResidualMLP(3072,d,10,L).to(dev)
+ print(f"[{args.activation}_{args.method} s={args.seed}] Training...",flush=True)
+ se=None
+ if args.method=='bp':train_bp(model,trl,tel,dev)
+ elif args.method=='dfa':train_dfa(model,trl,tel,dev)
+ elif args.method=='state_bridge':se=train_state_bridge(model,trl,tel,dev)
+ elif args.method=='credit_bridge':train_credit_bridge(model,trl,tel,dev)
+ acc=evaluate(model,tel,dev)
+ diag=compute_diagnostics(model,tel,dev)
+ torch.save(model.state_dict(),os.path.join(args.output_dir,f'{args.activation}_{args.method}_s{args.seed}.pt'))
+ result={'activation':args.activation,'method':args.method,'seed':args.seed,'acc':acc,
+ 'StateErr':se,'Gamma':diag['Gamma'],'rho':diag['rho'],'naive_StateErr':diag['naive_StateErr']}
+ with open(os.path.join(args.output_dir,f'{args.activation}_{args.method}_s{args.seed}.json'),'w') as f:
+ json.dump(result,f,indent=2,default=float)
+ print(f"[{args.activation}_{args.method} s={args.seed}] acc={acc:.4f} ρ={diag['rho']:.4f}",flush=True)
+
+if __name__=='__main__':main()