"""Phase 10A.8C: 3-seed minimal scaffold replication.

Trains a ResidualMLP on CIFAR-10 with local (detached, per-module) updates and
compares three credit-assignment branches over several seeds:

* ``dfa``      -- credit is a fixed random projection of the output error only;
* ``perlayer`` -- DFA credit mixed with one learned vector per block;
* ``vec``      -- DFA credit mixed with the output of a learned VectorCreditNet.

Every run repeats the same seeded DFA warmup before the branch-specific phase.
"""
import argparse
import copy
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP  # noqa: E402
from models.value_net import SinusoidalTimeEmbed  # noqa: E402

# Fixed per-channel normalization constants applied to both splits.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


class VectorCreditNet(nn.Module):
    """MLP predicting a ``d``-dimensional credit vector for a hidden state.

    The input is the concatenation of ``LayerNorm(h)``, a sinusoidal embedding
    of the scalar position ``t`` and a side-information vector ``s``.

    Args:
        d: width of the hidden state ``h`` and of the output.
        s_dim: width of ``s``.
        hidden: width of the hidden Linear layers (default 256).
        t_dim: size passed to ``SinusoidalTimeEmbed``. Assumed to equal the
            embedding's output width -- TODO confirm against models.value_net.
        n_hidden: number of Linear+GELU layers before the output Linear
            (default 3).
    """

    def __init__(self, d, s_dim, hidden=256, t_dim=32, n_hidden=3):
        super().__init__()
        # Construction order (ln, te, then Linear layers in sequence) is kept
        # fixed so parameter initialization consumes the RNG in the same order.
        self.ln = nn.LayerNorm(d)
        self.te = SinusoidalTimeEmbed(t_dim)
        layers = []
        in_dim = d + t_dim + s_dim
        for _ in range(n_hidden):
            layers += [nn.Linear(in_dim, hidden), nn.GELU()]
            in_dim = hidden
        layers.append(nn.Linear(in_dim, d))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        return self.net(torch.cat([self.ln(h), self.te(t), s], -1))


class PerLayerVector(nn.Module):
    """One learned ``d``-vector per block, broadcast over the batch.

    ``set_block(l)`` selects which vector later ``forward`` calls return.
    ``forward`` uses only ``h.size(0)``; ``t`` and ``s`` are accepted for
    interface compatibility with VectorCreditNet and ignored.
    """

    def __init__(self, d, L):
        super().__init__()
        self.vecs = nn.ParameterList(
            [nn.Parameter(torch.randn(d) * 0.01) for _ in range(L)]
        )
        self._b = 0  # index of the currently selected block vector

    def set_block(self, l):
        self._b = l

    def forward(self, h, t, s):
        return self.vecs[self._b].unsqueeze(0).expand(h.size(0), -1)


def get_cifar10(bs=128, root='./data', num_workers=4):
    """Return ``(train_loader, test_loader)`` for CIFAR-10.

    The train split uses RandomCrop(32, padding 4) and a random horizontal flip
    and is shuffled. Both splits are converted to tensors and normalized with
    the module-level mean/std constants. The data is downloaded to ``root`` if
    missing.
    """
    normalize = transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    train_set = torchvision.datasets.CIFAR10(root, train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(root, train=False, download=True, transform=test_tf)
    return (
        DataLoader(train_set, batch_size=bs, shuffle=True,
                   num_workers=num_workers, pin_memory=True),
        DataLoader(test_set, batch_size=bs, shuffle=False,
                   num_workers=num_workers, pin_memory=True),
    )


def _flatten(x, dev):
    """Flatten every sample of batch ``x`` to a vector and move it to ``dev``."""
    return x.view(x.size(0), -1).to(dev)


def evaluate(m, tl, dev):
    """Return top-1 accuracy of model ``m`` over loader ``tl``.

    Puts ``m`` in eval mode and does not restore train mode; callers re-enable
    it themselves.
    """
    m.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in tl:
            x = _flatten(x, dev)
            y = y.to(dev)
            correct += (m(x).argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


def _rms(a):
    """Per-row root-mean-square over the last dim, plus 1e-6 to avoid /0."""
    return (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6


def _make_optimizers(model, epochs, lr, wd):
    """Build fresh AdamW optimizers and cosine schedules for ``model``.

    Creates one optimizer per block, one for ``model.embed`` and one for the
    head (``out_head`` + ``out_ln``). Each gets a CosineAnnealingLR with
    ``T_max=epochs``; schedulers are listed as blocks..., embed, head.
    """
    bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
    hop = optim.AdamW(head_params, lr=lr, weight_decay=wd)
    schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops + [eop, hop]]
    return bops, eop, hop, schs


def _forward_with_error(model, x, y):
    """Run ``model`` without grad; return ``(hidden, err)``.

    ``hidden`` is whatever ``model(x, return_hidden=True)`` returns second; the
    callers feed ``hidden[l]`` to block ``l`` and ``hidden[-1]`` to the head.
    ``err = softmax(logits) - onehot(y)``, i.e. the per-sample gradient of the
    cross-entropy w.r.t. the logits.
    """
    with torch.no_grad():
        logits, hidden = model(x, return_hidden=True)
        err = logits.softmax(-1)
        err[torch.arange(x.size(0), device=err.device), y] -= 1
    return hidden, err


def _dfa_credits(err, Bs):
    """Project the output error through each fixed feedback matrix (DFA)."""
    return [(err @ B.T).detach() for B in Bs]


def _apply_local_updates(model, bops, eop, hop, x, y, hidden, creds):
    """One local update of head, every block, then the embedding.

    * Head: ordinary cross-entropy backprop from the detached last hidden state.
    * Block ``l``: minimizes ``<block_l(hidden[l]), creds[l] / rms>``, so its
      gradient follows the RMS-normalized credit. Gradients are clipped to norm 1.
    * Embedding: same surrogate using ``creds[0]`` (no clipping, as before).

    Inputs to each module are detached, so no gradient crosses module borders.
    """
    head_loss = F.cross_entropy(model.out_head(model.out_ln(hidden[-1].detach())), y)
    hop.zero_grad()
    head_loss.backward()
    hop.step()

    for l, (block, opt) in enumerate(zip(model.blocks, bops)):
        c = creds[l]
        surrogate = (block(hidden[l].detach()) * (c / _rms(c))).sum(-1).mean()
        opt.zero_grad()
        surrogate.backward()
        torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
        opt.step()

    c0 = creds[0]
    embed_loss = (model.embed(x) * (c0 / _rms(c0))).sum(-1).mean()
    eop.zero_grad()
    embed_loss.backward()
    eop.step()


def _fit_aux(aux, aop, branch, model, hidden, y, err, n_layers, eps, n_probes):
    """One optimizer step of the auxiliary credit model ``aux``.

    The loss has two terms:

    * terminal: squared error between ``aux`` at ``t=1`` on the last hidden
      state and the exact gradient of the summed head cross-entropy w.r.t.
      that state;
    * probe: at one uniformly sampled layer, ``<aux(h), v>`` is regressed onto a
      central finite-difference estimate (step ``eps``) of the directional
      derivative of the per-sample loss along ``n_probes`` random unit
      directions ``v``, computed via ``model.forward_from_layer``.

    Only ``aux`` parameters receive gradients (gradients clipped to norm 1).
    """
    h_last = hidden[-1].detach()
    batch, dev = h_last.size(0), h_last.device

    if branch == 'perlayer':
        aux.set_block(n_layers - 1)
    pred_last = aux(h_last, torch.ones(batch, device=dev), err)
    h_req = h_last.clone().requires_grad_(True)
    head_loss = F.cross_entropy(model.out_head(model.out_ln(h_req)), y, reduction='sum')
    target_last = torch.autograd.grad(head_loss, h_req, create_graph=False)[0].detach()
    terminal_loss = ((pred_last - target_last) ** 2).sum(-1).mean()

    layer = np.random.randint(0, n_layers)
    h = hidden[layer].detach()
    t = torch.full((batch,), layer / n_layers, device=dev)
    if branch == 'perlayer':
        aux.set_block(layer)
    pred = aux(h, t, err)
    probe_loss = torch.tensor(0.0, device=dev)
    for _ in range(n_probes):
        v = torch.randn_like(h)
        v = v / (v.norm(-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            plus = F.cross_entropy(model.forward_from_layer(h + eps * v, layer), y, reduction='none')
            minus = F.cross_entropy(model.forward_from_layer(h - eps * v, layer), y, reduction='none')
            dir_grad = (plus - minus) / (2 * eps)
        probe_loss = probe_loss + (((pred * v).sum(-1) - dir_grad.detach()) ** 2).mean()
    probe_loss = probe_loss / n_probes

    loss = terminal_loss + probe_loss
    aop.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(aux.parameters(), 1.0)
    aop.step()


def _mixed_credits(aux, branch, alpha, dfa, hidden, err):
    """Blend RMS-normalized aux and DFA credits per block.

    Returns ``dfa`` unchanged when there is no aux model or ``alpha`` is not
    positive. Otherwise returns, for block ``l``,
    ``alpha * aux/rms(aux) + (1 - alpha) * dfa/rms(dfa)``.
    """
    if aux is None or not alpha > 0:
        return dfa
    n_layers, batch = len(dfa), err.size(0)
    mixed = []
    for l, c_dfa in enumerate(dfa):
        if branch == 'perlayer':
            aux.set_block(l)
        with torch.no_grad():
            t = torch.full((batch,), l / n_layers, device=err.device)
            c_aux = aux(hidden[l].detach(), t, err).detach()
        mixed.append(alpha * c_aux / _rms(c_aux) + (1 - alpha) * c_dfa / _rms(c_dfa))
    return mixed


def run_branch(seed, branch, gpu, epochs=100, t0=5, alpha=0.75, lr=1e-3, wd=0.01, lr_fb=1e-3, M=4):
    """Train one seed of one branch and return its per-epoch test accuracies.

    Training has two phases:

    * Epochs ``1..t0`` are a pure-DFA warmup.
    * The optimizers are then re-created, so the AdamW state is fresh, and
      their cosine schedules are advanced by ``t0`` steps.
    * Epochs ``t0+1..epochs`` then train with the branch's credit signal.

    Args:
        seed: seeds torch (CPU and all CUDA devices) and numpy.
        branch: 'dfa', 'perlayer' or 'vec'. Any other value behaves like 'dfa'.
        gpu: CUDA device index.
        epochs: total epochs, including the warmup; also the cosine ``T_max``.
        t0: number of warmup epochs.
        alpha: weight of the learned credit in the mix (0 gives pure DFA).
        lr, wd: AdamW learning rate and weight decay for the model.
        lr_fb: Adam learning rate for the auxiliary credit model.
        M: number of random probe directions per batch for the aux fit.

    Returns:
        list of test accuracies for epochs ``t0+1..epochs``
        (``len == epochs - t0``); epoch ``e`` is at index ``e - t0 - 1``.
    """
    dev = torch.device(f'cuda:{gpu}')
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    trl, tel = get_cifar10()
    L, d, inp = 4, 256, 3072
    model = ResidualMLP(inp, d, 10, L).to(dev)
    # Fixed random feedback matrices, one per block; never updated.
    Bs = [torch.randn(d, 10, device=dev) / np.sqrt(10) for _ in range(L)]
    bops, eop, hop, schs = _make_optimizers(model, epochs, lr, wd)

    # DFA warmup to t0
    for _ in range(1, t0 + 1):
        model.train()
        for x, y in trl:
            x, y = _flatten(x, dev), y.to(dev)
            hi, err = _forward_with_error(model, x, y)
            _apply_local_updates(model, bops, eop, hop, x, y, hi, _dfa_credits(err, Bs))
        for sch in schs:
            sch.step()

    # Fork point: keep the warmed-up weights and feedback matrices, but start
    # the branch with fresh optimizer state and cosine schedules at epoch t0.
    bops, eop, hop, schs = _make_optimizers(model, epochs, lr, wd)
    for _ in range(t0):
        for sch in schs:
            sch.step()

    aux, aop = None, None
    if branch == 'perlayer':
        aux = PerLayerVector(d, L).to(dev)
    elif branch == 'vec':
        aux = VectorCreditNet(d, 10).to(dev)
    if aux is not None:
        aop = optim.Adam(aux.parameters(), lr=lr_fb)

    eps = 1e-3  # finite-difference step for the aux probe loss
    accs = []
    for _ in range(t0 + 1, epochs + 1):
        model.train()
        if aux is not None:
            aux.train()
        for x, y in trl:
            x, y = _flatten(x, dev), y.to(dev)
            hi, err = _forward_with_error(model, x, y)
            if aux is not None:
                _fit_aux(aux, aop, branch, model, hi, y, err, L, eps, M)
            creds = _mixed_credits(aux, branch, alpha, _dfa_credits(err, Bs), hi, err)
            _apply_local_updates(model, bops, eop, hop, x, y, hi, creds)
        for sch in schs:
            sch.step()
        accs.append(evaluate(model, tel, dev))
    return accs


def main():
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=int, default=2)
    p.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
    p.add_argument('--output_dir', type=str, default='results/scaffold_replication')
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--t0', type=int, default=5)
    a = p.parse_args()
    if a.epochs <= a.t0:
        p.error('--epochs must be greater than --t0')
    os.makedirs(a.output_dir, exist_ok=True)

    branches = ['dfa', 'perlayer', 'vec']
    # run_branch returns accuracies starting at epoch t0+1, so epoch 20 sits
    # at this index (14 for the default t0=5).
    i20 = 20 - (a.t0 + 1)
    results = {}
    for branch in branches:
        results[branch] = {'final': [], 'acc20': []}
        alpha = 0.75 if branch != 'dfa' else 0.0
        for seed in a.seeds:
            print(f" Running {branch} seed={seed}...")
            accs = run_branch(seed, branch, a.gpu, epochs=a.epochs, t0=a.t0, alpha=alpha)
            final = accs[-1]
            # If epoch 20 falls outside the post-warmup range, report final.
            acc20 = accs[i20] if 0 <= i20 < len(accs) else final
            results[branch]['final'].append(final)
            results[branch]['acc20'].append(acc20)
            print(f" {branch} s={seed}: acc@20={acc20:.4f}, final={final:.4f}")

    print(f"\n{'='*60}\nSUMMARY ({len(a.seeds)}-seed)\n{'='*60}")
    print(f"{'Method':<25} {'acc@20':>15} {'final':>15} {'diff vs DFA':>12}")
    print("-" * 70)
    dfa_mean = np.mean(results['dfa']['final'])
    for b in branches:
        r = results[b]
        # np.std is the population std (ddof=0).
        acc20_s = f"{np.mean(r['acc20']):.4f}±{np.std(r['acc20']):.4f}"
        final_s = f"{np.mean(r['final']):.4f}±{np.std(r['final']):.4f}"
        diff = np.mean(r['final']) - dfa_mean
        print(f"{b:<25} {acc20_s:>15} {final_s:>15} {diff:>+12.4f}")

    with open(os.path.join(a.output_dir, 'replication.json'), 'w') as f:
        json.dump(results, f, indent=2, default=float)
    print(f"\nSaved to {a.output_dir}/replication.json")


if __name__ == '__main__':
    main()