summaryrefslogtreecommitdiff
path: root/experiments/prefit_threshold_curve.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/prefit_threshold_curve.py')
-rw-r--r--experiments/prefit_threshold_curve.py360
1 files changed, 360 insertions, 0 deletions
diff --git a/experiments/prefit_threshold_curve.py b/experiments/prefit_threshold_curve.py
new file mode 100644
index 0000000..95ec713
--- /dev/null
+++ b/experiments/prefit_threshold_curve.py
@@ -0,0 +1,360 @@
+"""
+Phase 10A: Prefit Threshold Curve.
+
+Quantify: how much offline prefit does Vec need before blend handoff starts helping?
+
+Sweep E_prefit in {0, 5, 15, 30, 60, 120} on a fixed DFA checkpoint (t0=5).
+For each, measure frozen credit quality, then branch into continue_DFA vs blend handoff.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1))
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(), transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(), transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ return (DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True),
+ DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True))
+
+
+def evaluate(model, test_loader, device):
+ model.eval(); c, t = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device)
+ c += (model(x).argmax(1)==y).sum().item(); t += x.size(0)
+ return c/t
+
+
+def train_dfa_with_checkpoint(model, train_loader, test_loader, device, total_epochs, t0, lr, wd):
+ """Train DFA, save checkpoint at t0, continue to total_epochs."""
+ d = model.d_hidden; L = model.num_blocks
+ Bs = [torch.randn(d,10,device=device)/np.sqrt(10) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o,T_max=total_epochs) for o in block_opts]+\
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt,T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt,T_max=total_epochs)]
+ ckpt_state = None
+ for epoch in range(1, total_epochs+1):
+ model.train(); tl,c,t=0,0,0
+ for x,y in train_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);b=x.size(0)
+ with torch.no_grad():
+ lo,hi=model(x,return_hidden=True);lv=F.cross_entropy(lo,y)
+ eT=lo.softmax(-1);eT[torch.arange(b),y]-=1
+ hL=hi[-1].detach()
+ lo2=F.cross_entropy(model.out_head(model.out_ln(hL)),y)
+ head_opt.zero_grad();lo2.backward();head_opt.step()
+ for l in range(L):
+ a=(eT@Bs[l].T).detach();rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ f=model.blocks[l](hi[l].detach());ll=(f*(a/rm)).sum(-1).mean()
+ block_opts[l].zero_grad();ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);block_opts[l].step()
+ a0=(eT@Bs[0].T).detach();r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el=(model.embed(x)*(a0/r0)).sum(-1).mean()
+ embed_opt.zero_grad();el.backward();embed_opt.step()
+ tl+=lv.item()*b;c+=(lo.argmax(1)==y).sum().item();t+=b
+ for s in scheds:s.step()
+ if epoch == t0:
+ acc = evaluate(model, test_loader, device)
+ ckpt_state = {'model': copy.deepcopy(model.state_dict()), 'Bs': [B.clone() for B in Bs], 'acc': acc}
+ print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}")
+ if epoch % 20 == 0:
+ acc = evaluate(model, test_loader, device)
+ print(f" [DFA] Epoch {epoch}: acc={acc:.4f}")
+ final_acc = evaluate(model, test_loader, device)
+ return Bs, ckpt_state, final_acc
+
+
+def offline_fit_vec(model, train_loader, device, epochs, lr_fb=1e-3, M=4):
+ d = model.d_hidden; L = model.num_blocks; eps = 1e-3
+ vec_net = VectorCreditNet(d_hidden=d, s_dim=10).to(device)
+ vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb)
+ model.eval()
+ for p in model.parameters(): p.requires_grad_(False)
+ for ep in range(1, epochs+1):
+ vec_net.train()
+ for x,y in train_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0)
+ with torch.no_grad():
+ lo,hi=model(x,return_hidden=True)
+ eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach()
+ hL=hi[-1].detach();t_L=torch.ones(batch,device=device)
+ a_term=vec_net(hL,t_L,s)
+ hL_req=hL.clone().requires_grad_(True)
+ ce=F.cross_entropy(model.out_head(model.out_ln(hL_req)),y,reduction='sum')
+ dL=torch.autograd.grad(ce,hL_req,create_graph=False)[0].detach()
+ loss_term=((a_term-dL)**2).sum(-1).mean()
+ l=np.random.randint(0,L);h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device)
+ a_l=vec_net(h_l,t_l,s);loss_proj=torch.tensor(0.0,device=device)
+ for _ in range(M):
+ v=torch.randn_like(h_l);v=v/(v.norm(-1,keepdim=True)+1e-8)
+ with torch.no_grad():
+ lp=F.cross_entropy(model.forward_from_layer(h_l+eps*v,l),y,reduction='none')
+ lm=F.cross_entropy(model.forward_from_layer(h_l-eps*v,l),y,reduction='none')
+ gj=(lp-lm)/(2*eps)
+ loss_proj=loss_proj+(((a_l*v).sum(-1)-gj.detach())**2).mean()
+ loss_proj/=M
+ vloss=loss_term+loss_proj
+ vec_opt.zero_grad();vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vec_net.parameters(),1.0);vec_opt.step()
+ for p in model.parameters(): p.requires_grad_(True)
+ return vec_net
+
+
+def eval_frozen_credit_quality(model, vec_net, test_loader, device):
+ """Evaluate Vec credit quality on frozen model."""
+ model.eval(); vec_net.eval(); L = model.num_blocks
+ for x,y in test_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);break
+ batch=x.size(0)
+ # BP grads
+ for p in model.parameters(): p.requires_grad_(True)
+ model.zero_grad()
+ lo,hbp=model(x,return_hidden=True)
+ for l in range(L+1): hbp[l].retain_grad()
+ F.cross_entropy(lo,y).backward()
+ bp={l:hbp[l].grad.detach().clone() for l in range(L+1)}
+ for p in model.parameters(): p.requires_grad_(False)
+ with torch.no_grad():
+ lo2,hi=model(x,return_hidden=True)
+ eT=lo2.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach()
+ gammas,rhos,nudges=[],[],[]
+ for l in range(L):
+ h_l=hi[l].detach();t_l=torch.full((batch,),l/L,device=device)
+ a_l=vec_net(h_l,t_l,s).detach()
+ gammas.append(cosine_similarity_batch(a_l,bp[l]))
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c=h
+ for i in range(sl,L):c=c+model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none')
+ return f
+ rhos.append(perturbation_correlation(h_l,a_l,make_fwd(l),epsilon=1e-3,M=16))
+ nudges.append(nudging_test(h_l,a_l,make_fwd(l),eta=0.003))
+ return float(np.mean(gammas)),float(np.mean(rhos)),float(np.mean(nudges))
+
+
+def continue_from_checkpoint(model, vec_net, Bs, train_loader, test_loader, device,
+ t0, total_epochs, blend_alpha, lr, lr_fb, wd, M, branch_name):
+ """Continue training from checkpoint with blend credit."""
+ d=model.d_hidden;L=model.num_blocks;eps=1e-3
+ block_opts=[optim.AdamW(b.parameters(),lr=lr,weight_decay=wd) for b in model.blocks]
+ embed_opt=optim.AdamW(model.embed.parameters(),lr=lr,weight_decay=wd)
+ head_opt=optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()),lr=lr,weight_decay=wd)
+ vec_opt=optim.Adam(vec_net.parameters(),lr=lr_fb) if blend_alpha>0 else None
+ scheds=[optim.lr_scheduler.CosineAnnealingLR(o,T_max=total_epochs) for o in block_opts]+\
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt,T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt,T_max=total_epochs)]
+ for _ in range(t0):
+ for s in scheds:s.step()
+ log={'test_acc':[],'train_loss':[]}
+ for epoch in range(t0+1,total_epochs+1):
+ model.train()
+ if vec_net:vec_net.train()
+ tl,c,t=0,0,0
+ for x,y in train_loader:
+ x=x.view(x.size(0),-1).to(device);y=y.to(device);batch=x.size(0)
+ with torch.no_grad():
+ lo,hi=model(x,return_hidden=True);lv=F.cross_entropy(lo,y)
+ eT=lo.softmax(-1);eT[torch.arange(batch),y]-=1;s=eT.detach()
+ hL=hi[-1].detach()
+ if blend_alpha>0 and vec_opt:
+ t_L=torch.ones(batch,device=device)
+ a_term=vec_net(hL,t_L,s)
+ hL_req=hL.clone().requires_grad_(True)
+ ce=F.cross_entropy(model.out_head(model.out_ln(hL_req)),y,reduction='sum')
+ dL=torch.autograd.grad(ce,hL_req,create_graph=False)[0].detach()
+ loss_t=((a_term-dL)**2).sum(-1).mean()
+ lt=np.random.randint(0,L);h_l=hi[lt].detach();t_l=torch.full((batch,),lt/L,device=device)
+ a_l=vec_net(h_l,t_l,s);lp2=torch.tensor(0.0,device=device)
+ for _ in range(M):
+ v=torch.randn_like(h_l);v=v/(v.norm(-1,keepdim=True)+1e-8)
+ with torch.no_grad():
+ lp=F.cross_entropy(model.forward_from_layer(h_l+eps*v,lt),y,reduction='none')
+ lm=F.cross_entropy(model.forward_from_layer(h_l-eps*v,lt),y,reduction='none')
+ gj=(lp-lm)/(2*eps)
+ lp2=lp2+(((a_l*v).sum(-1)-gj.detach())**2).mean()
+ lp2/=M
+ vl=loss_t+lp2;vec_opt.zero_grad();vl.backward()
+ torch.nn.utils.clip_grad_norm_(vec_net.parameters(),1.0);vec_opt.step()
+ with torch.no_grad():
+ vc=[vec_net(hi[l].detach(),torch.full((batch,),l/L,device=device),s).detach() for l in range(L)] if blend_alpha>0 else [None]*L
+ dc=[(eT@Bs[l].T).detach() for l in range(L)]
+ credits=[]
+ for l in range(L):
+ if blend_alpha<=0:credits.append(dc[l])
+ elif blend_alpha>=1:credits.append(vc[l])
+ else:
+ rv=(vc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6;rd=(dc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(blend_alpha*vc[l]/rv+(1-blend_alpha)*dc[l]/rd)
+ lo2=F.cross_entropy(model.out_head(model.out_ln(hL)),y)
+ head_opt.zero_grad();lo2.backward();head_opt.step()
+ for l in range(L):
+ a=credits[l];rm=(a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ f=model.blocks[l](hi[l].detach());ll=(f*(a/rm)).sum(-1).mean()
+ block_opts[l].zero_grad();ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0);block_opts[l].step()
+ a0=credits[0];r0=(a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el=(model.embed(x)*(a0/r0)).sum(-1).mean()
+ embed_opt.zero_grad();el.backward();embed_opt.step()
+ tl+=lv.item()*batch;c+=(lo.argmax(1)==y).sum().item();t+=batch
+ for s in scheds:s.step()
+ ta=evaluate(model,test_loader,device);log['test_acc'].append(ta);log['train_loss'].append(tl/t)
+ if epoch%20==0 or epoch==t0+1 or epoch==total_epochs:
+ print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}")
+ return log
+
+
+def run_experiment(args):
+ device=torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir,exist_ok=True)
+ torch.manual_seed(args.seed);np.random.seed(args.seed);torch.cuda.manual_seed_all(args.seed)
+ train_loader,test_loader=get_cifar10(args.batch_size)
+ input_dim=32*32*3;L=args.num_blocks;d=args.d_hidden
+
+ # Train DFA and get checkpoint
+ print(f"\n{'='*60}\nTraining DFA baseline with checkpoint at t0={args.t0}\n{'='*60}")
+ model_dfa=ResidualMLP(input_dim,d,10,L).to(device)
+ Bs,ckpt,dfa_final=train_dfa_with_checkpoint(model_dfa,train_loader,test_loader,device,
+ args.epochs,args.t0,args.lr,args.wd)
+ print(f" DFA final: {dfa_final:.4f}")
+
+ # continue_DFA baseline (from checkpoint)
+ print(f"\n{'='*60}\ncontinue_DFA from t0={args.t0}\n{'='*60}")
+ model_base=ResidualMLP(input_dim,d,10,L).to(device)
+ model_base.load_state_dict(ckpt['model'])
+ log_dfa=continue_from_checkpoint(model_base,None,ckpt['Bs'],train_loader,test_loader,device,
+ args.t0,args.epochs,0.0,args.lr,args.lr_fb,args.wd,args.M,'continue_DFA')
+ dfa_cont_final=log_dfa['test_acc'][-1]
+ dfa_cont_acc20=log_dfa['test_acc'][20-args.t0-1] if len(log_dfa['test_acc'])>20-args.t0-1 else log_dfa['test_acc'][-1]
+ print(f" continue_DFA final: {dfa_cont_final:.4f}")
+
+ # Sweep E_prefit
+ all_results=[]
+ for E in args.prefit_epochs:
+ print(f"\n{'='*60}\nE_prefit={E}\n{'='*60}")
+ model_frozen=ResidualMLP(input_dim,d,10,L).to(device)
+ model_frozen.load_state_dict(ckpt['model'])
+ model_frozen.eval()
+ for p in model_frozen.parameters():p.requires_grad_(False)
+
+ # Offline fit Vec
+ torch.manual_seed(args.seed+E*100+4000)
+ if E>0:
+ print(f" Offline fitting Vec for {E} epochs...")
+ vec_net=offline_fit_vec(model_frozen,train_loader,device,epochs=E,lr_fb=args.lr_fb,M=args.M)
+ else:
+ vec_net=VectorCreditNet(d_hidden=d,s_dim=10).to(device)
+ print(f" E_prefit=0: random Vec")
+
+ # Evaluate frozen credit quality
+ gamma_f,rho_f,nudge_f=eval_frozen_credit_quality(model_frozen,vec_net,test_loader,device)
+ print(f" Frozen quality: Gamma={gamma_f:.4f}, rho={rho_f:.4f}, nudge={nudge_f:.6f}")
+
+ # Branch: blend handoff
+ for branch_name,alpha in [('blend_075',0.75)]:
+ print(f"\n --- {branch_name} (E_prefit={E}) ---")
+ model_branch=ResidualMLP(input_dim,d,10,L).to(device)
+ model_branch.load_state_dict(ckpt['model'])
+ vec_branch=copy.deepcopy(vec_net)
+ for p in model_branch.parameters():p.requires_grad_(True)
+ log=continue_from_checkpoint(model_branch,vec_branch,ckpt['Bs'],train_loader,test_loader,device,
+ args.t0,args.epochs,alpha,args.lr,args.lr_fb,args.wd,args.M,
+ f'E{E}_{branch_name}')
+ final=log['test_acc'][-1]
+ acc20=log['test_acc'][20-args.t0-1] if len(log['test_acc'])>20-args.t0-1 else log['test_acc'][-1]
+ diff_final=final-dfa_cont_final
+ diff_acc20=acc20-dfa_cont_acc20
+
+ r={'E_prefit':E,'branch':branch_name,'gamma_frozen':gamma_f,'rho_frozen':rho_f,
+ 'nudge_frozen':nudge_f,'final_acc':final,'acc_at_20':acc20,
+ 'diff_final':diff_final,'diff_acc20':diff_acc20,
+ 'test_acc':log['test_acc']}
+ all_results.append(r)
+ print(f" E={E} {branch_name}: final={final:.4f} (diff={diff_final:+.4f}), acc@20={acc20:.4f} (diff={diff_acc20:+.4f})")
+
+ # Summary
+ print(f"\n{'='*80}")
+ print("SUMMARY")
+ print(f"{'='*80}")
+ print(f"continue_DFA: final={dfa_cont_final:.4f}, acc@20={dfa_cont_acc20:.4f}")
+ print(f"\n{'E_prefit':>8} {'Gamma_f':>8} {'rho_f':>8} {'final':>8} {'diff':>8} {'acc@20':>8} {'d@20':>8}")
+ print("-"*62)
+ for r in all_results:
+ print(f"{r['E_prefit']:>8} {r['gamma_frozen']:>8.4f} {r['rho_frozen']:>8.4f} "
+ f"{r['final_acc']:>8.4f} {r['diff_final']:>+8.4f} {r['acc_at_20']:>8.4f} {r['diff_acc20']:>+8.4f}")
+
+ # Save
+ save_data={'dfa_cont_final':float(dfa_cont_final),'dfa_cont_acc20':float(dfa_cont_acc20),
+ 'results':[{k:v for k,v in r.items()} for r in all_results]}
+ out_path=os.path.join(args.output_dir,f'prefit_curve_t{args.t0}_s{args.seed}.json')
+ with open(out_path,'w') as f:json.dump(save_data,f,indent=2,default=float)
+ print(f"\nSaved to {out_path}")
+
+
+def main():
+ parser=argparse.ArgumentParser(description='Phase 10A: Prefit Threshold Curve')
+ parser.add_argument('--num_blocks',type=int,default=4)
+ parser.add_argument('--d_hidden',type=int,default=256)
+ parser.add_argument('--batch_size',type=int,default=128)
+ parser.add_argument('--epochs',type=int,default=100)
+ parser.add_argument('--t0',type=int,default=5)
+ parser.add_argument('--prefit_epochs',type=int,nargs='+',default=[0,15,60])
+ parser.add_argument('--lr',type=float,default=1e-3)
+ parser.add_argument('--lr_fb',type=float,default=1e-3)
+ parser.add_argument('--wd',type=float,default=0.01)
+ parser.add_argument('--M',type=int,default=4)
+ parser.add_argument('--seed',type=int,default=42)
+ parser.add_argument('--gpu',type=int,default=2)
+ parser.add_argument('--output_dir',type=str,default='results/prefit_threshold')
+ args=parser.parse_args()
+ run_experiment(args)
+
+
+if __name__=='__main__':
+ main()