summary | refs | log | tree | commit | diff
path: root/experiments
diff options
context:
space:
mode:
author: YurenHao0426 <Blackhao0426@gmail.com> 2026-03-27 18:07:58 -0500
committer: YurenHao0426 <Blackhao0426@gmail.com> 2026-03-27 18:07:58 -0500
commit: 2a230acd5ee3fa6605892d524badf281ba7e9cfd (patch)
tree: 9b3eadb60966b895a0349cbb457e6dab2004af47 /experiments
parent: 4d6e689fe6bfffef6db7a4650aec210cd3eeed5c (diff)
Add Phase 10A.8C: 3-seed replication — scaffold gains are marginal
3-seed results (mean±std):
- DFA: 0.306±0.006
- perlayer_vector α=0.75: 0.304±0.006 (-0.2%, not significant)
- random_trainable α=0.75: 0.313±0.007 (+0.7%, marginal, error bars overlap)

Single-seed gains (+1.1% perlayer, +0.8% vec) do not robustly replicate. The scaffold mechanism provides at best a marginal, statistically uncertain benefit.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r-- experiments/minimal_scaffold_replication.py 157
1 files changed, 157 insertions, 0 deletions
diff --git a/experiments/minimal_scaffold_replication.py b/experiments/minimal_scaffold_replication.py
new file mode 100644
index 0000000..c2dbb0c
--- /dev/null
+++ b/experiments/minimal_scaffold_replication.py
@@ -0,0 +1,157 @@
+"""Phase 10A.8C: 3-seed minimal scaffold replication."""
+import os, sys, json, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+import torch.optim as optim, copy
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+
class VectorCreditNet(nn.Module):
    """MLP that maps a hidden state, a time value and a signal to a d-dim vector.

    The input to the MLP is the concatenation (along the last axis) of the
    layer-normalised hidden state, the time embedding and the signal ``s``.

    Args:
        d: Width of the hidden state ``h``; also the output width.
        s_dim: Width of the signal ``s``.
        time_dim: Argument passed to ``SinusoidalTimeEmbed``. The first linear
            layer is sized as ``d + time_dim + s_dim``, so this assumes the
            embedding emits ``time_dim`` features (TODO confirm against
            ``models.value_net``).
        hidden: Width of every hidden linear layer.
        n_hidden: Number of ``Linear -> GELU`` stages before the output layer.

    Raises:
        ValueError: If ``n_hidden`` is smaller than 1.
    """

    def __init__(self, d, s_dim, time_dim=32, hidden=256, n_hidden=3):
        super().__init__()
        if n_hidden < 1:
            raise ValueError(f"n_hidden must be >= 1, got {n_hidden}")
        self.ln = nn.LayerNorm(d)
        self.te = SinusoidalTimeEmbed(time_dim)
        layers = []
        in_features = d + time_dim + s_dim
        for _ in range(n_hidden):
            layers += [nn.Linear(in_features, hidden), nn.GELU()]
            in_features = hidden
        layers.append(nn.Linear(hidden, d))
        # Same Sequential layout as before at the defaults, so existing
        # state-dict keys (net.0, net.2, ...) keep loading.
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the network output for hidden state ``h``, time ``t`` and signal ``s``."""
        features = torch.cat([self.ln(h), self.te(t), s], -1)
        return self.net(features)
+
class PerLayerVector(nn.Module):
    """One learnable d-dimensional vector per layer, selected via ``set_block``.

    ``forward`` ignores the content of its inputs (only the batch size of ``h``
    is read) and returns the currently selected vector broadcast over the
    batch. The ``(h, t, s)`` signature mirrors ``VectorCreditNet.forward`` so
    both can be called the same way.

    Args:
        d: Length of each vector.
        L: Number of vectors (one per layer index).
    """

    def __init__(self, d, L):
        super().__init__()
        self.vecs = nn.ParameterList(
            [nn.Parameter(torch.randn(d) * 0.01) for _ in range(L)]
        )
        self._b = 0  # index of the vector returned by forward()

    def set_block(self, l):
        """Select which vector ``forward`` returns.

        Raises:
            IndexError: If ``l`` is not a valid index into the vector list.
                Checked here so a bad index fails at the call site instead of
                later inside ``forward``.
        """
        n = len(self.vecs)
        if not -n <= l < n:
            raise IndexError(f"block index {l} out of range for {n} layers")
        self._b = l

    def forward(self, h, t, s):
        """Return the selected vector expanded to ``(h.size(0), d)``."""
        return self.vecs[self._b].unsqueeze(0).expand(h.size(0), -1)
+
def get_cifar10(bs=128, root='./data', num_workers=4):
    """Build the CIFAR-10 train and test data loaders.

    The training pipeline applies a padded random crop and a random horizontal
    flip before tensor conversion and normalisation; the test pipeline only
    converts and normalises. Both datasets are downloaded into ``root`` if
    they are not already present.

    Args:
        bs: Batch size used by both loaders.
        root: Directory the datasets are stored in / downloaded to.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``; the train loader shuffles, the test
        loader does not.
    """
    # Per-channel mean/std handed to Normalize; presumably the CIFAR-10
    # training-set statistics (TODO confirm). Defined once so the train and
    # test pipelines cannot drift apart.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = torchvision.datasets.CIFAR10(root, train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(root, train=False, download=True, transform=test_tf)

    train_loader = DataLoader(train_set, batch_size=bs, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=bs, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader
+
def evaluate(m, tl, dev):
    """Return the top-1 accuracy of ``m`` over the batches yielded by ``tl``.

    Each input batch is flattened to ``(batch, -1)`` and moved to ``dev``
    before being passed to the model. The model is switched to eval mode and
    left there; callers that keep training must call ``train()`` again.

    Args:
        m: Model returning per-class scores of shape ``(batch, classes)``.
        tl: Iterable of ``(inputs, labels)`` batches.
        dev: Device the batches are moved to.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        ``tl`` yields no samples (previously a ``ZeroDivisionError``).
    """
    m.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in tl:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            correct += (m(x).argmax(1) == y).sum().item()
            total += x.size(0)
    if total == 0:
        return 0.0
    return correct / total
+
def _make_optimizers(model, lr, wd, epochs):
    """Create per-block, embed and head AdamW optimisers, each with a cosine LR schedule."""
    bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    hop = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=lr, weight_decay=wd,
    )
    schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in (*bops, eop, hop)]
    return bops, eop, hop, schs


def _rms_normalize(a):
    """Divide ``a`` by its root-mean-square over the last axis (plus 1e-6)."""
    return a / ((a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6)


def _forward_error(model, x, y, dev):
    """Flatten and move a batch to ``dev``; return it with ``softmax(logits) - onehot(y)`` and the hidden states."""
    x = x.view(x.size(0), -1).to(dev)
    y = y.to(dev)
    with torch.no_grad():
        logits, hi = model(x, return_hidden=True)
        eT = logits.softmax(-1)
        eT[torch.arange(x.size(0), device=dev), y] -= 1
    return x, y, eT, hi


def _local_step(model, x, y, hi, creds, bops, eop, hop):
    """Update the head with cross-entropy, then each block and the embed layer with their credit signal."""
    hL = hi[-1].detach()
    hop.zero_grad()
    F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
    hop.step()
    for l, cred in enumerate(creds):
        # Local surrogate loss: inner product of the block output with the
        # RMS-normalised credit vector, so only this block's parameters get gradients.
        block_loss = (model.blocks[l](hi[l].detach()) * _rms_normalize(cred)).sum(-1).mean()
        bops[l].zero_grad()
        block_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
        bops[l].step()
    # The embed layer reuses the credit signal of block 0.
    embed_loss = (model.embed(x) * _rms_normalize(creds[0])).sum(-1).mean()
    eop.zero_grad()
    embed_loss.backward()
    eop.step()


def _aux_step(model, aux, aop, branch, hi, y, s, L, M, eps):
    """Take one optimiser step on the auxiliary net (terminal-gradient loss plus finite-difference probe loss)."""
    batch = y.size(0)
    dev = y.device
    hL = hi[-1].detach()

    # Terminal loss: at t=1 the aux output should match the exact gradient of
    # the summed cross-entropy with respect to the last hidden state.
    if branch == 'perlayer':
        aux.set_block(L - 1)
    at = aux(hL, torch.ones(batch, device=dev), s)
    hLr = hL.clone().requires_grad_(True)
    head_loss = F.cross_entropy(model.out_head(model.out_ln(hLr)), y, reduction='sum')
    dL = torch.autograd.grad(head_loss, hLr, create_graph=False)[0].detach()
    lt = ((at - dL) ** 2).sum(-1).mean()

    # Probe loss at one randomly chosen layer: the projection of the aux
    # output on a random unit direction should match the central
    # finite-difference directional derivative of the per-sample loss.
    ll = np.random.randint(0, L)
    hl = hi[ll].detach()
    if branch == 'perlayer':
        aux.set_block(ll)
    al = aux(hl, torch.full((batch,), ll / L, device=dev), s)
    lp = torch.tensor(0.0, device=dev)
    for _ in range(M):
        v = torch.randn_like(hl)
        # FIX: the first positional argument of Tensor.norm is the order `p`,
        # not `dim`. `v.norm(-1, keepdim=True)` computed a single p=-1 norm
        # over the whole tensor, so the probes were not unit vectors.
        v = v / (v.norm(p=2, dim=-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            loss_plus = F.cross_entropy(model.forward_from_layer(hl + eps * v, ll), y, reduction='none')
            loss_minus = F.cross_entropy(model.forward_from_layer(hl - eps * v, ll), y, reduction='none')
            gj = (loss_plus - loss_minus) / (2 * eps)
        lp = lp + (((al * v).sum(-1) - gj.detach()) ** 2).mean()

    vl = lt + lp / M
    aop.zero_grad()
    vl.backward()
    torch.nn.utils.clip_grad_norm_(aux.parameters(), 1.0)
    aop.step()


def run_branch(seed, branch, gpu, epochs=100, t0=5, alpha=0.75, lr=1e-3, wd=0.01, lr_fb=1e-3, M=4):
    """Train one model with DFA warm-up followed by one credit-assignment branch.

    Epochs ``1..t0`` train every block with fixed random feedback matrices
    (DFA). From epoch ``t0 + 1`` on, the per-layer credit signal is either
    plain DFA (``branch == 'dfa'`` or ``alpha == 0``) or a blend
    ``alpha * aux + (1 - alpha) * dfa`` of RMS-normalised signals, where
    ``aux`` is a ``PerLayerVector`` (``'perlayer'``) or a ``VectorCreditNet``
    (``'vec'``) trained online alongside the model.

    Args:
        seed: Seed for the torch and numpy RNGs.
        branch: ``'dfa'``, ``'perlayer'`` or ``'vec'``; any other value
            behaves like ``'dfa'`` because no auxiliary net is created.
        gpu: CUDA device index. When CUDA is unavailable the run falls back
            to CPU and prints a notice.
        epochs: Total number of epochs including warm-up.
        t0: Number of DFA warm-up epochs.
        alpha: Weight of the auxiliary credit signal in the blend.
        lr: Learning rate of the model optimisers.
        wd: Weight decay of the model optimisers.
        lr_fb: Learning rate of the auxiliary-net optimiser.
        M: Number of random probe directions per auxiliary update.

    Returns:
        List of test accuracies, one per epoch from ``t0 + 1`` to ``epochs``.
    """
    if torch.cuda.is_available():
        dev = torch.device(f'cuda:{gpu}')
    else:
        print(f" CUDA unavailable; running {branch} seed={seed} on CPU")
        dev = torch.device('cpu')
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    trl, tel = get_cifar10()
    L, d, inp = 4, 256, 3072
    model = ResidualMLP(inp, d, 10, L).to(dev)
    # Fixed random feedback matrices, one per block (never trained).
    Bs = [torch.randn(d, 10, device=dev) / np.sqrt(10) for _ in range(L)]
    bops, eop, hop, schs = _make_optimizers(model, lr, wd, epochs)

    # DFA warmup to t0
    for _ in range(t0):
        model.train()
        for x, y in trl:
            x, y, eT, hi = _forward_error(model, x, y, dev)
            _local_step(model, x, y, hi, [(eT @ B.T).detach() for B in Bs], bops, eop, hop)
        for sch in schs:
            sch.step()

    # Branch point: snapshot and restore the warm-up state (numerically a
    # no-op here), then start from fresh optimisers, which discards the AdamW
    # moment estimates, and fast-forward the schedules to epoch t0.
    ckpt_state = copy.deepcopy(model.state_dict())
    ckpt_Bs = [B.clone() for B in Bs]
    model.load_state_dict(ckpt_state)
    Bs = [B.clone() for B in ckpt_Bs]
    bops, eop, hop, schs = _make_optimizers(model, lr, wd, epochs)
    for _ in range(t0):
        for sch in schs:
            sch.step()

    aux = None
    aop = None
    if branch == 'perlayer':
        aux = PerLayerVector(d, L).to(dev)
        aop = optim.Adam(aux.parameters(), lr=lr_fb)
    elif branch == 'vec':
        aux = VectorCreditNet(d, 10).to(dev)
        aop = optim.Adam(aux.parameters(), lr=lr_fb)

    eps = 1e-3  # half-width of the central finite difference
    accs = []
    for _ in range(t0 + 1, epochs + 1):
        model.train()
        if aux is not None:
            aux.train()
        for x, y in trl:
            x, y, eT, hi = _forward_error(model, x, y, dev)
            batch = x.size(0)
            s = eT.detach()
            if aux is not None and aop is not None:
                _aux_step(model, aux, aop, branch, hi, y, s, L, M, eps)

            dfa_c = [(eT @ B.T).detach() for B in Bs]
            creds = []
            for l in range(L):
                if aux is not None and alpha > 0:
                    if branch == 'perlayer':
                        aux.set_block(l)
                    with torch.no_grad():
                        av = aux(hi[l].detach(), torch.full((batch,), l / L, device=dev), s).detach()
                    creds.append(alpha * _rms_normalize(av) + (1 - alpha) * _rms_normalize(dfa_c[l]))
                else:
                    creds.append(dfa_c[l])
            _local_step(model, x, y, hi, creds, bops, eop, hop)
        for sch in schs:
            sch.step()
        accs.append(evaluate(model, tel, dev))
    return accs
+
def main():
    """Run every branch for every seed, print a summary table and save the raw accuracies.

    For each of the branches ``dfa``, ``perlayer`` and ``vec`` and each seed,
    ``run_branch`` is called and the accuracy at epoch 20 and at the final
    epoch are recorded. The per-seed values are written to
    ``<output_dir>/replication.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=int, default=2)
    p.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
    p.add_argument('--output_dir', type=str, default='results/scaffold_replication')
    a = p.parse_args()
    os.makedirs(a.output_dir, exist_ok=True)

    branches = ['dfa', 'perlayer', 'vec']
    t0 = 5
    # run_branch returns one accuracy per epoch starting at epoch t0 + 1, so
    # epoch 20 sits at index 20 - (t0 + 1) = 14 for t0 = 5.
    acc20_idx = 20 - (t0 + 1)

    results = {}
    for branch in branches:
        results[branch] = {'final': [], 'acc20': []}
        alpha = 0.75 if branch != 'dfa' else 0.0
        for seed in a.seeds:
            print(f" Running {branch} seed={seed}...")
            accs = run_branch(seed, branch, a.gpu, t0=t0, alpha=alpha)
            final = accs[-1]
            # Fall back to the last epoch if the run was shorter than 20 epochs.
            acc20 = accs[acc20_idx] if len(accs) > acc20_idx else accs[-1]
            results[branch]['final'].append(final)
            results[branch]['acc20'].append(acc20)
            print(f" {branch} s={seed}: acc@20={acc20:.4f}, final={final:.4f}")

    # Header reflects the actual number of seeds instead of a fixed "3-seed".
    print(f"\n{'='*60}\nSUMMARY ({len(a.seeds)}-seed)\n{'='*60}")
    print(f"{'Method':<25} {'acc@20':>15} {'final':>15} {'diff vs DFA':>12}")
    print("-"*70)
    dfa_mean = np.mean(results['dfa']['final'])
    for b in branches:
        # NOTE(review): np.std defaults to ddof=0 (population std); with only
        # a handful of seeds this understates the spread relative to ddof=1.
        m20 = np.mean(results[b]['acc20'])
        s20 = np.std(results[b]['acc20'])
        mf = np.mean(results[b]['final'])
        sf = np.std(results[b]['final'])
        diff = mf - dfa_mean
        print(f"{b:<25} {m20:.4f}±{s20:.4f} {mf:.4f}±{sf:.4f} {diff:+.4f}")

    with open(os.path.join(a.output_dir, 'replication.json'), 'w') as f:
        json.dump(results, f, indent=2, default=float)
    print(f"\nSaved to {a.output_dir}/replication.json")


if __name__ == '__main__':
    main()