""" Compute naive state prediction baseline: ||h_l - h_L|| / ||h_L|| for all methods. This is the identity-prediction error — how much does h change from layer l to L. Computed at layer L//2 on test/eval data after training. Retrains models with same seeds as confirmatory experiments, then computes metric. """ import os, sys, json, csv, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F import torch.optim as optim, copy from torch.utils.data import DataLoader import torchvision, torchvision.transforms as transforms sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema from models.state_bridge import StateBridgeNet def compute_naive_state_err(model, dataloader, device, eval_layer=None): """Compute ||h_l - h_L||_2 / ||h_L||_2 averaged over data (L2 norm ratio, scalar).""" model.eval() L = model.num_blocks if eval_layer is None: eval_layer = L // 2 total_err, n = 0.0, 0 with torch.no_grad(): for x, y in dataloader: if hasattr(model, 'embed'): x = x.view(x.size(0), -1).to(device) else: x = x.to(device) _, hiddens = model(x, return_hidden=True) h_l = hiddens[eval_layer] h_L = hiddens[-1] diff_norm = (h_l - h_L).norm(dim=-1) # (batch,) hL_norm = h_L.norm(dim=-1).clamp(min=1e-8) # (batch,) ratio = (diff_norm / hL_norm).mean() # scalar total_err += ratio.item() * x.size(0) n += x.size(0) return total_err / n # ============ Synthetic ============ class TeacherNet(nn.Module): def __init__(self, d, C, L, alpha=1.0, seed=0): super().__init__() self.alpha = alpha rng = torch.Generator().manual_seed(seed) self.Ws = nn.ParameterList() for _ in range(L): W = torch.randn(d, d, generator=rng) * 0.3 / (d**0.5) U, S, Vh = torch.linalg.svd(W, full_matrices=False) W = U @ torch.diag(S.clamp(max=0.3)) @ Vh self.Ws.append(nn.Parameter(W, requires_grad=False)) self.U = nn.Parameter(torch.randn(C, d, generator=rng) / (d**0.5), requires_grad=False) def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z) def forward(self, x): h = x for W in self.Ws: h = h + self.phi(h @ W.T) return h @ self.U.T class StudentBlock(nn.Module): def __init__(self, d, alpha=1.0): super().__init__() self.ln = nn.LayerNorm(d); self.w = nn.Linear(d, d, bias=False) nn.init.normal_(self.w.weight, std=0.01); self.alpha = alpha def phi(self, z): return (1-self.alpha)*z + self.alpha*torch.tanh(z) def forward(self, h): return self.w(self.phi(self.ln(h))) class StudentNet(nn.Module): def __init__(self, d, C, L, alpha=1.0): super().__init__() self.blocks = nn.ModuleList([StudentBlock(d, alpha) for _ in range(L)]) self.out_head = nn.Linear(d, C); self.d_hidden = d; self.num_blocks = L def forward(self, x, return_hidden=False): h = x; hiddens = [h] if return_hidden else None for b in self.blocks: h = h + b(h) if return_hidden: hiddens.append(h) logits = self.out_head(h) return (logits, hiddens) if return_hidden else logits def forward_from_layer(self, h, sl): for i in range(sl, self.num_blocks): h = h + self.blocks[i](h) return self.out_head(h) def train_synth_method(method, model, teacher, device, d, C, L, epochs=80, steps=50, bs=256, lr=1e-3, lr_fb=1e-3): """Train synthetic model with given method, return trained model.""" if method == 'bp': opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) for ep in range(epochs): model.train() for _ in range(steps): x = torch.randn(bs, d, device=device) with torch.no_grad(): y = teacher(x).argmax(-1) loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step() elif method == 'dfa': Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) for ep in range(epochs): model.train() for _ in range(steps): x = torch.randn(bs, d, device=device) with torch.no_grad(): y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1 lo2 = F.cross_entropy(model.out_head(hi[-1].detach()), y); hop.zero_grad(); lo2.backward(); hop.step() for l in range(L): a = (eT @ Bs[l].T).detach(); rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() elif method == 'state_bridge': sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device) Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) sop = optim.Adam(sp.parameters(), lr=lr_fb) warmup = max(1, epochs//5) for ep in range(epochs): model.train(); sp.train() for _ in range(steps): x = torch.randn(bs, d, device=device) with torch.no_grad(): y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach() hL = hi[-1].detach() sl = 0.0 for l in range(L): t_l = torch.full((bs,), l/L, device=device) pred = sp(hi[l].detach(), t_l, s) tn = hL.norm(-1, keepdim=True).clamp(min=1.0) sl += (((pred-hL)/tn)**2).sum(-1).mean() sl /= L; sop.zero_grad(); sl.backward(); sop.step() credits = [] for l in range(L): hl = hi[l].detach().requires_grad_(True); t_l = torch.full((bs,), l/L, device=device) pred = sp(hl, t_l, s); pl = F.cross_entropy(model.out_head(pred), y, reduction='sum') credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach()) lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step() for l in range(L): a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() elif method == 'credit_bridge': vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn) Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=0.01) for b in model.blocks] hop = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=0.01) vop = optim.Adam(vn.parameters(), lr=lr_fb) warmup = max(1, epochs//5) for ep in range(epochs): model.train(); vn.train() blend = 0.0 if ep < warmup else min(1.0, (ep-warmup)/max(1, warmup)) for _ in range(steps): x = torch.randn(bs, d, device=device) with torch.no_grad(): y = teacher(x).argmax(-1); lo, hi = model(x, return_hidden=True) eT = lo.softmax(-1); eT[torch.arange(bs), y] -= 1; s = eT.detach() tl_val = F.cross_entropy(lo, y, reduction='none').detach() hL = hi[-1].detach(); t_L = torch.ones(bs, device=device) lt = ((vn(hL, t_L, s)-tl_val)**2).mean() hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s) gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0] hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(hLr2), y, reduction='sum') aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach() ltg = ((gV-aLe)**2).sum(-1).mean() lb = 0.0 for l in range(L): hl = hi[l].detach(); tl = torch.full((bs,), l/L, device=device); tn = torch.full((bs,), (l+1)/L, device=device) Vl = vn(hl, tl, s) with torch.no_grad(): hn = hi[l+1].detach(); lts = [] for k in range(4): ns = 0.05*torch.randn_like(hn); lts.append(-ve(hn+ns, tn, s)/0.1) Vt = -0.1*(torch.logsumexp(torch.stack(lts, -1), -1)-np.log(4)) lb += ((Vl-Vt.detach())**2).mean() lb /= L; vl = lt+lb+1.0*ltg vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(), 1.0); vop.step() update_ema(vn, ve, 0.995) cbc = [] for l in range(L): hl = hi[l].detach().requires_grad_(True); tl = torch.full((bs,), l/L, device=device) Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach()) dfac = [(eT@Bs[l].T).detach() for l in range(L)] credits = [] for l in range(L): if blend >= 1: credits.append(cbc[l]) elif blend <= 0: credits.append(dfac[l]) else: cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr) lo2 = F.cross_entropy(model.out_head(hL), y); hop.zero_grad(); lo2.backward(); hop.step() for l in range(L): a = credits[l]; rm = (a**2).mean(-1, keepdim=True).sqrt()+1e-6 ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); bops[l].step() return model def run_A1_naive(args): """Compute naive state err for synthetic ladder.""" device = torch.device(f'cuda:{args.gpu}') seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] alphas = [0.0, 0.5, 1.0]; depths = [4, 8]; d = 128; C = 10 methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] rows = [] for alpha in alphas: for L in depths: for seed in seeds: teacher = TeacherNet(d, C, L, alpha=alpha, seed=seed*1000).to(device) # Eval data eval_x = torch.randn(512, d, device=device) with torch.no_grad(): eval_y = teacher(eval_x).argmax(-1) eval_data = [(eval_x, eval_y)] for method in methods: ckpt_dir = os.path.join(args.output_dir, 'checkpoints_A1') os.makedirs(ckpt_dir, exist_ok=True) ckpt_path = os.path.join(ckpt_dir, f'a{alpha}_L{L}_{method}_s{seed}.pt') torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) model = StudentNet(d, C, L, alpha=alpha).to(device) if os.path.exists(ckpt_path): model.load_state_dict(torch.load(ckpt_path, map_location=device)) print(f" A1 alpha={alpha} L={L} {method} s={seed}: loaded checkpoint", flush=True) else: model = train_synth_method(method, model, teacher, device, d, C, L) torch.save(model.state_dict(), ckpt_path) nse = compute_naive_state_err(model, eval_data, device, eval_layer=L//2) rows.append({'alpha': alpha, 'depth': L, 'method': method, 'seed': seed, 'naive_StateErr': nse}) print(f" A1 alpha={alpha} L={L} {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) # Save CSV out = os.path.join(args.output_dir, 'A1_naive_state_err.csv') with open(out, 'w', newline='') as f: w = csv.DictWriter(f, fieldnames=['alpha','depth','method','seed','naive_StateErr']) w.writeheader(); w.writerows(rows) print(f"Saved {len(rows)} rows to {out}", flush=True) def get_cifar10(bs=128): tt = transforms.Compose([transforms.RandomCrop(32,4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) tv = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) return (DataLoader(torchvision.datasets.CIFAR10('./data',True,download=True,transform=tt),bs,True,num_workers=4,pin_memory=True), DataLoader(torchvision.datasets.CIFAR10('./data',False,download=True,transform=tv),bs,False,num_workers=4,pin_memory=True)) def train_cifar_method(method, model, train_loader, test_loader, device, L, d, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01): """Train CIFAR model, return trained model.""" C = 10 if method == 'bp': opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) for ep in range(1, epochs+1): model.train() for x, y in train_loader: x = x.view(x.size(0),-1).to(device); y = y.to(device) loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step() sch.step() return model elif method == 'dfa': Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] for ep in range(1, epochs+1): model.train() for x, y in train_loader: x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1 hL = hi[-1].detach() F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward(); hop.step(); hop.zero_grad() for l in range(L): a = (eT@Bs[l].T).detach(); rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() a0 = (eT@Bs[0].T).detach(); r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() for s in schs: s.step() elif method == 'state_bridge': sp = StateBridgeNet(d_hidden=d, s_dim=C).to(device) Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) sop = optim.Adam(sp.parameters(), lr=lr_fb) schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] for ep in range(1, epochs+1): model.train(); sp.train() for x, y in train_loader: x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach() hL = hi[-1].detach() sl = 0.0 for l in range(L): tl = torch.full((b,), l/L, device=device) pred = sp(hi[l].detach(), tl, s); tn = hL.norm(-1,keepdim=True).clamp(min=1.0) sl += (((pred-hL)/tn)**2).sum(-1).mean() sl /= L; sop.zero_grad(); sl.backward(); sop.step() credits = [] for l in range(L): hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device) pred = sp(hl, tl, s); pl = F.cross_entropy(model.out_head(model.out_ln(pred)), y, reduction='sum') credits.append(torch.autograd.grad(pl, hl, create_graph=False)[0].detach()) lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step() for l in range(L): a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() for s in schs: s.step() elif method == 'credit_bridge': vn = ValueNet(d_hidden=d, s_dim=C).to(device); ve = create_ema_model(vn) Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) hop = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=lr, weight_decay=wd) vop = optim.Adam(vn.parameters(), lr=lr_fb) schs = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in bops]+[optim.lr_scheduler.CosineAnnealingLR(eop, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(hop, T_max=epochs)] warmup = max(1, epochs//5) for ep in range(1, epochs+1): model.train(); vn.train() blend = 0.0 if ep <= warmup else min(1.0, (ep-warmup)/max(1, warmup)) for x, y in train_loader: x = x.view(x.size(0),-1).to(device); y = y.to(device); b = x.size(0) with torch.no_grad(): lo, hi = model(x, return_hidden=True); eT = lo.softmax(-1); eT[torch.arange(b),y] -= 1; s = eT.detach(); tlv = F.cross_entropy(lo, y, reduction='none').detach() hL = hi[-1].detach(); t_L = torch.ones(b, device=device) lt = ((vn(hL, t_L, s)-tlv)**2).mean() hLr = hL.clone().requires_grad_(True); VL = vn(hLr, t_L, s) gV = torch.autograd.grad(VL.sum(), hLr, create_graph=True)[0] hLr2 = hL.clone().requires_grad_(True); ce = F.cross_entropy(model.out_head(model.out_ln(hLr2)), y, reduction='sum') aLe = torch.autograd.grad(ce, hLr2, create_graph=False)[0].detach() ltg = ((gV-aLe)**2).sum(-1).mean() lb = 0.0 for l in range(L): hl = hi[l].detach(); tl = torch.full((b,), l/L, device=device); tn = torch.full((b,), (l+1)/L, device=device) Vl = vn(hl, tl, s) with torch.no_grad(): hn = hi[l+1].detach(); lts = [] for k in range(4): lts.append(-ve(hn+0.05*torch.randn_like(hn), tn, s)/0.1) Vt = -0.1*(torch.logsumexp(torch.stack(lts,-1),-1)-np.log(4)) lb += ((Vl-Vt.detach())**2).mean() lb /= L; vl = lt+lb+1.0*ltg vop.zero_grad(); vl.backward(); torch.nn.utils.clip_grad_norm_(vn.parameters(),1.0); vop.step() update_ema(vn, ve, 0.995) cbc = [] for l in range(L): hl = hi[l].detach().requires_grad_(True); tl = torch.full((b,), l/L, device=device) Vl = vn(hl, tl, s); cbc.append(torch.autograd.grad(Vl.sum(), hl, create_graph=False)[0].detach()) dfac = [(eT@Bs[l].T).detach() for l in range(L)] credits = [] for l in range(L): if blend >= 1: credits.append(cbc[l]) elif blend <= 0: credits.append(dfac[l]) else: cr = (cbc[l]**2).mean(-1,keepdim=True).sqrt()+1e-6; dr = (dfac[l]**2).mean(-1,keepdim=True).sqrt()+1e-6 credits.append(blend*cbc[l]/cr+(1-blend)*dfac[l]/dr) lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y); hop.zero_grad(); lo2.backward(); hop.step() for l in range(L): a = credits[l]; rm = (a**2).mean(-1,keepdim=True).sqrt()+1e-6 ll = (model.blocks[l](hi[l].detach())*(a/rm)).sum(-1).mean() bops[l].zero_grad(); ll.backward(); torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(),1.0); bops[l].step() a0 = credits[0]; r0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6 el = (model.embed(x.view(x.size(0),-1))*(a0/r0)).sum(-1).mean(); eop.zero_grad(); el.backward(); eop.step() for s in schs: s.step() return model def run_A2_naive(args): """Compute naive state err for CIFAR methods.""" device = torch.device(f'cuda:{args.gpu}') seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] L, d = 4, 256; methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] train_loader, test_loader = get_cifar10() rows = [] for seed in seeds: for method in methods: ckpt_dir = os.path.join(args.output_dir, 'checkpoints_A2') os.makedirs(ckpt_dir, exist_ok=True) ckpt_path = os.path.join(ckpt_dir, f'{method}_s{seed}.pt') torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) model = ResidualMLP(3072, d, 10, L).to(device) if os.path.exists(ckpt_path): model.load_state_dict(torch.load(ckpt_path, map_location=device)) print(f" A2 {method} s={seed}: loaded checkpoint", flush=True) else: print(f" A2 {method} s={seed}: training...", flush=True) model = train_cifar_method(method, model, train_loader, test_loader, device, L, d) torch.save(model.state_dict(), ckpt_path) nse = compute_naive_state_err(model, test_loader, device, eval_layer=L//2) rows.append({'method': method, 'seed': seed, 'naive_StateErr': nse}) print(f" A2 {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) out = os.path.join(args.output_dir, 'A2_naive_state_err.csv') with open(out, 'w', newline='') as f: w = csv.DictWriter(f, fieldnames=['method','seed','naive_StateErr']) w.writeheader(); w.writerows(rows) print(f"Saved {len(rows)} rows to {out}", flush=True) def main(): p = argparse.ArgumentParser() p.add_argument('--experiment', type=str, default='A1', choices=['A1','A2','both']) p.add_argument('--gpu', type=int, default=3) p.add_argument('--output_dir', type=str, default='results/confirmatory') args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) if args.experiment in ['A1', 'both']: print("=== A1: Synthetic naive state err ===", flush=True) run_A1_naive(args) if args.experiment in ['A2', 'both']: print("=== A2: CIFAR naive state err ===", flush=True) run_A2_naive(args) print("Done.", flush=True) if __name__ == '__main__': main()