1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
"""Phase 10A.8C: 3-seed minimal scaffold replication."""
import os, sys, json, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim, copy
from torch.utils.data import DataLoader
import torchvision, torchvision.transforms as transforms
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from models.value_net import SinusoidalTimeEmbed
class VectorCreditNet(nn.Module):
    """Learned credit predictor: (hidden state h, time t, conditioning s) -> d-dim vector.

    The input features are LayerNorm(h), a SinusoidalTimeEmbed(32) encoding of t
    (the first Linear's fan-in d + 32 + s_dim assumes this encoding is 32-wide),
    and s (width s_dim), concatenated on the last dim. They pass through three
    Linear(-> 256) + GELU stages and a final Linear(256 -> d).
    """

    def __init__(self, d, s_dim):
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.te = SinusoidalTimeEmbed(32)
        hidden_width = 256
        fan_ins = (d + 32 + s_dim, hidden_width, hidden_width)
        stack = []
        for fan_in in fan_ins:
            stack.append(nn.Linear(fan_in, hidden_width))
            stack.append(nn.GELU())
        stack.append(nn.Linear(hidden_width, d))
        self.net = nn.Sequential(*stack)

    def forward(self, h, t, s):
        features = torch.cat([self.ln(h), self.te(t), s], dim=-1)
        return self.net(features)
class PerLayerVector(nn.Module):
    """Input-independent credit baseline: one learned d-dim vector per layer.

    ``set_block(l)`` selects which of the L vectors later ``forward`` calls
    return. The selected vector is broadcast (via ``expand``) to one row per
    sample in ``h``. ``t`` and ``s`` are accepted only so the call signature
    matches VectorCreditNet; they are ignored.
    """

    def __init__(self, d, L):
        super().__init__()
        initial = [nn.Parameter(torch.randn(d) * 0.01) for _ in range(L)]
        self.vecs = nn.ParameterList(initial)
        self._b = 0  # index of the currently selected layer vector

    def set_block(self, l):
        self._b = l

    def forward(self, h, t, s):
        rows = h.size(0)
        selected = self.vecs[self._b]
        return selected.unsqueeze(0).expand(rows, -1)
def get_cifar10(bs=128):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 under ``./data`` (downloaded if missing).

    Training images are augmented with a random 32x32 crop (padding 4) and a
    random horizontal flip. Both splits are converted to tensors and normalized
    with the per-channel mean/std constants below. Both loaders use batch size
    ``bs``, 4 workers and pinned memory. Only the training loader shuffles.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])

    def make_loader(train, tf):
        # Positional args: CIFAR10(root, train, ...) and DataLoader(dataset, batch_size, shuffle, ...).
        dataset = torchvision.datasets.CIFAR10('./data', train, download=True, transform=tf)
        return DataLoader(dataset, bs, train, num_workers=4, pin_memory=True)

    train_loader = make_loader(True, train_tf)
    test_loader = make_loader(False, test_tf)
    return train_loader, test_loader
def evaluate(m, tl, dev):
    """Return the top-1 accuracy of ``m`` over loader ``tl`` as a fraction in [0, 1].

    Switches ``m`` to eval mode and leaves it there. Each input batch is
    flattened to ``(batch, -1)`` and moved to ``dev``. Argmax predictions are
    compared with the labels under ``torch.no_grad()``. An empty loader raises
    ZeroDivisionError.
    """
    m.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in tl:
            n = inputs.size(0)
            flat = inputs.view(n, -1).to(dev)
            predictions = m(flat).argmax(1)
            correct += (predictions == labels.to(dev)).sum().item()
            seen += n
    return correct / seen
def _rms(a):
    """Per-row root-mean-square of ``a`` over its last dim, plus 1e-6 so it is safe to divide by."""
    return (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6


def _build_optimizers(model, lr, wd, epochs):
    """Create fresh AdamW optimizers for blocks, embedding and head, plus one cosine schedule each.

    Returns ``(bops, eop, hop, schs)``:
    - ``bops`` has one optimizer per entry of ``model.blocks``.
    - ``hop`` covers ``out_head`` and ``out_ln``.
    - ``schs`` lists the CosineAnnealingLR schedulers (T_max=epochs) in the
      order blocks, embed, head.
    """
    bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
    hop = optim.AdamW(head_params, lr=lr, weight_decay=wd)
    schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in [*bops, eop, hop]]
    return bops, eop, hop, schs


def _output_error(model, x, y):
    """Run a gradient-free forward pass and return ``(hidden, err)``.

    ``hidden`` is the hidden-state list from ``model(x, return_hidden=True)``.
    ``err`` is ``softmax(logits) - onehot(y)``, the per-sample gradient of
    cross-entropy with respect to the logits.
    """
    with torch.no_grad():
        logits, hidden = model(x, return_hidden=True)
        err = logits.softmax(-1)
        err[torch.arange(x.size(0)), y] -= 1
    return hidden, err


def _head_step(model, hop, h_last, y):
    """Take one cross-entropy step on ``out_ln``/``out_head`` from the detached final hidden state."""
    loss = F.cross_entropy(model.out_head(model.out_ln(h_last)), y)
    hop.zero_grad()
    loss.backward()
    hop.step()


def _credit_step(model, bops, eop, x, hidden, creds):
    """Update every block and the embedding from per-layer credit vectors ``creds``.

    Block ``l`` takes one optimizer step on ``<block_l(hidden[l]), c_l / rms(c_l)>``
    with its input detached, so only block ``l`` receives gradients, and its
    grad norm is clipped to 1. The embedding is updated the same way from
    ``creds[0]``, without clipping.
    """
    for l, credit in enumerate(creds):
        block = model.blocks[l]
        loss = (block(hidden[l].detach()) * (credit / _rms(credit))).sum(-1).mean()
        bops[l].zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
        bops[l].step()
    c0 = creds[0]
    loss = (model.embed(x) * (c0 / _rms(c0))).sum(-1).mean()
    eop.zero_grad()
    loss.backward()
    eop.step()


def _fit_aux(aux, aop, model, branch, hidden, s, y, n_layers, eps, M):
    """Take one optimizer step fitting the credit predictor ``aux`` to loss gradients.

    The loss is the sum of two regression terms:

    * Terminal: at t=1, ``aux(hidden[-1])`` is regressed onto the exact
      per-sample gradient of the head's cross-entropy with respect to
      ``hidden[-1]``.
    * Probe: at a uniformly drawn layer ``ll`` (t = ll / n_layers), the
      projection ``aux(hidden[ll]) . v`` is regressed onto a central
      finite-difference estimate (step ``eps``) of the per-sample loss's
      derivative along ``v``. ``v`` is a per-sample random unit direction,
      and the term is averaged over ``M`` draws.
      ``model.forward_from_layer(h, ll)`` is presumably the forward pass from
      ``hidden[ll]`` to the logits (TODO confirm against ResidualMLP).

    Every model input is detached and the targets are computed without tracking
    gradients, so only ``aux``'s parameters receive gradients.
    """
    batch = s.size(0)
    dev = s.device

    h_last = hidden[-1].detach()
    if branch == 'perlayer':
        aux.set_block(n_layers - 1)
    pred_last = aux(h_last, torch.ones(batch, device=dev), s)
    h_req = h_last.clone().requires_grad_(True)
    # reduction='sum' makes each row of the gradient that sample's own dCE/dh.
    head_loss = F.cross_entropy(model.out_head(model.out_ln(h_req)), y, reduction='sum')
    target_last = torch.autograd.grad(head_loss, h_req, create_graph=False)[0].detach()
    loss_terminal = ((pred_last - target_last) ** 2).sum(-1).mean()

    ll = np.random.randint(0, n_layers)
    h_l = hidden[ll].detach()
    if branch == 'perlayer':
        aux.set_block(ll)
    pred_l = aux(h_l, torch.full((batch,), ll / n_layers, device=dev), s)
    loss_probe = torch.tensor(0.0, device=dev)
    for _ in range(M):
        v = torch.randn_like(h_l)
        # dim must be passed by keyword: Tensor.norm's first positional parameter
        # is the order p, so norm(-1) would compute one global p=-1 "norm" of the
        # whole batch instead of a per-row L2 norm.
        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            up = F.cross_entropy(model.forward_from_layer(h_l + eps * v, ll), y, reduction='none')
            down = F.cross_entropy(model.forward_from_layer(h_l - eps * v, ll), y, reduction='none')
            fd = (up - down) / (2 * eps)
        loss_probe = loss_probe + (((pred_l * v).sum(-1) - fd) ** 2).mean()
    loss = loss_terminal + loss_probe / M

    aop.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(aux.parameters(), 1.0)
    aop.step()


def run_branch(seed, branch, gpu, epochs=100, t0=5, alpha=0.75, lr=1e-3, wd=0.01, lr_fb=1e-3, M=4):
    """Train a ResidualMLP (L=4, width 256) on flattened CIFAR-10: DFA warm-up, then ``branch``.

    Epochs 1..t0 use plain DFA with fixed random feedback matrices ``Bs``, which
    are never updated. The optimizers are then rebuilt, discarding their Adam
    state, and the cosine schedules are fast-forwarded by t0 steps. Epochs
    t0+1..epochs then run ``branch``:

    * ``'perlayer'``: learn a PerLayerVector credit model.
    * ``'vec'``: learn a VectorCreditNet credit model conditioned on the
      output error.
    * anything else (e.g. ``'dfa'``): keep using DFA credits.

    With a learned model and ``alpha > 0``, layer ``l``'s credit is
    ``alpha * aux / rms(aux) + (1 - alpha) * dfa / rms(dfa)``.

    Args:
        seed: seed for torch (CPU and all CUDA devices) and numpy.
        branch: which credit scheme to use after warm-up (see above).
        gpu: CUDA device index; a CUDA device is required.
        epochs: total epochs including warm-up; also the cosine T_max.
        t0: number of DFA warm-up epochs.
        alpha: mixing weight of learned credit versus DFA credit.
        lr: AdamW learning rate for the model's parts.
        wd: AdamW weight decay for the model's parts.
        lr_fb: Adam learning rate for the credit model.
        M: number of random probe directions per credit-model step.

    Returns:
        Test accuracies, one per epoch t0+1..epochs (length ``epochs - t0``).
    """
    dev = torch.device(f'cuda:{gpu}')
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    trl, tel = get_cifar10()
    L, d = 4, 256
    inp = 3072  # flattened 3x32x32 image
    model = ResidualMLP(inp, d, 10, L).to(dev)
    Bs = [torch.randn(d, 10, device=dev) / np.sqrt(10) for _ in range(L)]
    bops, eop, hop, schs = _build_optimizers(model, lr, wd, epochs)

    # Phase 1: DFA warm-up for epochs 1..t0.
    for _ in range(t0):
        model.train()
        for x, y in trl:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            hi, eT = _output_error(model, x, y)
            _head_step(model, hop, hi[-1].detach(), y)
            _credit_step(model, bops, eop, x, hi, [(eT @ B.T).detach() for B in Bs])
        for sch in schs:
            sch.step()

    # Phase 2 setup: fresh optimizers (warm-up Adam state dropped); replay the
    # t0 scheduler steps so the learning rates continue the same cosine curve.
    bops, eop, hop, schs = _build_optimizers(model, lr, wd, epochs)
    for _ in range(t0):
        for sch in schs:
            sch.step()

    aux = aop = None
    if branch == 'perlayer':
        aux = PerLayerVector(d, L).to(dev)
    elif branch == 'vec':
        aux = VectorCreditNet(d, 10).to(dev)
    if aux is not None:
        aop = optim.Adam(aux.parameters(), lr=lr_fb)

    eps = 1e-3
    accs = []
    for _ in range(t0 + 1, epochs + 1):
        model.train()
        if aux is not None:
            aux.train()
        for x, y in trl:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            batch = x.size(0)
            hi, eT = _output_error(model, x, y)
            s = eT.detach()  # the credit model is conditioned on the output error
            if aux is not None:
                _fit_aux(aux, aop, model, branch, hi, s, y, L, eps, M)
            dfa_c = [(eT @ B.T).detach() for B in Bs]
            if aux is not None and alpha > 0:
                creds = []
                for l in range(L):
                    if branch == 'perlayer':
                        aux.set_block(l)
                    with torch.no_grad():
                        t_l = torch.full((batch,), l / L, device=dev)
                        av = aux(hi[l].detach(), t_l, s).detach()
                    creds.append(alpha * av / _rms(av) + (1 - alpha) * dfa_c[l] / _rms(dfa_c[l]))
            else:
                creds = dfa_c
            _head_step(model, hop, hi[-1].detach(), y)
            _credit_step(model, bops, eop, x, hi, creds)
        for sch in schs:
            sch.step()
        accs.append(evaluate(model, tel, dev))
    return accs
def main():
    """Run every branch ('dfa', 'perlayer', 'vec') for each seed and summarise the results.

    For each run, records the final test accuracy and the accuracy at epoch 20.
    Prints per-run lines and a mean ± std table; std is numpy's default
    population std (ddof=0). Writes ``replication.json`` to ``--output_dir``.
    The JSON is rewritten after every run, so an interrupted sweep keeps its
    partial results.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=int, default=2)
    p.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
    p.add_argument('--output_dir', type=str, default='results/scaffold_replication')
    a = p.parse_args()
    os.makedirs(a.output_dir, exist_ok=True)
    out_path = os.path.join(a.output_dir, 'replication.json')

    branches = ['dfa', 'perlayer', 'vec']
    t0 = 5  # DFA warm-up epochs; passed to run_branch explicitly so idx20 stays in sync
    # run_branch returns one accuracy per epoch t0+1..epochs, so epoch E sits at index E-(t0+1).
    idx20 = 20 - (t0 + 1)

    results = {}

    def save():
        with open(out_path, 'w') as f:
            json.dump(results, f, indent=2, default=float)

    for branch in branches:
        results[branch] = {'final': [], 'acc20': []}
        alpha = 0.75 if branch != 'dfa' else 0.0
        for seed in a.seeds:
            print(f" Running {branch} seed={seed}...")
            accs = run_branch(seed, branch, a.gpu, t0=t0, alpha=alpha)
            final = accs[-1]
            acc20 = accs[idx20] if 0 <= idx20 < len(accs) else accs[-1]
            results[branch]['final'].append(final)
            results[branch]['acc20'].append(acc20)
            print(f" {branch} s={seed}: acc@20={acc20:.4f}, final={final:.4f}")
            save()

    print(f"\n{'='*60}\nSUMMARY ({len(a.seeds)}-seed)\n{'='*60}")
    print(f"{'Method':<25} {'acc@20':>15} {'final':>15} {'diff vs DFA':>12}")
    print("-" * 70)
    dfa_mean = np.mean(results['dfa']['final'])
    for b in branches:
        m20, s20 = np.mean(results[b]['acc20']), np.std(results[b]['acc20'])
        mf, sf = np.mean(results[b]['final']), np.std(results[b]['final'])
        diff = mf - dfa_mean
        col20 = f"{m20:.4f}±{s20:.4f}"
        colf = f"{mf:.4f}±{sf:.4f}"
        # Pad the value columns to the header widths so the table lines up.
        print(f"{b:<25} {col20:>15} {colf:>15} {diff:>+12.4f}")
    save()
    print(f"\nSaved to {a.output_dir}/replication.json")
# Script entry point: run the full replication sweep only when executed directly.
if __name__ == '__main__':
    main()
|