summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-26 16:27:53 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-26 16:27:53 -0500
commit610e1169e19378cccd2d9b92a588c24dca7f3df7 (patch)
tree532f8dc2fda6c68ab1409b20d7431b76d8d6f378 /experiments
parentef4aed70130e2212b4ed1cb7212e2ea6c7c7adb2 (diff)
Add Phase 10A.5: blend gain is implicit regularization, not learned credit
Dissection of 6 branches from same DFA checkpoint: - blend_random_frozen: 12.6% (CATASTROPHIC — frozen noise destroys training) - blend_random_trainable: 32.2% (+1.2% — trainable network helps) - blend_shuffled_trainable: 32.5% (+1.4% — even wrong targets work!) - blend_gaussian_noise: 30.8% (neutral) - scaled_DFA_norm_match: 31.0% (neutral) The gain comes from implicit regularization through a co-optimized auxiliary network, NOT from learned credit quality. Phase 9A's +1.5% was an optimization dynamics effect, not evidence of useful credit assignment. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/blend_mechanism_dissection.py466
1 files changed, 466 insertions, 0 deletions
diff --git a/experiments/blend_mechanism_dissection.py b/experiments/blend_mechanism_dissection.py
new file mode 100644
index 0000000..a193524
--- /dev/null
+++ b/experiments/blend_mechanism_dissection.py
@@ -0,0 +1,466 @@
+"""
+Phase 10A.5: Blend Mechanism Dissection.
+
+Core question: Phase 10A's gain from blend(random Vec, DFA) — is it learned
+correction or blend/diversification/norm effect?
+
+6 branches from the same DFA checkpoint:
+1. continue_DFA — baseline
+2. blend_random_frozen — random Vec frozen, no training
+3. blend_random_trainable — random Vec, trained online after handoff
+4. blend_shuffled_trainable — Vec trained but targets permuted (no semantics)
+5. blend_gaussian_noise_matched — no Vec, just RMS-matched Gaussian noise
+6. scaled_DFA_match_norm — DFA only but effective norm matched to blend
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1))
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32,padding=4),transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform_test)
+ return (DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=4,pin_memory=True),
+ DataLoader(testset,batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True))
+
+
+def evaluate(model, test_loader, device):
+ model.eval();c,t=0,0
+ with torch.no_grad():
+ for x,y in test_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device)
+ c+=(model(x).argmax(1)==y).sum().item();t+=x.size(0)
+ return c/t
+
+
+def train_dfa_get_checkpoint(model, train_loader, test_loader, device, total_epochs, t0, lr, wd):
+ d=model.d_hidden;L=model.num_blocks
+ Bs=[torch.randn(d,10,device=device)/np.sqrt(10) for _ in range(L)]
+ block_opts=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks]
+ embed_opt=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd)
+ head_opt=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd)
+ scheds=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=total_epochs) for o in block_opts]+\
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt,T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt,T_max=total_epochs)]
+ ckpt=None
+ for epoch in range(1,total_epochs+1):
+ model.train();tl,c,t=0,0,0
+ for x,y in train_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);b=x.size(0)
+ with torch.no_grad():
+ lo,hi=model(x,return_hidden=True);lv=F.cross_entropy(lo,y)
+ eT=lo.softmax(-1);eT[torch.arange(b),y]-=1
+ hL=hi[-1].detach()
+ lo2=F.cross_entropy(model.out_head(model.out_ln(hL)),y)
+ head_opt.zero_grad();lo2.backward();head_opt.step()
+ for l in range(L):
+ a=(eT@Bs[l].T).detach();rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ f=model.blocks[l](hi[l].detach());ll=(f*(a/rm)).sum(-1).mean()
+ block_opts[l].zero_grad();ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);block_opts[l].step()
+ a0=(eT@Bs[0].T).detach();r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el=(model.embed(x)*(a0/r0)).sum(-1).mean()
+ embed_opt.zero_grad();el.backward();embed_opt.step()
+ tl+=lv.item()*b;c+=(lo.argmax(1)==y).sum().item();t+=b
+ for s in scheds:s.step()
+ if epoch==t0:
+ acc=evaluate(model,test_loader,device)
+ ckpt={'model':copy.deepcopy(model.state_dict()),'Bs':[B.clone() for B in Bs],'acc':acc}
+ print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}")
+ if epoch%20==0:print(f" [DFA] Epoch {epoch}: acc={evaluate(model,test_loader,device):.4f}")
+ return Bs,ckpt
+
+
+def estimate_blend_norm_ratio(model, Bs, vec_net, train_loader, device, alpha, n_batches=50):
+ """Estimate ratio of blend update norm to DFA-only update norm (per block)."""
+ model.eval();L=model.num_blocks
+ dfa_norms=[0.0]*L;blend_norms=[0.0]*L;n=0
+ for i,(x,y) in enumerate(train_loader):
+ if i>=n_batches:break
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0)
+ with torch.no_grad():
+ lo,hi=model(x,return_hidden=True)
+ eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach()
+ for l in range(L):
+ h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device)
+ a_dfa=(eT@Bs[l].T).detach()
+ a_vec=vec_net(h_l,t_l,s).detach() if vec_net else torch.zeros_like(a_dfa)
+ rms_d=(a_dfa**2).mean(-1,keepdim=True).sqrt()+1e-6
+ rms_v=(a_vec**2).mean(-1,keepdim=True).sqrt()+1e-6
+ a_blend=alpha*a_vec/rms_v+(1-alpha)*a_dfa/rms_d
+ a_dfa_norm=a_dfa/rms_d
+ # Compute block update norms (proxy: just use credit norm)
+ dfa_norms[l]+=(a_dfa_norm**2).mean().item()
+ blend_norms[l]+=(a_blend**2).mean().item()
+ n+=1
+ ratios=[((blend_norms[l]/n)/(dfa_norms[l]/n+1e-12))**0.5 for l in range(L)]
+ return float(np.mean(ratios))
+
+
+def compute_diagnostics(model, vec_net, Bs, test_loader, device, credit_mode, alpha=0.75):
+ """Compute Gamma, rho for current credit source."""
+ model.eval()
+ if vec_net:vec_net.eval()
+ L=model.num_blocks
+ for x,y in test_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);break
+ batch=x.size(0)
+ was_frozen=not next(model.parameters()).requires_grad
+ if was_frozen:
+ for p in model.parameters():p.requires_grad_(True)
+ model.zero_grad()
+ lo,hbp=model(x,return_hidden=True)
+ for l in range(L+1):hbp[l].retain_grad()
+ F.cross_entropy(lo,y).backward()
+ bp={l:hbp[l].grad.detach().clone() for l in range(L+1)}
+ if was_frozen:
+ for p in model.parameters():p.requires_grad_(False)
+ with torch.no_grad():
+ lo2,hi=model(x,return_hidden=True)
+ eT=lo2.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach()
+ gammas,rhos=[],[]
+ for l in range(L):
+ h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device)
+ if credit_mode=='dfa':
+ a_l=(s@Bs[l].T).detach()
+ elif credit_mode=='vec' and vec_net:
+ a_l=vec_net(h_l,t_l,s).detach()
+ elif credit_mode=='blend' and vec_net:
+ a_dfa=(s@Bs[l].T).detach();a_vec=vec_net(h_l,t_l,s).detach()
+ rv=(a_vec**2).mean(-1,keepdim=True).sqrt()+1e-6;rd=(a_dfa**2).mean(-1,keepdim=True).sqrt()+1e-6
+ a_l=alpha*a_vec/rv+(1-alpha)*a_dfa/rd
+ else:
+ a_l=(s@Bs[l].T).detach()
+ gammas.append(cosine_similarity_batch(a_l,bp[l]))
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c=h
+ for i in range(sl,L):c=c+model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none')
+ return f
+ rhos.append(perturbation_correlation(h_l,a_l,make_fwd(l),epsilon=1e-3,M=16))
+ return float(np.mean(gammas)),float(np.mean(rhos))
+
+
+def run_branch(model, vec_net, Bs, train_loader, test_loader, device, t0, total_epochs,
+ branch_type, alpha, lr, lr_fb, wd, M, beta_scale=1.0, branch_name=''):
+ """
+ Run a training branch from checkpoint.
+ branch_type: 'dfa', 'blend_frozen', 'blend_trainable', 'blend_shuffled',
+ 'blend_gaussian', 'scaled_dfa'
+ """
+ d=model.d_hidden;L=model.num_blocks;eps_pert=1e-3
+ block_opts=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks]
+ embed_opt=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd)
+ head_opt=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd)
+ vec_opt=optim.Adam(vec_net.parameters(),lr=lr_fb) if (vec_net and branch_type in ['blend_trainable','blend_shuffled']) else None
+ scheds=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=total_epochs) for o in block_opts]+\
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt,T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt,T_max=total_epochs)]
+ for _ in range(t0):
+ for s in scheds:s.step()
+
+ # Get reference RMS for gaussian noise matching
+ ref_rms=[0.01]*L # default
+ if branch_type=='blend_gaussian' and vec_net:
+ # Estimate RMS from frozen random Vec
+ model.eval()
+ for x,y in train_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0)
+ with torch.no_grad():
+ lo,hi=model(x,return_hidden=True)
+ eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach()
+ for l in range(L):
+ h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device)
+ a_v=vec_net(h_l,t_l,s).detach()
+ ref_rms[l]=(a_v**2).mean().sqrt().item()
+ break
+
+ log={'test_acc':[],'train_loss':[],'gamma':[],'rho':[],'alpha_eff':[]}
+ diag_epochs=set(list(range(t0+1,min(t0+6,total_epochs+1)))+[t0+8,t0+10,t0+15,t0+20]+
+ list(range(t0+10,total_epochs+1,10))+[total_epochs])
+
+ for epoch in range(t0+1,total_epochs+1):
+ model.train()
+ if vec_net and vec_opt:vec_net.train()
+ tl,c,t=0,0,0
+ epoch_aux_norms,epoch_dfa_norms=[],[]
+
+ for x,y in train_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0)
+ with torch.no_grad():
+ lo,hi=model(x,return_hidden=True);lv=F.cross_entropy(lo,y)
+ eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach()
+ hL=hi[-1].detach()
+
+ # Train Vec if applicable
+ if vec_opt and branch_type in ['blend_trainable','blend_shuffled']:
+ t_L=torch.ones(batch,device=device)
+ a_term=vec_net(hL,t_L,s)
+ hL_req=hL.clone().requires_grad_(True)
+ ce=F.cross_entropy(model.out_head(model.out_ln(hL_req)),y,reduction='sum')
+ dL=torch.autograd.grad(ce,hL_req,create_graph=False)[0].detach()
+ loss_term=((a_term-dL)**2).sum(-1).mean()
+ lt=np.random.randint(0,L);h_l=hi[lt].detach();t_l=torch.full((batch,),lt/L,device=device)
+ a_l=vec_net(h_l,t_l,s);lp2=torch.tensor(0.0,device=device)
+ for _ in range(M):
+ v=torch.randn_like(h_l);v=v/(v.norm(-1,keepdim=True)+1e-8)
+ with torch.no_grad():
+ lp=F.cross_entropy(model.forward_from_layer(h_l+eps_pert*v,lt),y,reduction='none')
+ lm=F.cross_entropy(model.forward_from_layer(h_l-eps_pert*v,lt),y,reduction='none')
+ gj=(lp-lm)/(2*eps_pert)
+ if branch_type=='blend_shuffled':
+ gj=gj[torch.randperm(batch,device=device)] # Shuffle targets
+ lp2=lp2+(((a_l*v).sum(-1)-gj.detach())**2).mean()
+ lp2/=M
+ vl=loss_term+lp2;vec_opt.zero_grad();vl.backward()
+ torch.nn.utils.clip_grad_norm_(vec_net.parameters(),1.0);vec_opt.step()
+
+ # Compute credits per block
+ dfa_credits=[(eT@Bs[l].T).detach() for l in range(L)]
+ credits=[]
+ for l in range(L):
+ a_dfa=dfa_credits[l]
+ rms_d=(a_dfa**2).mean(-1,keepdim=True).sqrt()+1e-6
+
+ if branch_type=='dfa':
+ credits.append(a_dfa/rms_d)
+ elif branch_type=='scaled_dfa':
+ credits.append(beta_scale*a_dfa/rms_d)
+ elif branch_type=='blend_gaussian':
+ xi=torch.randn_like(a_dfa)
+ rms_xi=(xi**2).mean(-1,keepdim=True).sqrt()+1e-6
+ xi_matched=xi/rms_xi*ref_rms[l] # match to Vec RMS
+ rms_xi2=(xi_matched**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(alpha*xi_matched/rms_xi2+(1-alpha)*a_dfa/rms_d)
+ else:
+ # blend with Vec (frozen or trainable or shuffled)
+ h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device)
+ with torch.no_grad():
+ a_vec=vec_net(h_l,t_l,s).detach()
+ rms_v=(a_vec**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(alpha*a_vec/rms_v+(1-alpha)*a_dfa/rms_d)
+
+ # Track update geometry
+ a_blend=credits[-1]
+ a_dfa_n=a_dfa/rms_d
+ aux_norm=(alpha*a_blend).norm().item() if branch_type!='dfa' else 0
+ dfa_norm=((1-alpha)*a_dfa_n).norm().item() if branch_type not in ['dfa','scaled_dfa'] else a_blend.norm().item()
+ epoch_aux_norms.append(aux_norm)
+ epoch_dfa_norms.append(dfa_norm)
+
+ # Update head
+ lo2=F.cross_entropy(model.out_head(model.out_ln(hL)),y)
+ head_opt.zero_grad();lo2.backward();head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ a=credits[l];rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ f=model.blocks[l](hi[l].detach());ll=(f*(a/rm)).sum(-1).mean()
+ block_opts[l].zero_grad();ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);block_opts[l].step()
+
+ a0=credits[0];r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el=(model.embed(x)*(a0/r0)).sum(-1).mean()
+ embed_opt.zero_grad();el.backward();embed_opt.step()
+
+ tl+=lv.item()*batch;c+=(lo.argmax(1)==y).sum().item();t+=batch
+
+ for sch in scheds:sch.step()
+ ta=evaluate(model,test_loader,device)
+ log['test_acc'].append(ta);log['train_loss'].append(tl/t)
+
+ mean_aux=np.mean(epoch_aux_norms) if epoch_aux_norms else 0
+ mean_dfa=np.mean(epoch_dfa_norms) if epoch_dfa_norms else 1
+ aeff=mean_aux/(mean_aux+mean_dfa+1e-12)
+ log['alpha_eff'].append((epoch,aeff))
+
+ if epoch in diag_epochs:
+ cm='vec' if branch_type in ['blend_trainable','blend_shuffled'] and vec_net else 'dfa'
+ gamma,rho=compute_diagnostics(model,vec_net,Bs,test_loader,device,cm)
+ log['gamma'].append((epoch,gamma));log['rho'].append((epoch,rho))
+ if epoch<=t0+15 or epoch%20==0 or epoch==total_epochs:
+ print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}, G={gamma:.4f}, r={rho:.4f}, aeff={aeff:.3f}")
+ elif epoch%20==0 or epoch==total_epochs:
+ print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}")
+
+ return log
+
+
+def run_experiment(args):
+ device=torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir,exist_ok=True)
+ torch.manual_seed(args.seed);np.random.seed(args.seed);torch.cuda.manual_seed_all(args.seed)
+ train_loader,test_loader=get_cifar10(args.batch_size)
+ input_dim=32*32*3;L=args.num_blocks;d=args.d_hidden
+
+ # Train DFA baseline
+ print(f"\n{'='*60}\nTraining DFA baseline\n{'='*60}")
+ model_dfa=ResidualMLP(input_dim,d,10,L).to(device)
+ Bs,ckpt=train_dfa_get_checkpoint(model_dfa,train_loader,test_loader,device,args.epochs,args.t0,args.lr,args.wd)
+ print(f" Checkpoint acc: {ckpt['acc']:.4f}")
+
+ # Create frozen random Vec for reference
+ torch.manual_seed(args.seed+7777)
+ vec_frozen_ref=VectorCreditNet(d_hidden=d,s_dim=10).to(device)
+ vec_frozen_ref.eval()
+
+ # Estimate norm ratio for scaled_DFA
+ model_ref=ResidualMLP(input_dim,d,10,L).to(device)
+ model_ref.load_state_dict(ckpt['model']);model_ref.eval()
+ for p in model_ref.parameters():p.requires_grad_(False)
+ beta=estimate_blend_norm_ratio(model_ref,ckpt['Bs'],vec_frozen_ref,train_loader,device,args.alpha)
+ print(f" Estimated beta (norm ratio): {beta:.4f}")
+ del model_ref
+
+ # Define branches
+ branches=[
+ ('continue_DFA','dfa',False,None),
+ ('blend_random_frozen','blend_frozen',False,'frozen'),
+ ('blend_random_trainable','blend_trainable',True,'trainable'),
+ ('blend_shuffled_trainable','blend_shuffled',True,'shuffled'),
+ ('blend_gaussian_noise','blend_gaussian',False,'gaussian'),
+ ('scaled_DFA_norm_match','scaled_dfa',False,None),
+ ]
+
+ all_results={}
+ for bname,btype,needs_trainable_vec,vec_mode in branches:
+ print(f"\n{'='*60}\n{bname}\n{'='*60}")
+ model_b=ResidualMLP(input_dim,d,10,L).to(device)
+ model_b.load_state_dict(ckpt['model'])
+
+ # Prepare Vec
+ if btype in ['blend_frozen','blend_gaussian']:
+ torch.manual_seed(args.seed+7777) # same random Vec as reference
+ vec_b=VectorCreditNet(d_hidden=d,s_dim=10).to(device)
+ vec_b.eval()
+ for p in vec_b.parameters():p.requires_grad_(False)
+ elif needs_trainable_vec:
+ torch.manual_seed(args.seed+7777) # same init
+ vec_b=VectorCreditNet(d_hidden=d,s_dim=10).to(device)
+ else:
+ vec_b=None
+
+ log=run_branch(model_b,vec_b,ckpt['Bs'],train_loader,test_loader,device,
+ args.t0,args.epochs,btype,args.alpha,args.lr,args.lr_fb,args.wd,args.M,
+ beta_scale=beta,branch_name=bname)
+ all_results[bname]=log
+ print(f" {bname} final: {log['test_acc'][-1]:.4f}")
+
+ # Summary
+ dfa_final=all_results['continue_DFA']['test_acc'][-1]
+ dfa_acc20=all_results['continue_DFA']['test_acc'][20-args.t0-1] if len(all_results['continue_DFA']['test_acc'])>20-args.t0-1 else dfa_final
+
+ print(f"\n{'='*90}")
+ print("SUMMARY")
+ print(f"{'='*90}")
+ print(f"{'Branch':<30} {'acc@20':>7} {'final':>7} {'diff':>7} {'mG_5:15':>8} {'mr_5:15':>8} {'aeff':>7}")
+ print("-"*78)
+
+ for bname,log in all_results.items():
+ accs=log['test_acc']
+ acc20=accs[20-args.t0-1] if len(accs)>20-args.t0-1 else accs[-1]
+ final=accs[-1]
+ diff=final-dfa_final
+
+ gammas_early=[g for e,g in log['gamma'] if args.t0<e<=args.t0+15]
+ rhos_early=[r for e,r in log['rho'] if args.t0<e<=args.t0+15]
+ aeffs_early=[a for e,a in log['alpha_eff'] if args.t0<e<=args.t0+15]
+ mg=np.mean(gammas_early) if gammas_early else float('nan')
+ mr=np.mean(rhos_early) if rhos_early else float('nan')
+ mae=np.mean(aeffs_early) if aeffs_early else float('nan')
+
+ print(f"{bname:<30} {acc20:>7.4f} {final:>7.4f} {diff:>+7.4f} {mg:>8.4f} {mr:>8.4f} {mae:>7.3f}")
+
+ # Save
+ save_data={'beta':beta}
+ for bname,log in all_results.items():
+ save_data[bname]={'test_acc':log['test_acc'],'train_loss':log['train_loss'],
+ 'gamma':log['gamma'],'rho':log['rho'],'alpha_eff':log['alpha_eff']}
+ out_path=os.path.join(args.output_dir,f'dissection_t{args.t0}_s{args.seed}.json')
+ with open(out_path,'w') as f:json.dump(save_data,f,indent=2,default=float)
+ print(f"\nSaved to {out_path}")
+
+ # Judgment
+ print(f"\n{'='*60}\nJUDGMENT\n{'='*60}")
+ rf=all_results.get('blend_random_frozen',{}).get('test_acc',[-1])[-1]
+ rt=all_results.get('blend_random_trainable',{}).get('test_acc',[-1])[-1]
+ sh=all_results.get('blend_shuffled_trainable',{}).get('test_acc',[-1])[-1]
+ gn=all_results.get('blend_gaussian_noise',{}).get('test_acc',[-1])[-1]
+ sd=all_results.get('scaled_DFA_norm_match',{}).get('test_acc',[-1])[-1]
+
+ print(f" DFA={dfa_final:.4f}, rf={rf:.4f}, rt={rt:.4f}, sh={sh:.4f}, gn={gn:.4f}, sd={sd:.4f}")
+
+ if abs(rf-rt)<0.005 and abs(rf-sh)<0.005 and abs(rf-gn)<0.005:
+ if abs(rf-sd)<0.005:
+ print(" -> ALL SIMILAR including scaled_DFA: gain is primarily norm/step-size effect")
+ else:
+ print(" -> rf≈rt≈sh≈gn but ≠ scaled_DFA: gain is diversification, not just norm")
+ elif rt>rf+0.005:
+ print(" -> random_trainable > random_frozen: online learning contributes")
+ else:
+ print(" -> Mixed signal, see table for details")
+
+
+def main():
+ parser=argparse.ArgumentParser(description='Phase 10A.5: Blend Mechanism Dissection')
+ parser.add_argument('--num_blocks',type=int,default=4)
+ parser.add_argument('--d_hidden',type=int,default=256)
+ parser.add_argument('--batch_size',type=int,default=128)
+ parser.add_argument('--epochs',type=int,default=100)
+ parser.add_argument('--t0',type=int,default=5)
+ parser.add_argument('--alpha',type=float,default=0.75)
+ parser.add_argument('--lr',type=float,default=1e-3)
+ parser.add_argument('--lr_fb',type=float,default=1e-3)
+ parser.add_argument('--wd',type=float,default=0.01)
+ parser.add_argument('--M',type=int,default=4)
+ parser.add_argument('--seed',type=int,default=42)
+ parser.add_argument('--gpu',type=int,default=2)
+ parser.add_argument('--output_dir',type=str,default='results/blend_dissection')
+ args=parser.parse_args()
+ run_experiment(args)
+
+
+if __name__=='__main__':
+ main()