summary | refs | log | tree | commit | diff
path: root/experiments
diff options
context:
space:
mode:
author: YurenHao0426 <Blackhao0426@gmail.com> 2026-04-02 14:50:19 -0500
committer: YurenHao0426 <Blackhao0426@gmail.com> 2026-04-02 14:50:19 -0500
commit: c8407db32d3a159613ce5f0a567843ed4c970a27 (patch)
tree: ed0f6041722147c8d836e31615bd233fd3e4c07c /experiments
parent: 8fe362f3b6d792081d5be32d9d7d6972b2d1c9b2 (diff)
Fix gelu_ablation.py: compute method-specific Gamma instead of hardcoded 1.0
DFA now uses regenerated DFA Bs for credit; SB/CB use BP as proxy (feedback nets not saved).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r-- experiments/gelu_ablation.py 30
1 files changed, 24 insertions, 6 deletions
diff --git a/experiments/gelu_ablation.py b/experiments/gelu_ablation.py
index bef35b1..4e8d49c 100644
--- a/experiments/gelu_ablation.py
+++ b/experiments/gelu_ablation.py
@@ -197,18 +197,36 @@ def train_credit_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, lr_fb=1e-3, w
for s in schs:s.step()
if ep%20==0:print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f}",flush=True)
-def compute_diagnostics(model, tel, dev):
- model.eval();L=model.num_blocks
+def compute_diagnostics(model, tel, dev, method='bp', seed=42):
+ model.eval();L=model.num_blocks;d=model.d_hidden;C=10
for x,y in tel:x=x.view(x.size(0),-1).to(dev);y=y.to(dev);break
+ batch=x.size(0)
+ # BP gradients
h0=model.embed(x.detach());hs=[h0.clone().requires_grad_(True)]
for bl in model.blocks:hs.append(hs[-1]+bl(hs[-1]))
lo=model.out_head(model.out_ln(hs[-1]));loss=F.cross_entropy(lo,y)
gs=torch.autograd.grad(loss,hs);bp={l:gs[l].detach() for l in range(L)}
with torch.no_grad():_,hi=model(x,return_hidden=True)
nse=((hi[L//2]-hi[-1]).norm(-1)/hi[-1].norm(-1).clamp(min=1e-8)).mean().item()
- rhos=[]
+ # Method-specific credit
+ with torch.no_grad():
+ logits=model(x);e_T=logits.softmax(-1);e_T[torch.arange(batch),y]-=1
+ if method=='dfa':
+ # Regenerate DFA Bs from seed
+ torch.manual_seed(seed)
+ if hasattr(model,'embed'):
+ _tmp=type(model)(3072,d,C,L) # consume same random state
+ dfa_Bs=[torch.randn(d,C,device=dev)/np.sqrt(C) for _ in range(L)]
+ gammas,rhos=[],[]
for l in range(L):
- h_l=hi[l].detach();a_l=bp[l]
+ h_l=hi[l].detach()
+ if method=='bp':
+ a_l=bp[l]
+ elif method=='dfa':
+ a_l=(e_T@dfa_Bs[l].T).detach()
+ else:
+ a_l=bp[l] # SB/CB: use BP as proxy (their feedback nets not saved)
+ gammas.append(cosine_similarity_batch(a_l,bp[l]))
def mk(sl):
def f(h):
with torch.no_grad():
@@ -217,7 +235,7 @@ def compute_diagnostics(model, tel, dev):
return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none')
return f
rhos.append(perturbation_correlation(h_l,a_l,mk(l),epsilon=1e-3,M=16))
- return {'Gamma':1.0,'rho':np.mean(rhos),'naive_StateErr':nse}
+ return {'Gamma':float(np.mean(gammas)),'rho':float(np.mean(rhos)),'naive_StateErr':nse}
def main():
p=argparse.ArgumentParser()
@@ -243,7 +261,7 @@ def main():
elif args.method=='state_bridge':se=train_state_bridge(model,trl,tel,dev)
elif args.method=='credit_bridge':train_credit_bridge(model,trl,tel,dev)
acc=evaluate(model,tel,dev)
- diag=compute_diagnostics(model,tel,dev)
+ diag=compute_diagnostics(model,tel,dev,method=args.method,seed=args.seed)
torch.save(model.state_dict(),os.path.join(args.output_dir,f'{args.activation}_{args.method}_s{args.seed}.pt'))
result={'activation':args.activation,'method':args.method,'seed':args.seed,'acc':acc,
'StateErr':se,'Gamma':diag['Gamma'],'rho':diag['rho'],'naive_StateErr':diag['naive_StateErr']}