diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-31 10:36:02 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-31 10:36:02 -0500 |
| commit | 1359b7e7a96ab57be0bb24ebdf842a793ce01223 (patch) | |
| tree | ba13d16ed1bf4604ffb76e12ceb1fdb5958900bb /experiments | |
| parent | 8b21fb32bf0997e3f4266c1c22414e49f1fdcfcc (diff) | |
Add naive state prediction baseline for A1 and A2
A1: 240 rows (3 alpha × 2 depth × 4 methods × 10 seeds)
A2: 30 rows (3 methods × 10 seeds)
naive_StateErr = ||h_{L//2} - h_L|| / ||h_L||
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/compute_naive_state_err.py | 390 |
1 files changed, 390 insertions, 0 deletions
diff --git a/experiments/compute_naive_state_err.py b/experiments/compute_naive_state_err.py new file mode 100644 index 0000000..64c8790 --- /dev/null +++ b/experiments/compute_naive_state_err.py @@ -0,0 +1,390 @@ +""" +Compute naive state prediction baseline: ||h_l - h_L|| / ||h_L|| for all methods. +This is the identity-prediction error — how much does h change from layer l to L. +Computed at layer L//2 on test/eval data after training. + +Retrains models with same seeds as confirmatory experiments, then computes metric. +""" +import os, sys, json, csv, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F +import torch.optim as optim, copy +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet + + +def compute_naive_state_err(model, dataloader, device, eval_layer=None): + """Compute ||h_l - h_L|| / ||h_L|| averaged over data.""" + model.eval() + L = model.num_blocks + if eval_layer is None: + eval_layer = L // 2 + total_err, n = 0.0, 0 + with torch.no_grad(): + for x, y in dataloader: + if hasattr(model, 'embed'): + x = x.view(x.size(0), -1).to(device) + else: + x = x.to(device) + _, hiddens = model(x, return_hidden=True) + h_l = hiddens[eval_layer] + h_L = hiddens[-1] + norm_L = h_L.norm(dim=-1, keepdim=True).clamp(min=1.0) + err = ((h_l - h_L) / norm_L).pow(2).sum(-1).mean() + total_err += err.item() * x.size(0) + n += x.size(0) + return total_err / n + + +# ============ Synthetic ============ +class TeacherNet(nn.Module): + def __init__(self, d, C, L, alpha=1.0, seed=0): + super().__init__() + self.alpha = alpha + rng = torch.Generator().manual_seed(seed) + self.Ws = nn.ParameterList() + for _ in range(L): + W = torch.randn(d, d, generator=rng) * 0.3 / (d**0.5) + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + W = U @ torch.diag(S.clamp(max=0.3)) @ Vh + self.Ws.append(nn.Parameter(W, requires_grad=False)) + self.U = nn.Parameter(torch.randn(C, d, generator=rng) / (d**0.5), requires_grad=False) + def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z) + def forward(self, x): + h = x + for W in self.Ws: h = h + self.phi(h @ W.T) + return h @ self.U.T + +class StudentBlock(nn.Module): + def __init__(self, d, alpha=1.0): + super().__init__() + self.ln = nn.LayerNorm(d); self.w = nn.Linear(d, d, bias=False) + nn.init.normal_(self.w.weight, std=0.01); self.alpha = alpha + def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z) + def forward(self, h): return self.w(self.phi(self.ln(h))) + +class StudentNet(nn.Module): + def __init__(self, d, C, L, alpha=1.0): + super().__init__() + self.blocks = nn.ModuleList([StudentBlock(d, alpha) for _ in range(L)]) + self.out_head = nn.Linear(d, C); self.d_hidden = d; self.num_blocks = L + def forward(self, x, return_hidden=False): + h = x; hiddens = [h] if return_hidden else None + for b in self.blocks: + h = h + b(h) + if return_hidden: hiddens.append(h) + logits = self.out_head(h) + return (logits, hiddens) if return_hidden else logits + def forward_from_layer(self, h, sl): + for i in range(sl, self.num_blocks): h = h + self.blocks[i](h) + return self.out_head(h) + + +def train_synth_method(method, model, teacher, device, d, C, L, epochs=80, steps=50, bs=256, lr=1e-3, lr_fb=1e-3): + """Train synthetic model with given method, return trained model.""" + if method == 'bp': + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + for ep in range(epochs): + model.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): y = teacher(x).argmax(-1) + loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step() + elif method == 'dfa': + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + for ep in range(epochs): + model.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1 + lo2 = F.cross_entropy(model.out_head(hi[-1].detach()), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = (eT @ Bs[l].T).detach(); rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + elif method == 'state_bridge': + sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + sop = optim.Adam(sp.parameters(), lr=lr_fb) + warmup = max(1, epochs//5) + for ep in range(epochs): + model.train(); sp.train() + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach() + hL = hi[-1].detach() + sl = 0.0 + for l in range(L): + t_l = torch.full((bs,), l/L, device=device) + pred = sp(hi[l].detach(), t_l, s) + tn = hL.norm(-1, keepdim=True).clamp(min=1.0) + sl += (((pred-hL)/tn)**2).sum(-1).mean() + sl /= L; sop.zero_grad(); sl.backward(); sop.step() + credits = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); t_l = torch.full((bs,), l/L, device=device) + pred = sp(hl, t_l, s); pl = F.cross_entropy(model.out_head(pred), y, reduction='sum') + credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach()) + lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + elif method == 'credit_bridge': + vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] + hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) + vop = optim.Adam(vn.parameters(), lr=lr_fb) + warmup = max(1, epochs//5) + for ep in range(epochs): + model.train(); vn.train() + blend = 0.0 if ep < warmup else min(1.0, (ep-warmup)/max(1, warmup)) + for _ in range(steps): + x = torch.randn(bs, d, device=device) + with torch.no_grad(): + y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach() + tl_val = F.cross_entropy(lo, y, reduction='none').detach() + hL = hi[-1].detach(); t_L = torch.ones(bs, device=device) + lt = ((vn(hL, t_L, s)-tl_val)**2).mean() + hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s) + gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0] + hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(hLr2), y, reduction='sum') + aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach() + ltg = ((gV-aLe)**2).sum(-1).mean() + lb = 0.0 + for l in range(L): + hl = hi[l].detach(); tl = torch.full((bs,), l/L, device=device); tn = torch.full((bs,), (l+1)/L, device=device) + Vl = vn(hl, tl, s) + with torch.no_grad(): + hn = hi[l+1].detach(); lts = [] + for k in range(4): + ns = 0.05*torch.randn_like(hn); lts.append(-ve(hn+ns, tn, s)/0.1) + Vt = -0.1*(torch.logsumexp(torch.stack(lts, -1), -1)-np.log(4)) + lb += ((Vl-Vt.detach())**2).mean() + lb /= L; vl = lt+lb+1.0*ltg + vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(), 1.0); vop.step() + update_ema(vn, ve, 0.995) + cbc = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((bs,), l/L, device=device) + Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach()) + dfac = [(eT@Bs[l].T).detach() for l in range(L)] + credits = [] + for l in range(L): + if blend >= 1: credits.append(cbc[l]) + elif blend <= 0: credits.append(dfac[l]) + else: + cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 + credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr) + lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() + return model + + +def run_A1_naive(args): + """Compute naive state err for synthetic ladder.""" + device = torch.device(f'cuda:{args.gpu}') + seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] + alphas = [0.0, 0.5, 1.0]; depths = [4, 8]; d = 128; C = 10 + methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + rows = [] + for alpha in alphas: + for L in depths: + for seed in seeds: + teacher = TeacherNet(d, C, L, alpha=alpha, seed=seed*1000).to(device) + # Eval data + eval_x = torch.randn(512, d, device=device) + with torch.no_grad(): eval_y = teacher(eval_x).argmax(-1) + eval_data = [(eval_x, eval_y)] + for method in methods: + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = StudentNet(d, C, L, alpha=alpha).to(device) + model = train_synth_method(method, model, teacher, device, d, C, L) + nse = compute_naive_state_err(model, eval_data, device, eval_layer=L//2) + rows.append({'alpha': alpha, 'depth': L, 'method': method, 'seed': seed, 'naive_StateErr': nse}) + print(f" A1 alpha={alpha} L={L} {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) + # Save CSV + out = os.path.join(args.output_dir, 'A1_naive_state_err.csv') + with open(out, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['alpha','depth','method','seed','naive_StateErr']) + w.writeheader(); w.writerows(rows) + print(f"Saved {len(rows)} rows to {out}", flush=True) + + +def get_cifar10(bs=128): + tt = transforms.Compose([transforms.RandomCrop(32,4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + tv = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + return (DataLoader(torchvision.datasets.CIFAR10('./data',True,download=True,transform=tt),bs,True,num_workers=4,pin_memory=True), + DataLoader(torchvision.datasets.CIFAR10('./data',False,download=True,transform=tv),bs,False,num_workers=4,pin_memory=True)) + + +def train_cifar_method(method, model, train_loader, test_loader, device, L, d, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01): + """Train CIFAR model, return trained model.""" + C = 10 + if method == 'dfa': + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + for ep in range(1, epochs+1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1 + hL = hi[-1].detach() + F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward(); hop.step(); hop.zero_grad() + for l in range(L): + a = (eT@Bs[l].T).detach(); rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = (eT@Bs[0].T).detach(); r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + elif method == 'state_bridge': + sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + sop = optim.Adam(sp.parameters(), lr=lr_fb) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + for ep in range(1, epochs+1): + model.train(); sp.train() + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach() + hL = hi[-1].detach() + sl = 0.0 + for l in range(L): + tl = torch.full((b,), l/L, device=device) + pred = sp(hi[l].detach(), tl, s); tn = hL.norm(-1,keepdim=True).clamp(min=1.0) + sl += (((pred-hL)/tn)**2).sum(-1).mean() + sl /= L; sop.zero_grad(); sl.backward(); sop.step() + credits = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device) + pred = sp(hl, tl, s); pl = F.cross_entropy(model.out_head(model.out_ln(pred)), y, reduction='sum') + credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach()) + lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + elif method == 'credit_bridge': + vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn) + Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] + bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + vop = optim.Adam(vn.parameters(), lr=lr_fb) + schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] + warmup = max(1, epochs//5) + for ep in range(1, epochs+1): + model.train(); vn.train() + blend = 0.0 if ep <= warmup else min(1.0, (ep-warmup)/max(1, warmup)) + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach(); tlv = F.cross_entropy(lo, y, reduction='none').detach() + hL = hi[-1].detach(); t_L = torch.ones(b, device=device) + lt = ((vn(hL, t_L, s)-tlv)**2).mean() + hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s) + gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0] + hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(model.out_ln(hLr2)), y, reduction='sum') + aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach() + ltg = ((gV-aLe)**2).sum(-1).mean() + lb = 0.0 + for l in range(L): + hl = hi[l].detach(); tl = torch.full((b,), l/L, device=device); tn = torch.full((b,), (l+1)/L, device=device) + Vl = vn(hl, tl, s) + with torch.no_grad(): + hn = hi[l+1].detach(); lts = [] + for k in range(4): lts.append(-ve(hn+0.05*torch.randn_like(hn), tn, s)/0.1) + Vt = -0.1*(torch.logsumexp(torch.stack(lts,-1),-1)-np.log(4)) + lb += ((Vl-Vt.detach())**2).mean() + lb /= L; vl = lt+lb+1.0*ltg + vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(),1.0); vop.step() + update_ema(vn, ve, 0.995) + cbc = [] + for l in range(L): + hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device) + Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach()) + dfac = [(eT@Bs[l].T).detach() for l in range(L)] + credits = [] + for l in range(L): + if blend >= 1: credits.append(cbc[l]) + elif blend <= 0: credits.append(dfac[l]) + else: + cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 + credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr) + lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step() + for l in range(L): + a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 + ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() + bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() + a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 + el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() + for s in schs: s.step() + return model + + +def run_A2_naive(args): + """Compute naive state err for CIFAR methods.""" + device = torch.device(f'cuda:{args.gpu}') + seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] + L, d = 4, 256; methods = ['dfa', 'state_bridge', 'credit_bridge'] + train_loader, test_loader = get_cifar10() + rows = [] + for seed in seeds: + for method in methods: + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = ResidualMLP(3072, d, 10, L).to(device) + print(f" A2 {method} s={seed}: training...", flush=True) + model = train_cifar_method(method, model, train_loader, test_loader, device, L, d) + nse = compute_naive_state_err(model, test_loader, device, eval_layer=L//2) + rows.append({'method': method, 'seed': seed, 'naive_StateErr': nse}) + print(f" A2 {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) + out = os.path.join(args.output_dir, 'A2_naive_state_err.csv') + with open(out, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['method','seed','naive_StateErr']) + w.writeheader(); w.writerows(rows) + print(f"Saved {len(rows)} rows to {out}", flush=True) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--experiment', type=str, default='A1', choices=['A1','A2','both']) + p.add_argument('--gpu', type=int, default=3) + p.add_argument('--output_dir', type=str, default='results/confirmatory') + args = p.parse_args() + os.makedirs(args.output_dir, exist_ok=True) + if args.experiment in ['A1', 'both']: + print("=== A1: Synthetic naive state err ===", flush=True) + run_A1_naive(args) + if args.experiment in ['A2', 'both']: + print("=== A2: CIFAR naive state err ===", flush=True) + run_A2_naive(args) + print("Done.", flush=True) + + +if __name__ == '__main__': + main() |
