summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--experiments/compute_naive_state_err.py390
-rw-r--r--results/confirmatory/A1_naive_state_err.csv241
-rw-r--r--results/confirmatory/A2_naive_state_err.csv31
3 files changed, 662 insertions, 0 deletions
diff --git a/experiments/compute_naive_state_err.py b/experiments/compute_naive_state_err.py
new file mode 100644
index 0000000..64c8790
--- /dev/null
+++ b/experiments/compute_naive_state_err.py
@@ -0,0 +1,390 @@
+"""
+Compute naive state prediction baseline: ||h_l - h_L|| / ||h_L|| for all methods.
+This is the identity-prediction error — how much does h change from layer l to L.
+Computed at layer L//2 on test/eval data after training.
+
+Retrains models with same seeds as confirmatory experiments, then computes metric.
+"""
+import os, sys, json, csv, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+import torch.optim as optim, copy
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+
+
+def compute_naive_state_err(model, dataloader, device, eval_layer=None):
+ """Compute ||h_l - h_L|| / ||h_L|| averaged over data."""
+ model.eval()
+ L = model.num_blocks
+ if eval_layer is None:
+ eval_layer = L // 2
+ total_err, n = 0.0, 0
+ with torch.no_grad():
+ for x, y in dataloader:
+ if hasattr(model, 'embed'):
+ x = x.view(x.size(0), -1).to(device)
+ else:
+ x = x.to(device)
+ _, hiddens = model(x, return_hidden=True)
+ h_l = hiddens[eval_layer]
+ h_L = hiddens[-1]
+ norm_L = h_L.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ err = ((h_l - h_L) / norm_L).pow(2).sum(-1).mean()
+ total_err += err.item() * x.size(0)
+ n += x.size(0)
+ return total_err / n
+
+
+# ============ Synthetic ============
+class TeacherNet(nn.Module):
+ def __init__(self, d, C, L, alpha=1.0, seed=0):
+ super().__init__()
+ self.alpha = alpha
+ rng = torch.Generator().manual_seed(seed)
+ self.Ws = nn.ParameterList()
+ for _ in range(L):
+ W = torch.randn(d, d, generator=rng) * 0.3 / (d**0.5)
+ U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+ W = U @ torch.diag(S.clamp(max=0.3)) @ Vh
+ self.Ws.append(nn.Parameter(W, requires_grad=False))
+ self.U = nn.Parameter(torch.randn(C, d, generator=rng) / (d**0.5), requires_grad=False)
+ def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z)
+ def forward(self, x):
+ h = x
+ for W in self.Ws: h = h + self.phi(h @ W.T)
+ return h @ self.U.T
+
+class StudentBlock(nn.Module):
+ def __init__(self, d, alpha=1.0):
+ super().__init__()
+ self.ln = nn.LayerNorm(d); self.w = nn.Linear(d, d, bias=False)
+ nn.init.normal_(self.w.weight, std=0.01); self.alpha = alpha
+ def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z)
+ def forward(self, h): return self.w(self.phi(self.ln(h)))
+
+class StudentNet(nn.Module):
+ def __init__(self, d, C, L, alpha=1.0):
+ super().__init__()
+ self.blocks = nn.ModuleList([StudentBlock(d, alpha) for _ in range(L)])
+ self.out_head = nn.Linear(d, C); self.d_hidden = d; self.num_blocks = L
+ def forward(self, x, return_hidden=False):
+ h = x; hiddens = [h] if return_hidden else None
+ for b in self.blocks:
+ h = h + b(h)
+ if return_hidden: hiddens.append(h)
+ logits = self.out_head(h)
+ return (logits, hiddens) if return_hidden else logits
+ def forward_from_layer(self, h, sl):
+ for i in range(sl, self.num_blocks): h = h + self.blocks[i](h)
+ return self.out_head(h)
+
+
+def train_synth_method(method, model, teacher, device, d, C, L, epochs=80, steps=50, bs=256, lr=1e-3, lr_fb=1e-3):
+ """Train synthetic model with given method, return trained model."""
+ if method == 'bp':
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
+ for ep in range(epochs):
+ model.train()
+ for _ in range(steps):
+ x = torch.randn(bs, d, device=device)
+ with torch.no_grad(): y = teacher(x).argmax(-1)
+ loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step()
+ elif method == 'dfa':
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks]
+ hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01)
+ for ep in range(epochs):
+ model.train()
+ for _ in range(steps):
+ x = torch.randn(bs, d, device=device)
+ with torch.no_grad():
+ y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True)
+ eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1
+ lo2 = F.cross_entropy(model.out_head(hi[-1].detach()), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = (eT @ Bs[l].T).detach(); rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step()
+ elif method == 'state_bridge':
+ sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device)
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks]
+ hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01)
+ sop = optim.Adam(sp.parameters(), lr=lr_fb)
+ warmup = max(1, epochs//5)
+ for ep in range(epochs):
+ model.train(); sp.train()
+ for _ in range(steps):
+ x = torch.randn(bs, d, device=device)
+ with torch.no_grad():
+ y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True)
+ eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach()
+ hL = hi[-1].detach()
+ sl = 0.0
+ for l in range(L):
+ t_l = torch.full((bs,), l/L, device=device)
+ pred = sp(hi[l].detach(), t_l, s)
+ tn = hL.norm(-1, keepdim=True).clamp(min=1.0)
+ sl += (((pred-hL)/tn)**2).sum(-1).mean()
+ sl /= L; sop.zero_grad(); sl.backward(); sop.step()
+ credits = []
+ for l in range(L):
+ hl = hi[l].detach().requires_grad_(True); t_l = torch.full((bs,), l/L, device=device)
+ pred = sp(hl, t_l, s); pl = F.cross_entropy(model.out_head(pred), y, reduction='sum')
+ credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach())
+ lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step()
+ elif method == 'credit_bridge':
+ vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn)
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks]
+ hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01)
+ vop = optim.Adam(vn.parameters(), lr=lr_fb)
+ warmup = max(1, epochs//5)
+ for ep in range(epochs):
+ model.train(); vn.train()
+ blend = 0.0 if ep < warmup else min(1.0, (ep-warmup)/max(1, warmup))
+ for _ in range(steps):
+ x = torch.randn(bs, d, device=device)
+ with torch.no_grad():
+ y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True)
+ eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach()
+ tl_val = F.cross_entropy(lo, y, reduction='none').detach()
+ hL = hi[-1].detach(); t_L = torch.ones(bs, device=device)
+ lt = ((vn(hL, t_L, s)-tl_val)**2).mean()
+ hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s)
+ gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0]
+ hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(hLr2), y, reduction='sum')
+ aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach()
+ ltg = ((gV-aLe)**2).sum(-1).mean()
+ lb = 0.0
+ for l in range(L):
+ hl = hi[l].detach(); tl = torch.full((bs,), l/L, device=device); tn = torch.full((bs,), (l+1)/L, device=device)
+ Vl = vn(hl, tl, s)
+ with torch.no_grad():
+ hn = hi[l+1].detach(); lts = []
+ for k in range(4):
+ ns = 0.05*torch.randn_like(hn); lts.append(-ve(hn+ns, tn, s)/0.1)
+ Vt = -0.1*(torch.logsumexp(torch.stack(lts, -1), -1)-np.log(4))
+ lb += ((Vl-Vt.detach())**2).mean()
+ lb /= L; vl = lt+lb+1.0*ltg
+ vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(), 1.0); vop.step()
+ update_ema(vn, ve, 0.995)
+ cbc = []
+ for l in range(L):
+ hl = hi[l].detach().requires_grad_(True); tl = torch.full((bs,), l/L, device=device)
+ Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach())
+ dfac = [(eT@Bs[l].T).detach() for l in range(L)]
+ credits = []
+ for l in range(L):
+ if blend >= 1: credits.append(cbc[l])
+ elif blend <= 0: credits.append(dfac[l])
+ else:
+ cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr)
+ lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step()
+ return model
+
+
+def run_A1_naive(args):
+ """Compute naive state err for synthetic ladder."""
+ device = torch.device(f'cuda:{args.gpu}')
+ seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000]
+ alphas = [0.0, 0.5, 1.0]; depths = [4, 8]; d = 128; C = 10
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ rows = []
+ for alpha in alphas:
+ for L in depths:
+ for seed in seeds:
+ teacher = TeacherNet(d, C, L, alpha=alpha, seed=seed*1000).to(device)
+ # Eval data
+ eval_x = torch.randn(512, d, device=device)
+ with torch.no_grad(): eval_y = teacher(eval_x).argmax(-1)
+ eval_data = [(eval_x, eval_y)]
+ for method in methods:
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = StudentNet(d, C, L, alpha=alpha).to(device)
+ model = train_synth_method(method, model, teacher, device, d, C, L)
+ nse = compute_naive_state_err(model, eval_data, device, eval_layer=L//2)
+ rows.append({'alpha': alpha, 'depth': L, 'method': method, 'seed': seed, 'naive_StateErr': nse})
+ print(f" A1 alpha={alpha} L={L} {method} s={seed}: naive_StateErr={nse:.6f}", flush=True)
+ # Save CSV
+ out = os.path.join(args.output_dir, 'A1_naive_state_err.csv')
+ with open(out, 'w', newline='') as f:
+ w = csv.DictWriter(f, fieldnames=['alpha','depth','method','seed','naive_StateErr'])
+ w.writeheader(); w.writerows(rows)
+ print(f"Saved {len(rows)} rows to {out}", flush=True)
+
+
+def get_cifar10(bs=128):
+ tt = transforms.Compose([transforms.RandomCrop(32,4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ tv = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ return (DataLoader(torchvision.datasets.CIFAR10('./data',True,download=True,transform=tt),bs,True,num_workers=4,pin_memory=True),
+ DataLoader(torchvision.datasets.CIFAR10('./data',False,download=True,transform=tv),bs,False,num_workers=4,pin_memory=True))
+
+
+def train_cifar_method(method, model, train_loader, test_loader, device, L, d, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01):
+ """Train CIFAR model, return trained model."""
+ C = 10
+ if method == 'dfa':
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)]
+ for ep in range(1, epochs+1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0)
+ with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1
+ hL = hi[-1].detach()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward(); hop.step(); hop.zero_grad()
+ for l in range(L):
+ a = (eT@Bs[l].T).detach(); rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step()
+ a0 = (eT@Bs[0].T).detach(); r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step()
+ for s in schs: s.step()
+ elif method == 'state_bridge':
+ sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device)
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ sop = optim.Adam(sp.parameters(), lr=lr_fb)
+ schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)]
+ for ep in range(1, epochs+1):
+ model.train(); sp.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0)
+ with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach()
+ hL = hi[-1].detach()
+ sl = 0.0
+ for l in range(L):
+ tl = torch.full((b,), l/L, device=device)
+ pred = sp(hi[l].detach(), tl, s); tn = hL.norm(-1,keepdim=True).clamp(min=1.0)
+ sl += (((pred-hL)/tn)**2).sum(-1).mean()
+ sl /= L; sop.zero_grad(); sl.backward(); sop.step()
+ credits = []
+ for l in range(L):
+ hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device)
+ pred = sp(hl, tl, s); pl = F.cross_entropy(model.out_head(model.out_ln(pred)), y, reduction='sum')
+ credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach())
+ lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step()
+ a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step()
+ for s in schs: s.step()
+ elif method == 'credit_bridge':
+ vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn)
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ vop = optim.Adam(vn.parameters(), lr=lr_fb)
+ schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)]
+ warmup = max(1, epochs//5)
+ for ep in range(1, epochs+1):
+ model.train(); vn.train()
+ blend = 0.0 if ep <= warmup else min(1.0, (ep-warmup)/max(1, warmup))
+ for x, y in train_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0)
+ with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach(); tlv = F.cross_entropy(lo, y, reduction='none').detach()
+ hL = hi[-1].detach(); t_L = torch.ones(b, device=device)
+ lt = ((vn(hL, t_L, s)-tlv)**2).mean()
+ hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s)
+ gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0]
+ hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(model.out_ln(hLr2)), y, reduction='sum')
+ aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach()
+ ltg = ((gV-aLe)**2).sum(-1).mean()
+ lb = 0.0
+ for l in range(L):
+ hl = hi[l].detach(); tl = torch.full((b,), l/L, device=device); tn = torch.full((b,), (l+1)/L, device=device)
+ Vl = vn(hl, tl, s)
+ with torch.no_grad():
+ hn = hi[l+1].detach(); lts = []
+ for k in range(4): lts.append(-ve(hn+0.05*torch.randn_like(hn), tn, s)/0.1)
+ Vt = -0.1*(torch.logsumexp(torch.stack(lts,-1),-1)-np.log(4))
+ lb += ((Vl-Vt.detach())**2).mean()
+ lb /= L; vl = lt+lb+1.0*ltg
+ vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(),1.0); vop.step()
+ update_ema(vn, ve, 0.995)
+ cbc = []
+ for l in range(L):
+ hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device)
+ Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach())
+ dfac = [(eT@Bs[l].T).detach() for l in range(L)]
+ credits = []
+ for l in range(L):
+ if blend >= 1: credits.append(cbc[l])
+ elif blend <= 0: credits.append(dfac[l])
+ else:
+ cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr)
+ lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step()
+ a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step()
+ for s in schs: s.step()
+ return model
+
+
+def run_A2_naive(args):
+ """Compute naive state err for CIFAR methods."""
+ device = torch.device(f'cuda:{args.gpu}')
+ seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000]
+ L, d = 4, 256; methods = ['dfa', 'state_bridge', 'credit_bridge']
+ train_loader, test_loader = get_cifar10()
+ rows = []
+ for seed in seeds:
+ for method in methods:
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ResidualMLP(3072, d, 10, L).to(device)
+ print(f" A2 {method} s={seed}: training...", flush=True)
+ model = train_cifar_method(method, model, train_loader, test_loader, device, L, d)
+ nse = compute_naive_state_err(model, test_loader, device, eval_layer=L//2)
+ rows.append({'method': method, 'seed': seed, 'naive_StateErr': nse})
+ print(f" A2 {method} s={seed}: naive_StateErr={nse:.6f}", flush=True)
+ out = os.path.join(args.output_dir, 'A2_naive_state_err.csv')
+ with open(out, 'w', newline='') as f:
+ w = csv.DictWriter(f, fieldnames=['method','seed','naive_StateErr'])
+ w.writeheader(); w.writerows(rows)
+ print(f"Saved {len(rows)} rows to {out}", flush=True)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--experiment', type=str, default='A1', choices=['A1','A2','both'])
+ p.add_argument('--gpu', type=int, default=3)
+ p.add_argument('--output_dir', type=str, default='results/confirmatory')
+ args = p.parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+ if args.experiment in ['A1', 'both']:
+ print("=== A1: Synthetic naive state err ===", flush=True)
+ run_A1_naive(args)
+ if args.experiment in ['A2', 'both']:
+ print("=== A2: CIFAR naive state err ===", flush=True)
+ run_A2_naive(args)
+ print("Done.", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/results/confirmatory/A1_naive_state_err.csv b/results/confirmatory/A1_naive_state_err.csv
new file mode 100644
index 0000000..6ddcfe7
--- /dev/null
+++ b/results/confirmatory/A1_naive_state_err.csv
@@ -0,0 +1,241 @@
+alpha,depth,method,seed,naive_StateErr
+0.0,4,bp,42,0.26068803668022156
+0.0,4,dfa,42,0.7258815169334412
+0.0,4,state_bridge,42,0.7580535411834717
+0.0,4,credit_bridge,42,0.23254971206188202
+0.0,4,bp,123,0.24781781435012817
+0.0,4,dfa,123,0.6604026556015015
+0.0,4,state_bridge,123,0.5317853689193726
+0.0,4,credit_bridge,123,0.17576289176940918
+0.0,4,bp,456,0.2476104348897934
+0.0,4,dfa,456,0.6199116110801697
+0.0,4,state_bridge,456,0.863245964050293
+0.0,4,credit_bridge,456,0.12035267055034637
+0.0,4,bp,789,0.2656460106372833
+0.0,4,dfa,789,0.7435265779495239
+0.0,4,state_bridge,789,0.6364812850952148
+0.0,4,credit_bridge,789,0.2562270164489746
+0.0,4,bp,1024,0.2501564025878906
+0.0,4,dfa,1024,0.6919815540313721
+0.0,4,state_bridge,1024,0.6960811018943787
+0.0,4,credit_bridge,1024,0.2409060299396515
+0.0,4,bp,2048,0.24213725328445435
+0.0,4,dfa,2048,0.8119536638259888
+0.0,4,state_bridge,2048,1.0384111404418945
+0.0,4,credit_bridge,2048,0.29382896423339844
+0.0,4,bp,3000,0.25653713941574097
+0.0,4,dfa,3000,0.7727931141853333
+0.0,4,state_bridge,3000,0.7795010805130005
+0.0,4,credit_bridge,3000,0.1640912890434265
+0.0,4,bp,4000,0.2790989279747009
+0.0,4,dfa,4000,0.6488389372825623
+0.0,4,state_bridge,4000,0.8933002948760986
+0.0,4,credit_bridge,4000,0.129751056432724
+0.0,4,bp,5000,0.23482249677181244
+0.0,4,dfa,5000,0.5603621006011963
+0.0,4,state_bridge,5000,0.7118808031082153
+0.0,4,credit_bridge,5000,0.1876041740179062
+0.0,4,bp,6000,0.2366245985031128
+0.0,4,dfa,6000,0.7504462003707886
+0.0,4,state_bridge,6000,0.8727204203605652
+0.0,4,credit_bridge,6000,0.18112479150295258
+0.0,8,bp,42,0.07397101819515228
+0.0,8,dfa,42,0.5621968507766724
+0.0,8,state_bridge,42,0.816449761390686
+0.0,8,credit_bridge,42,0.15379583835601807
+0.0,8,bp,123,0.07488589733839035
+0.0,8,dfa,123,0.6059672832489014
+0.0,8,state_bridge,123,0.5263348817825317
+0.0,8,credit_bridge,123,0.12262149155139923
+0.0,8,bp,456,0.07783643156290054
+0.0,8,dfa,456,0.7705618143081665
+0.0,8,state_bridge,456,0.4180244505405426
+0.0,8,credit_bridge,456,0.19311878085136414
+0.0,8,bp,789,0.09013043344020844
+0.0,8,dfa,789,0.5577319860458374
+0.0,8,state_bridge,789,0.8297959566116333
+0.0,8,credit_bridge,789,0.05089893937110901
+0.0,8,bp,1024,0.0830039381980896
+0.0,8,dfa,1024,0.5080844163894653
+0.0,8,state_bridge,1024,0.39521247148513794
+0.0,8,credit_bridge,1024,0.06527997553348541
+0.0,8,bp,2048,0.07513846457004547
+0.0,8,dfa,2048,0.6275949478149414
+0.0,8,state_bridge,2048,0.6864550113677979
+0.0,8,credit_bridge,2048,0.08002562075853348
+0.0,8,bp,3000,0.10901156067848206
+0.0,8,dfa,3000,0.6083968281745911
+0.0,8,state_bridge,3000,0.8938897848129272
+0.0,8,credit_bridge,3000,0.38711249828338623
+0.0,8,bp,4000,0.09766732156276703
+0.0,8,dfa,4000,0.5503160357475281
+0.0,8,state_bridge,4000,0.39751216769218445
+0.0,8,credit_bridge,4000,0.13849811255931854
+0.0,8,bp,5000,0.07875370979309082
+0.0,8,dfa,5000,0.5541619658470154
+0.0,8,state_bridge,5000,0.6255663633346558
+0.0,8,credit_bridge,5000,0.07148516178131104
+0.0,8,bp,6000,0.07798311114311218
+0.0,8,dfa,6000,0.5998473167419434
+0.0,8,state_bridge,6000,0.5577906370162964
+0.0,8,credit_bridge,6000,0.18784484267234802
+0.5,4,bp,42,0.2670668959617615
+0.5,4,dfa,42,0.78990238904953
+0.5,4,state_bridge,42,0.7565754055976868
+0.5,4,credit_bridge,42,0.28575778007507324
+0.5,4,bp,123,0.26157334446907043
+0.5,4,dfa,123,0.8240824937820435
+0.5,4,state_bridge,123,0.5587974786758423
+0.5,4,credit_bridge,123,0.18487781286239624
+0.5,4,bp,456,0.26119792461395264
+0.5,4,dfa,456,0.5580360889434814
+0.5,4,state_bridge,456,0.6804043650627136
+0.5,4,credit_bridge,456,0.10438685864210129
+0.5,4,bp,789,0.2688371241092682
+0.5,4,dfa,789,0.7619737386703491
+0.5,4,state_bridge,789,0.5724117755889893
+0.5,4,credit_bridge,789,0.2902269661426544
+0.5,4,bp,1024,0.2626464068889618
+0.5,4,dfa,1024,0.7350623607635498
+0.5,4,state_bridge,1024,0.7774407863616943
+0.5,4,credit_bridge,1024,0.2907182276248932
+0.5,4,bp,2048,0.25887274742126465
+0.5,4,dfa,2048,0.8292539119720459
+0.5,4,state_bridge,2048,0.6785949468612671
+0.5,4,credit_bridge,2048,0.23325428366661072
+0.5,4,bp,3000,0.2686220407485962
+0.5,4,dfa,3000,0.7964614033699036
+0.5,4,state_bridge,3000,0.6747612953186035
+0.5,4,credit_bridge,3000,0.15331298112869263
+0.5,4,bp,4000,0.27715128660202026
+0.5,4,dfa,4000,0.834259033203125
+0.5,4,state_bridge,4000,0.6629055738449097
+0.5,4,credit_bridge,4000,0.09292227774858475
+0.5,4,bp,5000,0.25511646270751953
+0.5,4,dfa,5000,0.8486669063568115
+0.5,4,state_bridge,5000,0.6432816386222839
+0.5,4,credit_bridge,5000,0.08898796141147614
+0.5,4,bp,6000,0.25425827503204346
+0.5,4,dfa,6000,0.7771180868148804
+0.5,4,state_bridge,6000,0.8868570327758789
+0.5,4,credit_bridge,6000,0.28548482060432434
+0.5,8,bp,42,0.08691950142383575
+0.5,8,dfa,42,0.5566835403442383
+0.5,8,state_bridge,42,0.5521173477172852
+0.5,8,credit_bridge,42,0.026307538151741028
+0.5,8,bp,123,0.08442661166191101
+0.5,8,dfa,123,0.6884247064590454
+0.5,8,state_bridge,123,1.1720242500305176
+0.5,8,credit_bridge,123,0.09352543950080872
+0.5,8,bp,456,0.0845673531293869
+0.5,8,dfa,456,0.7805156707763672
+0.5,8,state_bridge,456,0.564720630645752
+0.5,8,credit_bridge,456,0.04183648154139519
+0.5,8,bp,789,0.09304642677307129
+0.5,8,dfa,789,0.5289106965065002
+0.5,8,state_bridge,789,1.1913974285125732
+0.5,8,credit_bridge,789,0.0811903178691864
+0.5,8,bp,1024,0.08988181501626968
+0.5,8,dfa,1024,0.5268670320510864
+0.5,8,state_bridge,1024,0.38997870683670044
+0.5,8,credit_bridge,1024,0.0551433339715004
+0.5,8,bp,2048,0.08537864685058594
+0.5,8,dfa,2048,0.6378879547119141
+0.5,8,state_bridge,2048,0.8929705619812012
+0.5,8,credit_bridge,2048,0.05500475689768791
+0.5,8,bp,3000,0.10150086879730225
+0.5,8,dfa,3000,0.6010029315948486
+0.5,8,state_bridge,3000,0.6921526193618774
+0.5,8,credit_bridge,3000,0.0750318169593811
+0.5,8,bp,4000,0.09839055687189102
+0.5,8,dfa,4000,0.5477902889251709
+0.5,8,state_bridge,4000,0.4617120325565338
+0.5,8,credit_bridge,4000,0.0474717952311039
+0.5,8,bp,5000,0.09180224686861038
+0.5,8,dfa,5000,0.5527122616767883
+0.5,8,state_bridge,5000,0.5788565874099731
+0.5,8,credit_bridge,5000,0.08813595771789551
+0.5,8,bp,6000,0.08880583941936493
+0.5,8,dfa,6000,0.5930142402648926
+0.5,8,state_bridge,6000,0.5428022742271423
+0.5,8,credit_bridge,6000,0.1234222799539566
+1.0,4,bp,42,0.37005650997161865
+1.0,4,dfa,42,0.869762659072876
+1.0,4,state_bridge,42,0.87589430809021
+1.0,4,credit_bridge,42,0.2053530365228653
+1.0,4,bp,123,0.36214226484298706
+1.0,4,dfa,123,0.8254338502883911
+1.0,4,state_bridge,123,0.8833000659942627
+1.0,4,credit_bridge,123,0.21273687481880188
+1.0,4,bp,456,0.3626979887485504
+1.0,4,dfa,456,0.7873569130897522
+1.0,4,state_bridge,456,0.8605778217315674
+1.0,4,credit_bridge,456,0.21030157804489136
+1.0,4,bp,789,0.37471723556518555
+1.0,4,dfa,789,0.7885147929191589
+1.0,4,state_bridge,789,0.8968838453292847
+1.0,4,credit_bridge,789,0.18812689185142517
+1.0,4,bp,1024,0.36623838543891907
+1.0,4,dfa,1024,0.8550044894218445
+1.0,4,state_bridge,1024,0.8468506932258606
+1.0,4,credit_bridge,1024,0.32974082231521606
+1.0,4,bp,2048,0.3584083914756775
+1.0,4,dfa,2048,0.823596715927124
+1.0,4,state_bridge,2048,0.8956995010375977
+1.0,4,credit_bridge,2048,0.1712440550327301
+1.0,4,bp,3000,0.37105458974838257
+1.0,4,dfa,3000,0.8486852645874023
+1.0,4,state_bridge,3000,0.9403642416000366
+1.0,4,credit_bridge,3000,0.15761351585388184
+1.0,4,bp,4000,0.384659081697464
+1.0,4,dfa,4000,0.8266894817352295
+1.0,4,state_bridge,4000,0.8737555742263794
+1.0,4,credit_bridge,4000,0.0490226186811924
+1.0,4,bp,5000,0.3479670286178589
+1.0,4,dfa,5000,0.7952747344970703
+1.0,4,state_bridge,5000,0.8428233861923218
+1.0,4,credit_bridge,5000,0.2047322541475296
+1.0,4,bp,6000,0.35523170232772827
+1.0,4,dfa,6000,0.850082516670227
+1.0,4,state_bridge,6000,0.8412976264953613
+1.0,4,credit_bridge,6000,0.18619702756404877
+1.0,8,bp,42,0.20174872875213623
+1.0,8,dfa,42,0.5564329028129578
+1.0,8,state_bridge,42,0.6702669262886047
+1.0,8,credit_bridge,42,0.11915956437587738
+1.0,8,bp,123,0.20575112104415894
+1.0,8,dfa,123,0.7708055973052979
+1.0,8,state_bridge,123,0.6753015518188477
+1.0,8,credit_bridge,123,0.12676304578781128
+1.0,8,bp,456,0.20126384496688843
+1.0,8,dfa,456,0.789068341255188
+1.0,8,state_bridge,456,0.5630311965942383
+1.0,8,credit_bridge,456,0.0718986764550209
+1.0,8,bp,789,0.21950820088386536
+1.0,8,dfa,789,0.801855206489563
+1.0,8,state_bridge,789,0.5838653445243835
+1.0,8,credit_bridge,789,0.1493334174156189
+1.0,8,bp,1024,0.2049655318260193
+1.0,8,dfa,1024,0.7054224014282227
+1.0,8,state_bridge,1024,0.564822793006897
+1.0,8,credit_bridge,1024,0.0481550358235836
+1.0,8,bp,2048,0.2031216025352478
+1.0,8,dfa,2048,0.6300236582756042
+1.0,8,state_bridge,2048,0.584318995475769
+1.0,8,credit_bridge,2048,0.22037747502326965
+1.0,8,bp,3000,0.22935035824775696
+1.0,8,dfa,3000,0.712699294090271
+1.0,8,state_bridge,3000,0.7996810674667358
+1.0,8,credit_bridge,3000,0.09546110779047012
+1.0,8,bp,4000,0.22556602954864502
+1.0,8,dfa,4000,0.5526505708694458
+1.0,8,state_bridge,4000,0.5007961392402649
+1.0,8,credit_bridge,4000,0.04477586969733238
+1.0,8,bp,5000,0.2003089338541031
+1.0,8,dfa,5000,0.777186393737793
+1.0,8,state_bridge,5000,0.3940466642379761
+1.0,8,credit_bridge,5000,0.1061391606926918
+1.0,8,bp,6000,0.20401637256145477
+1.0,8,dfa,6000,0.5910547971725464
+1.0,8,state_bridge,6000,0.6176798343658447
+1.0,8,credit_bridge,6000,0.10662685334682465
diff --git a/results/confirmatory/A2_naive_state_err.csv b/results/confirmatory/A2_naive_state_err.csv
new file mode 100644
index 0000000..fcb5f01
--- /dev/null
+++ b/results/confirmatory/A2_naive_state_err.csv
@@ -0,0 +1,31 @@
+method,seed,naive_StateErr
+dfa,42,0.8608599968910218
+state_bridge,42,1.1073074493408204
+credit_bridge,42,0.18549979393482208
+dfa,123,0.6620678688049316
+state_bridge,123,1.157683981513977
+credit_bridge,123,0.5449653492927551
+dfa,456,0.862576175403595
+state_bridge,456,0.5174413678169251
+credit_bridge,456,0.02991231045126915
+dfa,789,0.5742178151130676
+state_bridge,789,0.32427600870132445
+credit_bridge,789,0.34830747079849245
+dfa,1024,0.7317259416580201
+state_bridge,1024,1.1036756999969481
+credit_bridge,1024,0.5166968782424927
+dfa,2048,0.29954994792938233
+state_bridge,2048,0.4312982501029968
+credit_bridge,2048,0.8699032402038575
+dfa,3000,0.7138990812301635
+state_bridge,3000,0.08414710862636567
+credit_bridge,3000,0.07805714750289917
+dfa,4000,0.6306101009368896
+state_bridge,4000,0.9965199737548828
+credit_bridge,4000,0.4791879728317261
+dfa,5000,0.9650134473800659
+state_bridge,5000,0.9529169537544251
+credit_bridge,5000,0.6871149513244629
+dfa,6000,0.8224315113067627
+state_bridge,6000,0.14781150970458984
+credit_bridge,6000,0.5449538031578064