summaryrefslogtreecommitdiff
path: root/experiments/a2_bp_supplement.py
blob: c594f9ea2dade4841e79fffe393e980f83e4293c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Supplement A2 with BP baseline: 10 seeds, CIFAR-10, L=4, d=256."""
import os, sys, json, csv, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision, torchvision.transforms as transforms
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation

def get_cifar10(bs=128):
    """Build the CIFAR-10 ``(train_loader, test_loader)`` pair.

    Both splits are converted to tensors and normalized with fixed per-channel
    mean/std constants. The training split additionally gets random crops
    (size 32, padding 4) and random horizontal flips. Datasets are stored under
    ``./data`` and downloaded if missing. Only the training loader shuffles.

    Args:
        bs: batch size used by both loaders.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])

    def _make_loader(is_train, tf):
        # Positional args: CIFAR10(root, train, ...) and DataLoader(dataset, batch_size, shuffle).
        dataset = torchvision.datasets.CIFAR10('./data', is_train, download=True, transform=tf)
        return DataLoader(dataset, bs, is_train, num_workers=4, pin_memory=True)

    train_loader = _make_loader(True, train_tf)
    test_loader = _make_loader(False, test_tf)
    return (train_loader, test_loader)

def evaluate(m,tl,dev):
    """Return the top-1 accuracy of model ``m`` over every batch of ``tl``.

    Switches ``m`` to eval mode (and leaves it there), flattens each input
    batch to ``(batch, -1)``, moves inputs and labels to ``dev``, and runs
    without gradient tracking. Raises ZeroDivisionError if ``tl`` is empty.
    """
    m.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in tl:
            inputs = inputs.view(inputs.size(0), -1).to(dev)
            labels = labels.to(dev)
            predictions = m(inputs).argmax(1)
            n_correct += (predictions == labels).sum().item()
            n_seen += inputs.size(0)
    return n_correct / n_seen

def compute_naive_state_err(model, test_loader, device, eval_layer=None):
    """Mean relative squared distance between an intermediate and the final hidden state.

    For every test sample this computes
    ``||h_l - h_L||^2 / max(||h_L||, 1)^2``, where ``h_l = hidden[eval_layer]``,
    ``h_L = hidden[-1]`` (from ``model(x, return_hidden=True)``), and the norm
    is taken per sample over the last dimension. The result is averaged over
    all samples of ``test_loader``.

    Args:
        model: module exposing ``num_blocks`` and ``forward(x, return_hidden=True)``
            returning ``(output, hidden_list)``.
        test_loader: iterable of ``(inputs, labels)``; inputs are flattened to
            ``(batch, -1)`` and labels are ignored.
        device: device the inputs are moved to.
        eval_layer: index into the hidden list; defaults to ``num_blocks // 2``.

    Returns:
        The per-sample average error as a Python float.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    if eval_layer is None:
        eval_layer = model.num_blocks // 2
    total_err, count = 0.0, 0
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.view(x.size(0), -1).to(device)
            _, hidden = model(x, return_hidden=True)
            h_final = hidden[-1]
            h_mid = hidden[eval_layer]
            # Per-sample L2 norm over the feature dim. The former
            # ``h_final.norm(-1, keepdim=True)`` bound -1 to ``p`` (not ``dim``),
            # yielding one tiny ord=-1 value over the whole batch that the clamp
            # turned into 1.0, so the error was never actually normalized.
            # NOTE(review): sibling A2 scripts writing to the same CSV may carry
            # the same bug; compare numbers only if they compute it identically.
            scale = h_final.norm(dim=-1, keepdim=True).clamp(min=1.0)
            per_sample = ((h_mid - h_final) / scale).pow(2).sum(-1)
            total_err += per_sample.sum().item()
            count += x.size(0)
    if count == 0:
        raise ValueError("test_loader yielded no samples")
    return total_err / count

def compute_bp_diagnostics(model, test_loader, device):
    """Compute (Gamma, rho) credit diagnostics for a BP model on the first test batch.

    The credit assigned to hidden state ``h_l`` is the BP gradient
    ``dLoss/dh_l`` of the mean cross-entropy on one batch. Gamma is the cosine
    of that credit with the BP gradient, i.e. 1.0 per block by definition.
    rho is the mean over blocks ``l in [0, L)`` of
    ``perturbation_correlation(h_l, dLoss/dh_l, f_l)``, where ``f_l`` replays
    blocks ``l..L-1`` plus ``out_ln``/``out_head`` from a perturbed ``h_l`` and
    returns the per-sample cross-entropy.

    Args:
        model: module exposing ``num_blocks``, ``blocks``, ``out_ln``,
            ``out_head`` and ``forward(x, return_hidden=True)``.
        test_loader: iterable of ``(inputs, labels)``; only the first batch is used.
        device: device the batch is moved to.

    Returns:
        ``(gamma, rho)`` as Python floats.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    L = model.num_blocks
    try:
        x, y = next(iter(test_loader))
    except StopIteration:
        raise ValueError("test_loader yielded no batches") from None
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)

    logits, hidden = model(x, return_hidden=True)
    # autograd.grad returns dLoss/dh_l directly and, unlike .backward(), does not
    # accumulate gradients into the model parameters' .grad buffers.
    grads = torch.autograd.grad(F.cross_entropy(logits, y), [hidden[l] for l in range(L)])

    def make_fwd(start):
        """Map a (perturbed) hidden state at block ``start`` to per-sample loss."""
        def f(h):
            with torch.no_grad():
                c = h
                for i in range(start, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
        return f

    # BP credit *is* the BP gradient, so its cosine with itself is 1.0 for every block.
    gammas = [1.0] * L
    # Perturb around the exact activations the gradients were taken at (no second
    # forward pass needed).
    rhos = [
        perturbation_correlation(hidden[l].detach(), grads[l].detach(), make_fwd(l),
                                 epsilon=1e-3, M=16)
        for l in range(L)
    ]
    return float(np.mean(gammas)), float(np.mean(rhos))

def _append_csv_rows(path, fieldnames, rows):
    """Append dict ``rows`` to the CSV at ``path``, writing a header if the file has none.

    If the file already has a header row, that header's column order is used, so
    appended values line up with the existing columns. A row key missing from the
    existing header raises ValueError (csv.DictWriter's default
    ``extrasaction='raise'``) instead of silently misaligning data.
    """
    header = None
    if os.path.exists(path) and os.path.getsize(path) > 0:
        with open(path, newline='') as f:
            header = next(csv.reader(f), None)
    with open(path, 'a', newline='') as f:
        w = csv.DictWriter(f, fieldnames=header if header else fieldnames)
        if not header:
            w.writeheader()
        w.writerows(rows)


def _train_bp(seed, L, d, trl, tel, device, epochs=100):
    """Seed all RNGs, then train one ResidualMLP with plain backprop and return it."""
    torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
    # 3072 = flattened 3x32x32 CIFAR-10 image; 10 classes.
    model = ResidualMLP(3072, d, 10, L).to(device)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    # T_max tied to the epoch count so the cosine schedule ends at the last epoch.
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    print(f"  BP seed={seed}: training...", flush=True)
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x = x.view(x.size(0), -1).to(device); y = y.to(device)
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad(); loss.backward(); opt.step()
        sch.step()
        if ep % 20 == 0:
            print(f"    Ep {ep}: acc={evaluate(model, tel, device):.4f}", flush=True)
    return model


def main():
    """Train the BP baseline for 10 seeds and record its A2 diagnostics.

    For each seed, trains a ResidualMLP (L=4, d=256) with backprop on CIFAR-10,
    then measures test accuracy, (Gamma, rho) and the naive StateErr. Rows are
    appended to ``A2_cifar_state_vs_credit.csv`` and ``A2_naive_state_err.csv``
    in ``--output_dir`` (a header is written if a file is new), and all results
    are also written, overwriting, to ``A2_bp_supplement.csv``.

    NOTE: the two A2 CSVs are appended to, so re-running duplicates BP rows.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=int, default=3)
    p.add_argument('--output_dir', type=str, default='results/confirmatory')
    a = p.parse_args()
    device = torch.device(f'cuda:{a.gpu}')
    os.makedirs(a.output_dir, exist_ok=True)
    seeds = [42, 123, 456, 789, 1024, 2048, 3000, 4000, 5000, 6000]
    L, d = 4, 256
    trl, tel = get_cifar10()
    main_fields = ['method', 'seed', 'StateErr', 'Gamma', 'rho', 'acc']
    naive_fields = ['method', 'seed', 'naive_StateErr']
    rows_main = []; rows_naive = []
    for seed in seeds:
        model = _train_bp(seed, L, d, trl, tel, device)
        acc = evaluate(model, tel, device)
        gamma, rho = compute_bp_diagnostics(model, tel, device)
        nse = compute_naive_state_err(model, tel, device)
        # StateErr is not defined for BP itself, hence NaN.
        rows_main.append({'method': 'bp', 'seed': seed, 'StateErr': float('nan'),
                          'Gamma': gamma, 'rho': rho, 'acc': acc})
        rows_naive.append({'method': 'bp', 'seed': seed, 'naive_StateErr': nse})
        print(f"  BP seed={seed}: acc={acc:.4f}, Gamma={gamma:.4f}, rho={rho:.4f}, naive_StateErr={nse:.6f}", flush=True)
    # Append to A2 CSV
    a2_path = os.path.join(a.output_dir, 'A2_cifar_state_vs_credit.csv')
    _append_csv_rows(a2_path, main_fields, rows_main)
    print(f"Appended {len(rows_main)} BP rows to {a2_path}", flush=True)
    # Append to naive CSV
    naive_path = os.path.join(a.output_dir, 'A2_naive_state_err.csv')
    _append_csv_rows(naive_path, naive_fields, rows_naive)
    print(f"Appended {len(rows_naive)} BP rows to {naive_path}", flush=True)
    # Also save standalone
    bp_path = os.path.join(a.output_dir, 'A2_bp_supplement.csv')
    with open(bp_path, 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=main_fields + ['naive_StateErr'])
        w.writeheader()
        for m, n in zip(rows_main, rows_naive):
            w.writerow({**m, 'naive_StateErr': n['naive_StateErr']})
    print(f"Saved standalone to {bp_path}", flush=True)

# Script entry point: run the experiment only when executed directly.
if __name__ == '__main__':
    main()