""" Phase 10A.5: Blend Mechanism Dissection. Core question: Phase 10A's gain from blend(random Vec, DFA) — is it learned correction or blend/diversification/norm effect? 6 branches from the same DFA checkpoint: 1. continue_DFA — baseline 2. blend_random_frozen — random Vec frozen, no training 3. blend_random_trainable — random Vec, trained online after handoff 4. blend_shuffled_trainable — Vec trained but targets permuted (no semantics) 5. blend_gaussian_noise_matched — no Vec, just RMS-matched Gaussian noise 6. scaled_DFA_match_norm — DFA only but effective norm matched to blend """ import os import sys import json import argparse import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader import torchvision import torchvision.transforms as transforms import copy sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from models.value_net import SinusoidalTimeEmbed from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation class VectorCreditNet(nn.Module): def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.time_embed = SinusoidalTimeEmbed(time_embed_dim) input_dim = d_hidden + time_embed_dim + s_dim layers = [] for i in range(num_layers): in_d = input_dim if i == 0 else hidden_dim layers.append(nn.Linear(in_d, hidden_dim)) layers.append(nn.GELU()) layers.append(nn.Linear(hidden_dim, d_hidden)) self.net = nn.Sequential(*layers) def forward(self, h, t, s): return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1)) def get_cifar10(batch_size=128): transform_train = transforms.Compose([ transforms.RandomCrop(32,padding=4),transforms.RandomHorizontalFlip(), transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) transform_test = transforms.Compose([ transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform_train) testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform_test) return (DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=4,pin_memory=True), DataLoader(testset,batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True)) def evaluate(model, test_loader, device): model.eval();c,t=0,0 with torch.no_grad(): for x,y in test_loader: x=x.view(x.size(0),-1).to(device);y=y.to(device) c+=(model(x).argmax(1)==y).sum().item();t+=x.size(0) return c/t def train_dfa_get_checkpoint(model, train_loader, test_loader, device, total_epochs, t0, lr, wd): d=model.d_hidden;L=model.num_blocks Bs=[torch.randn(d,10,device=device)/np.sqrt(10) for _ in range(L)] block_opts=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks] embed_opt=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd) head_opt=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd) scheds=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=total_epochs) for o in block_opts]+\ [optim.lr_scheduler.CosineAnnealingLR(embed_opt,T_max=total_epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt,T_max=total_epochs)] ckpt=None for epoch in range(1,total_epochs+1): model.train();tl,c,t=0,0,0 for x,y in train_loader: x=x.view(x.size(0),-1).to(device);y=y.to(device);b=x.size(0) with torch.no_grad(): lo,hi=model(x,return_hidden=True);lv=F.cross_entropy(lo,y) eT=lo.softmax(-1);eT[torch.arange(b),y]-=1 hL=hi[-1].detach() lo2=F.cross_entropy(model.out_head(model.out_ln(hL)),y) head_opt.zero_grad();lo2.backward();head_opt.step() for l in range(L): a=(eT@Bs[l].T).detach();rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6 f=model.blocks[l](hi[l].detach());ll=(f*(a/rm)).sum(-1).mean() block_opts[l].zero_grad();ll.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);block_opts[l].step() a0=(eT@Bs[0].T).detach();r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6 el=(model.embed(x)*(a0/r0)).sum(-1).mean() embed_opt.zero_grad();el.backward();embed_opt.step() tl+=lv.item()*b;c+=(lo.argmax(1)==y).sum().item();t+=b for s in scheds:s.step() if epoch==t0: acc=evaluate(model,test_loader,device) ckpt={'model':copy.deepcopy(model.state_dict()),'Bs':[B.clone() for B in Bs],'acc':acc} print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}") if epoch%20==0:print(f" [DFA] Epoch {epoch}: acc={evaluate(model,test_loader,device):.4f}") return Bs,ckpt def estimate_blend_norm_ratio(model, Bs, vec_net, train_loader, device, alpha, n_batches=50): """Estimate ratio of blend update norm to DFA-only update norm (per block).""" model.eval();L=model.num_blocks dfa_norms=[0.0]*L;blend_norms=[0.0]*L;n=0 for i,(x,y) in enumerate(train_loader): if i>=n_batches:break x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0) with torch.no_grad(): lo,hi=model(x,return_hidden=True) eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach() for l in range(L): h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device) a_dfa=(eT@Bs[l].T).detach() a_vec=vec_net(h_l,t_l,s).detach() if vec_net else torch.zeros_like(a_dfa) rms_d=(a_dfa**2).mean(-1,keepdim=True).sqrt()+1e-6 rms_v=(a_vec**2).mean(-1,keepdim=True).sqrt()+1e-6 a_blend=alpha*a_vec/rms_v+(1-alpha)*a_dfa/rms_d a_dfa_norm=a_dfa/rms_d # Compute block update norms (proxy: just use credit norm) dfa_norms[l]+=(a_dfa_norm**2).mean().item() blend_norms[l]+=(a_blend**2).mean().item() n+=1 ratios=[((blend_norms[l]/n)/(dfa_norms[l]/n+1e-12))**0.5 for l in range(L)] return float(np.mean(ratios)) def compute_diagnostics(model, vec_net, Bs, test_loader, device, credit_mode, alpha=0.75): """Compute Gamma, rho for current credit source.""" model.eval() if vec_net:vec_net.eval() L=model.num_blocks for x,y in test_loader: x=x.view(x.size(0),-1).to(device);y=y.to(device);break batch=x.size(0) was_frozen=not next(model.parameters()).requires_grad if was_frozen: for p in model.parameters():p.requires_grad_(True) model.zero_grad() lo,hbp=model(x,return_hidden=True) for l in range(L+1):hbp[l].retain_grad() F.cross_entropy(lo,y).backward() bp={l:hbp[l].grad.detach().clone() for l in range(L+1)} if was_frozen: for p in model.parameters():p.requires_grad_(False) with torch.no_grad(): lo2,hi=model(x,return_hidden=True) eT=lo2.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach() gammas,rhos=[],[] for l in range(L): h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device) if credit_mode=='dfa': a_l=(s@Bs[l].T).detach() elif credit_mode=='vec' and vec_net: a_l=vec_net(h_l,t_l,s).detach() elif credit_mode=='blend' and vec_net: a_dfa=(s@Bs[l].T).detach();a_vec=vec_net(h_l,t_l,s).detach() rv=(a_vec**2).mean(-1,keepdim=True).sqrt()+1e-6;rd=(a_dfa**2).mean(-1,keepdim=True).sqrt()+1e-6 a_l=alpha*a_vec/rv+(1-alpha)*a_dfa/rd else: a_l=(s@Bs[l].T).detach() gammas.append(cosine_similarity_batch(a_l,bp[l])) def make_fwd(sl): def f(h): with torch.no_grad(): c=h for i in range(sl,L):c=c+model.blocks[i](c) return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none') return f rhos.append(perturbation_correlation(h_l,a_l,make_fwd(l),epsilon=1e-3,M=16)) return float(np.mean(gammas)),float(np.mean(rhos)) def run_branch(model, vec_net, Bs, train_loader, test_loader, device, t0, total_epochs, branch_type, alpha, lr, lr_fb, wd, M, beta_scale=1.0, branch_name=''): """ Run a training branch from checkpoint. branch_type: 'dfa', 'blend_frozen', 'blend_trainable', 'blend_shuffled', 'blend_gaussian', 'scaled_dfa' """ d=model.d_hidden;L=model.num_blocks;eps_pert=1e-3 block_opts=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks] embed_opt=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd) head_opt=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd) vec_opt=optim.Adam(vec_net.parameters(),lr=lr_fb) if (vec_net and branch_type in ['blend_trainable','blend_shuffled']) else None scheds=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=total_epochs) for o in block_opts]+\ [optim.lr_scheduler.CosineAnnealingLR(embed_opt,T_max=total_epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt,T_max=total_epochs)] for _ in range(t0): for s in scheds:s.step() # Get reference RMS for gaussian noise matching ref_rms=[0.01]*L # default if branch_type=='blend_gaussian' and vec_net: # Estimate RMS from frozen random Vec model.eval() for x,y in train_loader: x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0) with torch.no_grad(): lo,hi=model(x,return_hidden=True) eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach() for l in range(L): h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device) a_v=vec_net(h_l,t_l,s).detach() ref_rms[l]=(a_v**2).mean().sqrt().item() break log={'test_acc':[],'train_loss':[],'gamma':[],'rho':[],'alpha_eff':[]} diag_epochs=set(list(range(t0+1,min(t0+6,total_epochs+1)))+[t0+8,t0+10,t0+15,t0+20]+ list(range(t0+10,total_epochs+1,10))+[total_epochs]) for epoch in range(t0+1,total_epochs+1): model.train() if vec_net and vec_opt:vec_net.train() tl,c,t=0,0,0 epoch_aux_norms,epoch_dfa_norms=[],[] for x,y in train_loader: x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0) with torch.no_grad(): lo,hi=model(x,return_hidden=True);lv=F.cross_entropy(lo,y) eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach() hL=hi[-1].detach() # Train Vec if applicable if vec_opt and branch_type in ['blend_trainable','blend_shuffled']: t_L=torch.ones(batch,device=device) a_term=vec_net(hL,t_L,s) hL_req=hL.clone().requires_grad_(True) ce=F.cross_entropy(model.out_head(model.out_ln(hL_req)),y,reduction='sum') dL=torch.autograd.grad(ce,hL_req,create_graph=False)[0].detach() loss_term=((a_term-dL)**2).sum(-1).mean() lt=np.random.randint(0,L);h_l=hi[lt].detach();t_l=torch.full((batch,),lt/L,device=device) a_l=vec_net(h_l,t_l,s);lp2=torch.tensor(0.0,device=device) for _ in range(M): v=torch.randn_like(h_l);v=v/(v.norm(-1,keepdim=True)+1e-8) with torch.no_grad(): lp=F.cross_entropy(model.forward_from_layer(h_l+eps_pert*v,lt),y,reduction='none') lm=F.cross_entropy(model.forward_from_layer(h_l-eps_pert*v,lt),y,reduction='none') gj=(lp-lm)/(2*eps_pert) if branch_type=='blend_shuffled': gj=gj[torch.randperm(batch,device=device)] # Shuffle targets lp2=lp2+(((a_l*v).sum(-1)-gj.detach())**2).mean() lp2/=M vl=loss_term+lp2;vec_opt.zero_grad();vl.backward() torch.nn.utils.clip_grad_norm_(vec_net.parameters(),1.0);vec_opt.step() # Compute credits per block dfa_credits=[(eT@Bs[l].T).detach() for l in range(L)] credits=[] for l in range(L): a_dfa=dfa_credits[l] rms_d=(a_dfa**2).mean(-1,keepdim=True).sqrt()+1e-6 if branch_type=='dfa': credits.append(a_dfa/rms_d) elif branch_type=='scaled_dfa': credits.append(beta_scale*a_dfa/rms_d) elif branch_type=='blend_gaussian': xi=torch.randn_like(a_dfa) rms_xi=(xi**2).mean(-1,keepdim=True).sqrt()+1e-6 xi_matched=xi/rms_xi*ref_rms[l] # match to Vec RMS rms_xi2=(xi_matched**2).mean(-1,keepdim=True).sqrt()+1e-6 credits.append(alpha*xi_matched/rms_xi2+(1-alpha)*a_dfa/rms_d) else: # blend with Vec (frozen or trainable or shuffled) h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device) with torch.no_grad(): a_vec=vec_net(h_l,t_l,s).detach() rms_v=(a_vec**2).mean(-1,keepdim=True).sqrt()+1e-6 credits.append(alpha*a_vec/rms_v+(1-alpha)*a_dfa/rms_d) # Track update geometry a_blend=credits[-1] a_dfa_n=a_dfa/rms_d aux_norm=(alpha*a_blend).norm().item() if branch_type!='dfa' else 0 dfa_norm=((1-alpha)*a_dfa_n).norm().item() if branch_type not in ['dfa','scaled_dfa'] else a_blend.norm().item() epoch_aux_norms.append(aux_norm) epoch_dfa_norms.append(dfa_norm) # Update head lo2=F.cross_entropy(model.out_head(model.out_ln(hL)),y) head_opt.zero_grad();lo2.backward();head_opt.step() # Update blocks for l in range(L): a=credits[l];rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6 f=model.blocks[l](hi[l].detach());ll=(f*(a/rm)).sum(-1).mean() block_opts[l].zero_grad();ll.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);block_opts[l].step() a0=credits[0];r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6 el=(model.embed(x)*(a0/r0)).sum(-1).mean() embed_opt.zero_grad();el.backward();embed_opt.step() tl+=lv.item()*batch;c+=(lo.argmax(1)==y).sum().item();t+=batch for sch in scheds:sch.step() ta=evaluate(model,test_loader,device) log['test_acc'].append(ta);log['train_loss'].append(tl/t) mean_aux=np.mean(epoch_aux_norms) if epoch_aux_norms else 0 mean_dfa=np.mean(epoch_dfa_norms) if epoch_dfa_norms else 1 aeff=mean_aux/(mean_aux+mean_dfa+1e-12) log['alpha_eff'].append((epoch,aeff)) if epoch in diag_epochs: cm='vec' if branch_type in ['blend_trainable','blend_shuffled'] and vec_net else 'dfa' gamma,rho=compute_diagnostics(model,vec_net,Bs,test_loader,device,cm) log['gamma'].append((epoch,gamma));log['rho'].append((epoch,rho)) if epoch<=t0+15 or epoch%20==0 or epoch==total_epochs: print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}, G={gamma:.4f}, r={rho:.4f}, aeff={aeff:.3f}") elif epoch%20==0 or epoch==total_epochs: print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}") return log def run_experiment(args): device=torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") os.makedirs(args.output_dir,exist_ok=True) torch.manual_seed(args.seed);np.random.seed(args.seed);torch.cuda.manual_seed_all(args.seed) train_loader,test_loader=get_cifar10(args.batch_size) input_dim=32*32*3;L=args.num_blocks;d=args.d_hidden # Train DFA baseline print(f"\n{'='*60}\nTraining DFA baseline\n{'='*60}") model_dfa=ResidualMLP(input_dim,d,10,L).to(device) Bs,ckpt=train_dfa_get_checkpoint(model_dfa,train_loader,test_loader,device,args.epochs,args.t0,args.lr,args.wd) print(f" Checkpoint acc: {ckpt['acc']:.4f}") # Create frozen random Vec for reference torch.manual_seed(args.seed+7777) vec_frozen_ref=VectorCreditNet(d_hidden=d,s_dim=10).to(device) vec_frozen_ref.eval() # Estimate norm ratio for scaled_DFA model_ref=ResidualMLP(input_dim,d,10,L).to(device) model_ref.load_state_dict(ckpt['model']);model_ref.eval() for p in model_ref.parameters():p.requires_grad_(False) beta=estimate_blend_norm_ratio(model_ref,ckpt['Bs'],vec_frozen_ref,train_loader,device,args.alpha) print(f" Estimated beta (norm ratio): {beta:.4f}") del model_ref # Define branches branches=[ ('continue_DFA','dfa',False,None), ('blend_random_frozen','blend_frozen',False,'frozen'), ('blend_random_trainable','blend_trainable',True,'trainable'), ('blend_shuffled_trainable','blend_shuffled',True,'shuffled'), ('blend_gaussian_noise','blend_gaussian',False,'gaussian'), ('scaled_DFA_norm_match','scaled_dfa',False,None), ] all_results={} for bname,btype,needs_trainable_vec,vec_mode in branches: print(f"\n{'='*60}\n{bname}\n{'='*60}") model_b=ResidualMLP(input_dim,d,10,L).to(device) model_b.load_state_dict(ckpt['model']) # Prepare Vec if btype in ['blend_frozen','blend_gaussian']: torch.manual_seed(args.seed+7777) # same random Vec as reference vec_b=VectorCreditNet(d_hidden=d,s_dim=10).to(device) vec_b.eval() for p in vec_b.parameters():p.requires_grad_(False) elif needs_trainable_vec: torch.manual_seed(args.seed+7777) # same init vec_b=VectorCreditNet(d_hidden=d,s_dim=10).to(device) else: vec_b=None log=run_branch(model_b,vec_b,ckpt['Bs'],train_loader,test_loader,device, args.t0,args.epochs,btype,args.alpha,args.lr,args.lr_fb,args.wd,args.M, beta_scale=beta,branch_name=bname) all_results[bname]=log print(f" {bname} final: {log['test_acc'][-1]:.4f}") # Summary dfa_final=all_results['continue_DFA']['test_acc'][-1] dfa_acc20=all_results['continue_DFA']['test_acc'][20-args.t0-1] if len(all_results['continue_DFA']['test_acc'])>20-args.t0-1 else dfa_final print(f"\n{'='*90}") print("SUMMARY") print(f"{'='*90}") print(f"{'Branch':<30} {'acc@20':>7} {'final':>7} {'diff':>7} {'mG_5:15':>8} {'mr_5:15':>8} {'aeff':>7}") print("-"*78) for bname,log in all_results.items(): accs=log['test_acc'] acc20=accs[20-args.t0-1] if len(accs)>20-args.t0-1 else accs[-1] final=accs[-1] diff=final-dfa_final gammas_early=[g for e,g in log['gamma'] if args.t07.4f} {final:>7.4f} {diff:>+7.4f} {mg:>8.4f} {mr:>8.4f} {mae:>7.3f}") # Save save_data={'beta':beta} for bname,log in all_results.items(): save_data[bname]={'test_acc':log['test_acc'],'train_loss':log['train_loss'], 'gamma':log['gamma'],'rho':log['rho'],'alpha_eff':log['alpha_eff']} out_path=os.path.join(args.output_dir,f'dissection_t{args.t0}_s{args.seed}.json') with open(out_path,'w') as f:json.dump(save_data,f,indent=2,default=float) print(f"\nSaved to {out_path}") # Judgment print(f"\n{'='*60}\nJUDGMENT\n{'='*60}") rf=all_results.get('blend_random_frozen',{}).get('test_acc',[-1])[-1] rt=all_results.get('blend_random_trainable',{}).get('test_acc',[-1])[-1] sh=all_results.get('blend_shuffled_trainable',{}).get('test_acc',[-1])[-1] gn=all_results.get('blend_gaussian_noise',{}).get('test_acc',[-1])[-1] sd=all_results.get('scaled_DFA_norm_match',{}).get('test_acc',[-1])[-1] print(f" DFA={dfa_final:.4f}, rf={rf:.4f}, rt={rt:.4f}, sh={sh:.4f}, gn={gn:.4f}, sd={sd:.4f}") if abs(rf-rt)<0.005 and abs(rf-sh)<0.005 and abs(rf-gn)<0.005: if abs(rf-sd)<0.005: print(" -> ALL SIMILAR including scaled_DFA: gain is primarily norm/step-size effect") else: print(" -> rf≈rt≈sh≈gn but ≠ scaled_DFA: gain is diversification, not just norm") elif rt>rf+0.005: print(" -> random_trainable > random_frozen: online learning contributes") else: print(" -> Mixed signal, see table for details") def main(): parser=argparse.ArgumentParser(description='Phase 10A.5: Blend Mechanism Dissection') parser.add_argument('--num_blocks',type=int,default=4) parser.add_argument('--d_hidden',type=int,default=256) parser.add_argument('--batch_size',type=int,default=128) parser.add_argument('--epochs',type=int,default=100) parser.add_argument('--t0',type=int,default=5) parser.add_argument('--alpha',type=float,default=0.75) parser.add_argument('--lr',type=float,default=1e-3) parser.add_argument('--lr_fb',type=float,default=1e-3) parser.add_argument('--wd',type=float,default=0.01) parser.add_argument('--M',type=int,default=4) parser.add_argument('--seed',type=int,default=42) parser.add_argument('--gpu',type=int,default=2) parser.add_argument('--output_dir',type=str,default='results/blend_dissection') args=parser.parse_args() run_experiment(args) if __name__=='__main__': main()