diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 14:50:19 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 14:50:19 -0500 |
| commit | c8407db32d3a159613ce5f0a567843ed4c970a27 (patch) | |
| tree | ed0f6041722147c8d836e31615bd233fd3e4c07c /experiments | |
| parent | 8fe362f3b6d792081d5be32d9d7d6972b2d1c9b2 (diff) | |
Fix gelu_ablation.py: compute method-specific Gamma instead of hardcoded 1.0
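For DFA, the credit signal is now computed from DFA feedback matrices (Bs) regenerated from the run's seed; for the state-bridge (SB) and credit-bridge (CB) methods, the BP gradient is used as a proxy for the credit signal, because their feedback networks are not saved.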
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/gelu_ablation.py | 30 |
1 file changed, 24 insertions, 6 deletions
diff --git a/experiments/gelu_ablation.py b/experiments/gelu_ablation.py index bef35b1..4e8d49c 100644 --- a/experiments/gelu_ablation.py +++ b/experiments/gelu_ablation.py @@ -197,18 +197,36 @@ def train_credit_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, lr_fb=1e-3, w for s in schs:s.step() if ep%20==0:print(f" Ep {ep}: acc={evaluate(model,tel,dev):.4f}",flush=True) -def compute_diagnostics(model, tel, dev): - model.eval();L=model.num_blocks +def compute_diagnostics(model, tel, dev, method='bp', seed=42): + model.eval();L=model.num_blocks;d=model.d_hidden;C=10 for x,y in tel:x=x.view(x.size(0),-1).to(dev);y=y.to(dev);break + batch=x.size(0) + # BP gradients h0=model.embed(x.detach());hs=[h0.clone().requires_grad_(True)] for bl in model.blocks:hs.append(hs[-1]+bl(hs[-1])) lo=model.out_head(model.out_ln(hs[-1]));loss=F.cross_entropy(lo,y) gs=torch.autograd.grad(loss,hs);bp={l:gs[l].detach() for l in range(L)} with torch.no_grad():_,hi=model(x,return_hidden=True) nse=((hi[L//2]-hi[-1]).norm(-1)/hi[-1].norm(-1).clamp(min=1e-8)).mean().item() - rhos=[] + # Method-specific credit + with torch.no_grad(): + logits=model(x);e_T=logits.softmax(-1);e_T[torch.arange(batch),y]-=1 + if method=='dfa': + # Regenerate DFA Bs from seed + torch.manual_seed(seed) + if hasattr(model,'embed'): + _tmp=type(model)(3072,d,C,L) # consume same random state + dfa_Bs=[torch.randn(d,C,device=dev)/np.sqrt(C) for _ in range(L)] + gammas,rhos=[],[] for l in range(L): - h_l=hi[l].detach();a_l=bp[l] + h_l=hi[l].detach() + if method=='bp': + a_l=bp[l] + elif method=='dfa': + a_l=(e_T@dfa_Bs[l].T).detach() + else: + a_l=bp[l] # SB/CB: use BP as proxy (their feedback nets not saved) + gammas.append(cosine_similarity_batch(a_l,bp[l])) def mk(sl): def f(h): with torch.no_grad(): @@ -217,7 +235,7 @@ def compute_diagnostics(model, tel, dev): return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none') return f rhos.append(perturbation_correlation(h_l,a_l,mk(l),epsilon=1e-3,M=16)) - 
return {'Gamma':1.0,'rho':np.mean(rhos),'naive_StateErr':nse} + return {'Gamma':float(np.mean(gammas)),'rho':float(np.mean(rhos)),'naive_StateErr':nse} def main(): p=argparse.ArgumentParser() @@ -243,7 +261,7 @@ def main(): elif args.method=='state_bridge':se=train_state_bridge(model,trl,tel,dev) elif args.method=='credit_bridge':train_credit_bridge(model,trl,tel,dev) acc=evaluate(model,tel,dev) - diag=compute_diagnostics(model,tel,dev) + diag=compute_diagnostics(model,tel,dev,method=args.method,seed=args.seed) torch.save(model.state_dict(),os.path.join(args.output_dir,f'{args.activation}_{args.method}_s{args.seed}.pt')) result={'activation':args.activation,'method':args.method,'seed':args.seed,'acc':acc, 'StateErr':se,'Gamma':diag['Gamma'],'rho':diag['rho'],'naive_StateErr':diag['naive_StateErr']} |
