From ef4aed70130e2212b4ed1cb7212e2ea6c7c7adb2 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Thu, 26 Mar 2026 08:37:39 -0500 Subject: =?UTF-8?q?Add=20Phase=2010A:=20no=20prefit=20threshold=20?= =?UTF-8?q?=E2=80=94=20even=20random=20Vec=20blend=20beats=20DFA=20by=20+1?= =?UTF-8?q?.3%?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit E_prefit=0 (random Vec) + blend(0.75): 32.4% vs DFA 31.1% (+1.3%) E_prefit=15: 32.3% (+1.2%) E_prefit=60: 32.5% (+1.4%) Frozen Gamma/rho near zero at all prefit levels. The Phase 9A success was NOT from Vec learning useful credit — it was from the blend mechanism itself providing regularization/diversification over pure DFA. Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/prefit_threshold_curve.py | 360 ++++++++++++++++++++++++++++++++++ 1 file changed, 360 insertions(+) create mode 100644 experiments/prefit_threshold_curve.py (limited to 'experiments/prefit_threshold_curve.py') diff --git a/experiments/prefit_threshold_curve.py b/experiments/prefit_threshold_curve.py new file mode 100644 index 0000000..95ec713 --- /dev/null +++ b/experiments/prefit_threshold_curve.py @@ -0,0 +1,360 @@ +""" +Phase 10A: Prefit Threshold Curve. + +Quantify: how much offline prefit does Vec need before blend handoff starts helping? + +Sweep E_prefit in {0, 5, 15, 30, 60, 120} on a fixed DFA checkpoint (t0=5). +For each, measure frozen credit quality, then branch into continue_DFA vs blend handoff. +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import SinusoidalTimeEmbed +from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test + + +class VectorCreditNet(nn.Module): + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1)) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + transform_test = transforms.Compose([ + transforms.ToTensor(), transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + return (DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True), + DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)) + + +def evaluate(model, test_loader, device): + model.eval(); c, t = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device) + c += (model(x).argmax(1)==y).sum().item(); t += x.size(0) + return c/t + + +def train_dfa_with_checkpoint(model, train_loader, test_loader, device, total_epochs, t0, lr, wd): + """Train DFA, save checkpoint at t0, continue to total_epochs.""" + d = model.d_hidden; L = model.num_blocks + Bs = [torch.randn(d,10,device=device)/np.sqrt(10) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd) + head_opt = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd) + scheds = [optim.lr_scheduler.CosineAnnealingLR(o,T_max=total_epochs) for o in block_opts]+\ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt,T_max=total_epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt,T_max=total_epochs)] + ckpt_state = None + for epoch in range(1, total_epochs+1): + model.train(); tl,c,t=0,0,0 + for x,y in train_loader: + x=x.view(x.size(0),-1).to(device);y=y.to(device);b=x.size(0) + with torch.no_grad(): + lo,hi=model(x,return_hidden=True);lv=F.cross_entropy(lo,y) + eT=lo.softmax(-1);eT[torch.arange(b),y]-=1 + hL=hi[-1].detach() + lo2=F.cross_entropy(model.out_head(model.out_ln(hL)),y) + head_opt.zero_grad();lo2.backward();head_opt.step() + for l in range(L): + a=(eT@Bs[l].T).detach();rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6 + f=model.blocks[l](hi[l].detach());ll=(f*(a/rm)).sum(-1).mean() + block_opts[l].zero_grad();ll.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);block_opts[l].step() + a0=(eT@Bs[0].T).detach();r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el=(model.embed(x)*(a0/r0)).sum(-1).mean() + embed_opt.zero_grad();el.backward();embed_opt.step() + tl+=lv.item()*b;c+=(lo.argmax(1)==y).sum().item();t+=b + for s in scheds:s.step() + if epoch == t0: + acc = evaluate(model, test_loader, device) + ckpt_state = {'model': copy.deepcopy(model.state_dict()), 'Bs': [B.clone() for B in Bs], 'acc': acc} + print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}") + if epoch % 20 == 0: + acc = evaluate(model, test_loader, device) + print(f" [DFA] Epoch {epoch}: acc={acc:.4f}") + final_acc = evaluate(model, test_loader, device) + return Bs, ckpt_state, final_acc + + +def offline_fit_vec(model, train_loader, device, epochs, lr_fb=1e-3, M=4): + d = model.d_hidden; L = model.num_blocks; eps = 1e-3 + vec_net = VectorCreditNet(d_hidden=d, s_dim=10).to(device) + vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb) + model.eval() + for p in model.parameters(): p.requires_grad_(False) + for ep in range(1, epochs+1): + vec_net.train() + for x,y in train_loader: + x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0) + with torch.no_grad(): + lo,hi=model(x,return_hidden=True) + eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach() + hL=hi[-1].detach();t_L=torch.ones(batch,device=device) + a_term=vec_net(hL,t_L,s) + hL_req=hL.clone().requires_grad_(True) + ce=F.cross_entropy(model.out_head(model.out_ln(hL_req)),y,reduction='sum') + dL=torch.autograd.grad(ce,hL_req,create_graph=False)[0].detach() + loss_term=((a_term-dL)**2).sum(-1).mean() + l=np.random.randint(0,L);h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device) + a_l=vec_net(h_l,t_l,s);loss_proj=torch.tensor(0.0,device=device) + for _ in range(M): + v=torch.randn_like(h_l);v=v/(v.norm(-1,keepdim=True)+1e-8) + with torch.no_grad(): + lp=F.cross_entropy(model.forward_from_layer(h_l+eps*v,l),y,reduction='none') + lm=F.cross_entropy(model.forward_from_layer(h_l-eps*v,l),y,reduction='none') + gj=(lp-lm)/(2*eps) + loss_proj=loss_proj+(((a_l*v).sum(-1)-gj.detach())**2).mean() + loss_proj/=M + vloss=loss_term+loss_proj + vec_opt.zero_grad();vloss.backward() + torch.nn.utils.clip_grad_norm_(vec_net.parameters(),1.0);vec_opt.step() + for p in model.parameters(): p.requires_grad_(True) + return vec_net + + +def eval_frozen_credit_quality(model, vec_net, test_loader, device): + """Evaluate Vec credit quality on frozen model.""" + model.eval(); vec_net.eval(); L = model.num_blocks + for x,y in test_loader: + x=x.view(x.size(0),-1).to(device);y=y.to(device);break + batch=x.size(0) + # BP grads + for p in model.parameters(): p.requires_grad_(True) + model.zero_grad() + lo,hbp=model(x,return_hidden=True) + for l in range(L+1): hbp[l].retain_grad() + F.cross_entropy(lo,y).backward() + bp={l:hbp[l].grad.detach().clone() for l in range(L+1)} + for p in model.parameters(): p.requires_grad_(False) + with torch.no_grad(): + lo2,hi=model(x,return_hidden=True) + eT=lo2.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach() + gammas,rhos,nudges=[],[],[] + for l in range(L): + h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device) + a_l=vec_net(h_l,t_l,s).detach() + gammas.append(cosine_similarity_batch(a_l,bp[l])) + def make_fwd(sl): + def f(h): + with torch.no_grad(): + c=h + for i in range(sl,L):c=c+model.blocks[i](c) + return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none') + return f + rhos.append(perturbation_correlation(h_l,a_l,make_fwd(l),epsilon=1e-3,M=16)) + nudges.append(nudging_test(h_l,a_l,make_fwd(l),eta=0.003)) + return float(np.mean(gammas)),float(np.mean(rhos)),float(np.mean(nudges)) + + +def continue_from_checkpoint(model, vec_net, Bs, train_loader, test_loader, device, + t0, total_epochs, blend_alpha, lr, lr_fb, wd, M, branch_name): + """Continue training from checkpoint with blend credit.""" + d=model.d_hidden;L=model.num_blocks;eps=1e-3 + block_opts=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks] + embed_opt=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd) + head_opt=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd) + vec_opt=optim.Adam(vec_net.parameters(),lr=lr_fb) if blend_alpha>0 else None + scheds=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=total_epochs) for o in block_opts]+\ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt,T_max=total_epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt,T_max=total_epochs)] + for _ in range(t0): + for s in scheds:s.step() + log={'test_acc':[],'train_loss':[]} + for epoch in range(t0+1,total_epochs+1): + model.train() + if vec_net:vec_net.train() + tl,c,t=0,0,0 + for x,y in train_loader: + x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0) + with torch.no_grad(): + lo,hi=model(x,return_hidden=True);lv=F.cross_entropy(lo,y) + eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach() + hL=hi[-1].detach() + if blend_alpha>0 and vec_opt: + t_L=torch.ones(batch,device=device) + a_term=vec_net(hL,t_L,s) + hL_req=hL.clone().requires_grad_(True) + ce=F.cross_entropy(model.out_head(model.out_ln(hL_req)),y,reduction='sum') + dL=torch.autograd.grad(ce,hL_req,create_graph=False)[0].detach() + loss_t=((a_term-dL)**2).sum(-1).mean() + lt=np.random.randint(0,L);h_l=hi[lt].detach();t_l=torch.full((batch,),lt/L,device=device) + a_l=vec_net(h_l,t_l,s);lp2=torch.tensor(0.0,device=device) + for _ in range(M): + v=torch.randn_like(h_l);v=v/(v.norm(-1,keepdim=True)+1e-8) + with torch.no_grad(): + lp=F.cross_entropy(model.forward_from_layer(h_l+eps*v,lt),y,reduction='none') + lm=F.cross_entropy(model.forward_from_layer(h_l-eps*v,lt),y,reduction='none') + gj=(lp-lm)/(2*eps) + lp2=lp2+(((a_l*v).sum(-1)-gj.detach())**2).mean() + lp2/=M + vl=loss_t+lp2;vec_opt.zero_grad();vl.backward() + torch.nn.utils.clip_grad_norm_(vec_net.parameters(),1.0);vec_opt.step() + with torch.no_grad(): + vc=[vec_net(hi[l].detach(),torch.full((batch,),l/L,device=device),s).detach() for l in range(L)] if blend_alpha>0 else [None]*L + dc=[(eT@Bs[l].T).detach() for l in range(L)] + credits=[] + for l in range(L): + if blend_alpha<=0:credits.append(dc[l]) + elif blend_alpha>=1:credits.append(vc[l]) + else: + rv=(vc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6;rd=(dc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 + credits.append(blend_alpha*vc[l]/rv+(1-blend_alpha)*dc[l]/rd) + lo2=F.cross_entropy(model.out_head(model.out_ln(hL)),y) + head_opt.zero_grad();lo2.backward();head_opt.step() + for l in range(L): + a=credits[l];rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6 + f=model.blocks[l](hi[l].detach());ll=(f*(a/rm)).sum(-1).mean() + block_opts[l].zero_grad();ll.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);block_opts[l].step() + a0=credits[0];r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el=(model.embed(x)*(a0/r0)).sum(-1).mean() + embed_opt.zero_grad();el.backward();embed_opt.step() + tl+=lv.item()*batch;c+=(lo.argmax(1)==y).sum().item();t+=batch + for s in scheds:s.step() + ta=evaluate(model,test_loader,device);log['test_acc'].append(ta);log['train_loss'].append(tl/t) + if epoch%20==0 or epoch==t0+1 or epoch==total_epochs: + print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}") + return log + + +def run_experiment(args): + device=torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir,exist_ok=True) + torch.manual_seed(args.seed);np.random.seed(args.seed);torch.cuda.manual_seed_all(args.seed) + train_loader,test_loader=get_cifar10(args.batch_size) + input_dim=32*32*3;L=args.num_blocks;d=args.d_hidden + + # Train DFA and get checkpoint + print(f"\n{'='*60}\nTraining DFA baseline with checkpoint at t0={args.t0}\n{'='*60}") + model_dfa=ResidualMLP(input_dim,d,10,L).to(device) + Bs,ckpt,dfa_final=train_dfa_with_checkpoint(model_dfa,train_loader,test_loader,device, + args.epochs,args.t0,args.lr,args.wd) + print(f" DFA final: {dfa_final:.4f}") + + # continue_DFA baseline (from checkpoint) + print(f"\n{'='*60}\ncontinue_DFA from t0={args.t0}\n{'='*60}") + model_base=ResidualMLP(input_dim,d,10,L).to(device) + model_base.load_state_dict(ckpt['model']) + log_dfa=continue_from_checkpoint(model_base,None,ckpt['Bs'],train_loader,test_loader,device, + args.t0,args.epochs,0.0,args.lr,args.lr_fb,args.wd,args.M,'continue_DFA') + dfa_cont_final=log_dfa['test_acc'][-1] + dfa_cont_acc20=log_dfa['test_acc'][20-args.t0-1] if len(log_dfa['test_acc'])>20-args.t0-1 else log_dfa['test_acc'][-1] + print(f" continue_DFA final: {dfa_cont_final:.4f}") + + # Sweep E_prefit + all_results=[] + for E in args.prefit_epochs: + print(f"\n{'='*60}\nE_prefit={E}\n{'='*60}") + model_frozen=ResidualMLP(input_dim,d,10,L).to(device) + model_frozen.load_state_dict(ckpt['model']) + model_frozen.eval() + for p in model_frozen.parameters():p.requires_grad_(False) + + # Offline fit Vec + torch.manual_seed(args.seed+E*100+4000) + if E>0: + print(f" Offline fitting Vec for {E} epochs...") + vec_net=offline_fit_vec(model_frozen,train_loader,device,epochs=E,lr_fb=args.lr_fb,M=args.M) + else: + vec_net=VectorCreditNet(d_hidden=d,s_dim=10).to(device) + print(f" E_prefit=0: random Vec") + + # Evaluate frozen credit quality + gamma_f,rho_f,nudge_f=eval_frozen_credit_quality(model_frozen,vec_net,test_loader,device) + print(f" Frozen quality: Gamma={gamma_f:.4f}, rho={rho_f:.4f}, nudge={nudge_f:.6f}") + + # Branch: blend handoff + for branch_name,alpha in [('blend_075',0.75)]: + print(f"\n --- {branch_name} (E_prefit={E}) ---") + model_branch=ResidualMLP(input_dim,d,10,L).to(device) + model_branch.load_state_dict(ckpt['model']) + vec_branch=copy.deepcopy(vec_net) + for p in model_branch.parameters():p.requires_grad_(True) + log=continue_from_checkpoint(model_branch,vec_branch,ckpt['Bs'],train_loader,test_loader,device, + args.t0,args.epochs,alpha,args.lr,args.lr_fb,args.wd,args.M, + f'E{E}_{branch_name}') + final=log['test_acc'][-1] + acc20=log['test_acc'][20-args.t0-1] if len(log['test_acc'])>20-args.t0-1 else log['test_acc'][-1] + diff_final=final-dfa_cont_final + diff_acc20=acc20-dfa_cont_acc20 + + r={'E_prefit':E,'branch':branch_name,'gamma_frozen':gamma_f,'rho_frozen':rho_f, + 'nudge_frozen':nudge_f,'final_acc':final,'acc_at_20':acc20, + 'diff_final':diff_final,'diff_acc20':diff_acc20, + 'test_acc':log['test_acc']} + all_results.append(r) + print(f" E={E} {branch_name}: final={final:.4f} (diff={diff_final:+.4f}), acc@20={acc20:.4f} (diff={diff_acc20:+.4f})") + + # Summary + print(f"\n{'='*80}") + print("SUMMARY") + print(f"{'='*80}") + print(f"continue_DFA: final={dfa_cont_final:.4f}, acc@20={dfa_cont_acc20:.4f}") + print(f"\n{'E_prefit':>8} {'Gamma_f':>8} {'rho_f':>8} {'final':>8} {'diff':>8} {'acc@20':>8} {'d@20':>8}") + print("-"*62) + for r in all_results: + print(f"{r['E_prefit']:>8} {r['gamma_frozen']:>8.4f} {r['rho_frozen']:>8.4f} " + f"{r['final_acc']:>8.4f} {r['diff_final']:>+8.4f} {r['acc_at_20']:>8.4f} {r['diff_acc20']:>+8.4f}") + + # Save + save_data={'dfa_cont_final':float(dfa_cont_final),'dfa_cont_acc20':float(dfa_cont_acc20), + 'results':[{k:v for k,v in r.items()} for r in all_results]} + out_path=os.path.join(args.output_dir,f'prefit_curve_t{args.t0}_s{args.seed}.json') + with open(out_path,'w') as f:json.dump(save_data,f,indent=2,default=float) + print(f"\nSaved to {out_path}") + + +def main(): + parser=argparse.ArgumentParser(description='Phase 10A: Prefit Threshold Curve') + parser.add_argument('--num_blocks',type=int,default=4) + parser.add_argument('--d_hidden',type=int,default=256) + parser.add_argument('--batch_size',type=int,default=128) + parser.add_argument('--epochs',type=int,default=100) + parser.add_argument('--t0',type=int,default=5) + parser.add_argument('--prefit_epochs',type=int,nargs='+',default=[0,15,60]) + parser.add_argument('--lr',type=float,default=1e-3) + parser.add_argument('--lr_fb',type=float,default=1e-3) + parser.add_argument('--wd',type=float,default=0.01) + parser.add_argument('--M',type=int,default=4) + parser.add_argument('--seed',type=int,default=42) + parser.add_argument('--gpu',type=int,default=2) + parser.add_argument('--output_dir',type=str,default='results/prefit_threshold') + args=parser.parse_args() + run_experiment(args) + + +if __name__=='__main__': + main() -- cgit v1.2.3