diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-31 10:36:02 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-31 10:36:02 -0500 |
| commit | 1359b7e7a96ab57be0bb24ebdf842a793ce01223 (patch) | |
| tree | ba13d16ed1bf4604ffb76e12ceb1fdb5958900bb | |
| parent | 8b21fb32bf0997e3f4266c1c22414e49f1fdcfcc (diff) | |
Add naive state prediction baseline for A1 and A2
A1: 240 rows (3 alpha × 2 depth × 4 methods × 10 seeds)
A2: 30 rows (3 methods × 10 seeds)
naive_StateErr = ||h_{L//2} - h_L|| / ||h_L||
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
| -rw-r--r-- | experiments/compute_naive_state_err.py | 390 | ||||
| -rw-r--r-- | results/confirmatory/A1_naive_state_err.csv | 241 | ||||
| -rw-r--r-- | results/confirmatory/A2_naive_state_err.csv | 31 |
3 files changed, 662 insertions, 0 deletions
diff --git a/experiments/compute_naive_state_err.py b/experiments/compute_naive_state_err.py new file mode 100644 index 0000000..64c8790 --- /dev/null +++ b/experiments/compute_naive_state_err.py @@ -0,0 +1,390 @@ +""" +Compute naive state prediction baseline: ||h_l - h_L|| / ||h_L|| for all methods. +This is the identity-prediction error — how much does h change from layer l to L. +Computed at layer L//2 on test/eval data after training. + +Retrains models with same seeds as confirmatory experiments, then computes metric. +""" +import os, sys, json, csv, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F +import torch.optim as optim, copy +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet + + +def compute_naive_state_err(model, dataloader, device, eval_layer=None): + """Compute ||h_l - h_L|| / ||h_L|| averaged over data.""" + model.eval() + L = model.num_blocks + if eval_layer is None: + eval_layer = L // 2 + total_err, n = 0.0, 0 + with torch.no_grad(): + for x, y in dataloader: + if hasattr(model, 'embed'): + x = x.view(x.size(0), -1).to(device) + else: + x = x.to(device) + _, hiddens = model(x, return_hidden=True) + h_l = hiddens[eval_layer] + h_L = hiddens[-1] + norm_L = h_L.norm(dim=-1, keepdim=True).clamp(min=1.0) + err = ((h_l - h_L) / norm_L).pow(2).sum(-1).mean() + total_err += err.item() * x.size(0) + n += x.size(0) + return total_err / n + + +# ============ Synthetic ============ +class TeacherNet(nn.Module): + def __init__(self, d, C, L, alpha=1.0, seed=0): + super().__init__() + self.alpha = alpha + rng = torch.Generator().manual_seed(seed) + self.Ws = nn.ParameterList() + for _ in range(L): + W = torch.randn(d, d, generator=rng) * 0.3 / (d**0.5) + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + W = U @ torch.diag(S.clamp(max=0.3)) @ Vh + self.Ws.append(nn.Parameter(W, requires_grad=False)) + self.U = nn.Parameter(torch.randn(C, d, generator=rng) / (d**0.5), requires_grad=False) + def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z) + def forward(self, x): + h = x + for W in self.Ws: h = h + self.phi(h @ W.T) + return h @ self.U.T + +class StudentBlock(nn.Module): + def __init__(self, d, alpha=1.0): + super().__init__() + self.ln = nn.LayerNorm(d); self.w = nn.Linear(d, d, bias=False) + nn.init.normal_(self.w.weight, std=0.01); self.alpha = alpha + def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z) + def forward(self, h): return self.w(self.phi(self.ln(h))) + +class StudentNet(nn.Module): + def __init__(self, d, C, L, alpha=1.0): + super().__init__() + self.blocks = nn.ModuleList([StudentBlock(d, alpha) for _ in range(L)]) + self.out_head = nn.Linear(d, C); self.d_hidden = d; self.num_blocks = L + def forward(self, x, return_hidden=False): + h = x; hiddens = [h] if return_hidden else None + for b in self.blocks: + h = h + b(h) + if return_hidden: hiddens.append(h) + logits = self.out_head(h) + return (logits, hiddens) if return_hidden else logits + def forward_from_layer(self, h, sl): + for i in range(sl, self.num_blocks): h = h + self.blocks[i](h) + return self.out_head(h) + + +def train_synth_method(method, model, teacher, device, d, C, L, epochs=80, steps=50, bs=256, lr=1e-3, lr_fb=1e-3): + """Train synthetic model with given method, return trained model.""" + if method == 'bp': + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + for ep in range(epochs): + model.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): y = teacher(x).argmax(-1) + loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step() + elif method == 'dfa': + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + for ep in range(epochs): + model.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1 + lo2 = F.cross_entropy(model.out_head(hi[-1].detach()), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = (eT @ Bs[l].T).detach(); rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + elif method == 'state_bridge': + sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + sop = optim.Adam(sp.parameters(), lr=lr_fb) + warmup = max(1, epochs//5) + for ep in range(epochs): + model.train(); sp.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach() + hL = hi[-1].detach() + sl = 0.0 + for l in range(L): + t_l = torch.full((bs,), l/L, device=device) + pred = sp(hi[l].detach(), t_l, s) + tn = hL.norm(-1, keepdim=True).clamp(min=1.0) + sl += (((pred-hL)/tn)**2).sum(-1).mean() + sl /= L; sop.zero_grad(); sl.backward(); sop.step() + credits = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); t_l = torch.full((bs,), l/L, device=device) + pred = sp(hl, t_l, s); pl = F.cross_entropy(model.out_head(pred), y, reduction='sum') + credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach()) + lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + elif method == 'credit_bridge': + vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + vop = optim.Adam(vn.parameters(), lr=lr_fb) + warmup = max(1, epochs//5) + for ep in range(epochs): + model.train(); vn.train() + blend = 0.0 if ep < warmup else min(1.0, (ep-warmup)/max(1, warmup)) + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach() + tl_val = F.cross_entropy(lo, y, reduction='none').detach() + hL = hi[-1].detach(); t_L = torch.ones(bs, device=device) + lt = ((vn(hL, t_L, s)-tl_val)**2).mean() + hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s) + gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0] + hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(hLr2), y, reduction='sum') + aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach() + ltg = ((gV-aLe)**2).sum(-1).mean() + lb = 0.0 + for l in range(L): + hl = hi[l].detach(); tl = torch.full((bs,), l/L, device=device); tn = torch.full((bs,), (l+1)/L, device=device) + Vl = vn(hl, tl, s) + with torch.no_grad(): + hn = hi[l+1].detach(); lts = [] + for k in range(4): + ns = 0.05*torch.randn_like(hn); lts.append(-ve(hn+ns, tn, s)/0.1) + Vt = -0.1*(torch.logsumexp(torch.stack(lts, -1), -1)-np.log(4)) + lb += ((Vl-Vt.detach())**2).mean() + lb /= L; vl = lt+lb+1.0*ltg + vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(), 1.0); vop.step() + update_ema(vn, ve, 0.995) + cbc = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((bs,), l/L, device=device) + Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach()) + dfac = [(eT@Bs[l].T).detach() for l in range(L)] + credits = [] + for l in range(L): + if blend >= 1: credits.append(cbc[l]) + elif blend <= 0: credits.append(dfac[l]) + else: + cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 + credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr) + lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + return model + + +def run_A1_naive(args): + """Compute naive state err for synthetic ladder.""" + device = torch.device(f'cuda:{args.gpu}') + seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] + alphas = [0.0, 0.5, 1.0]; depths = [4, 8]; d = 128; C = 10 + methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + rows = [] + for alpha in alphas: + for L in depths: + for seed in seeds: + teacher = TeacherNet(d, C, L, alpha=alpha, seed=seed*1000).to(device) + # Eval data + eval_x = torch.randn(512, d, device=device) + with torch.no_grad(): eval_y = teacher(eval_x).argmax(-1) + eval_data = [(eval_x, eval_y)] + for method in methods: + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = StudentNet(d, C, L, alpha=alpha).to(device) + model = train_synth_method(method, model, teacher, device, d, C, L) + nse = compute_naive_state_err(model, eval_data, device, eval_layer=L//2) + rows.append({'alpha': alpha, 'depth': L, 'method': method, 'seed': seed, 'naive_StateErr': nse}) + print(f" A1 alpha={alpha} L={L} {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) + # Save CSV + out = os.path.join(args.output_dir, 'A1_naive_state_err.csv') + with open(out, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['alpha','depth','method','seed','naive_StateErr']) + w.writeheader(); w.writerows(rows) + print(f"Saved {len(rows)} rows to {out}", flush=True) + + +def get_cifar10(bs=128): + tt = transforms.Compose([transforms.RandomCrop(32,4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + tv = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + return (DataLoader(torchvision.datasets.CIFAR10('./data',True,download=True,transform=tt),bs,True,num_workers=4,pin_memory=True), + DataLoader(torchvision.datasets.CIFAR10('./data',False,download=True,transform=tv),bs,False,num_workers=4,pin_memory=True)) + + +def train_cifar_method(method, model, train_loader, test_loader, device, L, d, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01): + """Train CIFAR model, return trained model.""" + C = 10 + if method == 'dfa': + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + for ep in range(1, epochs+1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1 + hL = hi[-1].detach() + F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward(); hop.step(); hop.zero_grad() + for l in range(L): + a = (eT@Bs[l].T).detach(); rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = (eT@Bs[0].T).detach(); r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + elif method == 'state_bridge': + sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + sop = optim.Adam(sp.parameters(), lr=lr_fb) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + for ep in range(1, epochs+1): + model.train(); sp.train() + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach() + hL = hi[-1].detach() + sl = 0.0 + for l in range(L): + tl = torch.full((b,), l/L, device=device) + pred = sp(hi[l].detach(), tl, s); tn = hL.norm(-1,keepdim=True).clamp(min=1.0) + sl += (((pred-hL)/tn)**2).sum(-1).mean() + sl /= L; sop.zero_grad(); sl.backward(); sop.step() + credits = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device) + pred = sp(hl, tl, s); pl = F.cross_entropy(model.out_head(model.out_ln(pred)), y, reduction='sum') + credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach()) + lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + elif method == 'credit_bridge': + vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + vop = optim.Adam(vn.parameters(), lr=lr_fb) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + warmup = max(1, epochs//5) + for ep in range(1, epochs+1): + model.train(); vn.train() + blend = 0.0 if ep <= warmup else min(1.0, (ep-warmup)/max(1, warmup)) + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach(); tlv = F.cross_entropy(lo, y, reduction='none').detach() + hL = hi[-1].detach(); t_L = torch.ones(b, device=device) + lt = ((vn(hL, t_L, s)-tlv)**2).mean() + hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s) + gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0] + hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(model.out_ln(hLr2)), y, reduction='sum') + aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach() + ltg = ((gV-aLe)**2).sum(-1).mean() + lb = 0.0 + for l in range(L): + hl = hi[l].detach(); tl = torch.full((b,), l/L, device=device); tn = torch.full((b,), (l+1)/L, device=device) + Vl = vn(hl, tl, s) + with torch.no_grad(): + hn = hi[l+1].detach(); lts = [] + for k in range(4): lts.append(-ve(hn+0.05*torch.randn_like(hn), tn, s)/0.1) + Vt = -0.1*(torch.logsumexp(torch.stack(lts,-1),-1)-np.log(4)) + lb += ((Vl-Vt.detach())**2).mean() + lb /= L; vl = lt+lb+1.0*ltg + vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(),1.0); vop.step() + update_ema(vn, ve, 0.995) + cbc = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device) + Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach()) + dfac = [(eT@Bs[l].T).detach() for l in range(L)] + credits = [] + for l in range(L): + if blend >= 1: credits.append(cbc[l]) + elif blend <= 0: credits.append(dfac[l]) + else: + cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 + credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr) + lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + return model + + +def run_A2_naive(args): + """Compute naive state err for CIFAR methods.""" + device = torch.device(f'cuda:{args.gpu}') + seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] + L, d = 4, 256; methods = ['dfa', 'state_bridge', 'credit_bridge'] + train_loader, test_loader = get_cifar10() + rows = [] + for seed in seeds: + for method in methods: + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = ResidualMLP(3072, d, 10, L).to(device) + print(f" A2 {method} s={seed}: training...", flush=True) + model = train_cifar_method(method, model, train_loader, test_loader, device, L, d) + nse = compute_naive_state_err(model, test_loader, device, eval_layer=L//2) + rows.append({'method': method, 'seed': seed, 'naive_StateErr': nse}) + print(f" A2 {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) + out = os.path.join(args.output_dir, 'A2_naive_state_err.csv') + with open(out, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['method','seed','naive_StateErr']) + w.writeheader(); w.writerows(rows) + print(f"Saved {len(rows)} rows to {out}", flush=True) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--experiment', type=str, default='A1', choices=['A1','A2','both']) + p.add_argument('--gpu', type=int, default=3) + p.add_argument('--output_dir', type=str, default='results/confirmatory') + args = p.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + if args.experiment in ['A1', 'both']: + print("=== A1: Synthetic naive state err ===", flush=True) + run_A1_naive(args) + if args.experiment in ['A2', 'both']: + print("=== A2: CIFAR naive state err ===", flush=True) + run_A2_naive(args) + print("Done.", flush=True) + + +if __name__ == '__main__': + main() diff --git a/results/confirmatory/A1_naive_state_err.csv b/results/confirmatory/A1_naive_state_err.csv new file mode 100644 index 0000000..6ddcfe7 --- /dev/null +++ b/results/confirmatory/A1_naive_state_err.csv @@ -0,0 +1,241 @@ +alpha,depth,method,seed,naive_StateErr
+0.0,4,bp,42,0.26068803668022156
+0.0,4,dfa,42,0.7258815169334412
+0.0,4,state_bridge,42,0.7580535411834717
+0.0,4,credit_bridge,42,0.23254971206188202
+0.0,4,bp,123,0.24781781435012817
+0.0,4,dfa,123,0.6604026556015015
+0.0,4,state_bridge,123,0.5317853689193726
+0.0,4,credit_bridge,123,0.17576289176940918
+0.0,4,bp,456,0.2476104348897934
+0.0,4,dfa,456,0.6199116110801697
+0.0,4,state_bridge,456,0.863245964050293
+0.0,4,credit_bridge,456,0.12035267055034637
+0.0,4,bp,789,0.2656460106372833
+0.0,4,dfa,789,0.7435265779495239
+0.0,4,state_bridge,789,0.6364812850952148
+0.0,4,credit_bridge,789,0.2562270164489746
+0.0,4,bp,1024,0.2501564025878906
+0.0,4,dfa,1024,0.6919815540313721
+0.0,4,state_bridge,1024,0.6960811018943787
+0.0,4,credit_bridge,1024,0.2409060299396515
+0.0,4,bp,2048,0.24213725328445435
+0.0,4,dfa,2048,0.8119536638259888
+0.0,4,state_bridge,2048,1.0384111404418945
+0.0,4,credit_bridge,2048,0.29382896423339844
+0.0,4,bp,3000,0.25653713941574097
+0.0,4,dfa,3000,0.7727931141853333
+0.0,4,state_bridge,3000,0.7795010805130005
+0.0,4,credit_bridge,3000,0.1640912890434265
+0.0,4,bp,4000,0.2790989279747009
+0.0,4,dfa,4000,0.6488389372825623
+0.0,4,state_bridge,4000,0.8933002948760986
+0.0,4,credit_bridge,4000,0.129751056432724
+0.0,4,bp,5000,0.23482249677181244
+0.0,4,dfa,5000,0.5603621006011963
+0.0,4,state_bridge,5000,0.7118808031082153
+0.0,4,credit_bridge,5000,0.1876041740179062
+0.0,4,bp,6000,0.2366245985031128
+0.0,4,dfa,6000,0.7504462003707886
+0.0,4,state_bridge,6000,0.8727204203605652
+0.0,4,credit_bridge,6000,0.18112479150295258
+0.0,8,bp,42,0.07397101819515228
+0.0,8,dfa,42,0.5621968507766724
+0.0,8,state_bridge,42,0.816449761390686
+0.0,8,credit_bridge,42,0.15379583835601807
+0.0,8,bp,123,0.07488589733839035
+0.0,8,dfa,123,0.6059672832489014
+0.0,8,state_bridge,123,0.5263348817825317
+0.0,8,credit_bridge,123,0.12262149155139923
+0.0,8,bp,456,0.07783643156290054
+0.0,8,dfa,456,0.7705618143081665
+0.0,8,state_bridge,456,0.4180244505405426
+0.0,8,credit_bridge,456,0.19311878085136414
+0.0,8,bp,789,0.09013043344020844
+0.0,8,dfa,789,0.5577319860458374
+0.0,8,state_bridge,789,0.8297959566116333
+0.0,8,credit_bridge,789,0.05089893937110901
+0.0,8,bp,1024,0.0830039381980896
+0.0,8,dfa,1024,0.5080844163894653
+0.0,8,state_bridge,1024,0.39521247148513794
+0.0,8,credit_bridge,1024,0.06527997553348541
+0.0,8,bp,2048,0.07513846457004547
+0.0,8,dfa,2048,0.6275949478149414
+0.0,8,state_bridge,2048,0.6864550113677979
+0.0,8,credit_bridge,2048,0.08002562075853348
+0.0,8,bp,3000,0.10901156067848206
+0.0,8,dfa,3000,0.6083968281745911
+0.0,8,state_bridge,3000,0.8938897848129272
+0.0,8,credit_bridge,3000,0.38711249828338623
+0.0,8,bp,4000,0.09766732156276703
+0.0,8,dfa,4000,0.5503160357475281
+0.0,8,state_bridge,4000,0.39751216769218445
+0.0,8,credit_bridge,4000,0.13849811255931854
+0.0,8,bp,5000,0.07875370979309082
+0.0,8,dfa,5000,0.5541619658470154
+0.0,8,state_bridge,5000,0.6255663633346558
+0.0,8,credit_bridge,5000,0.07148516178131104
+0.0,8,bp,6000,0.07798311114311218
+0.0,8,dfa,6000,0.5998473167419434
+0.0,8,state_bridge,6000,0.5577906370162964
+0.0,8,credit_bridge,6000,0.18784484267234802
+0.5,4,bp,42,0.2670668959617615
+0.5,4,dfa,42,0.78990238904953
+0.5,4,state_bridge,42,0.7565754055976868
+0.5,4,credit_bridge,42,0.28575778007507324
+0.5,4,bp,123,0.26157334446907043
+0.5,4,dfa,123,0.8240824937820435
+0.5,4,state_bridge,123,0.5587974786758423
+0.5,4,credit_bridge,123,0.18487781286239624
+0.5,4,bp,456,0.26119792461395264
+0.5,4,dfa,456,0.5580360889434814
+0.5,4,state_bridge,456,0.6804043650627136
+0.5,4,credit_bridge,456,0.10438685864210129
+0.5,4,bp,789,0.2688371241092682
+0.5,4,dfa,789,0.7619737386703491
+0.5,4,state_bridge,789,0.5724117755889893
+0.5,4,credit_bridge,789,0.2902269661426544
+0.5,4,bp,1024,0.2626464068889618
+0.5,4,dfa,1024,0.7350623607635498
+0.5,4,state_bridge,1024,0.7774407863616943
+0.5,4,credit_bridge,1024,0.2907182276248932
+0.5,4,bp,2048,0.25887274742126465
+0.5,4,dfa,2048,0.8292539119720459
+0.5,4,state_bridge,2048,0.6785949468612671
+0.5,4,credit_bridge,2048,0.23325428366661072
+0.5,4,bp,3000,0.2686220407485962
+0.5,4,dfa,3000,0.7964614033699036
+0.5,4,state_bridge,3000,0.6747612953186035
+0.5,4,credit_bridge,3000,0.15331298112869263
+0.5,4,bp,4000,0.27715128660202026
+0.5,4,dfa,4000,0.834259033203125
+0.5,4,state_bridge,4000,0.6629055738449097
+0.5,4,credit_bridge,4000,0.09292227774858475
+0.5,4,bp,5000,0.25511646270751953
+0.5,4,dfa,5000,0.8486669063568115
+0.5,4,state_bridge,5000,0.6432816386222839
+0.5,4,credit_bridge,5000,0.08898796141147614
+0.5,4,bp,6000,0.25425827503204346
+0.5,4,dfa,6000,0.7771180868148804
+0.5,4,state_bridge,6000,0.8868570327758789
+0.5,4,credit_bridge,6000,0.28548482060432434
+0.5,8,bp,42,0.08691950142383575
+0.5,8,dfa,42,0.5566835403442383
+0.5,8,state_bridge,42,0.5521173477172852
+0.5,8,credit_bridge,42,0.026307538151741028
+0.5,8,bp,123,0.08442661166191101
+0.5,8,dfa,123,0.6884247064590454
+0.5,8,state_bridge,123,1.1720242500305176
+0.5,8,credit_bridge,123,0.09352543950080872
+0.5,8,bp,456,0.0845673531293869
+0.5,8,dfa,456,0.7805156707763672
+0.5,8,state_bridge,456,0.564720630645752
+0.5,8,credit_bridge,456,0.04183648154139519
+0.5,8,bp,789,0.09304642677307129
+0.5,8,dfa,789,0.5289106965065002
+0.5,8,state_bridge,789,1.1913974285125732
+0.5,8,credit_bridge,789,0.0811903178691864
+0.5,8,bp,1024,0.08988181501626968
+0.5,8,dfa,1024,0.5268670320510864
+0.5,8,state_bridge,1024,0.38997870683670044
+0.5,8,credit_bridge,1024,0.0551433339715004
+0.5,8,bp,2048,0.08537864685058594
+0.5,8,dfa,2048,0.6378879547119141
+0.5,8,state_bridge,2048,0.8929705619812012
+0.5,8,credit_bridge,2048,0.05500475689768791
+0.5,8,bp,3000,0.10150086879730225
+0.5,8,dfa,3000,0.6010029315948486
+0.5,8,state_bridge,3000,0.6921526193618774
+0.5,8,credit_bridge,3000,0.0750318169593811
+0.5,8,bp,4000,0.09839055687189102
+0.5,8,dfa,4000,0.5477902889251709
+0.5,8,state_bridge,4000,0.4617120325565338
+0.5,8,credit_bridge,4000,0.0474717952311039
+0.5,8,bp,5000,0.09180224686861038
+0.5,8,dfa,5000,0.5527122616767883
+0.5,8,state_bridge,5000,0.5788565874099731
+0.5,8,credit_bridge,5000,0.08813595771789551
+0.5,8,bp,6000,0.08880583941936493
+0.5,8,dfa,6000,0.5930142402648926
+0.5,8,state_bridge,6000,0.5428022742271423
+0.5,8,credit_bridge,6000,0.1234222799539566
+1.0,4,bp,42,0.37005650997161865
+1.0,4,dfa,42,0.869762659072876
+1.0,4,state_bridge,42,0.87589430809021
+1.0,4,credit_bridge,42,0.2053530365228653
+1.0,4,bp,123,0.36214226484298706
+1.0,4,dfa,123,0.8254338502883911
+1.0,4,state_bridge,123,0.8833000659942627
+1.0,4,credit_bridge,123,0.21273687481880188
+1.0,4,bp,456,0.3626979887485504
+1.0,4,dfa,456,0.7873569130897522
+1.0,4,state_bridge,456,0.8605778217315674
+1.0,4,credit_bridge,456,0.21030157804489136
+1.0,4,bp,789,0.37471723556518555
+1.0,4,dfa,789,0.7885147929191589
+1.0,4,state_bridge,789,0.8968838453292847
+1.0,4,credit_bridge,789,0.18812689185142517
+1.0,4,bp,1024,0.36623838543891907
+1.0,4,dfa,1024,0.8550044894218445
+1.0,4,state_bridge,1024,0.8468506932258606
+1.0,4,credit_bridge,1024,0.32974082231521606
+1.0,4,bp,2048,0.3584083914756775
+1.0,4,dfa,2048,0.823596715927124
+1.0,4,state_bridge,2048,0.8956995010375977
+1.0,4,credit_bridge,2048,0.1712440550327301
+1.0,4,bp,3000,0.37105458974838257
+1.0,4,dfa,3000,0.8486852645874023
+1.0,4,state_bridge,3000,0.9403642416000366
+1.0,4,credit_bridge,3000,0.15761351585388184
+1.0,4,bp,4000,0.384659081697464
+1.0,4,dfa,4000,0.8266894817352295
+1.0,4,state_bridge,4000,0.8737555742263794
+1.0,4,credit_bridge,4000,0.0490226186811924
+1.0,4,bp,5000,0.3479670286178589
+1.0,4,dfa,5000,0.7952747344970703
+1.0,4,state_bridge,5000,0.8428233861923218
+1.0,4,credit_bridge,5000,0.2047322541475296
+1.0,4,bp,6000,0.35523170232772827
+1.0,4,dfa,6000,0.850082516670227
+1.0,4,state_bridge,6000,0.8412976264953613
+1.0,4,credit_bridge,6000,0.18619702756404877
+1.0,8,bp,42,0.20174872875213623
+1.0,8,dfa,42,0.5564329028129578
+1.0,8,state_bridge,42,0.6702669262886047
+1.0,8,credit_bridge,42,0.11915956437587738
+1.0,8,bp,123,0.20575112104415894
+1.0,8,dfa,123,0.7708055973052979
+1.0,8,state_bridge,123,0.6753015518188477
+1.0,8,credit_bridge,123,0.12676304578781128
+1.0,8,bp,456,0.20126384496688843
+1.0,8,dfa,456,0.789068341255188
+1.0,8,state_bridge,456,0.5630311965942383
+1.0,8,credit_bridge,456,0.0718986764550209
+1.0,8,bp,789,0.21950820088386536
+1.0,8,dfa,789,0.801855206489563
+1.0,8,state_bridge,789,0.5838653445243835
+1.0,8,credit_bridge,789,0.1493334174156189
+1.0,8,bp,1024,0.2049655318260193
+1.0,8,dfa,1024,0.7054224014282227
+1.0,8,state_bridge,1024,0.564822793006897
+1.0,8,credit_bridge,1024,0.0481550358235836
+1.0,8,bp,2048,0.2031216025352478
+1.0,8,dfa,2048,0.6300236582756042
+1.0,8,state_bridge,2048,0.584318995475769
+1.0,8,credit_bridge,2048,0.22037747502326965
+1.0,8,bp,3000,0.22935035824775696
+1.0,8,dfa,3000,0.712699294090271
+1.0,8,state_bridge,3000,0.7996810674667358
+1.0,8,credit_bridge,3000,0.09546110779047012
+1.0,8,bp,4000,0.22556602954864502
+1.0,8,dfa,4000,0.5526505708694458
+1.0,8,state_bridge,4000,0.5007961392402649
+1.0,8,credit_bridge,4000,0.04477586969733238
+1.0,8,bp,5000,0.2003089338541031
+1.0,8,dfa,5000,0.777186393737793
+1.0,8,state_bridge,5000,0.3940466642379761
+1.0,8,credit_bridge,5000,0.1061391606926918
+1.0,8,bp,6000,0.20401637256145477
+1.0,8,dfa,6000,0.5910547971725464
+1.0,8,state_bridge,6000,0.6176798343658447
+1.0,8,credit_bridge,6000,0.10662685334682465
diff --git a/results/confirmatory/A2_naive_state_err.csv b/results/confirmatory/A2_naive_state_err.csv new file mode 100644 index 0000000..fcb5f01 --- /dev/null +++ b/results/confirmatory/A2_naive_state_err.csv @@ -0,0 +1,31 @@ +method,seed,naive_StateErr
+dfa,42,0.8608599968910218
+state_bridge,42,1.1073074493408204
+credit_bridge,42,0.18549979393482208
+dfa,123,0.6620678688049316
+state_bridge,123,1.157683981513977
+credit_bridge,123,0.5449653492927551
+dfa,456,0.862576175403595
+state_bridge,456,0.5174413678169251
+credit_bridge,456,0.02991231045126915
+dfa,789,0.5742178151130676
+state_bridge,789,0.32427600870132445
+credit_bridge,789,0.34830747079849245
+dfa,1024,0.7317259416580201
+state_bridge,1024,1.1036756999969481
+credit_bridge,1024,0.5166968782424927
+dfa,2048,0.29954994792938233
+state_bridge,2048,0.4312982501029968
+credit_bridge,2048,0.8699032402038575
+dfa,3000,0.7138990812301635
+state_bridge,3000,0.08414710862636567
+credit_bridge,3000,0.07805714750289917
+dfa,4000,0.6306101009368896
+state_bridge,4000,0.9965199737548828
+credit_bridge,4000,0.4791879728317261
+dfa,5000,0.9650134473800659
+state_bridge,5000,0.9529169537544251
+credit_bridge,5000,0.6871149513244629
+dfa,6000,0.8224315113067627
+state_bridge,6000,0.14781150970458984
+credit_bridge,6000,0.5449538031578064
|
