diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-27 18:07:58 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-27 18:07:58 -0500 |
| commit | 2a230acd5ee3fa6605892d524badf281ba7e9cfd (patch) | |
| tree | 9b3eadb60966b895a0349cbb457e6dab2004af47 /experiments | |
| parent | 4d6e689fe6bfffef6db7a4650aec210cd3eeed5c (diff) | |
Add Phase 10A.8C: 3-seed replication — scaffold gains are marginal
3-seed results (mean±std):
- DFA: 0.306±0.006
- perlayer_vector α=0.75: 0.304±0.006 (-0.2%, not significant)
- random_trainable α=0.75: 0.313±0.007 (+0.7%, marginal, error bars overlap)
Single-seed gains (+1.1% perlayer, +0.8% vec) do not robustly replicate.
The scaffold mechanism provides at best a marginal, statistically uncertain benefit.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/minimal_scaffold_replication.py | 157 |
1 file changed, 157 insertions, 0 deletions
# Source: experiments/minimal_scaffold_replication.py (new file in this commit, blob c2dbb0c).
"""Phase 10A.8C: 3-seed minimal scaffold replication.

Trains a residual MLP on CIFAR-10 with layer-local updates driven by fixed
random feedback matrices ("dfa" branch), optionally blended with a learned
auxiliary credit signal ("perlayer" and "vec" branches), and reports test
accuracy aggregated over several seeds.
"""
import argparse
import copy
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Make the repository root importable so the project-local `models` package resolves.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from models.value_net import SinusoidalTimeEmbed


class VectorCreditNet(nn.Module):
    """MLP that maps (hidden state, time, side signal) to a d-dimensional credit vector.

    Args:
        d: Width of the hidden state; also the output width.
        s_dim: Width of the side signal `s` concatenated to the input.
    """

    def __init__(self, d, s_dim):
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.te = SinusoidalTimeEmbed(32)
        layers = []
        ind = d + 32 + s_dim  # normalized hidden + 32-wide time embedding + side signal
        for i in range(3):
            layers += [nn.Linear(ind if i == 0 else 256, 256), nn.GELU()]
        layers.append(nn.Linear(256, d))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the predicted credit for hidden state `h` at time `t` given `s`."""
        return self.net(torch.cat([self.ln(h), self.te(t), s], -1))


class PerLayerVector(nn.Module):
    """One learnable d-vector per layer, independent of the input values.

    The active layer is selected with `set_block` before calling `forward`.
    """

    def __init__(self, d, L):
        super().__init__()
        self.vecs = nn.ParameterList(
            [nn.Parameter(torch.randn(d) * 0.01) for _ in range(L)]
        )
        self._b = 0

    def set_block(self, l):
        """Select which layer's vector `forward` returns."""
        self._b = l

    def forward(self, h, t, s):
        """Broadcast the selected vector over the batch; `t` and `s` are ignored."""
        return self.vecs[self._b].unsqueeze(0).expand(h.size(0), -1)


def get_cifar10(bs=128):
    """Build CIFAR-10 loaders rooted at ./data (downloaded if missing).

    Args:
        bs: Batch size for both loaders.

    Returns:
        (train_loader, test_loader). The train loader shuffles and applies a
        padded random crop plus horizontal flip; the test loader does neither.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    train_set = torchvision.datasets.CIFAR10('./data', True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10('./data', False, download=True, transform=test_tf)
    return (
        DataLoader(train_set, bs, True, num_workers=4, pin_memory=True),
        DataLoader(test_set, bs, False, num_workers=4, pin_memory=True),
    )


def evaluate(m, tl, dev):
    """Return the fraction of samples in loader `tl` that `m` classifies correctly.

    Inputs are flattened to (batch, -1) before being passed to the model.
    The model is left in eval mode.
    """
    m.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in tl:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            correct += (m(x).argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


def _build_optimizers(model, lr, wd, epochs):
    """Create fresh AdamW optimizers (per block, embed, head) and cosine schedulers.

    Returns:
        (block_opts, embed_opt, head_opt, schedulers); schedulers are ordered
        blocks first, then embed, then head.
    """
    block_opts = [
        optim.AdamW(blk.parameters(), lr=lr, weight_decay=wd) for blk in model.blocks
    ]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=lr, weight_decay=wd,
    )
    schedulers = [
        optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
        for o in block_opts + [embed_opt, head_opt]
    ]
    return block_opts, embed_opt, head_opt, schedulers


def _forward_error(model, x, y):
    """Run a gradient-free forward pass; return (softmax - onehot, hidden states)."""
    with torch.no_grad():
        logits, hidden = model(x, return_hidden=True)
        err = logits.softmax(-1)
        err[torch.arange(x.size(0), device=x.device), y] -= 1
    return err, hidden


def _head_step(model, h_last, y, head_opt):
    """Take one cross-entropy step on the output LayerNorm + head from a detached hidden state."""
    head_opt.zero_grad()
    F.cross_entropy(model.out_head(model.out_ln(h_last)), y).backward()
    head_opt.step()


def _rms_normalize(a):
    """Scale each row of `a` by the inverse of its RMS (plus 1e-6)."""
    return a / ((a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)


def _local_updates(model, x, hidden, credits, block_opts, embed_opt):
    """Apply one credit-driven step to every block and then to the embedding.

    Each block's surrogate loss is the batch mean of <block(hidden[l]), credit_l>
    with the credit RMS-normalized; the embedding reuses the credit of block 0.
    NOTE(review): assumes hidden[l] is the input of block l — confirm against ResidualMLP.
    """
    for l, credit in enumerate(credits):
        direction = _rms_normalize(credit)
        block_loss = (model.blocks[l](hidden[l].detach()) * direction).sum(-1).mean()
        block_opts[l].zero_grad()
        block_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
        block_opts[l].step()
    embed_loss = (model.embed(x) * _rms_normalize(credits[0])).sum(-1).mean()
    embed_opt.zero_grad()
    embed_loss.backward()
    embed_opt.step()


def _aux_step(model, aux, aux_opt, branch, hidden, y, err, num_layers, num_probes, eps):
    """Take one optimizer step on the auxiliary credit model.

    The loss is the sum of:
      * a terminal term matching aux(hidden[-1], t=1) to the exact gradient of
        the summed cross-entropy w.r.t. the last hidden state, and
      * a probe term at one randomly drawn layer, matching <aux, v> to a
        central finite-difference directional derivative of the per-sample
        loss along `num_probes` random unit directions v.
    """
    batch = y.size(0)
    dev = y.device
    h_last = hidden[-1].detach()

    # Terminal term.
    if branch == 'perlayer':
        aux.set_block(num_layers - 1)
    pred_last = aux(h_last, torch.ones(batch, device=dev), err)
    h_req = h_last.clone().requires_grad_(True)
    true_grad = torch.autograd.grad(
        F.cross_entropy(model.out_head(model.out_ln(h_req)), y, reduction='sum'),
        h_req, create_graph=False,
    )[0].detach()
    terminal_loss = ((pred_last - true_grad) ** 2).sum(-1).mean()

    # Probe term at a random layer.
    layer = np.random.randint(0, num_layers)
    h_l = hidden[layer].detach()
    t_l = torch.full((batch,), layer / num_layers, device=dev)
    if branch == 'perlayer':
        aux.set_block(layer)
    pred_l = aux(h_l, t_l, err)
    probe_loss = torch.tensor(0.0, device=dev)
    for _ in range(num_probes):
        v = torch.randn_like(h_l)
        # FIX: the original called v.norm(-1, keepdim=True), which binds -1 to
        # `p` (a p=-1 norm over the whole tensor), not to `dim`. Per-row L2
        # normalization needs dim=-1 so every probe direction has unit length.
        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            loss_plus = F.cross_entropy(
                model.forward_from_layer(h_l + eps * v, layer), y, reduction='none')
            loss_minus = F.cross_entropy(
                model.forward_from_layer(h_l - eps * v, layer), y, reduction='none')
            gj = (loss_plus - loss_minus) / (2 * eps)
        probe_loss = probe_loss + (((pred_l * v).sum(-1) - gj.detach()) ** 2).mean()
    probe_loss = probe_loss / num_probes

    total = terminal_loss + probe_loss
    aux_opt.zero_grad()
    total.backward()
    torch.nn.utils.clip_grad_norm_(aux.parameters(), 1.0)
    aux_opt.step()


def run_branch(seed, branch, gpu, epochs=100, t0=5, alpha=0.75, lr=1e-3, wd=0.01, lr_fb=1e-3, M=4):
    """Train one model for one seed/branch and return per-epoch test accuracy.

    Args:
        seed: Seed applied to torch (CPU and CUDA) and numpy.
        branch: 'dfa' (no auxiliary model), 'perlayer' (PerLayerVector) or
            'vec' (VectorCreditNet). Any other value behaves like 'dfa'.
        gpu: CUDA device index.
        epochs: Total number of epochs including warmup.
        t0: Number of warmup epochs trained with random-feedback credit only.
        alpha: Blend weight of the auxiliary credit (1 - alpha goes to the
            random-feedback credit). Ignored when there is no auxiliary model.
        lr, wd: AdamW learning rate and weight decay for the main model.
        lr_fb: Adam learning rate for the auxiliary model.
        M: Number of random probe directions per auxiliary step.

    Returns:
        List of test accuracies, one per post-warmup epoch: index 0 is epoch
        t0 + 1 and the last entry is epoch `epochs`.
    """
    dev = torch.device(f'cuda:{gpu}')
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    trl, tel = get_cifar10()
    L, d = 4, 256
    inp = 3072
    model = ResidualMLP(inp, d, 10, L).to(dev)
    # Fixed random feedback matrices, one per block (never trained in this file).
    Bs = [torch.randn(d, 10, device=dev) / np.sqrt(10) for _ in range(L)]
    bops, eop, hop, schs = _build_optimizers(model, lr, wd, epochs)

    # DFA warmup to t0
    for ep in range(1, t0 + 1):
        model.train()
        for x, y in trl:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            eT, hi = _forward_error(model, x, y)
            _head_step(model, hi[-1].detach(), y, hop)
            _local_updates(model, x, hi, [(eT @ B.T).detach() for B in Bs], bops, eop)
        for sch in schs:
            sch.step()
    ckpt_state = copy.deepcopy(model.state_dict())
    ckpt_Bs = [B.clone() for B in Bs]

    # Reset and reload for branch. Reloading the snapshot just taken leaves the
    # weights unchanged; what actually resets is the optimizer state below.
    model.load_state_dict(ckpt_state)
    Bs = [B.clone() for B in ckpt_Bs]
    bops, eop, hop, schs = _build_optimizers(model, lr, wd, epochs)
    # Advance the fresh schedulers by t0 steps so they resume at the warmup position.
    for _ in range(t0):
        for sch in schs:
            sch.step()

    aux = None
    aop = None
    if branch == 'perlayer':
        aux = PerLayerVector(d, L).to(dev)
        aop = optim.Adam(aux.parameters(), lr=lr_fb)
    elif branch == 'vec':
        aux = VectorCreditNet(d, 10).to(dev)
        aop = optim.Adam(aux.parameters(), lr=lr_fb)

    eps = 1e-3  # half-width of the central finite difference
    accs = []
    for ep in range(t0 + 1, epochs + 1):
        model.train()
        if aux is not None:
            aux.train()
        for x, y in trl:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            batch = x.size(0)
            eT, hi = _forward_error(model, x, y)
            s = eT.detach()
            hL = hi[-1].detach()

            # Train the auxiliary model first so the credits below use its updated state.
            if aux is not None and aop is not None:
                _aux_step(model, aux, aop, branch, hi, y, s, L, M, eps)

            dfa_c = [(eT @ B.T).detach() for B in Bs]
            creds = []
            for l in range(L):
                if aux is not None and alpha > 0:
                    if branch == 'perlayer':
                        aux.set_block(l)
                    with torch.no_grad():
                        av = aux(hi[l].detach(), torch.full((batch,), l / L, device=dev), s).detach()
                    rv = (av ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                    rd = (dfa_c[l] ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                    # Blend RMS-normalized credits so alpha is a scale-free mixing weight.
                    creds.append(alpha * av / rv + (1 - alpha) * dfa_c[l] / rd)
                else:
                    creds.append(dfa_c[l])

            _head_step(model, hL, y, hop)
            _local_updates(model, x, hi, creds, bops, eop)
        for sch in schs:
            sch.step()
        accs.append(evaluate(model, tel, dev))
    return accs


def main():
    """Run every branch for every seed, print a summary table, and save the raw results as JSON."""
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=int, default=2)
    p.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
    p.add_argument('--output_dir', type=str, default='results/scaffold_replication')
    a = p.parse_args()
    os.makedirs(a.output_dir, exist_ok=True)

    branches = ['dfa', 'perlayer', 'vec']
    t0 = 5
    report_epoch = 20
    # accs[0] is epoch t0 + 1, so epoch 20 sits at index 20 - t0 - 1 (= 14 for t0 = 5).
    report_idx = report_epoch - t0 - 1

    results = {}
    for branch in branches:
        results[branch] = {'final': [], 'acc20': []}
        for seed in a.seeds:
            alpha = 0.75 if branch != 'dfa' else 0.0
            print(f" Running {branch} seed={seed}...")
            accs = run_branch(seed, branch, a.gpu, t0=t0, alpha=alpha)
            final = accs[-1]
            acc20 = accs[report_idx] if len(accs) > report_idx else accs[-1]
            results[branch]['final'].append(final)
            results[branch]['acc20'].append(acc20)
            print(f" {branch} s={seed}: acc@20={acc20:.4f}, final={final:.4f}")

    # Report the actual number of seeds rather than a hard-coded "3".
    print(f"\n{'='*60}\nSUMMARY ({len(a.seeds)}-seed)\n{'='*60}")
    print(f"{'Method':<25} {'acc@20':>15} {'final':>15} {'diff vs DFA':>12}")
    print("-" * 70)
    dfa_mean = np.mean(results['dfa']['final'])
    for b in branches:
        # np.std is called with its defaults here (population std, ddof=0).
        m20 = np.mean(results[b]['acc20'])
        s20 = np.std(results[b]['acc20'])
        mf = np.mean(results[b]['final'])
        sf = np.std(results[b]['final'])
        diff = mf - dfa_mean
        print(f"{b:<25} {m20:.4f}±{s20:.4f} {mf:.4f}±{sf:.4f} {diff:+.4f}")
    with open(os.path.join(a.output_dir, 'replication.json'), 'w') as f:
        json.dump(results, f, indent=2, default=float)
    print(f"\nSaved to {a.output_dir}/replication.json")


if __name__ == '__main__':
    main()
