diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-26 16:27:53 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-26 16:27:53 -0500 |
| commit | 610e1169e19378cccd2d9b92a588c24dca7f3df7 (patch) | |
| tree | 532f8dc2fda6c68ab1409b20d7431b76d8d6f378 /experiments | |
| parent | ef4aed70130e2212b4ed1cb7212e2ea6c7c7adb2 (diff) | |
Add Phase 10A.5: blend gain is implicit regularization, not learned credit
Dissection of 6 branches from the same DFA checkpoint (the continue_DFA baseline plus the five variants listed below):
- blend_random_frozen: 12.6% (CATASTROPHIC — frozen noise destroys training)
- blend_random_trainable: 32.2% (+1.2% — trainable network helps)
- blend_shuffled_trainable: 32.5% (+1.4% — even wrong targets work!)
- blend_gaussian_noise: 30.8% (neutral)
- scaled_DFA_norm_match: 31.0% (neutral)
The gain comes from implicit regularization through a co-optimized auxiliary
network, NOT from learned credit quality. Phase 9A's +1.5% was an optimization
dynamics effect, not evidence of useful credit assignment.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/blend_mechanism_dissection.py | 466 |
1 files changed, 466 insertions, 0 deletions
# experiments/blend_mechanism_dissection.py
"""
Phase 10A.5: Blend Mechanism Dissection.

Core question: Phase 10A's gain from blend(random Vec, DFA) — is it learned
correction or blend/diversification/norm effect?

6 branches from the same DFA checkpoint:
1. continue_DFA — baseline
2. blend_random_frozen — random Vec frozen, no training
3. blend_random_trainable — random Vec, trained online after handoff
4. blend_shuffled_trainable — Vec trained but targets permuted (no semantics)
5. blend_gaussian_noise_matched — no Vec, just RMS-matched Gaussian noise
6. scaled_DFA_match_norm — DFA only but effective norm matched to blend
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import copy

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import SinusoidalTimeEmbed
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation


# Branch types whose auxiliary Vec network is trained online.
_TRAINABLE_VEC_BRANCHES = ('blend_trainable', 'blend_shuffled')


def _row_rms(a):
    """Return the RMS of `a` over its last dimension (keepdim), plus 1e-6."""
    return (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6


def _output_error(logits, y):
    """Return softmax(logits) with 1 subtracted at each row's target class.

    Modifies the freshly created softmax output in place; call it under
    `torch.no_grad()`.
    """
    err = logits.softmax(-1)
    err[torch.arange(logits.size(0)), y] -= 1
    return err


def _build_optimizers(model, lr, wd, total_epochs):
    """Create the per-block, embed and head AdamW optimizers plus schedulers.

    Returns:
        (block_opts, embed_opt, head_opt, scheds) where `scheds` holds one
        CosineAnnealingLR per optimizer, ordered blocks, embed, head.
    """
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd)
                  for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=lr, weight_decay=wd)
    scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs)
              for o in block_opts + [embed_opt, head_opt]]
    return block_opts, embed_opt, head_opt, scheds


def _apply_credit_updates(model, x, y, hi, credits, block_opts, embed_opt, head_opt):
    """Run one head / blocks / embed update from per-block credit vectors.

    The head is trained with cross-entropy on the detached last hidden state.
    Each block (and the embedding, using `credits[0]`) is trained on the
    surrogate loss `(output * credit / rms(credit)).sum(-1).mean()`.

    NOTE(review): because every credit is re-normalised by its own row RMS
    here, any overall scale applied to `credits` upstream (e.g. `beta_scale`)
    cancels out, up to the 1e-6 epsilon.
    """
    hL = hi[-1].detach()
    head_loss = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
    head_opt.zero_grad()
    head_loss.backward()
    head_opt.step()

    for l, a in enumerate(credits):
        block = model.blocks[l]
        out = block(hi[l].detach())
        block_loss = (out * (a / _row_rms(a))).sum(-1).mean()
        block_opts[l].zero_grad()
        block_loss.backward()
        torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
        block_opts[l].step()

    a0 = credits[0]
    embed_loss = (model.embed(x) * (a0 / _row_rms(a0))).sum(-1).mean()
    embed_opt.zero_grad()
    embed_loss.backward()
    embed_opt.step()


class VectorCreditNet(nn.Module):
    """MLP that maps (hidden state, time, error signal) to a credit vector.

    The input is the concatenation of LayerNorm(h), a sinusoidal embedding of
    `t`, and `s`; the output has `d_hidden` features.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        input_dim = d_hidden + time_embed_dim + s_dim
        # Layers are created in a fixed order so that seeding before
        # construction yields identical initial weights across instances.
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the predicted credit for hidden state `h` at time `t` given `s`."""
        return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1))


def get_cifar10(batch_size=128):
    """Return (train_loader, test_loader) for CIFAR-10 stored under ./data.

    The training set uses random crop (padding 4) and horizontal flip; both
    sets are normalised with the same per-channel mean/std. The dataset is
    downloaded if missing.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)
    return train_loader, test_loader


def evaluate(model, test_loader, device):
    """Return top-1 accuracy of `model` on `test_loader`.

    Inputs are flattened to (batch, -1). The model is left in eval mode.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            correct += (model(x).argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


def train_dfa_get_checkpoint(model, train_loader, test_loader, device, total_epochs, t0, lr, wd):
    """Train `model` with DFA for `total_epochs` and snapshot it at epoch `t0`.

    Each block receives the output error projected through its own fixed
    random matrix as credit.

    Returns:
        (Bs, ckpt): the list of random feedback matrices, and a dict with keys
        'model' (state-dict copy), 'Bs' (cloned matrices) and 'acc' (test
        accuracy) taken after epoch `t0`. `ckpt` is None when `t0` is not in
        1..total_epochs.
    """
    d = model.d_hidden
    L = model.num_blocks
    Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
    block_opts, embed_opt, head_opt, scheds = _build_optimizers(model, lr, wd, total_epochs)
    ckpt = None
    for epoch in range(1, total_epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            with torch.no_grad():
                lo, hi = model(x, return_hidden=True)
                eT = _output_error(lo, y)
            credits = [(eT @ Bs[l].T).detach() for l in range(L)]
            _apply_credit_updates(model, x, y, hi, credits, block_opts, embed_opt, head_opt)
        for s in scheds:
            s.step()
        if epoch == t0:
            acc = evaluate(model, test_loader, device)
            ckpt = {
                'model': copy.deepcopy(model.state_dict()),
                'Bs': [B.clone() for B in Bs],
                'acc': acc,
            }
            print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}")
        if epoch % 20 == 0:
            print(f" [DFA] Epoch {epoch}: acc={evaluate(model,test_loader,device):.4f}")
    return Bs, ckpt


def estimate_blend_norm_ratio(model, Bs, vec_net, train_loader, device, alpha, n_batches=50):
    """Estimate ratio of blend update norm to DFA-only update norm (per block).

    Uses the credit norm as a proxy for the update norm, accumulates mean
    squared credits over up to `n_batches` batches, and returns the mean over
    blocks of sqrt(blend / dfa). When `vec_net` is None the Vec credit is
    treated as zeros.
    """
    model.eval()
    L = model.num_blocks
    dfa_norms = [0.0] * L
    blend_norms = [0.0] * L
    n = 0
    for i, (x, y) in enumerate(train_loader):
        if i >= n_batches:
            break
        x = x.view(x.size(0), -1).to(device)
        y = y.to(device)
        batch = x.size(0)
        with torch.no_grad():
            lo, hi = model(x, return_hidden=True)
            eT = _output_error(lo, y)
            s = eT.detach()
        for l in range(L):
            h_l = hi[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            a_dfa = (eT @ Bs[l].T).detach()
            if vec_net is not None:
                a_vec = vec_net(h_l, t_l, s).detach()
            else:
                a_vec = torch.zeros_like(a_dfa)
            a_dfa_norm = a_dfa / _row_rms(a_dfa)
            a_blend = alpha * a_vec / _row_rms(a_vec) + (1 - alpha) * a_dfa_norm
            # Compute block update norms (proxy: just use credit norm)
            dfa_norms[l] += (a_dfa_norm ** 2).mean().item()
            blend_norms[l] += (a_blend ** 2).mean().item()
        n += 1
    ratios = [((blend_norms[l] / n) / (dfa_norms[l] / n + 1e-12)) ** 0.5 for l in range(L)]
    return float(np.mean(ratios))


def compute_diagnostics(model, vec_net, Bs, test_loader, device, credit_mode, alpha=0.75):
    """Compute Gamma, rho for current credit source.

    On the first test batch, Gamma is the mean over blocks of
    `cosine_similarity_batch(credit, backprop_gradient)` and rho is the mean
    over blocks of `perturbation_correlation` between the credit and the
    per-sample loss of the remaining network.

    Args:
        credit_mode: 'dfa', 'vec' or 'blend'. 'vec'/'blend' fall back to DFA
            credit when `vec_net` is None, as does any other value.

    Returns:
        (gamma, rho) as Python floats.

    Side effects: puts `model` (and `vec_net`) in eval mode and leaves
    gradients from one backward pass on the model parameters.
    """
    model.eval()
    if vec_net is not None:
        vec_net.eval()
    L = model.num_blocks
    x, y = next(iter(test_loader))
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)
    batch = x.size(0)

    # Backprop reference gradients need grad-enabled parameters; temporarily
    # unfreeze a frozen model and restore its state afterwards.
    was_frozen = not next(model.parameters()).requires_grad
    if was_frozen:
        for p in model.parameters():
            p.requires_grad_(True)
    model.zero_grad()
    lo, hbp = model(x, return_hidden=True)
    for l in range(L + 1):
        hbp[l].retain_grad()
    F.cross_entropy(lo, y).backward()
    bp = {l: hbp[l].grad.detach().clone() for l in range(L + 1)}
    if was_frozen:
        for p in model.parameters():
            p.requires_grad_(False)

    with torch.no_grad():
        lo2, hi = model(x, return_hidden=True)
        eT = _output_error(lo2, y)
        s = eT.detach()

    def make_fwd(sl):
        """Return h -> per-sample loss when h is fed through blocks sl..L-1."""
        def f(h):
            with torch.no_grad():
                c = h
                for i in range(sl, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
        return f

    gammas, rhos = [], []
    for l in range(L):
        h_l = hi[l].detach()
        t_l = torch.full((batch,), l / L, device=device)
        if credit_mode == 'vec' and vec_net is not None:
            a_l = vec_net(h_l, t_l, s).detach()
        elif credit_mode == 'blend' and vec_net is not None:
            a_dfa = (s @ Bs[l].T).detach()
            a_vec = vec_net(h_l, t_l, s).detach()
            a_l = alpha * a_vec / _row_rms(a_vec) + (1 - alpha) * a_dfa / _row_rms(a_dfa)
        else:
            # 'dfa' and every unrecognised / Vec-less mode use plain DFA credit.
            a_l = (s @ Bs[l].T).detach()
        gammas.append(cosine_similarity_batch(a_l, bp[l]))
        rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))
    return float(np.mean(gammas)), float(np.mean(rhos))


def _train_vec_step(model, vec_net, vec_opt, hi, s, y, L, M, eps_pert, shuffle_targets):
    """Run one optimisation step of the auxiliary Vec network.

    The loss is the sum of (a) a terminal term matching the Vec output at
    t=1 to the gradient of the summed cross-entropy w.r.t. the last hidden
    state, and (b) a perturbation term at one randomly chosen block, matching
    the Vec output projected on `M` random unit directions to central
    finite-difference loss derivatives. With `shuffle_targets` the
    finite-difference targets are permuted across the batch.
    """
    batch = y.size(0)
    device = y.device
    hL = hi[-1].detach()

    t_L = torch.ones(batch, device=device)
    a_term = vec_net(hL, t_L, s)
    hL_req = hL.clone().requires_grad_(True)
    ce = F.cross_entropy(model.out_head(model.out_ln(hL_req)), y, reduction='sum')
    dL = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
    loss_term = ((a_term - dL) ** 2).sum(-1).mean()

    lt = np.random.randint(0, L)
    h_l = hi[lt].detach()
    t_l = torch.full((batch,), lt / L, device=device)
    a_l = vec_net(h_l, t_l, s)
    lp2 = torch.tensor(0.0, device=device)
    for _ in range(M):
        v = torch.randn_like(h_l)
        # Normalise each row to unit length. `dim` must be passed by keyword:
        # the first positional argument of Tensor.norm is the order `p`, so
        # `v.norm(-1, keepdim=True)` computed a single p=-1 norm over the
        # whole tensor instead of a per-row L2 norm.
        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            lp = F.cross_entropy(model.forward_from_layer(h_l + eps_pert * v, lt), y, reduction='none')
            lm = F.cross_entropy(model.forward_from_layer(h_l - eps_pert * v, lt), y, reduction='none')
            gj = (lp - lm) / (2 * eps_pert)
        if shuffle_targets:
            gj = gj[torch.randperm(batch, device=device)]  # Shuffle targets
        lp2 = lp2 + (((a_l * v).sum(-1) - gj.detach()) ** 2).mean()
    lp2 = lp2 / M

    vl = loss_term + lp2
    vec_opt.zero_grad()
    vl.backward()
    torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
    vec_opt.step()


def run_branch(model, vec_net, Bs, train_loader, test_loader, device, t0, total_epochs,
               branch_type, alpha, lr, lr_fb, wd, M, beta_scale=1.0, branch_name=''):
    """
    Run a training branch from checkpoint.
    branch_type: 'dfa', 'blend_frozen', 'blend_trainable', 'blend_shuffled',
                 'blend_gaussian', 'scaled_dfa'

    Trains epochs t0+1..total_epochs. The cosine schedulers are first advanced
    by `t0` steps so the learning rate continues from the checkpoint epoch.

    Returns:
        dict with per-epoch lists 'test_acc' and 'train_loss', and lists of
        (epoch, value) pairs 'gamma', 'rho' (diagnostic epochs only) and
        'alpha_eff'.
    """
    L = model.num_blocks
    eps_pert = 1e-3
    trains_vec = vec_net is not None and branch_type in _TRAINABLE_VEC_BRANCHES
    block_opts, embed_opt, head_opt, scheds = _build_optimizers(model, lr, wd, total_epochs)
    vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb) if trains_vec else None
    for _ in range(t0):
        for s in scheds:
            s.step()

    # Get reference RMS for gaussian noise matching
    ref_rms = [0.01] * L  # default
    if branch_type == 'blend_gaussian' and vec_net is not None:
        # Estimate RMS from frozen random Vec (first training batch only)
        model.eval()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)
            with torch.no_grad():
                lo, hi = model(x, return_hidden=True)
                s = _output_error(lo, y).detach()
                for l in range(L):
                    h_l = hi[l].detach()
                    t_l = torch.full((batch,), l / L, device=device)
                    a_v = vec_net(h_l, t_l, s).detach()
                    ref_rms[l] = (a_v ** 2).mean().sqrt().item()
            break

    log = {'test_acc': [], 'train_loss': [], 'gamma': [], 'rho': [], 'alpha_eff': []}
    diag_epochs = set(
        list(range(t0 + 1, min(t0 + 6, total_epochs + 1)))
        + [t0 + 8, t0 + 10, t0 + 15, t0 + 20]
        + list(range(t0 + 10, total_epochs + 1, 10))
        + [total_epochs]
    )

    for epoch in range(t0 + 1, total_epochs + 1):
        model.train()
        if trains_vec:
            vec_net.train()
        tl, t = 0, 0
        epoch_aux_norms, epoch_dfa_norms = [], []

        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)
            with torch.no_grad():
                lo, hi = model(x, return_hidden=True)
                lv = F.cross_entropy(lo, y)
                eT = _output_error(lo, y)
                s = eT.detach()

            # Train Vec if applicable
            if trains_vec:
                _train_vec_step(model, vec_net, vec_opt, hi, s, y, L, M, eps_pert,
                                shuffle_targets=(branch_type == 'blend_shuffled'))

            # Compute credits per block
            credits = []
            for l in range(L):
                a_dfa = (eT @ Bs[l].T).detach()
                rms_d = _row_rms(a_dfa)

                if branch_type == 'dfa':
                    credits.append(a_dfa / rms_d)
                elif branch_type == 'scaled_dfa':
                    credits.append(beta_scale * a_dfa / rms_d)
                elif branch_type == 'blend_gaussian':
                    xi = torch.randn_like(a_dfa)
                    xi_matched = xi / _row_rms(xi) * ref_rms[l]  # match to Vec RMS
                    # NOTE(review): the noise is RMS-normalised again right
                    # here, so `ref_rms` only matters through the epsilon.
                    credits.append(alpha * xi_matched / _row_rms(xi_matched)
                                   + (1 - alpha) * a_dfa / rms_d)
                else:
                    # blend with Vec (frozen or trainable or shuffled)
                    h_l = hi[l].detach()
                    t_l = torch.full((batch,), l / L, device=device)
                    with torch.no_grad():
                        a_vec = vec_net(h_l, t_l, s).detach()
                    credits.append(alpha * a_vec / _row_rms(a_vec) + (1 - alpha) * a_dfa / rms_d)

                # Track update geometry
                # NOTE(review): aux_norm is taken from the *blended* credit
                # scaled by alpha, not from the auxiliary component alone —
                # confirm this is the intended definition of alpha_eff.
                a_blend = credits[-1]
                a_dfa_n = a_dfa / rms_d
                aux_norm = (alpha * a_blend).norm().item() if branch_type != 'dfa' else 0
                if branch_type not in ['dfa', 'scaled_dfa']:
                    dfa_norm = ((1 - alpha) * a_dfa_n).norm().item()
                else:
                    dfa_norm = a_blend.norm().item()
                epoch_aux_norms.append(aux_norm)
                epoch_dfa_norms.append(dfa_norm)

            # Update head, blocks and embedding from the credits
            _apply_credit_updates(model, x, y, hi, credits, block_opts, embed_opt, head_opt)

            tl += lv.item() * batch
            t += batch

        for sch in scheds:
            sch.step()
        ta = evaluate(model, test_loader, device)
        log['test_acc'].append(ta)
        log['train_loss'].append(tl / t)

        mean_aux = np.mean(epoch_aux_norms) if epoch_aux_norms else 0
        mean_dfa = np.mean(epoch_dfa_norms) if epoch_dfa_norms else 1
        aeff = mean_aux / (mean_aux + mean_dfa + 1e-12)
        log['alpha_eff'].append((epoch, aeff))

        if epoch in diag_epochs:
            cm = 'vec' if trains_vec else 'dfa'
            gamma, rho = compute_diagnostics(model, vec_net, Bs, test_loader, device, cm)
            log['gamma'].append((epoch, gamma))
            log['rho'].append((epoch, rho))
            if epoch <= t0 + 15 or epoch % 20 == 0 or epoch == total_epochs:
                print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}, G={gamma:.4f}, r={rho:.4f}, aeff={aeff:.3f}")
        elif epoch % 20 == 0 or epoch == total_epochs:
            print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}")

    return log


def _acc_at_epoch_20(accs, t0):
    """Return the accuracy logged for epoch 20, or the final one if absent.

    `accs[0]` corresponds to epoch t0+1, so epoch 20 sits at index 20-t0-1.
    A negative index (t0 >= 20) must not wrap around to the end of the list.
    """
    idx = 20 - t0 - 1
    return accs[idx] if 0 <= idx < len(accs) else accs[-1]


def run_experiment(args):
    """Train the DFA baseline, run all six branches, print and save results.

    Results are written as JSON to
    `<output_dir>/dissection_t<t0>_s<seed>.json`.

    Raises:
        ValueError: if `args.t0` is not in the range 1 <= t0 < epochs (no
            checkpoint would be taken, or the branches would train for zero
            epochs).
    """
    if not 1 <= args.t0 < args.epochs:
        raise ValueError(
            f"t0 must satisfy 1 <= t0 < epochs (got t0={args.t0}, epochs={args.epochs})")

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden

    # Train DFA baseline
    print(f"\n{'='*60}\nTraining DFA baseline\n{'='*60}")
    model_dfa = ResidualMLP(input_dim, d, 10, L).to(device)
    Bs, ckpt = train_dfa_get_checkpoint(
        model_dfa, train_loader, test_loader, device, args.epochs, args.t0, args.lr, args.wd)
    print(f" Checkpoint acc: {ckpt['acc']:.4f}")

    # Create frozen random Vec for reference
    torch.manual_seed(args.seed + 7777)
    vec_frozen_ref = VectorCreditNet(d_hidden=d, s_dim=10).to(device)
    vec_frozen_ref.eval()

    # Estimate norm ratio for scaled_DFA
    model_ref = ResidualMLP(input_dim, d, 10, L).to(device)
    model_ref.load_state_dict(ckpt['model'])
    model_ref.eval()
    for p in model_ref.parameters():
        p.requires_grad_(False)
    beta = estimate_blend_norm_ratio(
        model_ref, ckpt['Bs'], vec_frozen_ref, train_loader, device, args.alpha)
    print(f" Estimated beta (norm ratio): {beta:.4f}")
    del model_ref

    # Define branches: (name, branch_type, needs_trainable_vec, vec_mode)
    branches = [
        ('continue_DFA', 'dfa', False, None),
        ('blend_random_frozen', 'blend_frozen', False, 'frozen'),
        ('blend_random_trainable', 'blend_trainable', True, 'trainable'),
        ('blend_shuffled_trainable', 'blend_shuffled', True, 'shuffled'),
        ('blend_gaussian_noise', 'blend_gaussian', False, 'gaussian'),
        ('scaled_DFA_norm_match', 'scaled_dfa', False, None),
    ]

    all_results = {}
    for bname, btype, needs_trainable_vec, _vec_mode in branches:
        print(f"\n{'='*60}\n{bname}\n{'='*60}")
        model_b = ResidualMLP(input_dim, d, 10, L).to(device)
        model_b.load_state_dict(ckpt['model'])

        # Prepare Vec
        if btype in ['blend_frozen', 'blend_gaussian']:
            torch.manual_seed(args.seed + 7777)  # same random Vec as reference
            vec_b = VectorCreditNet(d_hidden=d, s_dim=10).to(device)
            vec_b.eval()
            for p in vec_b.parameters():
                p.requires_grad_(False)
        elif needs_trainable_vec:
            torch.manual_seed(args.seed + 7777)  # same init
            vec_b = VectorCreditNet(d_hidden=d, s_dim=10).to(device)
        else:
            vec_b = None

        log = run_branch(model_b, vec_b, ckpt['Bs'], train_loader, test_loader, device,
                         args.t0, args.epochs, btype, args.alpha, args.lr, args.lr_fb,
                         args.wd, args.M, beta_scale=beta, branch_name=bname)
        all_results[bname] = log
        print(f" {bname} final: {log['test_acc'][-1]:.4f}")

    # Summary
    dfa_final = all_results['continue_DFA']['test_acc'][-1]

    print(f"\n{'='*90}")
    print("SUMMARY")
    print(f"{'='*90}")
    print(f"{'Branch':<30} {'acc@20':>7} {'final':>7} {'diff':>7} {'mG_5:15':>8} {'mr_5:15':>8} {'aeff':>7}")
    print("-" * 78)

    for bname, log in all_results.items():
        accs = log['test_acc']
        acc20 = _acc_at_epoch_20(accs, args.t0)
        final = accs[-1]
        diff = final - dfa_final

        gammas_early = [g for e, g in log['gamma'] if args.t0 < e <= args.t0 + 15]
        rhos_early = [r for e, r in log['rho'] if args.t0 < e <= args.t0 + 15]
        aeffs_early = [a for e, a in log['alpha_eff'] if args.t0 < e <= args.t0 + 15]
        mg = np.mean(gammas_early) if gammas_early else float('nan')
        mr = np.mean(rhos_early) if rhos_early else float('nan')
        mae = np.mean(aeffs_early) if aeffs_early else float('nan')

        print(f"{bname:<30} {acc20:>7.4f} {final:>7.4f} {diff:>+7.4f} {mg:>8.4f} {mr:>8.4f} {mae:>7.3f}")

    # Save
    save_data = {'beta': beta}
    for bname, log in all_results.items():
        save_data[bname] = {
            'test_acc': log['test_acc'],
            'train_loss': log['train_loss'],
            'gamma': log['gamma'],
            'rho': log['rho'],
            'alpha_eff': log['alpha_eff'],
        }
    out_path = os.path.join(args.output_dir, f'dissection_t{args.t0}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        # default=float converts numpy scalars that json cannot serialise.
        json.dump(save_data, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")

    # Judgment
    print(f"\n{'='*60}\nJUDGMENT\n{'='*60}")
    rf = all_results.get('blend_random_frozen', {}).get('test_acc', [-1])[-1]
    rt = all_results.get('blend_random_trainable', {}).get('test_acc', [-1])[-1]
    sh = all_results.get('blend_shuffled_trainable', {}).get('test_acc', [-1])[-1]
    gn = all_results.get('blend_gaussian_noise', {}).get('test_acc', [-1])[-1]
    sd = all_results.get('scaled_DFA_norm_match', {}).get('test_acc', [-1])[-1]

    print(f" DFA={dfa_final:.4f}, rf={rf:.4f}, rt={rt:.4f}, sh={sh:.4f}, gn={gn:.4f}, sd={sd:.4f}")

    if abs(rf - rt) < 0.005 and abs(rf - sh) < 0.005 and abs(rf - gn) < 0.005:
        if abs(rf - sd) < 0.005:
            print(" -> ALL SIMILAR including scaled_DFA: gain is primarily norm/step-size effect")
        else:
            print(" -> rf≈rt≈sh≈gn but ≠ scaled_DFA: gain is diversification, not just norm")
    elif rt > rf + 0.005:
        print(" -> random_trainable > random_frozen: online learning contributes")
    else:
        print(" -> Mixed signal, see table for details")


def main():
    """Parse command-line arguments and run the dissection experiment."""
    parser = argparse.ArgumentParser(description='Phase 10A.5: Blend Mechanism Dissection')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--t0', type=int, default=5)
    parser.add_argument('--alpha', type=float, default=0.75)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--M', type=int, default=4)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=2)
    parser.add_argument('--output_dir', type=str, default='results/blend_dissection')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()
