diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 17:31:27 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 17:31:27 -0500 |
| commit | e96cdc09a4f39473f6a71d8a8a2b7af56f1c0983 (patch) | |
| tree | 47270df2f58c8621bcc8523f6134abfcf03c6cfb /experiments | |
| parent | 32a27a62c3aefa5354c2968b3440ee3cdb32f43b (diff) | |
Add cifar_d512_confirmatory.py: L=4 d=512 with checkpoint saving
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cifar_d512_confirmatory.py | 210 |
1 files changed, 210 insertions, 0 deletions
# experiments/cifar_d512_confirmatory.py (new file, mode 100644, blob 1281b03)
"""
CIFAR-10 L=4 d=512 confirmatory: BP/DFA/State Bridge/Credit Bridge, 5 seeds.
One method+seed per invocation for clean process isolation.
Usage: python cifar_d512_confirmatory.py --method bp --seed 42 --gpu 0
"""
import argparse
import csv
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Make the repository root importable so the project packages below resolve
# when this file is run directly as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from models.value_net import ValueNet, create_ema_model, update_ema
from models.state_bridge import StateBridgeNet
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation

_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2470, 0.2435, 0.2616)
_NUM_CLASSES = 10
_METHODS = ('bp', 'dfa', 'state_bridge', 'credit_bridge')
_LOG_EVERY = 20  # evaluate and print every this many epochs


def get_cifar10(bs=128):
    """Return ``(train_loader, test_loader)`` for CIFAR-10.

    The training set uses random crop (padding 4) and horizontal flip; both
    sets are converted to tensors and normalised with fixed per-channel
    statistics. Data is downloaded to ``./data`` when missing.
    """
    normalize = transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    train_set = torchvision.datasets.CIFAR10(
        './data', train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        './data', train=False, download=True, transform=test_tf)
    train_loader = DataLoader(
        train_set, batch_size=bs, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(
        test_set, batch_size=bs, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, test_loader


def _to_device(x, y, dev):
    """Flatten each input to a vector and move the batch to ``dev``."""
    return x.view(x.size(0), -1).to(dev), y.to(dev)


def evaluate(m, tl, dev):
    """Return the top-1 accuracy of model ``m`` over loader ``tl``.

    Puts ``m`` in eval mode and leaves it there; the training loops switch
    back to train mode at the start of every epoch.
    """
    m.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in tl:
            x, y = _to_device(x, y, dev)
            correct += (m(x).argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


def train_bp(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
    """Train ``model`` end-to-end with backprop (AdamW + cosine schedule)."""
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = _to_device(x, y, dev)
            F.cross_entropy(model(x), y).backward()
            opt.step()
            opt.zero_grad()
        sch.step()
        if ep % _LOG_EVERY == 0:
            print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f}", flush=True)
    return model


def _make_local_optimizers(model, epochs, lr, wd):
    """Create per-block, embed and head AdamW optimizers plus schedulers.

    Returns ``(block_opts, embed_opt, head_opt, schedulers)``; the head
    optimizer covers both ``out_head`` and ``out_ln``.
    """
    block_opts = [optim.AdamW(blk.parameters(), lr=lr, weight_decay=wd)
                  for blk in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
    head_opt = optim.AdamW(head_params, lr=lr, weight_decay=wd)
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
                  for o in (*block_opts, embed_opt, head_opt)]
    return block_opts, embed_opt, head_opt, schedulers


def _random_feedback(d, num_classes, num_layers, dev):
    """Draw one fixed random (d, num_classes) feedback matrix per layer."""
    return [torch.randn(d, num_classes, device=dev) / np.sqrt(num_classes)
            for _ in range(num_layers)]


def _forward_with_error(model, x, y):
    """Run a no-grad forward pass; return ``(logits, hiddens, softmax - onehot)``."""
    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        err = logits.softmax(-1)
        err[torch.arange(x.size(0)), y] -= 1
    return logits, hiddens, err


def _rms_normalize(a):
    """Divide each row of ``a`` by its root-mean-square (plus 1e-6)."""
    rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
    return a / rms


def _row_norm(t, keepdim=False):
    """Per-sample L2 norm over the last (feature) axis."""
    return torch.linalg.vector_norm(t, dim=-1, keepdim=keepdim)


def _apply_local_updates(model, x, y, hiddens, credits,
                         block_opts, embed_opt, head_opt):
    """Update head, blocks and embedding from per-layer credit signals.

    The head is trained with cross-entropy on the detached final hidden
    state. Each block ``l`` (and the embedding, which reuses ``credits[0]``)
    minimises the inner product of its output with the RMS-normalised credit,
    so no gradient flows between layers.
    """
    h_last = hiddens[-1].detach()
    F.cross_entropy(model.out_head(model.out_ln(h_last)), y).backward()
    head_opt.step()
    head_opt.zero_grad()

    for l, credit in enumerate(credits):
        block = model.blocks[l]
        surrogate = (block(hiddens[l].detach()) * _rms_normalize(credit)).sum(-1).mean()
        block_opts[l].zero_grad()
        surrogate.backward()
        torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
        block_opts[l].step()

    embed_surrogate = (model.embed(x) * _rms_normalize(credits[0])).sum(-1).mean()
    embed_surrogate.backward()
    embed_opt.step()
    embed_opt.zero_grad()


def train_dfa(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
    """Train with direct feedback alignment.

    The output error (softmax - onehot) is projected to every layer through
    a fixed random matrix and used as that layer's credit signal.
    """
    d = model.d_hidden
    L = model.num_blocks
    Bs = _random_feedback(d, _NUM_CLASSES, L, dev)
    block_opts, embed_opt, head_opt, schedulers = _make_local_optimizers(
        model, epochs, lr, wd)
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = _to_device(x, y, dev)
            _, hiddens, err = _forward_with_error(model, x, y)
            credits = [(err @ Bs[l].T).detach() for l in range(L)]
            _apply_local_updates(model, x, y, hiddens, credits,
                                 block_opts, embed_opt, head_opt)
        for sched in schedulers:
            sched.step()
        if ep % _LOG_EVERY == 0:
            print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f}", flush=True)
    return model


def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01):
    """Train with State Bridge credit.

    A ``StateBridgeNet`` learns to predict the final hidden state from each
    intermediate hidden state (given a layer-time ``l/L`` and the output
    error). Layer credit is the gradient of the cross-entropy of the head
    applied to that prediction, taken w.r.t. the intermediate state.

    Returns:
        ``(model, state_err)`` where ``state_err`` is the sample-weighted
        mean bridge loss over the last epoch (0 if ``epochs`` is 0).
    """
    d = model.d_hidden
    L = model.num_blocks
    sp = StateBridgeNet(d_hidden=d, s_dim=_NUM_CLASSES).to(dev)
    block_opts, embed_opt, head_opt, schedulers = _make_local_optimizers(
        model, epochs, lr, wd)
    sop = optim.Adam(sp.parameters(), lr=lr_fb)
    state_err = 0
    for ep in range(1, epochs + 1):
        model.train()
        sp.train()
        err_total, n_seen = 0, 0
        for x, y in trl:
            x, y = _to_device(x, y, dev)
            b = x.size(0)
            _, hiddens, err = _forward_with_error(model, x, y)
            signal = err.detach()
            h_last = hiddens[-1].detach()

            # --- fit the bridge: predict h_L from every h_l -----------------
            # Per-sample L2 norm of the target, floored at 1.0. (The earlier
            # `hL.norm(-1, keepdim=True)` passed -1 as the order `p`, i.e. a
            # p=-1 norm over the whole batch, which the clamp turned into a
            # constant 1.0 - so no normalisation happened.)
            target_norm = _row_norm(h_last, keepdim=True).clamp(min=1.0)
            bridge_loss = 0.0
            for l in range(L):
                t_l = torch.full((b,), l / L, device=dev)
                pred = sp(hiddens[l].detach(), t_l, signal)
                bridge_loss = bridge_loss + (((pred - h_last) / target_norm) ** 2).sum(-1).mean()
            bridge_loss = bridge_loss / L
            sop.zero_grad()
            bridge_loss.backward()
            sop.step()
            err_total += bridge_loss.item() * b
            n_seen += b

            # --- credit: d CE(head(bridge(h_l))) / d h_l --------------------
            # Computed before the head update so all layers see the same head.
            credits = []
            for l in range(L):
                h_l = hiddens[l].detach().requires_grad_(True)
                t_l = torch.full((b,), l / L, device=dev)
                bridged_ce = F.cross_entropy(
                    model.out_head(model.out_ln(sp(h_l, t_l, signal))), y,
                    reduction='sum')
                credits.append(
                    torch.autograd.grad(bridged_ce, h_l, create_graph=False)[0].detach())

            _apply_local_updates(model, x, y, hiddens, credits,
                                 block_opts, embed_opt, head_opt)
        for sched in schedulers:
            sched.step()
        state_err = err_total / n_seen
        if ep % _LOG_EVERY == 0:
            print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f} se={state_err:.6f}", flush=True)
    return model, state_err


def train_credit_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01):
    """Train with Credit Bridge credit.

    A ``ValueNet`` V(h, t, s) is fitted with three terms: a terminal loss
    against the per-sample cross-entropy, a terminal gradient-matching loss
    against dCE/dh_L, and a soft (log-sum-exp over 4 noisy samples) one-step
    backup against an EMA copy. Layer credit is dV/dh_l. For the first
    ``epochs // 5`` epochs pure DFA credit is used, then it is linearly
    blended into the value-gradient credit over the same number of epochs.
    """
    d = model.d_hidden
    L = model.num_blocks
    vn = ValueNet(d_hidden=d, s_dim=_NUM_CLASSES).to(dev)
    ve = create_ema_model(vn)
    Bs = _random_feedback(d, _NUM_CLASSES, L, dev)
    block_opts, embed_opt, head_opt, schedulers = _make_local_optimizers(
        model, epochs, lr, wd)
    vop = optim.Adam(vn.parameters(), lr=lr_fb)
    warmup = max(1, epochs // 5)
    n_noise, noise_std, temperature = 4, 0.05, 0.1

    for ep in range(1, epochs + 1):
        model.train()
        vn.train()
        blend = 0.0 if ep <= warmup else min(1.0, (ep - warmup) / max(1, warmup))
        for x, y in trl:
            x, y = _to_device(x, y, dev)
            b = x.size(0)
            logits, hiddens, err = _forward_with_error(model, x, y)
            signal = err.detach()
            with torch.no_grad():
                terminal_loss = F.cross_entropy(logits, y, reduction='none').detach()
            h_last = hiddens[-1].detach()
            t_final = torch.ones(b, device=dev)

            # Terminal value loss: V(h_L, 1, s) should match per-sample CE.
            # NOTE(review): assumes vn returns shape (b,), matching
            # terminal_loss - confirm against ValueNet.
            loss_terminal = ((vn(h_last, t_final, signal) - terminal_loss) ** 2).mean()

            # Terminal gradient loss: dV/dh_L should match dCE/dh_L.
            h_for_v = h_last.clone().requires_grad_(True)
            v_final = vn(h_for_v, t_final, signal)
            grad_v = torch.autograd.grad(v_final.sum(), h_for_v, create_graph=True)[0]
            h_for_ce = h_last.clone().requires_grad_(True)
            ce_sum = F.cross_entropy(
                model.out_head(model.out_ln(h_for_ce)), y, reduction='sum')
            grad_ce = torch.autograd.grad(ce_sum, h_for_ce, create_graph=False)[0].detach()
            loss_grad = ((grad_v - grad_ce) ** 2).sum(-1).mean()

            # Soft one-step backup against the EMA value net.
            loss_backup = 0.0
            for l in range(L):
                h_l = hiddens[l].detach()
                t_l = torch.full((b,), l / L, device=dev)
                t_next = torch.full((b,), (l + 1) / L, device=dev)
                v_l = vn(h_l, t_l, signal)
                with torch.no_grad():
                    h_next = hiddens[l + 1].detach()
                    samples = [
                        -ve(h_next + noise_std * torch.randn_like(h_next), t_next, signal)
                        / temperature
                        for _ in range(n_noise)
                    ]
                    v_target = -temperature * (
                        torch.logsumexp(torch.stack(samples, -1), -1) - np.log(n_noise))
                loss_backup = loss_backup + ((v_l - v_target.detach()) ** 2).mean()
            loss_backup = loss_backup / L

            value_loss = loss_terminal + loss_backup + 1.0 * loss_grad
            vop.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(vn.parameters(), 1.0)
            vop.step()
            update_ema(vn, ve, 0.995)

            # Value-gradient credit from the freshly updated value net.
            value_credit = []
            for l in range(L):
                h_l = hiddens[l].detach().requires_grad_(True)
                t_l = torch.full((b,), l / L, device=dev)
                v_l = vn(h_l, t_l, signal)
                value_credit.append(
                    torch.autograd.grad(v_l.sum(), h_l, create_graph=False)[0].detach())
            dfa_credit = [(err @ Bs[l].T).detach() for l in range(L)]

            credits = []
            for l in range(L):
                if blend >= 1:
                    credits.append(value_credit[l])
                elif blend <= 0:
                    credits.append(dfa_credit[l])
                else:
                    # Mix the two signals on a common (unit-RMS) scale.
                    cr = (value_credit[l] ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                    dr = (dfa_credit[l] ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                    credits.append(
                        blend * value_credit[l] / cr + (1 - blend) * dfa_credit[l] / dr)

            _apply_local_updates(model, x, y, hiddens, credits,
                                 block_opts, embed_opt, head_opt)
        for sched in schedulers:
            sched.step()
        if ep % _LOG_EVERY == 0:
            print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f}", flush=True)
    return model


def compute_diagnostics(model, tel, dev):
    """Compute credit diagnostics on the first batch of ``tel``.

    Returns a dict with:
        ``Gamma``: fixed at 1.0 per layer (the reference credit here is the
            true backprop gradient, so its self-cosine is 1).
        ``rho``: mean over layers of ``perturbation_correlation`` between the
            backprop gradient at layer ``l`` and the loss of the remaining
            network under perturbations of ``h_l``.
        ``naive_StateErr``: mean per-sample relative L2 distance between the
            middle hidden state ``h_{L//2}`` and the final hidden state.
    """
    model.eval()
    L = model.num_blocks
    x, y = next(iter(tel))
    x, y = _to_device(x, y, dev)

    # Rebuild the residual forward pass by hand so that every intermediate
    # state is available as a differentiation target.
    h0 = model.embed(x.detach())
    states = [h0.clone().requires_grad_(True)]
    for block in model.blocks:
        states.append(states[-1] + block(states[-1]))
    loss = F.cross_entropy(model.out_head(model.out_ln(states[-1])), y)
    grads = torch.autograd.grad(loss, states)
    bp_credit = {l: grads[l].detach() for l in range(L)}

    with torch.no_grad():
        _, hiddens = model(x, return_hidden=True)
        # Per-sample norms over the feature axis. (The earlier `.norm(-1)`
        # passed -1 as the order `p` and reduced over the whole batch.)
        gap = _row_norm(hiddens[L // 2] - hiddens[-1])
        scale = _row_norm(hiddens[-1]).clamp(min=1e-8)
        naive_state_err = (gap / scale).mean().item()

    def tail_loss_from(start):
        """Return f(h): per-sample CE after running blocks ``start..L-1`` on h."""
        def f(h):
            with torch.no_grad():
                c = h
                for i in range(start, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(
                    model.out_head(model.out_ln(c)), y, reduction='none')
        return f

    gammas, rhos = [], []
    for l in range(L):
        gammas.append(1.0)  # BP self-cosine
        rhos.append(perturbation_correlation(
            hiddens[l].detach(), bp_credit[l], tail_loss_from(l),
            epsilon=1e-3, M=16))
    return {'Gamma': np.mean(gammas), 'rho': np.mean(rhos),
            'naive_StateErr': naive_state_err}


def main():
    """Train one (method, seed) pair, then save checkpoint and JSON metrics."""
    p = argparse.ArgumentParser()
    # `choices` rejects typos up front; previously an unknown method fell
    # through and an untrained model was evaluated and saved as a result.
    p.add_argument('--method', type=str, required=True, choices=_METHODS)
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/confirmatory/cifar_d512')
    args = p.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    dev = torch.device(f'cuda:{args.gpu}')
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    trl, tel = get_cifar10()
    L, d = 4, 512
    model = ResidualMLP(3072, d, 10, L).to(dev)
    print(f"[{args.method} s={args.seed}] Training...", flush=True)
    se = None  # only the state-bridge trainer reports a state error
    if args.method == 'bp':
        model = train_bp(model, trl, tel, dev)
    elif args.method == 'dfa':
        model = train_dfa(model, trl, tel, dev)
    elif args.method == 'state_bridge':
        model, se = train_state_bridge(model, trl, tel, dev)
    elif args.method == 'credit_bridge':
        model = train_credit_bridge(model, trl, tel, dev)
    else:  # unreachable while argparse `choices` matches this chain
        raise ValueError(f"unknown method: {args.method!r}")
    acc = evaluate(model, tel, dev)
    diag = compute_diagnostics(model, tel, dev)
    # Save checkpoint
    torch.save(model.state_dict(),
               os.path.join(args.output_dir, f'{args.method}_s{args.seed}.pt'))
    result = {'method': args.method, 'seed': args.seed, 'acc': acc, 'StateErr': se,
              'Gamma': diag['Gamma'], 'rho': diag['rho'],
              'naive_StateErr': diag['naive_StateErr']}
    with open(os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json'), 'w') as f:
        json.dump(result, f, indent=2, default=float)
    print(f"[{args.method} s={args.seed}] acc={acc:.4f} Γ={diag['Gamma']:.4f} ρ={diag['rho']:.4f} nse={diag['naive_StateErr']:.4f}", flush=True)


if __name__ == '__main__':
    main()
