diff options
| -rw-r--r-- | experiments/a2_bp_supplement.py | 107 | ||||
| -rw-r--r-- | results/confirmatory/A2_bp_supplement.csv | 11 | ||||
| -rw-r--r-- | results/confirmatory/A2_cifar_state_vs_credit.csv | 10 | ||||
| -rw-r--r-- | results/confirmatory/A2_naive_state_err.csv | 10 |
4 files changed, 138 insertions, 0 deletions
"""Supplement A2 with BP baseline: 10 seeds, CIFAR-10, L=4, d=256.

Trains a ResidualMLP with plain backpropagation for each seed, then records
test accuracy, the credit diagnostics (Gamma, rho) and the naive state error.
BP rows are appended to the shared A2 CSV files and also written to a
standalone ``A2_bp_supplement.csv``.
"""
import argparse
import csv
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Put the repository root on sys.path so the project packages below resolve
# when this file is executed directly from the experiments/ directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation

# Per-channel statistics used to normalise CIFAR-10 images.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def get_cifar10(bs=128, root='./data', num_workers=4):
    """Build the CIFAR-10 train and test loaders.

    Args:
        bs: Batch size for both loaders.
        root: Directory the dataset is stored in (downloaded if missing).
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and
        applies random crop (padding 4) plus horizontal flip; the test loader
        only converts to tensor and normalises.
    """
    normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    train_set = torchvision.datasets.CIFAR10(
        root, train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        root, train=False, download=True, transform=test_tf)
    train_loader = DataLoader(train_set, batch_size=bs, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=bs, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader


def evaluate(m, tl, dev):
    """Return the top-1 accuracy of model ``m`` over loader ``tl`` on ``dev``.

    Inputs are flattened to ``(batch, -1)`` before being fed to the model.
    """
    m.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in tl:
            x = x.view(x.size(0), -1).to(dev)
            y = y.to(dev)
            correct += (m(x).argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


def compute_naive_state_err(model, test_loader, device, eval_layer=None):
    """Mean relative squared distance between a mid-layer state and the last.

    For every test sample this computes
    ``||h_l - h_L||^2 / max(||h_L||, 1)^2`` where ``h_l`` is the hidden state
    at ``eval_layer`` and ``h_L`` is the final hidden state, then averages
    over all samples.

    Args:
        model: Network exposing ``num_blocks`` and returning
            ``(logits, hidden_states)`` when called with ``return_hidden=True``.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        eval_layer: Index into the hidden-state list; defaults to
            ``model.num_blocks // 2``.

    Returns:
        The per-sample average as a Python float.
    """
    model.eval()
    if eval_layer is None:
        eval_layer = model.num_blocks // 2
    total_err, n_samples = 0.0, 0
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.view(x.size(0), -1).to(device)
            _, hidden = model(x, return_hidden=True)
            h_last = hidden[-1]
            h_mid = hidden[eval_layer]
            # `dim` must be passed by keyword: the first positional argument
            # of Tensor.norm is the norm order `p`, so `norm(-1, ...)` would
            # compute a global (-1)-norm and the clamp would then silently
            # disable the normalisation altogether.
            scale = h_last.norm(dim=-1, keepdim=True).clamp(min=1.0)
            total_err += ((h_mid - h_last) / scale).pow(2).sum().item()
            n_samples += x.size(0)
    return total_err / n_samples


def _make_tail_forward(model, start_layer, targets):
    """Return ``f(h)``: per-sample loss when the network resumes from ``h``.

    ``f`` applies the residual blocks ``start_layer .. num_blocks - 1`` as
    ``h <- h + block(h)``, then ``out_ln`` and ``out_head``, and returns the
    unreduced cross-entropy against ``targets``. No gradients are tracked.
    """
    def tail_loss(h):
        with torch.no_grad():
            state = h
            for i in range(start_layer, model.num_blocks):
                state = state + model.blocks[i](state)
            logits = model.out_head(model.out_ln(state))
            return F.cross_entropy(logits, targets, reduction='none')
    return tail_loss


def compute_bp_diagnostics(model, test_loader, device):
    """Compute the (Gamma, rho) credit diagnostics for a BP-trained model.

    Only the first batch of ``test_loader`` is used. The BP credit signal is
    the true gradient of the mean cross-entropy w.r.t. each hidden state, so
    Gamma (cosine similarity with the BP gradient) is 1.0 by definition. rho
    is the layer-average of ``perturbation_correlation`` between that gradient
    and the loss response of the remaining layers.

    Args:
        model: Network exposing ``num_blocks``, ``blocks``, ``out_ln`` and
            ``out_head`` and returning ``(logits, hidden_states)`` when called
            with ``return_hidden=True``.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batch is moved to.

    Returns:
        ``(gamma, rho)`` as Python floats.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    num_layers = model.num_blocks
    try:
        x, y = next(iter(test_loader))
    except StopIteration:
        raise ValueError("test_loader yielded no batches") from None
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)

    logits, hidden = model(x, return_hidden=True)
    loss = F.cross_entropy(logits, y)
    # autograd.grad returns the hidden-state gradients directly and, unlike
    # loss.backward(), leaves the parameters' .grad buffers untouched.
    bp_grads = torch.autograd.grad(
        loss, [hidden[l] for l in range(num_layers)])

    rhos = []
    for l in range(num_layers):
        # The hidden states from the pass above are reused (detached); the
        # model is in eval mode, so a second forward pass adds nothing.
        h_l = hidden[l].detach()
        a_l = bp_grads[l].detach()
        rhos.append(perturbation_correlation(
            h_l, a_l, _make_tail_forward(model, l, y), epsilon=1e-3, M=16))

    gamma = 1.0  # BP cosine with itself
    return gamma, float(np.mean(rhos))


def _append_rows(path, fieldnames, rows):
    """Append ``rows`` to the CSV at ``path``.

    A header row is written first when the file does not exist yet or is
    empty, so the output stays a valid CSV on a fresh output directory.
    """
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerows(rows)


def main():
    """Train the BP baseline for every seed and write the result CSVs.

    Command-line options:
        --gpu: CUDA device index (falls back to CPU if CUDA is unavailable).
        --output_dir: Directory holding the A2 result files.
        --epochs: Training epochs per seed (also the cosine schedule length).

    NOTE: running the script again appends a second copy of the BP rows to
    the shared A2 CSV files; the standalone file is overwritten.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--output_dir', type=str, default='results/confirmatory')
    parser.add_argument('--epochs', type=int, default=100)
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
    else:
        print("CUDA not available; falling back to CPU.", flush=True)
        device = torch.device('cpu')
    os.makedirs(args.output_dir, exist_ok=True)

    seeds = [42, 123, 456, 789, 1024, 2048, 3000, 4000, 5000, 6000]
    num_layers, width = 4, 256
    train_loader, test_loader = get_cifar10()
    rows_main, rows_naive = [], []

    for seed in seeds:
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)

        model = ResidualMLP(3072, width, 10, num_layers).to(device)
        opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
        sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)
        print(f" BP seed={seed}: training...", flush=True)
        for ep in range(1, args.epochs + 1):
            model.train()
            for x, y in train_loader:
                x = x.view(x.size(0), -1).to(device)
                y = y.to(device)
                loss = F.cross_entropy(model(x), y)
                opt.zero_grad()
                loss.backward()
                opt.step()
            sch.step()
            if ep % 20 == 0:
                print(f" Ep {ep}: acc={evaluate(model, test_loader, device):.4f}",
                      flush=True)

        acc = evaluate(model, test_loader, device)
        gamma, rho = compute_bp_diagnostics(model, test_loader, device)
        nse = compute_naive_state_err(model, test_loader, device)
        # StateErr is recorded as NaN for the BP rows.
        rows_main.append({'method': 'bp', 'seed': seed, 'StateErr': float('nan'),
                          'Gamma': gamma, 'rho': rho, 'acc': acc})
        rows_naive.append({'method': 'bp', 'seed': seed, 'naive_StateErr': nse})
        print(f" BP seed={seed}: acc={acc:.4f}, Gamma={gamma:.4f}, "
              f"rho={rho:.4f}, naive_StateErr={nse:.6f}", flush=True)

    # Append to A2 CSV
    a2_path = os.path.join(args.output_dir, 'A2_cifar_state_vs_credit.csv')
    _append_rows(a2_path,
                 ['method', 'seed', 'StateErr', 'Gamma', 'rho', 'acc'],
                 rows_main)
    print(f"Appended {len(rows_main)} BP rows to {a2_path}", flush=True)

    # Append to naive CSV
    naive_path = os.path.join(args.output_dir, 'A2_naive_state_err.csv')
    _append_rows(naive_path, ['method', 'seed', 'naive_StateErr'], rows_naive)
    print(f"Appended {len(rows_naive)} BP rows to {naive_path}", flush=True)

    # Also save standalone
    bp_path = os.path.join(args.output_dir, 'A2_bp_supplement.csv')
    with open(bp_path, 'w', newline='') as f:
        writer = csv.DictWriter(
            f, fieldnames=['method', 'seed', 'StateErr', 'Gamma', 'rho', 'acc',
                           'naive_StateErr'])
        writer.writeheader()
        for main_row, naive_row in zip(rows_main, rows_naive):
            writer.writerow({**main_row,
                             'naive_StateErr': naive_row['naive_StateErr']})
    print(f"Saved standalone to {bp_path}", flush=True)


if __name__ == '__main__':
    main()
+bp,42,nan,1.0,0.9973812252283096,0.6125,12270.6220625
+bp,123,nan,1.0,0.9977057427167892,0.6084,11526.486759375
+bp,456,nan,1.0,0.9973872005939484,0.6149,11941.3817234375
+bp,789,nan,1.0,0.9976505190134048,0.6165,12075.4278453125
+bp,1024,nan,1.0,0.9978837668895721,0.6132,11724.9296015625
+bp,2048,nan,1.0,0.997720941901207,0.6096,11844.94355
+bp,3000,nan,1.0,0.998002827167511,0.6139,10728.462175
+bp,4000,nan,1.0,0.9976081401109695,0.6153,12229.952775
+bp,5000,nan,1.0,0.9972489476203918,0.6188,10800.38535
+bp,6000,nan,1.0,0.9975202232599258,0.6131,11821.658378125
diff --git a/results/confirmatory/A2_cifar_state_vs_credit.csv b/results/confirmatory/A2_cifar_state_vs_credit.csv index 68156b5..8a57702 100644 --- a/results/confirmatory/A2_cifar_state_vs_credit.csv +++ b/results/confirmatory/A2_cifar_state_vs_credit.csv @@ -29,3 +29,13 @@ credit_bridge_eT,5000,nan,0.18241790123283863,0.007469391683116555,0.3002 dfa,6000,nan,0.09259073645807803,0.0021172836422920227,0.282
state_bridge,6000,0.00011229384421696886,0.2289915308356285,0.14905795291997492,0.2964
credit_bridge_eT,6000,nan,0.20601818710565567,-0.0038647495675832033,0.2882
+bp,42,nan,1.0,0.9973812252283096,0.6125
+bp,123,nan,1.0,0.9977057427167892,0.6084
+bp,456,nan,1.0,0.9973872005939484,0.6149
+bp,789,nan,1.0,0.9976505190134048,0.6165
+bp,1024,nan,1.0,0.9978837668895721,0.6132
+bp,2048,nan,1.0,0.997720941901207,0.6096
+bp,3000,nan,1.0,0.998002827167511,0.6139
+bp,4000,nan,1.0,0.9976081401109695,0.6153
+bp,5000,nan,1.0,0.9972489476203918,0.6188
+bp,6000,nan,1.0,0.9975202232599258,0.6131
diff --git a/results/confirmatory/A2_naive_state_err.csv b/results/confirmatory/A2_naive_state_err.csv index fcb5f01..905d1c2 100644 --- a/results/confirmatory/A2_naive_state_err.csv +++ b/results/confirmatory/A2_naive_state_err.csv @@ -29,3 +29,13 @@ credit_bridge,5000,0.6871149513244629 dfa,6000,0.8224315113067627
state_bridge,6000,0.14781150970458984
credit_bridge,6000,0.5449538031578064
+bp,42,12270.6220625
+bp,123,11526.486759375
+bp,456,11941.3817234375
+bp,789,12075.4278453125
+bp,1024,11724.9296015625
+bp,2048,11844.94355
+bp,3000,10728.462175
+bp,4000,12229.952775
+bp,5000,10800.38535
+bp,6000,11821.658378125
|
