summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-31 10:36:02 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-31 10:36:02 -0500
commit1359b7e7a96ab57be0bb24ebdf842a793ce01223 (patch)
treeba13d16ed1bf4604ffb76e12ceb1fdb5958900bb /experiments
parent8b21fb32bf0997e3f4266c1c22414e49f1fdcfcc (diff)
Add naive state prediction baseline for A1 and A2
A1: 240 rows (3 alpha × 2 depth × 4 methods × 10 seeds) A2: 30 rows (3 methods × 10 seeds) naive_StateErr = ||h_{L//2} - h_L|| / ||h_L|| Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/compute_naive_state_err.py390
1 files changed, 390 insertions, 0 deletions
diff --git a/experiments/compute_naive_state_err.py b/experiments/compute_naive_state_err.py
new file mode 100644
index 0000000..64c8790
--- /dev/null
+++ b/experiments/compute_naive_state_err.py
@@ -0,0 +1,390 @@
+"""
+Compute naive state prediction baseline: ||h_l - h_L|| / ||h_L|| for all methods.
+This is the identity-prediction error — how much does h change from layer l to L.
+Computed at layer L//2 on test/eval data after training.
+
+Retrains models with same seeds as confirmatory experiments, then computes metric.
+"""
+import os, sys, json, csv, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+import torch.optim as optim, copy
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+
+
+def compute_naive_state_err(model, dataloader, device, eval_layer=None):
+ """Compute ||h_l - h_L|| / ||h_L|| averaged over data."""
+ model.eval()
+ L = model.num_blocks
+ if eval_layer is None:
+ eval_layer = L // 2
+ total_err, n = 0.0, 0
+ with torch.no_grad():
+ for x, y in dataloader:
+ if hasattr(model, 'embed'):
+ x = x.view(x.size(0), -1).to(device)
+ else:
+ x = x.to(device)
+ _, hiddens = model(x, return_hidden=True)
+ h_l = hiddens[eval_layer]
+ h_L = hiddens[-1]
+ norm_L = h_L.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ err = ((h_l - h_L) / norm_L).pow(2).sum(-1).mean()
+ total_err += err.item() * x.size(0)
+ n += x.size(0)
+ return total_err / n
+
+
+# ============ Synthetic ============
+class TeacherNet(nn.Module):
+ def __init__(self, d, C, L, alpha=1.0, seed=0):
+ super().__init__()
+ self.alpha = alpha
+ rng = torch.Generator().manual_seed(seed)
+ self.Ws = nn.ParameterList()
+ for _ in range(L):
+ W = torch.randn(d, d, generator=rng) * 0.3 / (d**0.5)
+ U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+ W = U @ torch.diag(S.clamp(max=0.3)) @ Vh
+ self.Ws.append(nn.Parameter(W, requires_grad=False))
+ self.U = nn.Parameter(torch.randn(C, d, generator=rng) / (d**0.5), requires_grad=False)
+ def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z)
+ def forward(self, x):
+ h = x
+ for W in self.Ws: h = h + self.phi(h @ W.T)
+ return h @ self.U.T
+
+class StudentBlock(nn.Module):
+ def __init__(self, d, alpha=1.0):
+ super().__init__()
+ self.ln = nn.LayerNorm(d); self.w = nn.Linear(d, d, bias=False)
+ nn.init.normal_(self.w.weight, std=0.01); self.alpha = alpha
+ def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z)
+ def forward(self, h): return self.w(self.phi(self.ln(h)))
+
+class StudentNet(nn.Module):
+ def __init__(self, d, C, L, alpha=1.0):
+ super().__init__()
+ self.blocks = nn.ModuleList([StudentBlock(d, alpha) for _ in range(L)])
+ self.out_head = nn.Linear(d, C); self.d_hidden = d; self.num_blocks = L
+ def forward(self, x, return_hidden=False):
+ h = x; hiddens = [h] if return_hidden else None
+ for b in self.blocks:
+ h = h + b(h)
+ if return_hidden: hiddens.append(h)
+ logits = self.out_head(h)
+ return (logits, hiddens) if return_hidden else logits
+ def forward_from_layer(self, h, sl):
+ for i in range(sl, self.num_blocks): h = h + self.blocks[i](h)
+ return self.out_head(h)
+
+
+def train_synth_method(method, model, teacher, device, d, C, L, epochs=80, steps=50, bs=256, lr=1e-3, lr_fb=1e-3):
+ """Train synthetic model with given method, return trained model."""
+ if method == 'bp':
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
+ for ep in range(epochs):
+ model.train()
+ for _ in range(steps):
+ x = torch.randn(bs, d, device=device)
+ with torch.no_grad(): y = teacher(x).argmax(-1)
+ loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step()
+ elif method == 'dfa':
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks]
+ hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01)
+ for ep in range(epochs):
+ model.train()
+ for _ in range(steps):
+ x = torch.randn(bs, d, device=device)
+ with torch.no_grad():
+ y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True)
+ eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1
+ lo2 = F.cross_entropy(model.out_head(hi[-1].detach()), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = (eT @ Bs[l].T).detach(); rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step()
+ elif method == 'state_bridge':
+ sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device)
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks]
+ hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01)
+ sop = optim.Adam(sp.parameters(), lr=lr_fb)
+ warmup = max(1, epochs//5)
+ for ep in range(epochs):
+ model.train(); sp.train()
+ for _ in range(steps):
+ x = torch.randn(bs, d, device=device)
+ with torch.no_grad():
+ y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True)
+ eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach()
+ hL = hi[-1].detach()
+ sl = 0.0
+ for l in range(L):
+ t_l = torch.full((bs,), l/L, device=device)
+ pred = sp(hi[l].detach(), t_l, s)
+ tn = hL.norm(-1, keepdim=True).clamp(min=1.0)
+ sl += (((pred-hL)/tn)**2).sum(-1).mean()
+ sl /= L; sop.zero_grad(); sl.backward(); sop.step()
+ credits = []
+ for l in range(L):
+ hl = hi[l].detach().requires_grad_(True); t_l = torch.full((bs,), l/L, device=device)
+ pred = sp(hl, t_l, s); pl = F.cross_entropy(model.out_head(pred), y, reduction='sum')
+ credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach())
+ lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step()
+ elif method == 'credit_bridge':
+ vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn)
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks]
+ hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01)
+ vop = optim.Adam(vn.parameters(), lr=lr_fb)
+ warmup = max(1, epochs//5)
+ for ep in range(epochs):
+ model.train(); vn.train()
+ blend = 0.0 if ep < warmup else min(1.0, (ep-warmup)/max(1, warmup))
+ for _ in range(steps):
+ x = torch.randn(bs, d, device=device)
+ with torch.no_grad():
+ y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True)
+ eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach()
+ tl_val = F.cross_entropy(lo, y, reduction='none').detach()
+ hL = hi[-1].detach(); t_L = torch.ones(bs, device=device)
+ lt = ((vn(hL, t_L, s)-tl_val)**2).mean()
+ hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s)
+ gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0]
+ hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(hLr2), y, reduction='sum')
+ aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach()
+ ltg = ((gV-aLe)**2).sum(-1).mean()
+ lb = 0.0
+ for l in range(L):
+ hl = hi[l].detach(); tl = torch.full((bs,), l/L, device=device); tn = torch.full((bs,), (l+1)/L, device=device)
+ Vl = vn(hl, tl, s)
+ with torch.no_grad():
+ hn = hi[l+1].detach(); lts = []
+ for k in range(4):
+ ns = 0.05*torch.randn_like(hn); lts.append(-ve(hn+ns, tn, s)/0.1)
+ Vt = -0.1*(torch.logsumexp(torch.stack(lts, -1), -1)-np.log(4))
+ lb += ((Vl-Vt.detach())**2).mean()
+ lb /= L; vl = lt+lb+1.0*ltg
+ vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(), 1.0); vop.step()
+ update_ema(vn, ve, 0.995)
+ cbc = []
+ for l in range(L):
+ hl = hi[l].detach().requires_grad_(True); tl = torch.full((bs,), l/L, device=device)
+ Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach())
+ dfac = [(eT@Bs[l].T).detach() for l in range(L)]
+ credits = []
+ for l in range(L):
+ if blend >= 1: credits.append(cbc[l])
+ elif blend <= 0: credits.append(dfac[l])
+ else:
+ cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr)
+ lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step()
+ return model
+
+
+def run_A1_naive(args):
+ """Compute naive state err for synthetic ladder."""
+ device = torch.device(f'cuda:{args.gpu}')
+ seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000]
+ alphas = [0.0, 0.5, 1.0]; depths = [4, 8]; d = 128; C = 10
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ rows = []
+ for alpha in alphas:
+ for L in depths:
+ for seed in seeds:
+ teacher = TeacherNet(d, C, L, alpha=alpha, seed=seed*1000).to(device)
+ # Eval data
+ eval_x = torch.randn(512, d, device=device)
+ with torch.no_grad(): eval_y = teacher(eval_x).argmax(-1)
+ eval_data = [(eval_x, eval_y)]
+ for method in methods:
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = StudentNet(d, C, L, alpha=alpha).to(device)
+ model = train_synth_method(method, model, teacher, device, d, C, L)
+ nse = compute_naive_state_err(model, eval_data, device, eval_layer=L//2)
+ rows.append({'alpha': alpha, 'depth': L, 'method': method, 'seed': seed, 'naive_StateErr': nse})
+ print(f" A1 alpha={alpha} L={L} {method} s={seed}: naive_StateErr={nse:.6f}", flush=True)
+ # Save CSV
+ out = os.path.join(args.output_dir, 'A1_naive_state_err.csv')
+ with open(out, 'w', newline='') as f:
+ w = csv.DictWriter(f, fieldnames=['alpha','depth','method','seed','naive_StateErr'])
+ w.writeheader(); w.writerows(rows)
+ print(f"Saved {len(rows)} rows to {out}", flush=True)
+
+
+def get_cifar10(bs=128):
+ tt = transforms.Compose([transforms.RandomCrop(32,4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ tv = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ return (DataLoader(torchvision.datasets.CIFAR10('./data',True,download=True,transform=tt),bs,True,num_workers=4,pin_memory=True),
+ DataLoader(torchvision.datasets.CIFAR10('./data',False,download=True,transform=tv),bs,False,num_workers=4,pin_memory=True))
+
+
+def train_cifar_method(method, model, train_loader, test_loader, device, L, d, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01):
+ """Train CIFAR model, return trained model."""
+ C = 10
+ if method == 'dfa':
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)]
+ for ep in range(1, epochs+1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0)
+ with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1
+ hL = hi[-1].detach()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward(); hop.step(); hop.zero_grad()
+ for l in range(L):
+ a = (eT@Bs[l].T).detach(); rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step()
+ a0 = (eT@Bs[0].T).detach(); r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step()
+ for s in schs: s.step()
+ elif method == 'state_bridge':
+ sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device)
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ sop = optim.Adam(sp.parameters(), lr=lr_fb)
+ schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)]
+ for ep in range(1, epochs+1):
+ model.train(); sp.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0)
+ with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach()
+ hL = hi[-1].detach()
+ sl = 0.0
+ for l in range(L):
+ tl = torch.full((b,), l/L, device=device)
+ pred = sp(hi[l].detach(), tl, s); tn = hL.norm(-1,keepdim=True).clamp(min=1.0)
+ sl += (((pred-hL)/tn)**2).sum(-1).mean()
+ sl /= L; sop.zero_grad(); sl.backward(); sop.step()
+ credits = []
+ for l in range(L):
+ hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device)
+ pred = sp(hl, tl, s); pl = F.cross_entropy(model.out_head(model.out_ln(pred)), y, reduction='sum')
+ credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach())
+ lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step()
+ a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step()
+ for s in schs: s.step()
+ elif method == 'credit_bridge':
+ vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn)
+ Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
+ bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ vop = optim.Adam(vn.parameters(), lr=lr_fb)
+ schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)]
+ warmup = max(1, epochs//5)
+ for ep in range(1, epochs+1):
+ model.train(); vn.train()
+ blend = 0.0 if ep <= warmup else min(1.0, (ep-warmup)/max(1, warmup))
+ for x, y in train_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0)
+ with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach(); tlv = F.cross_entropy(lo, y, reduction='none').detach()
+ hL = hi[-1].detach(); t_L = torch.ones(b, device=device)
+ lt = ((vn(hL, t_L, s)-tlv)**2).mean()
+ hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s)
+ gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0]
+ hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(model.out_ln(hLr2)), y, reduction='sum')
+ aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach()
+ ltg = ((gV-aLe)**2).sum(-1).mean()
+ lb = 0.0
+ for l in range(L):
+ hl = hi[l].detach(); tl = torch.full((b,), l/L, device=device); tn = torch.full((b,), (l+1)/L, device=device)
+ Vl = vn(hl, tl, s)
+ with torch.no_grad():
+ hn = hi[l+1].detach(); lts = []
+ for k in range(4): lts.append(-ve(hn+0.05*torch.randn_like(hn), tn, s)/0.1)
+ Vt = -0.1*(torch.logsumexp(torch.stack(lts,-1),-1)-np.log(4))
+ lb += ((Vl-Vt.detach())**2).mean()
+ lb /= L; vl = lt+lb+1.0*ltg
+ vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(),1.0); vop.step()
+ update_ema(vn, ve, 0.995)
+ cbc = []
+ for l in range(L):
+ hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device)
+ Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach())
+ dfac = [(eT@Bs[l].T).detach() for l in range(L)]
+ credits = []
+ for l in range(L):
+ if blend >= 1: credits.append(cbc[l])
+ elif blend <= 0: credits.append(dfac[l])
+ else:
+ cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr)
+ lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step()
+ for l in range(L):
+ a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean()
+ bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step()
+ a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step()
+ for s in schs: s.step()
+ return model
+
+
+def run_A2_naive(args):
+ """Compute naive state err for CIFAR methods."""
+ device = torch.device(f'cuda:{args.gpu}')
+ seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000]
+ L, d = 4, 256; methods = ['dfa', 'state_bridge', 'credit_bridge']
+ train_loader, test_loader = get_cifar10()
+ rows = []
+ for seed in seeds:
+ for method in methods:
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ResidualMLP(3072, d, 10, L).to(device)
+ print(f" A2 {method} s={seed}: training...", flush=True)
+ model = train_cifar_method(method, model, train_loader, test_loader, device, L, d)
+ nse = compute_naive_state_err(model, test_loader, device, eval_layer=L//2)
+ rows.append({'method': method, 'seed': seed, 'naive_StateErr': nse})
+ print(f" A2 {method} s={seed}: naive_StateErr={nse:.6f}", flush=True)
+ out = os.path.join(args.output_dir, 'A2_naive_state_err.csv')
+ with open(out, 'w', newline='') as f:
+ w = csv.DictWriter(f, fieldnames=['method','seed','naive_StateErr'])
+ w.writeheader(); w.writerows(rows)
+ print(f"Saved {len(rows)} rows to {out}", flush=True)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--experiment', type=str, default='A1', choices=['A1','A2','both'])
+ p.add_argument('--gpu', type=int, default=3)
+ p.add_argument('--output_dir', type=str, default='results/confirmatory')
+ args = p.parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+ if args.experiment in ['A1', 'both']:
+ print("=== A1: Synthetic naive state err ===", flush=True)
+ run_A1_naive(args)
+ if args.experiment in ['A2', 'both']:
+ print("=== A2: CIFAR naive state err ===", flush=True)
+ run_A2_naive(args)
+ print("Done.", flush=True)
+
+
+if __name__ == '__main__':
+ main()