From 1359b7e7a96ab57be0bb24ebdf842a793ce01223 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 31 Mar 2026 10:36:02 -0500 Subject: Add naive state prediction baseline for A1 and A2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A1: 240 rows (3 alpha × 2 depth × 4 methods × 10 seeds) A2: 30 rows (3 methods × 10 seeds) naive_StateErr = ||h_{L//2} - h_L|| / ||h_L|| Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/compute_naive_state_err.py | 390 ++++++++++++++++++++++++++++ results/confirmatory/A1_naive_state_err.csv | 241 +++++++++++++++++ results/confirmatory/A2_naive_state_err.csv | 31 +++ 3 files changed, 662 insertions(+) create mode 100644 experiments/compute_naive_state_err.py create mode 100644 results/confirmatory/A1_naive_state_err.csv create mode 100644 results/confirmatory/A2_naive_state_err.csv diff --git a/experiments/compute_naive_state_err.py b/experiments/compute_naive_state_err.py new file mode 100644 index 0000000..64c8790 --- /dev/null +++ b/experiments/compute_naive_state_err.py @@ -0,0 +1,390 @@ +""" +Compute naive state prediction baseline: ||h_l - h_L|| / ||h_L|| for all methods. +This is the identity-prediction error — how much does h change from layer l to L. +Computed at layer L//2 on test/eval data after training. + +Retrains models with same seeds as confirmatory experiments, then computes metric. +""" +import os, sys, json, csv, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F +import torch.optim as optim, copy +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet + + +def compute_naive_state_err(model, dataloader, device, eval_layer=None): + """Compute ||h_l - h_L|| / ||h_L|| averaged over data.""" + model.eval() + L = model.num_blocks + if eval_layer is None: + eval_layer = L // 2 + total_err, n = 0.0, 0 + with torch.no_grad(): + for x, y in dataloader: + if hasattr(model, 'embed'): + x = x.view(x.size(0), -1).to(device) + else: + x = x.to(device) + _, hiddens = model(x, return_hidden=True) + h_l = hiddens[eval_layer] + h_L = hiddens[-1] + norm_L = h_L.norm(dim=-1, keepdim=True).clamp(min=1.0) + err = ((h_l - h_L) / norm_L).pow(2).sum(-1).mean() + total_err += err.item() * x.size(0) + n += x.size(0) + return total_err / n + + +# ============ Synthetic ============ +class TeacherNet(nn.Module): + def __init__(self, d, C, L, alpha=1.0, seed=0): + super().__init__() + self.alpha = alpha + rng = torch.Generator().manual_seed(seed) + self.Ws = nn.ParameterList() + for _ in range(L): + W = torch.randn(d, d, generator=rng) * 0.3 / (d**0.5) + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + W = U @ torch.diag(S.clamp(max=0.3)) @ Vh + self.Ws.append(nn.Parameter(W, requires_grad=False)) + self.U = nn.Parameter(torch.randn(C, d, generator=rng) / (d**0.5), requires_grad=False) + def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z) + def forward(self, x): + h = x + for W in self.Ws: h = h + self.phi(h @ W.T) + return h @ self.U.T + +class StudentBlock(nn.Module): + def __init__(self, d, alpha=1.0): + super().__init__() + self.ln = nn.LayerNorm(d); self.w = nn.Linear(d, d, bias=False) + nn.init.normal_(self.w.weight, std=0.01); self.alpha = alpha + def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z) + def forward(self, h): return self.w(self.phi(self.ln(h))) + +class StudentNet(nn.Module): + def __init__(self, d, C, L, alpha=1.0): + super().__init__() + self.blocks = nn.ModuleList([StudentBlock(d, alpha) for _ in range(L)]) + self.out_head = nn.Linear(d, C); self.d_hidden = d; self.num_blocks = L + def forward(self, x, return_hidden=False): + h = x; hiddens = [h] if return_hidden else None + for b in self.blocks: + h = h + b(h) + if return_hidden: hiddens.append(h) + logits = self.out_head(h) + return (logits, hiddens) if return_hidden else logits + def forward_from_layer(self, h, sl): + for i in range(sl, self.num_blocks): h = h + self.blocks[i](h) + return self.out_head(h) + + +def train_synth_method(method, model, teacher, device, d, C, L, epochs=80, steps=50, bs=256, lr=1e-3, lr_fb=1e-3): + """Train synthetic model with given method, return trained model.""" + if method == 'bp': + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + for ep in range(epochs): + model.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): y = teacher(x).argmax(-1) + loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step() + elif method == 'dfa': + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + for ep in range(epochs): + model.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1 + lo2 = F.cross_entropy(model.out_head(hi[-1].detach()), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = (eT @ Bs[l].T).detach(); rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + elif method == 'state_bridge': + sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + sop = optim.Adam(sp.parameters(), lr=lr_fb) + warmup = max(1, epochs//5) + for ep in range(epochs): + model.train(); sp.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach() + hL = hi[-1].detach() + sl = 0.0 + for l in range(L): + t_l = torch.full((bs,), l/L, device=device) + pred = sp(hi[l].detach(), t_l, s) + tn = hL.norm(-1, keepdim=True).clamp(min=1.0) + sl += (((pred-hL)/tn)**2).sum(-1).mean() + sl /= L; sop.zero_grad(); sl.backward(); sop.step() + credits = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); t_l = torch.full((bs,), l/L, device=device) + pred = sp(hl, t_l, s); pl = F.cross_entropy(model.out_head(pred), y, reduction='sum') + credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach()) + lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + elif method == 'credit_bridge': + vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + vop = optim.Adam(vn.parameters(), lr=lr_fb) + warmup = max(1, epochs//5) + for ep in range(epochs): + model.train(); vn.train() + blend = 0.0 if ep < warmup else min(1.0, (ep-warmup)/max(1, warmup)) + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach() + tl_val = F.cross_entropy(lo, y, reduction='none').detach() + hL = hi[-1].detach(); t_L = torch.ones(bs, device=device) + lt = ((vn(hL, t_L, s)-tl_val)**2).mean() + hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s) + gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0] + hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(hLr2), y, reduction='sum') + aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach() + ltg = ((gV-aLe)**2).sum(-1).mean() + lb = 0.0 + for l in range(L): + hl = hi[l].detach(); tl = torch.full((bs,), l/L, device=device); tn = torch.full((bs,), (l+1)/L, device=device) + Vl = vn(hl, tl, s) + with torch.no_grad(): + hn = hi[l+1].detach(); lts = [] + for k in range(4): + ns = 0.05*torch.randn_like(hn); lts.append(-ve(hn+ns, tn, s)/0.1) + Vt = -0.1*(torch.logsumexp(torch.stack(lts, -1), -1)-np.log(4)) + lb += ((Vl-Vt.detach())**2).mean() + lb /= L; vl = lt+lb+1.0*ltg + vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(), 1.0); vop.step() + update_ema(vn, ve, 0.995) + cbc = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((bs,), l/L, device=device) + Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach()) + dfac = [(eT@Bs[l].T).detach() for l in range(L)] + credits = [] + for l in range(L): + if blend >= 1: credits.append(cbc[l]) + elif blend <= 0: credits.append(dfac[l]) + else: + cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 + credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr) + lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + return model + + +def run_A1_naive(args): + """Compute naive state err for synthetic ladder.""" + device = torch.device(f'cuda:{args.gpu}') + seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] + alphas = [0.0, 0.5, 1.0]; depths = [4, 8]; d = 128; C = 10 + methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + rows = [] + for alpha in alphas: + for L in depths: + for seed in seeds: + teacher = TeacherNet(d, C, L, alpha=alpha, seed=seed*1000).to(device) + # Eval data + eval_x = torch.randn(512, d, device=device) + with torch.no_grad(): eval_y = teacher(eval_x).argmax(-1) + eval_data = [(eval_x, eval_y)] + for method in methods: + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = StudentNet(d, C, L, alpha=alpha).to(device) + model = train_synth_method(method, model, teacher, device, d, C, L) + nse = compute_naive_state_err(model, eval_data, device, eval_layer=L//2) + rows.append({'alpha': alpha, 'depth': L, 'method': method, 'seed': seed, 'naive_StateErr': nse}) + print(f" A1 alpha={alpha} L={L} {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) + # Save CSV + out = os.path.join(args.output_dir, 'A1_naive_state_err.csv') + with open(out, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['alpha','depth','method','seed','naive_StateErr']) + w.writeheader(); w.writerows(rows) + print(f"Saved {len(rows)} rows to {out}", flush=True) + + +def get_cifar10(bs=128): + tt = transforms.Compose([transforms.RandomCrop(32,4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + tv = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + return (DataLoader(torchvision.datasets.CIFAR10('./data',True,download=True,transform=tt),bs,True,num_workers=4,pin_memory=True), + DataLoader(torchvision.datasets.CIFAR10('./data',False,download=True,transform=tv),bs,False,num_workers=4,pin_memory=True)) + + +def train_cifar_method(method, model, train_loader, test_loader, device, L, d, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01): + """Train CIFAR model, return trained model.""" + C = 10 + if method == 'dfa': + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + for ep in range(1, epochs+1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1 + hL = hi[-1].detach() + F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward(); hop.step(); hop.zero_grad() + for l in range(L): + a = (eT@Bs[l].T).detach(); rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = (eT@Bs[0].T).detach(); r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + elif method == 'state_bridge': + sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + sop = optim.Adam(sp.parameters(), lr=lr_fb) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + for ep in range(1, epochs+1): + model.train(); sp.train() + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach() + hL = hi[-1].detach() + sl = 0.0 + for l in range(L): + tl = torch.full((b,), l/L, device=device) + pred = sp(hi[l].detach(), tl, s); tn = hL.norm(-1,keepdim=True).clamp(min=1.0) + sl += (((pred-hL)/tn)**2).sum(-1).mean() + sl /= L; sop.zero_grad(); sl.backward(); sop.step() + credits = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device) + pred = sp(hl, tl, s); pl = F.cross_entropy(model.out_head(model.out_ln(pred)), y, reduction='sum') + credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach()) + lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + elif method == 'credit_bridge': + vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + vop = optim.Adam(vn.parameters(), lr=lr_fb) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + warmup = max(1, epochs//5) + for ep in range(1, epochs+1): + model.train(); vn.train() + blend = 0.0 if ep <= warmup else min(1.0, (ep-warmup)/max(1, warmup)) + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach(); tlv = F.cross_entropy(lo, y, reduction='none').detach() + hL = hi[-1].detach(); t_L = torch.ones(b, device=device) + lt = ((vn(hL, t_L, s)-tlv)**2).mean() + hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s) + gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0] + hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(model.out_ln(hLr2)), y, reduction='sum') + aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach() + ltg = ((gV-aLe)**2).sum(-1).mean() + lb = 0.0 + for l in range(L): + hl = hi[l].detach(); tl = torch.full((b,), l/L, device=device); tn = torch.full((b,), (l+1)/L, device=device) + Vl = vn(hl, tl, s) + with torch.no_grad(): + hn = hi[l+1].detach(); lts = [] + for k in range(4): lts.append(-ve(hn+0.05*torch.randn_like(hn), tn, s)/0.1) + Vt = -0.1*(torch.logsumexp(torch.stack(lts,-1),-1)-np.log(4)) + lb += ((Vl-Vt.detach())**2).mean() + lb /= L; vl = lt+lb+1.0*ltg + vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(),1.0); vop.step() + update_ema(vn, ve, 0.995) + cbc = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device) + Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach()) + dfac = [(eT@Bs[l].T).detach() for l in range(L)] + credits = [] + for l in range(L): + if blend >= 1: credits.append(cbc[l]) + elif blend <= 0: credits.append(dfac[l]) + else: + cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 + credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr) + lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + return model + + +def run_A2_naive(args): + """Compute naive state err for CIFAR methods.""" + device = torch.device(f'cuda:{args.gpu}') + seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] + L, d = 4, 256; methods = ['dfa', 'state_bridge', 'credit_bridge'] + train_loader, test_loader = get_cifar10() + rows = [] + for seed in seeds: + for method in methods: + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = ResidualMLP(3072, d, 10, L).to(device) + print(f" A2 {method} s={seed}: training...", flush=True) + model = train_cifar_method(method, model, train_loader, test_loader, device, L, d) + nse = compute_naive_state_err(model, test_loader, device, eval_layer=L//2) + rows.append({'method': method, 'seed': seed, 'naive_StateErr': nse}) + print(f" A2 {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) + out = os.path.join(args.output_dir, 'A2_naive_state_err.csv') + with open(out, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['method','seed','naive_StateErr']) + w.writeheader(); w.writerows(rows) + print(f"Saved {len(rows)} rows to {out}", flush=True) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--experiment', type=str, default='A1', choices=['A1','A2','both']) + p.add_argument('--gpu', type=int, default=3) + p.add_argument('--output_dir', type=str, default='results/confirmatory') + args = p.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + if args.experiment in ['A1', 'both']: + print("=== A1: Synthetic naive state err ===", flush=True) + run_A1_naive(args) + if args.experiment in ['A2', 'both']: + print("=== A2: CIFAR naive state err ===", flush=True) + run_A2_naive(args) + print("Done.", flush=True) + + +if __name__ == '__main__': + main() diff --git a/results/confirmatory/A1_naive_state_err.csv b/results/confirmatory/A1_naive_state_err.csv new file mode 100644 index 0000000..6ddcfe7 --- /dev/null +++ b/results/confirmatory/A1_naive_state_err.csv @@ -0,0 +1,241 @@ +alpha,depth,method,seed,naive_StateErr +0.0,4,bp,42,0.26068803668022156 +0.0,4,dfa,42,0.7258815169334412 +0.0,4,state_bridge,42,0.7580535411834717 +0.0,4,credit_bridge,42,0.23254971206188202 +0.0,4,bp,123,0.24781781435012817 +0.0,4,dfa,123,0.6604026556015015 +0.0,4,state_bridge,123,0.5317853689193726 +0.0,4,credit_bridge,123,0.17576289176940918 +0.0,4,bp,456,0.2476104348897934 +0.0,4,dfa,456,0.6199116110801697 +0.0,4,state_bridge,456,0.863245964050293 +0.0,4,credit_bridge,456,0.12035267055034637 +0.0,4,bp,789,0.2656460106372833 +0.0,4,dfa,789,0.7435265779495239 +0.0,4,state_bridge,789,0.6364812850952148 +0.0,4,credit_bridge,789,0.2562270164489746 +0.0,4,bp,1024,0.2501564025878906 +0.0,4,dfa,1024,0.6919815540313721 +0.0,4,state_bridge,1024,0.6960811018943787 +0.0,4,credit_bridge,1024,0.2409060299396515 +0.0,4,bp,2048,0.24213725328445435 +0.0,4,dfa,2048,0.8119536638259888 +0.0,4,state_bridge,2048,1.0384111404418945 +0.0,4,credit_bridge,2048,0.29382896423339844 +0.0,4,bp,3000,0.25653713941574097 +0.0,4,dfa,3000,0.7727931141853333 +0.0,4,state_bridge,3000,0.7795010805130005 +0.0,4,credit_bridge,3000,0.1640912890434265 +0.0,4,bp,4000,0.2790989279747009 +0.0,4,dfa,4000,0.6488389372825623 +0.0,4,state_bridge,4000,0.8933002948760986 +0.0,4,credit_bridge,4000,0.129751056432724 +0.0,4,bp,5000,0.23482249677181244 +0.0,4,dfa,5000,0.5603621006011963 +0.0,4,state_bridge,5000,0.7118808031082153 +0.0,4,credit_bridge,5000,0.1876041740179062 +0.0,4,bp,6000,0.2366245985031128 +0.0,4,dfa,6000,0.7504462003707886 +0.0,4,state_bridge,6000,0.8727204203605652 +0.0,4,credit_bridge,6000,0.18112479150295258 +0.0,8,bp,42,0.07397101819515228 +0.0,8,dfa,42,0.5621968507766724 +0.0,8,state_bridge,42,0.816449761390686 +0.0,8,credit_bridge,42,0.15379583835601807 +0.0,8,bp,123,0.07488589733839035 +0.0,8,dfa,123,0.6059672832489014 +0.0,8,state_bridge,123,0.5263348817825317 +0.0,8,credit_bridge,123,0.12262149155139923 +0.0,8,bp,456,0.07783643156290054 +0.0,8,dfa,456,0.7705618143081665 +0.0,8,state_bridge,456,0.4180244505405426 +0.0,8,credit_bridge,456,0.19311878085136414 +0.0,8,bp,789,0.09013043344020844 +0.0,8,dfa,789,0.5577319860458374 +0.0,8,state_bridge,789,0.8297959566116333 +0.0,8,credit_bridge,789,0.05089893937110901 +0.0,8,bp,1024,0.0830039381980896 +0.0,8,dfa,1024,0.5080844163894653 +0.0,8,state_bridge,1024,0.39521247148513794 +0.0,8,credit_bridge,1024,0.06527997553348541 +0.0,8,bp,2048,0.07513846457004547 +0.0,8,dfa,2048,0.6275949478149414 +0.0,8,state_bridge,2048,0.6864550113677979 +0.0,8,credit_bridge,2048,0.08002562075853348 +0.0,8,bp,3000,0.10901156067848206 +0.0,8,dfa,3000,0.6083968281745911 +0.0,8,state_bridge,3000,0.8938897848129272 +0.0,8,credit_bridge,3000,0.38711249828338623 +0.0,8,bp,4000,0.09766732156276703 +0.0,8,dfa,4000,0.5503160357475281 +0.0,8,state_bridge,4000,0.39751216769218445 +0.0,8,credit_bridge,4000,0.13849811255931854 +0.0,8,bp,5000,0.07875370979309082 +0.0,8,dfa,5000,0.5541619658470154 +0.0,8,state_bridge,5000,0.6255663633346558 +0.0,8,credit_bridge,5000,0.07148516178131104 +0.0,8,bp,6000,0.07798311114311218 +0.0,8,dfa,6000,0.5998473167419434 +0.0,8,state_bridge,6000,0.5577906370162964 +0.0,8,credit_bridge,6000,0.18784484267234802 +0.5,4,bp,42,0.2670668959617615 +0.5,4,dfa,42,0.78990238904953 +0.5,4,state_bridge,42,0.7565754055976868 +0.5,4,credit_bridge,42,0.28575778007507324 +0.5,4,bp,123,0.26157334446907043 +0.5,4,dfa,123,0.8240824937820435 +0.5,4,state_bridge,123,0.5587974786758423 +0.5,4,credit_bridge,123,0.18487781286239624 +0.5,4,bp,456,0.26119792461395264 +0.5,4,dfa,456,0.5580360889434814 +0.5,4,state_bridge,456,0.6804043650627136 +0.5,4,credit_bridge,456,0.10438685864210129 +0.5,4,bp,789,0.2688371241092682 +0.5,4,dfa,789,0.7619737386703491 +0.5,4,state_bridge,789,0.5724117755889893 +0.5,4,credit_bridge,789,0.2902269661426544 +0.5,4,bp,1024,0.2626464068889618 +0.5,4,dfa,1024,0.7350623607635498 +0.5,4,state_bridge,1024,0.7774407863616943 +0.5,4,credit_bridge,1024,0.2907182276248932 +0.5,4,bp,2048,0.25887274742126465 +0.5,4,dfa,2048,0.8292539119720459 +0.5,4,state_bridge,2048,0.6785949468612671 +0.5,4,credit_bridge,2048,0.23325428366661072 +0.5,4,bp,3000,0.2686220407485962 +0.5,4,dfa,3000,0.7964614033699036 +0.5,4,state_bridge,3000,0.6747612953186035 +0.5,4,credit_bridge,3000,0.15331298112869263 +0.5,4,bp,4000,0.27715128660202026 +0.5,4,dfa,4000,0.834259033203125 +0.5,4,state_bridge,4000,0.6629055738449097 +0.5,4,credit_bridge,4000,0.09292227774858475 +0.5,4,bp,5000,0.25511646270751953 +0.5,4,dfa,5000,0.8486669063568115 +0.5,4,state_bridge,5000,0.6432816386222839 +0.5,4,credit_bridge,5000,0.08898796141147614 +0.5,4,bp,6000,0.25425827503204346 +0.5,4,dfa,6000,0.7771180868148804 +0.5,4,state_bridge,6000,0.8868570327758789 +0.5,4,credit_bridge,6000,0.28548482060432434 +0.5,8,bp,42,0.08691950142383575 +0.5,8,dfa,42,0.5566835403442383 +0.5,8,state_bridge,42,0.5521173477172852 +0.5,8,credit_bridge,42,0.026307538151741028 +0.5,8,bp,123,0.08442661166191101 +0.5,8,dfa,123,0.6884247064590454 +0.5,8,state_bridge,123,1.1720242500305176 +0.5,8,credit_bridge,123,0.09352543950080872 +0.5,8,bp,456,0.0845673531293869 +0.5,8,dfa,456,0.7805156707763672 +0.5,8,state_bridge,456,0.564720630645752 +0.5,8,credit_bridge,456,0.04183648154139519 +0.5,8,bp,789,0.09304642677307129 +0.5,8,dfa,789,0.5289106965065002 +0.5,8,state_bridge,789,1.1913974285125732 +0.5,8,credit_bridge,789,0.0811903178691864 +0.5,8,bp,1024,0.08988181501626968 +0.5,8,dfa,1024,0.5268670320510864 +0.5,8,state_bridge,1024,0.38997870683670044 +0.5,8,credit_bridge,1024,0.0551433339715004 +0.5,8,bp,2048,0.08537864685058594 +0.5,8,dfa,2048,0.6378879547119141 +0.5,8,state_bridge,2048,0.8929705619812012 +0.5,8,credit_bridge,2048,0.05500475689768791 +0.5,8,bp,3000,0.10150086879730225 +0.5,8,dfa,3000,0.6010029315948486 +0.5,8,state_bridge,3000,0.6921526193618774 +0.5,8,credit_bridge,3000,0.0750318169593811 +0.5,8,bp,4000,0.09839055687189102 +0.5,8,dfa,4000,0.5477902889251709 +0.5,8,state_bridge,4000,0.4617120325565338 +0.5,8,credit_bridge,4000,0.0474717952311039 +0.5,8,bp,5000,0.09180224686861038 +0.5,8,dfa,5000,0.5527122616767883 +0.5,8,state_bridge,5000,0.5788565874099731 +0.5,8,credit_bridge,5000,0.08813595771789551 +0.5,8,bp,6000,0.08880583941936493 +0.5,8,dfa,6000,0.5930142402648926 +0.5,8,state_bridge,6000,0.5428022742271423 +0.5,8,credit_bridge,6000,0.1234222799539566 +1.0,4,bp,42,0.37005650997161865 +1.0,4,dfa,42,0.869762659072876 +1.0,4,state_bridge,42,0.87589430809021 +1.0,4,credit_bridge,42,0.2053530365228653 +1.0,4,bp,123,0.36214226484298706 +1.0,4,dfa,123,0.8254338502883911 +1.0,4,state_bridge,123,0.8833000659942627 +1.0,4,credit_bridge,123,0.21273687481880188 +1.0,4,bp,456,0.3626979887485504 +1.0,4,dfa,456,0.7873569130897522 +1.0,4,state_bridge,456,0.8605778217315674 +1.0,4,credit_bridge,456,0.21030157804489136 +1.0,4,bp,789,0.37471723556518555 +1.0,4,dfa,789,0.7885147929191589 +1.0,4,state_bridge,789,0.8968838453292847 +1.0,4,credit_bridge,789,0.18812689185142517 +1.0,4,bp,1024,0.36623838543891907 +1.0,4,dfa,1024,0.8550044894218445 +1.0,4,state_bridge,1024,0.8468506932258606 +1.0,4,credit_bridge,1024,0.32974082231521606 +1.0,4,bp,2048,0.3584083914756775 +1.0,4,dfa,2048,0.823596715927124 +1.0,4,state_bridge,2048,0.8956995010375977 +1.0,4,credit_bridge,2048,0.1712440550327301 +1.0,4,bp,3000,0.37105458974838257 +1.0,4,dfa,3000,0.8486852645874023 +1.0,4,state_bridge,3000,0.9403642416000366 +1.0,4,credit_bridge,3000,0.15761351585388184 +1.0,4,bp,4000,0.384659081697464 +1.0,4,dfa,4000,0.8266894817352295 +1.0,4,state_bridge,4000,0.8737555742263794 +1.0,4,credit_bridge,4000,0.0490226186811924 +1.0,4,bp,5000,0.3479670286178589 +1.0,4,dfa,5000,0.7952747344970703 +1.0,4,state_bridge,5000,0.8428233861923218 +1.0,4,credit_bridge,5000,0.2047322541475296 +1.0,4,bp,6000,0.35523170232772827 +1.0,4,dfa,6000,0.850082516670227 +1.0,4,state_bridge,6000,0.8412976264953613 +1.0,4,credit_bridge,6000,0.18619702756404877 +1.0,8,bp,42,0.20174872875213623 +1.0,8,dfa,42,0.5564329028129578 +1.0,8,state_bridge,42,0.6702669262886047 +1.0,8,credit_bridge,42,0.11915956437587738 +1.0,8,bp,123,0.20575112104415894 +1.0,8,dfa,123,0.7708055973052979 +1.0,8,state_bridge,123,0.6753015518188477 +1.0,8,credit_bridge,123,0.12676304578781128 +1.0,8,bp,456,0.20126384496688843 +1.0,8,dfa,456,0.789068341255188 +1.0,8,state_bridge,456,0.5630311965942383 +1.0,8,credit_bridge,456,0.0718986764550209 +1.0,8,bp,789,0.21950820088386536 +1.0,8,dfa,789,0.801855206489563 +1.0,8,state_bridge,789,0.5838653445243835 +1.0,8,credit_bridge,789,0.1493334174156189 +1.0,8,bp,1024,0.2049655318260193 +1.0,8,dfa,1024,0.7054224014282227 +1.0,8,state_bridge,1024,0.564822793006897 +1.0,8,credit_bridge,1024,0.0481550358235836 +1.0,8,bp,2048,0.2031216025352478 +1.0,8,dfa,2048,0.6300236582756042 +1.0,8,state_bridge,2048,0.584318995475769 +1.0,8,credit_bridge,2048,0.22037747502326965 +1.0,8,bp,3000,0.22935035824775696 +1.0,8,dfa,3000,0.712699294090271 +1.0,8,state_bridge,3000,0.7996810674667358 +1.0,8,credit_bridge,3000,0.09546110779047012 +1.0,8,bp,4000,0.22556602954864502 +1.0,8,dfa,4000,0.5526505708694458 +1.0,8,state_bridge,4000,0.5007961392402649 +1.0,8,credit_bridge,4000,0.04477586969733238 +1.0,8,bp,5000,0.2003089338541031 +1.0,8,dfa,5000,0.777186393737793 +1.0,8,state_bridge,5000,0.3940466642379761 +1.0,8,credit_bridge,5000,0.1061391606926918 +1.0,8,bp,6000,0.20401637256145477 +1.0,8,dfa,6000,0.5910547971725464 +1.0,8,state_bridge,6000,0.6176798343658447 +1.0,8,credit_bridge,6000,0.10662685334682465 diff --git a/results/confirmatory/A2_naive_state_err.csv b/results/confirmatory/A2_naive_state_err.csv new file mode 100644 index 0000000..fcb5f01 --- /dev/null +++ b/results/confirmatory/A2_naive_state_err.csv @@ -0,0 +1,31 @@ +method,seed,naive_StateErr +dfa,42,0.8608599968910218 +state_bridge,42,1.1073074493408204 +credit_bridge,42,0.18549979393482208 +dfa,123,0.6620678688049316 +state_bridge,123,1.157683981513977 +credit_bridge,123,0.5449653492927551 +dfa,456,0.862576175403595 +state_bridge,456,0.5174413678169251 +credit_bridge,456,0.02991231045126915 +dfa,789,0.5742178151130676 +state_bridge,789,0.32427600870132445 +credit_bridge,789,0.34830747079849245 +dfa,1024,0.7317259416580201 +state_bridge,1024,1.1036756999969481 +credit_bridge,1024,0.5166968782424927 +dfa,2048,0.29954994792938233 +state_bridge,2048,0.4312982501029968 +credit_bridge,2048,0.8699032402038575 +dfa,3000,0.7138990812301635 +state_bridge,3000,0.08414710862636567 +credit_bridge,3000,0.07805714750289917 +dfa,4000,0.6306101009368896 +state_bridge,4000,0.9965199737548828 +credit_bridge,4000,0.4791879728317261 +dfa,5000,0.9650134473800659 +state_bridge,5000,0.9529169537544251 +credit_bridge,5000,0.6871149513244629 +dfa,6000,0.8224315113067627 +state_bridge,6000,0.14781150970458984 +credit_bridge,6000,0.5449538031578064 -- cgit v1.2.3