"""
Phase 10A: Prefit Threshold Curve.

Quantify: how much offline prefit does Vec need before blend handoff starts
helping?

A DFA model is trained once and checkpointed at epoch t0 (``--t0``, default 5).
For every E_prefit in ``--prefit_epochs`` (default {0, 15, 60}; a full curve
uses e.g. {0, 5, 15, 30, 60, 120}) a VectorCreditNet ("Vec") is fitted offline
for E_prefit epochs on the frozen checkpoint, its frozen credit quality is
measured, and training is then resumed from the checkpoint with a blend of Vec
and DFA credit. Every blend branch is compared against a continue_DFA baseline
resumed from the same checkpoint.
"""
import argparse
import copy
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from models.value_net import SinusoidalTimeEmbed
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test

# CIFAR-10 has 10 classes; this fixes the width of the output-error signal,
# the DFA feedback matrices and Vec's context input.
NUM_CLASSES = 10
# Central finite-difference step for Vec's directional-derivative targets.
FD_EPS = 1e-3
# Added to RMS denominators so an all-zero credit vector cannot divide by zero.
RMS_EPS = 1e-6
# Epoch at which the "acc@20" snapshot is read (output labels assume 20).
ACC_PROBE_EPOCH = 20


class VectorCreditNet(nn.Module):
    """MLP that predicts a credit vector for a hidden state.

    The input is the LayerNorm'd hidden state ``h`` concatenated with a
    sinusoidal embedding of the depth code ``t`` and a context vector ``s``;
    the output has the same width as ``h``. In this script ``t`` is
    ``l / num_blocks`` for hidden state ``l`` (1.0 for the final hidden state)
    and ``s`` is the output error ``softmax(logits) - onehot(y)``.

    Args:
        d_hidden: width of the hidden states (input and output width).
        s_dim: width of the context vector ``s``.
        time_embed_dim: width passed to ``SinusoidalTimeEmbed``; the first
            Linear assumes the embedding has exactly this width.
        hidden_dim: width of the internal layers.
        num_layers: number of Linear+GELU layers before the output Linear.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the credit vector for hidden state ``h`` at depth ``t`` given context ``s``."""
        return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1))


def get_cifar10(batch_size=128, num_workers=4, root='./data'):
    """Build CIFAR-10 train/test DataLoaders, downloading the data to ``root`` if needed.

    Training images get a random crop (padding 4) and a random horizontal
    flip; both splits are normalised with the fixed per-channel mean/std
    below. The train loader shuffles, the test loader does not.

    Returns:
        ``(train_loader, test_loader)``.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    return (
        DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),
        DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True),
    )


def _flatten_batch(x, y, device):
    """Flatten each image to a vector and move the batch to ``device``."""
    return x.view(x.size(0), -1).to(device), y.to(device)


def evaluate(model, test_loader, device):
    """Return the top-1 accuracy (a fraction in [0, 1]) of ``model`` on ``test_loader``.

    Leaves ``model`` in eval mode.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = _flatten_batch(x, y, device)
            correct += (model(x).argmax(1) == y).sum().item()
            total += x.size(0)
    if total == 0:
        raise ValueError("evaluate() received an empty loader")
    return correct / total


def _output_error(logits, y):
    """Return ``softmax(logits) - onehot(y)``, the per-sample gradient of CE w.r.t. the logits."""
    err = logits.softmax(-1)
    err[torch.arange(logits.size(0), device=logits.device), y] -= 1
    return err


def _rms(a):
    """Per-row RMS of ``a`` over its last dim plus RMS_EPS (last dim kept for broadcasting)."""
    return (a ** 2).mean(-1, keepdim=True).sqrt() + RMS_EPS


def _make_trunk_optimizers(model, lr, wd, total_epochs):
    """Create per-block, embedding and head AdamW optimisers with cosine LR schedules.

    All schedules use ``T_max=total_epochs``; callers step them once per epoch.

    Returns:
        ``(block_opts, embed_opt, head_opt, scheds)``.
    """
    block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
    head_opt = optim.AdamW(head_params, lr=lr, weight_decay=wd)
    scheds = [optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_epochs)
              for opt in [*block_opts, embed_opt, head_opt]]
    return block_opts, embed_opt, head_opt, scheds


def _apply_local_updates(model, x, y, hi, credits, head_opt, block_opts, embed_opt):
    """Apply one local (no end-to-end backprop) update to head, blocks and embedding.

    Args:
        x: flattened input batch (fed to ``model.embed``).
        y: labels.
        hi: hidden states from ``model(x, return_hidden=True)``; ``hi[l]`` is
            the input of block ``l`` and ``hi[-1]`` feeds the head.
        credits: one credit tensor per block, same shape as the block output.
    """
    # Head: ordinary cross-entropy on the detached final hidden state.
    hL = hi[-1].detach()
    head_loss = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
    head_opt.zero_grad()
    head_loss.backward()
    head_opt.step()

    # Blocks: surrogate <f_l, credit / rms>, so the gradient w.r.t. the block
    # output is the RMS-normalised credit (divided by the batch size) and
    # descent moves the block output against the credit direction.
    for l in range(model.num_blocks):
        block = model.blocks[l]
        a = credits[l]
        block_loss = (block(hi[l].detach()) * (a / _rms(a))).sum(-1).mean()
        block_opts[l].zero_grad()
        block_loss.backward()
        torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
        block_opts[l].step()

    # Embedding: same surrogate driven by block 0's credit (not clipped).
    a0 = credits[0]
    embed_loss = (model.embed(x) * (a0 / _rms(a0))).sum(-1).mean()
    embed_opt.zero_grad()
    embed_loss.backward()
    embed_opt.step()


def _vec_fit_loss(model, vec_net, hi, s, y, M, eps=FD_EPS):
    """Return Vec's regression loss on one batch.

    Two terms, averaged into one scalar:
      * terminal: squared error between Vec's output at the final hidden state
        (``t = 1``) and the exact gradient of the head's cross-entropy w.r.t.
        that state (``reduction='sum'`` so each row is its own sample's gradient);
      * projection: for one randomly chosen hidden state ``l`` (``t = l / L``)
        and ``M`` random unit directions ``v``, squared error between
        ``<vec, v>`` and the central finite difference of the per-sample loss
        along ``v``, computed with ``model.forward_from_layer``.

    The model is only read here: every target is detached or computed under
    ``no_grad``, so the returned loss backpropagates into ``vec_net`` alone.
    """
    if M < 1:
        raise ValueError(f"M must be >= 1, got {M}")
    L = model.num_blocks
    hL = hi[-1].detach()
    batch = hL.size(0)
    device = hL.device

    # Terminal term: exact dCE/dh_L through the head.
    a_term = vec_net(hL, torch.ones(batch, device=device), s)
    hL_req = hL.clone().requires_grad_(True)
    ce = F.cross_entropy(model.out_head(model.out_ln(hL_req)), y, reduction='sum')
    dL = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
    loss_term = ((a_term - dL) ** 2).sum(-1).mean()

    # Projection term at a random depth.
    l = np.random.randint(0, L)
    h_l = hi[l].detach()
    a_l = vec_net(h_l, torch.full((batch,), l / L, device=device), s)
    loss_proj = torch.tensor(0.0, device=device)
    for _ in range(M):
        # Per-row unit direction. Tensor.norm's first positional argument is
        # the order p, so the dim must be passed by keyword; `v.norm(-1, ...)`
        # would compute a p=-1 "norm" over the whole tensor instead.
        v = torch.randn_like(h_l)
        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y, reduction='none')
            lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y, reduction='none')
            gj = (lp - lm) / (2 * eps)
        loss_proj = loss_proj + (((a_l * v).sum(-1) - gj) ** 2).mean()
    loss_proj = loss_proj / M
    return loss_term + loss_proj


def _vec_step(vec_opt, vec_net, loss):
    """Backpropagate ``loss`` and take one optimiser step on ``vec_net`` (grad norm clipped to 1.0)."""
    vec_opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
    vec_opt.step()


def _blend_credit(vec_credit, dfa_credit, alpha):
    """Mix Vec and DFA credit for one layer.

    ``alpha <= 0`` returns the raw DFA credit and ``alpha >= 1`` the raw Vec
    credit; otherwise each is RMS-normalised per sample before the convex
    combination, so the mix does not depend on either signal's scale.
    """
    if alpha <= 0:
        return dfa_credit
    if alpha >= 1:
        return vec_credit
    return alpha * vec_credit / _rms(vec_credit) + (1 - alpha) * dfa_credit / _rms(dfa_credit)


def train_dfa_with_checkpoint(model, train_loader, test_loader, device, total_epochs, t0, lr, wd):
    """Train DFA, save checkpoint at t0, continue to total_epochs.

    Block ``l`` (and the embedding, which reuses block 0's signal) is trained
    with the fixed random feedback ``e @ B_l.T``. Here ``e`` is the output
    error and ``B_l`` is a ``(d_hidden, NUM_CLASSES)`` Gaussian matrix scaled
    by ``1/sqrt(NUM_CLASSES)``. The head is trained with ordinary
    cross-entropy.

    Args:
        t0: epoch after which the checkpoint is taken; must be in ``[1, total_epochs]``.

    Returns:
        ``(Bs, ckpt_state, final_acc)``. ``ckpt_state`` holds a copy of the
        model ``state_dict``, cloned ``Bs`` and the test accuracy at epoch
        ``t0``. Optimiser and scheduler states are not saved.

    Raises:
        ValueError: if ``t0`` is outside ``[1, total_epochs]``, where no
            checkpoint would ever be recorded.
    """
    if not 1 <= t0 <= total_epochs:
        raise ValueError(f"t0 must lie in [1, {total_epochs}], got {t0}")
    d = model.d_hidden
    L = model.num_blocks
    Bs = [torch.randn(d, NUM_CLASSES, device=device) / np.sqrt(NUM_CLASSES) for _ in range(L)]
    block_opts, embed_opt, head_opt, scheds = _make_trunk_optimizers(model, lr, wd, total_epochs)
    ckpt_state = None
    for epoch in range(1, total_epochs + 1):
        model.train()
        for x, y in train_loader:
            x, y = _flatten_batch(x, y, device)
            with torch.no_grad():
                lo, hi = model(x, return_hidden=True)
                eT = _output_error(lo, y)
            credits = [(eT @ B.T).detach() for B in Bs]
            _apply_local_updates(model, x, y, hi, credits, head_opt, block_opts, embed_opt)
        for sched in scheds:
            sched.step()
        if epoch == t0:
            acc = evaluate(model, test_loader, device)
            ckpt_state = {'model': copy.deepcopy(model.state_dict()),
                          'Bs': [B.clone() for B in Bs],
                          'acc': acc}
            print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}")
        if epoch % 20 == 0:
            acc = evaluate(model, test_loader, device)
            print(f" [DFA] Epoch {epoch}: acc={acc:.4f}")
    final_acc = evaluate(model, test_loader, device)
    return Bs, ckpt_state, final_acc


def offline_fit_vec(model, train_loader, device, epochs, lr_fb=1e-3, M=4):
    """Fit a fresh VectorCreditNet against a frozen ``model`` for ``epochs`` passes over ``train_loader``.

    The objective is described in ``_vec_fit_loss``. While fitting, ``model``
    is in eval mode and its parameters do not require grad. Their original
    ``requires_grad`` flags are restored afterwards, even on error, so a
    caller that froze the model keeps it frozen.

    Returns:
        The fitted VectorCreditNet (left in train mode).
    """
    vec_net = VectorCreditNet(d_hidden=model.d_hidden, s_dim=NUM_CLASSES).to(device)
    vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb)
    model.eval()
    params = list(model.parameters())
    prev_flags = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)
    try:
        for _ in range(epochs):
            vec_net.train()
            for x, y in train_loader:
                x, y = _flatten_batch(x, y, device)
                with torch.no_grad():
                    lo, hi = model(x, return_hidden=True)
                    s = _output_error(lo, y).detach()
                _vec_step(vec_opt, vec_net, _vec_fit_loss(model, vec_net, hi, s, y, M))
    finally:
        for p, flag in zip(params, prev_flags):
            p.requires_grad_(flag)
    return vec_net


def eval_frozen_credit_quality(model, vec_net, test_loader, device):
    """Evaluate Vec credit quality on frozen model.

    Only the first test batch is used. Vec's credit at each hidden state
    ``l`` in ``0..num_blocks-1`` is scored with three metrics:
      * Gamma: ``cosine_similarity_batch`` against the backprop gradient of
        the batch-mean cross-entropy w.r.t. that hidden state;
      * rho: ``perturbation_correlation`` against the per-sample loss of the
        remaining residual blocks plus head (epsilon=1e-3, M=16);
      * nudge: ``nudging_test`` with step eta=0.003.
    Each metric is averaged over layers.

    The ``requires_grad`` flags of ``model``'s parameters are restored on exit,
    and the parameter gradients produced by the backprop reference are cleared.

    Returns:
        ``(gamma, rho, nudge)`` as Python floats.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    vec_net.eval()
    L = model.num_blocks
    try:
        x, y = next(iter(test_loader))
    except StopIteration:
        raise ValueError("eval_frozen_credit_quality() received an empty test_loader") from None
    x, y = _flatten_batch(x, y, device)
    batch = x.size(0)

    params = list(model.parameters())
    prev_flags = [p.requires_grad for p in params]
    try:
        # Backprop reference: gradient of the mean CE w.r.t. every hidden state.
        for p in params:
            p.requires_grad_(True)
        model.zero_grad()
        lo, hbp = model(x, return_hidden=True)
        for l in range(L + 1):
            hbp[l].retain_grad()
        F.cross_entropy(lo, y).backward()
        bp = {l: hbp[l].grad.detach().clone() for l in range(L + 1)}
        model.zero_grad(set_to_none=True)

        for p in params:
            p.requires_grad_(False)
        with torch.no_grad():
            lo2, hi = model(x, return_hidden=True)
            s = _output_error(lo2, y).detach()

        def make_fwd(start):
            """Return h -> per-sample CE after residual blocks ``start..L-1`` and the head."""
            def fwd(h):
                with torch.no_grad():
                    c = h
                    for i in range(start, L):
                        c = c + model.blocks[i](c)
                    return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
            return fwd

        gammas, rhos, nudges = [], [], []
        for l in range(L):
            h_l = hi[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            a_l = vec_net(h_l, t_l, s).detach()
            gammas.append(cosine_similarity_batch(a_l, bp[l]))
            rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))
            nudges.append(nudging_test(h_l, a_l, make_fwd(l), eta=0.003))
    finally:
        for p, flag in zip(params, prev_flags):
            p.requires_grad_(flag)
    return float(np.mean(gammas)), float(np.mean(rhos)), float(np.mean(nudges))


def continue_from_checkpoint(model, vec_net, Bs, train_loader, test_loader, device, t0, total_epochs,
                             blend_alpha, lr, lr_fb, wd, M, branch_name):
    """Continue training from checkpoint with blend credit.

    Runs epochs ``t0+1 .. total_epochs``. Optimisers are created fresh, so
    their moments are not carried over from the DFA run. The cosine schedules
    are fast-forwarded ``t0`` steps so the LR matches the DFA run at epoch
    ``t0``.

    When ``blend_alpha > 0``, ``vec_net`` is also trained online each batch
    (see ``_vec_fit_loss``) before being queried, and each layer's credit is
    ``_blend_credit(vec, dfa, blend_alpha)``. With ``blend_alpha <= 0`` this
    is plain DFA with the feedback matrices ``Bs``, and ``vec_net`` may be None.

    Returns:
        ``{'test_acc': [...], 'train_loss': [...]}`` with one entry per epoch run.

    Raises:
        ValueError: if ``blend_alpha > 0`` but ``vec_net`` is None.
    """
    L = model.num_blocks
    use_vec = blend_alpha > 0
    if use_vec and vec_net is None:
        raise ValueError(f"blend_alpha={blend_alpha} > 0 requires a vec_net")
    block_opts, embed_opt, head_opt, scheds = _make_trunk_optimizers(model, lr, wd, total_epochs)
    vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb) if use_vec else None
    for _ in range(t0):
        for sched in scheds:
            sched.step()

    log = {'test_acc': [], 'train_loss': []}
    for epoch in range(t0 + 1, total_epochs + 1):
        model.train()
        if vec_net is not None:
            vec_net.train()
        loss_sum, n_seen = 0.0, 0
        for x, y in train_loader:
            x, y = _flatten_batch(x, y, device)
            batch = x.size(0)
            with torch.no_grad():
                lo, hi = model(x, return_hidden=True)
                lv = F.cross_entropy(lo, y)
                eT = _output_error(lo, y)
                s = eT.detach()

            if use_vec:
                # Update Vec against the current (pre-update) trunk, then query it.
                _vec_step(vec_opt, vec_net, _vec_fit_loss(model, vec_net, hi, s, y, M))
                with torch.no_grad():
                    vc = [vec_net(hi[l].detach(), torch.full((batch,), l / L, device=device), s).detach()
                          for l in range(L)]
            else:
                vc = [None] * L
            dc = [(eT @ Bs[l].T).detach() for l in range(L)]
            credits = [_blend_credit(vc[l], dc[l], blend_alpha) for l in range(L)]
            _apply_local_updates(model, x, y, hi, credits, head_opt, block_opts, embed_opt)

            loss_sum += lv.item() * batch
            n_seen += batch
        for sched in scheds:
            sched.step()
        ta = evaluate(model, test_loader, device)
        log['test_acc'].append(ta)
        log['train_loss'].append(loss_sum / n_seen if n_seen else float('nan'))
        if epoch % 20 == 0 or epoch == t0 + 1 or epoch == total_epochs:
            print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}")
    return log


def _acc_at_epoch(test_accs, t0, epoch):
    """Look up the accuracy logged after ``epoch`` by a run resumed from checkpoint ``t0``.

    ``test_accs[i]`` is the accuracy after epoch ``t0 + 1 + i``. An ``epoch``
    past the end of the log falls back to the final accuracy. Returns None
    when ``epoch <= t0`` (never logged by the continuation) or the log is
    empty, instead of indexing from the end of the list.
    """
    idx = epoch - t0 - 1
    if idx < 0 or not test_accs:
        return None
    return test_accs[min(idx, len(test_accs) - 1)]


def _fmt(value, spec, width=0):
    """Format ``value`` with ``spec``; render None as 'n/a' right-aligned to ``width``."""
    return 'n/a'.rjust(width) if value is None else format(value, spec)


def run_experiment(args):
    """Run the prefit-threshold sweep described in the module docstring and save a JSON summary.

    Output file: ``<output_dir>/prefit_curve_t<t0>_s<seed>.json``.

    Raises:
        ValueError: if ``args.t0`` is not in ``[1, args.epochs)``.
    """
    if not 1 <= args.t0 < args.epochs:
        raise ValueError(f"--t0 must satisfy 1 <= t0 < epochs (epochs={args.epochs}), got {args.t0}")
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden

    # Train DFA and get checkpoint
    print(f"\n{'='*60}\nTraining DFA baseline with checkpoint at t0={args.t0}\n{'='*60}")
    model_dfa = ResidualMLP(input_dim, d, NUM_CLASSES, L).to(device)
    _, ckpt, dfa_final = train_dfa_with_checkpoint(model_dfa, train_loader, test_loader, device,
                                                   args.epochs, args.t0, args.lr, args.wd)
    print(f" DFA final: {dfa_final:.4f}")

    # continue_DFA baseline (from checkpoint)
    print(f"\n{'='*60}\ncontinue_DFA from t0={args.t0}\n{'='*60}")
    model_base = ResidualMLP(input_dim, d, NUM_CLASSES, L).to(device)
    model_base.load_state_dict(ckpt['model'])
    log_dfa = continue_from_checkpoint(model_base, None, ckpt['Bs'], train_loader, test_loader, device,
                                       args.t0, args.epochs, 0.0, args.lr, args.lr_fb, args.wd, args.M,
                                       'continue_DFA')
    dfa_cont_final = log_dfa['test_acc'][-1]
    dfa_cont_acc20 = _acc_at_epoch(log_dfa['test_acc'], args.t0, ACC_PROBE_EPOCH)
    print(f" continue_DFA final: {dfa_cont_final:.4f}")

    # Sweep E_prefit
    all_results = []
    for E in args.prefit_epochs:
        print(f"\n{'='*60}\nE_prefit={E}\n{'='*60}")
        model_frozen = ResidualMLP(input_dim, d, NUM_CLASSES, L).to(device)
        model_frozen.load_state_dict(ckpt['model'])
        model_frozen.eval()
        for p in model_frozen.parameters():
            p.requires_grad_(False)

        # Offline fit Vec (torch reseeded per E_prefit)
        torch.manual_seed(args.seed + E * 100 + 4000)
        if E > 0:
            print(f" Offline fitting Vec for {E} epochs...")
            vec_net = offline_fit_vec(model_frozen, train_loader, device, epochs=E, lr_fb=args.lr_fb, M=args.M)
        else:
            vec_net = VectorCreditNet(d_hidden=d, s_dim=NUM_CLASSES).to(device)
            print(" E_prefit=0: random Vec")

        # Evaluate frozen credit quality
        gamma_f, rho_f, nudge_f = eval_frozen_credit_quality(model_frozen, vec_net, test_loader, device)
        print(f" Frozen quality: Gamma={gamma_f:.4f}, rho={rho_f:.4f}, nudge={nudge_f:.6f}")

        # Branch: blend handoff
        for branch_name, alpha in [('blend_075', 0.75)]:
            print(f"\n --- {branch_name} (E_prefit={E}) ---")
            model_branch = ResidualMLP(input_dim, d, NUM_CLASSES, L).to(device)
            model_branch.load_state_dict(ckpt['model'])
            vec_branch = copy.deepcopy(vec_net)
            for p in model_branch.parameters():
                p.requires_grad_(True)
            log = continue_from_checkpoint(model_branch, vec_branch, ckpt['Bs'], train_loader, test_loader,
                                           device, args.t0, args.epochs, alpha, args.lr, args.lr_fb,
                                           args.wd, args.M, f'E{E}_{branch_name}')
            final = log['test_acc'][-1]
            acc20 = _acc_at_epoch(log['test_acc'], args.t0, ACC_PROBE_EPOCH)
            diff_final = final - dfa_cont_final
            diff_acc20 = None if acc20 is None or dfa_cont_acc20 is None else acc20 - dfa_cont_acc20
            r = {'E_prefit': E, 'branch': branch_name, 'gamma_frozen': gamma_f, 'rho_frozen': rho_f,
                 'nudge_frozen': nudge_f, 'final_acc': final, 'acc_at_20': acc20,
                 'diff_final': diff_final, 'diff_acc20': diff_acc20,
                 'test_acc': log['test_acc']}
            all_results.append(r)
            print(f" E={E} {branch_name}: final={final:.4f} (diff={diff_final:+.4f}), "
                  f"acc@20={_fmt(acc20, '.4f')} (diff={_fmt(diff_acc20, '+.4f')})")

    # Summary
    print(f"\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")
    print(f"continue_DFA: final={dfa_cont_final:.4f}, acc@20={_fmt(dfa_cont_acc20, '.4f')}")
    print(f"\n{'E_prefit':>8} {'Gamma_f':>8} {'rho_f':>8} {'final':>8} {'diff':>8} {'acc@20':>8} {'d@20':>8}")
    print("-" * 62)
    for r in all_results:
        print(f"{r['E_prefit']:>8} {r['gamma_frozen']:>8.4f} {r['rho_frozen']:>8.4f} "
              f"{r['final_acc']:>8.4f} {r['diff_final']:>+8.4f} "
              f"{_fmt(r['acc_at_20'], '>8.4f', 8)} {_fmt(r['diff_acc20'], '>+8.4f', 8)}")

    # Save
    save_data = {
        'dfa_cont_final': float(dfa_cont_final),
        'dfa_cont_acc20': None if dfa_cont_acc20 is None else float(dfa_cont_acc20),
        'results': [dict(r) for r in all_results],
    }
    out_path = os.path.join(args.output_dir, f'prefit_curve_t{args.t0}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(save_data, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")


def main():
    """Parse command-line arguments and run the experiment."""
    parser = argparse.ArgumentParser(description='Phase 10A: Prefit Threshold Curve')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--t0', type=int, default=5,
                        help='DFA epoch at which the shared checkpoint is taken (1 <= t0 < epochs)')
    parser.add_argument('--prefit_epochs', type=int, nargs='+', default=[0, 15, 60],
                        help='offline Vec prefit lengths (E_prefit) to sweep')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--M', type=int, default=4)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=2)
    parser.add_argument('--output_dir', type=str, default='results/prefit_threshold')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()