""" Phase 5B: Frozen CIFAR Vector Credit Transfer. Test whether direct vector credit field can recover better credit than scalar CB on frozen BP-trained CIFAR representations. Methods compared: - DFA (random) - StateBridge_eT - ScalarCB_eT - ScalarCB_deltaL - VectorField_eT_M{4,8,16} - VectorField_deltaL_M{4,8,16} (if resources allow) """ import os import sys import json import argparse import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader import torchvision import torchvision.transforms as transforms import copy sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema from models.state_bridge import StateBridgeNet from metrics.credit_metrics import ( cosine_similarity_batch, perturbation_correlation, nudging_test ) class VectorCreditNet(nn.Module): """Direct vector credit field: a_phi(h_l, t_l, s) -> R^d.""" def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.time_embed = SinusoidalTimeEmbed(time_embed_dim) input_dim = d_hidden + time_embed_dim + s_dim layers = [] for i in range(num_layers): in_d = input_dim if i == 0 else hidden_dim layers.append(nn.Linear(in_d, hidden_dim)) layers.append(nn.GELU()) layers.append(nn.Linear(hidden_dim, d_hidden)) self.net = nn.Sequential(*layers) def forward(self, h, t, s): h_normed = self.ln(h) t_emb = self.time_embed(t) inp = torch.cat([h_normed, t_emb, s], dim=-1) return self.net(inp) def get_cifar10(batch_size=128): transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) return train_loader, test_loader def evaluate(model, test_loader, device): model.eval() correct, total = 0, 0 with torch.no_grad(): for x, y in test_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) logits = model(x) correct += (logits.argmax(1) == y).sum().item() total += x.size(0) return correct / total def train_bp_reference(model, train_loader, test_loader, device, epochs=100, lr=1e-3, wd=0.01): optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) for epoch in range(1, epochs + 1): model.train() total_loss, correct, total = 0, 0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) logits = model(x) loss = F.cross_entropy(logits, y) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() * x.size(0) correct += (logits.argmax(1) == y).sum().item() total += x.size(0) scheduler.step() if epoch % 20 == 0 or epoch == 1: test_acc = evaluate(model, test_loader, device) print(f" [BP ref] Ep {epoch}: loss={total_loss/total:.4f}, test={test_acc:.4f}") test_acc = evaluate(model, test_loader, device) print(f" [BP ref] Final: {test_acc:.4f}") return test_acc # ============================================================================= # Estimator training functions (all on frozen model) # ============================================================================= def train_state_bridge_frozen(model, train_loader, device, epochs, lr_fb): d = model.d_hidden L = model.num_blocks state_pred = StateBridgeNet(d_hidden=d, s_dim=10, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) state_opt = optim.Adam(state_pred.parameters(), lr=lr_fb) model.eval() for epoch in range(1, epochs + 1): state_pred.train() total_loss, n = 0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() hL_det = hiddens[-1].detach() state_loss = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) pred_hL = state_pred(h_l_det, t_l, s) target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0) state_loss += (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean() state_loss /= L state_opt.zero_grad() state_loss.backward() state_opt.step() total_loss += state_loss.item() * batch n += batch if epoch % 20 == 0 or epoch == 1: print(f" [SB] Ep {epoch}: loss={total_loss/n:.6f}") return state_pred def train_scalar_cb_frozen(model, train_loader, device, epochs, lr_fb, s_type='eT', lam=0.1, K=4, sigma_bridge=0.05, ema_momentum=0.995, term_grad_weight=1.0): d = model.d_hidden L = model.num_blocks s_dim = 10 if s_type == 'eT' else d value_net = ValueNet(d_hidden=d, s_dim=s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) value_net_ema = create_ema_model(value_net) value_opt = optim.Adam(value_net.parameters(), lr=lr_fb) model.eval() for epoch in range(1, epochs + 1): value_net.train() total_vloss, n = 0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 true_loss = F.cross_entropy(logits, y, reduction='none').detach() hL_det = hiddens[-1].detach() if s_type == 'eT': s = e_T.detach() else: hL_req = hL_det.clone().requires_grad_(True) logits_for_s = model.out_head(model.out_ln(hL_req)) ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum') s = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach() t_L = torch.ones(batch, device=device) V_terminal = value_net(hL_det, t_L, s) loss_term = ((V_terminal - true_loss) ** 2).mean() loss_tgrad = torch.tensor(0.0, device=device) if term_grad_weight > 0: hL_req = hL_det.clone().requires_grad_(True) V_at_L = value_net(hL_req, t_L, s) grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] hL_req2 = hL_det.clone().requires_grad_(True) logits_tgt = model.out_head(model.out_ln(hL_req2)) ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() loss_bridge = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) t_next = torch.full((batch,), (l + 1) / L, device=device) V_l = value_net(h_l_det, t_l, s) with torch.no_grad(): h_next_det = hiddens[l + 1].detach() log_terms = [] for k in range(K): noise = sigma_bridge * torch.randn_like(h_next_det) V_next = value_net_ema(h_next_det + noise, t_next, s) log_terms.append(-V_next / lam) log_stack = torch.stack(log_terms, dim=-1) V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K)) loss_bridge += ((V_l - V_target.detach()) ** 2).mean() loss_bridge /= L vloss = loss_term + loss_bridge + term_grad_weight * loss_tgrad value_opt.zero_grad() vloss.backward() torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) value_opt.step() update_ema(value_net, value_net_ema, ema_momentum) total_vloss += vloss.item() * batch n += batch if epoch % 20 == 0 or epoch == 1: print(f" [CB_{s_type}] Ep {epoch}: vloss={total_vloss/n:.6f}") return value_net def train_vector_field_frozen(model, train_loader, device, epochs, lr_fb, s_type='eT', M=4, eps=1e-3, beta=1.0, term_weight=1.0): """ Train vector credit field on frozen CIFAR features. Layer subsampling: each batch, randomly pick one layer for perturbation target. Terminal matching always uses layer L. """ d = model.d_hidden L = model.num_blocks s_dim = 10 if s_type == 'eT' else d vector_net = VectorCreditNet(d_hidden=d, s_dim=s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb) model.eval() for epoch in range(1, epochs + 1): vector_net.train() total_vloss, n = 0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() # Compute s if s_type == 'eT': s = e_T.detach() else: hL_req = hL_det.clone().requires_grad_(True) logits_for_s = model.out_head(model.out_ln(hL_req)) ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum') s = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach() # Terminal matching loss_term = torch.tensor(0.0, device=device) if term_weight > 0: t_L = torch.ones(batch, device=device) a_terminal = vector_net(hL_det, t_L, s) hL_req = hL_det.clone().requires_grad_(True) logits_tgt = model.out_head(model.out_ln(hL_req)) ce = F.cross_entropy(logits_tgt, y, reduction='sum') delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach() loss_term = ((a_terminal - delta_L) ** 2).sum(dim=-1).mean() # Perturbation target — subsample 1 random layer per batch l = np.random.randint(0, L) h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) a_l = vector_net(h_l_det, t_l, s) loss_proj = torch.tensor(0.0, device=device) for _ in range(M): v = torch.randn_like(h_l_det) v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) with torch.no_grad(): # Use model.forward_from_layer for tail forward logits_plus = model.forward_from_layer(h_l_det + eps * v, l) loss_plus = F.cross_entropy(logits_plus, y, reduction='none') logits_minus = model.forward_from_layer(h_l_det - eps * v, l) loss_minus = F.cross_entropy(logits_minus, y, reduction='none') g_j = (loss_plus - loss_minus) / (2 * eps) pred_j = (a_l * v).sum(dim=-1) loss_proj = loss_proj + ((pred_j - g_j.detach()) ** 2).mean() loss_proj = loss_proj / M vloss = term_weight * loss_term + beta * loss_proj vec_opt.zero_grad() vloss.backward() torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0) vec_opt.step() total_vloss += vloss.item() * batch n += batch if epoch % 20 == 0 or epoch == 1: print(f" [vec_{s_type}_M{M}] Ep {epoch}: vloss={total_vloss/n:.6f}") return vector_net # ============================================================================= # Evaluation # ============================================================================= def evaluate_all(model, test_loader, device, estimators): """Evaluate credit quality for all estimators on frozen features.""" model.eval() d = model.d_hidden L = model.num_blocks # DFA baseline dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] # Use multiple test batches for robust Gamma, single batch for rho/nudge (expensive) results = {} for name in list(estimators.keys()) + ['dfa']: results[name] = { 'bp_cosine': [[] for _ in range(L)], 'perturbation_rho': [0.0] * L, 'nudging_0.001': [0.0] * L, 'nudging_0.003': [0.0] * L, 'nudging_0.01': [0.0] * L, } n_batches = min(10, len(test_loader)) batch_idx = 0 for x, y in test_loader: if batch_idx >= n_batches: break batch_idx += 1 x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) # BP gradients for p in model.parameters(): p.requires_grad_(True) model.zero_grad() logits_bp, hiddens_bp = model(x, return_hidden=True) for l in range(L + 1): hiddens_bp[l].retain_grad() loss_bp = F.cross_entropy(logits_bp, y) loss_bp.backward() bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} for p in model.parameters(): p.requires_grad_(False) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s_eT = e_T.detach() hL_det = hiddens[-1].detach() hL_req = hL_det.clone().requires_grad_(True) logits_delta = model.out_head(model.out_ln(hL_req)) ce_delta = F.cross_entropy(logits_delta, y, reduction='sum') delta_L = torch.autograd.grad(ce_delta, hL_req, create_graph=False)[0].detach() for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) def make_fwd_fn(start_l): def fwd_fn(h): with torch.no_grad(): curr = h for i in range(start_l, L): curr = curr + model.blocks[i](curr) out = model.out_head(model.out_ln(curr)) return F.cross_entropy(out, y, reduction='none') return fwd_fn fwd_fn = make_fwd_fn(l) # DFA a_dfa = (s_eT @ dfa_Bs[l].T).detach() results['dfa']['bp_cosine'][l].append(cosine_similarity_batch(a_dfa, bp_grads[l])) if batch_idx == 1: results['dfa']['perturbation_rho'][l] = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=32) for eta in [0.001, 0.003, 0.01]: results['dfa'][f'nudging_{eta}'][l] = nudging_test(h_l, a_dfa, fwd_fn, eta=eta) # Other estimators for name, est in estimators.items(): if est['type'] == 'sb': net = est['net'] net.eval() h_l_req = h_l.clone().requires_grad_(True) pred_hL = net(h_l_req, t_l, s_eT) pred_logits = model.out_head(model.out_ln(pred_hL)) pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() elif est['type'] == 'cb': net = est['net'] net.eval() s = s_eT if est['s_type'] == 'eT' else delta_L h_l_req = h_l.clone().requires_grad_(True) V_l = net(h_l_req, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() elif est['type'] == 'vec': net = est['net'] net.eval() s = s_eT if est['s_type'] == 'eT' else delta_L a_l = net(h_l, t_l, s).detach() results[name]['bp_cosine'][l].append(cosine_similarity_batch(a_l, bp_grads[l])) if batch_idx == 1: results[name]['perturbation_rho'][l] = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32) for eta in [0.001, 0.003, 0.01]: results[name][f'nudging_{eta}'][l] = nudging_test(h_l, a_l, fwd_fn, eta=eta) # Average bp_cosine for name in results: for l in range(L): vals = results[name]['bp_cosine'][l] results[name]['bp_cosine'][l] = float(np.mean(vals)) if vals else 0.0 return results # ============================================================================= # Main # ============================================================================= def run_experiment(args): device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") os.makedirs(args.output_dir, exist_ok=True) torch.manual_seed(args.seed) np.random.seed(args.seed) torch.cuda.manual_seed_all(args.seed) train_loader, test_loader = get_cifar10(batch_size=args.batch_size) input_dim = 32 * 32 * 3 # Step 1: Load/train BP reference bp_ckpt = os.path.join(args.output_dir, f'bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt') model = ResidualMLP(input_dim, args.d_hidden, 10, args.num_blocks).to(device) # Try loading from frozen_cifar directory first alt_ckpt = f'results/frozen_cifar/bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt' if os.path.exists(alt_ckpt) and not args.retrain_bp: print(f" Loading BP ref from {alt_ckpt}") model.load_state_dict(torch.load(alt_ckpt, map_location=device)) bp_acc = evaluate(model, test_loader, device) elif os.path.exists(bp_ckpt) and not args.retrain_bp: print(f" Loading BP ref from {bp_ckpt}") model.load_state_dict(torch.load(bp_ckpt, map_location=device)) bp_acc = evaluate(model, test_loader, device) else: bp_acc = train_bp_reference(model, train_loader, test_loader, device, epochs=args.bp_epochs, lr=args.lr, wd=args.wd) torch.save(model.state_dict(), bp_ckpt) print(f" BP ref acc: {bp_acc:.4f}") model.eval() for p in model.parameters(): p.requires_grad_(False) L = args.num_blocks d = args.d_hidden # Step 2: Train estimators print(f"\n{'='*60}") print(f"Training estimators (L={L}, d={d}, {args.estimator_epochs} epochs)") print(f"{'='*60}") estimators = {} # StateBridge_eT print("\n--- StateBridge_eT ---") torch.manual_seed(args.seed + 1000) sb = train_state_bridge_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb) estimators['sb_eT'] = {'type': 'sb', 'net': sb, 's_type': 'eT'} # ScalarCB_eT print("\n--- ScalarCB_eT ---") torch.manual_seed(args.seed + 2000) cb_eT = train_scalar_cb_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb, s_type='eT', term_grad_weight=args.term_grad_weight) estimators['cb_eT'] = {'type': 'cb', 'net': cb_eT, 's_type': 'eT'} # ScalarCB_deltaL print("\n--- ScalarCB_deltaL ---") torch.manual_seed(args.seed + 3000) cb_dL = train_scalar_cb_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb, s_type='deltaL', term_grad_weight=args.term_grad_weight) estimators['cb_deltaL'] = {'type': 'cb', 'net': cb_dL, 's_type': 'deltaL'} # Vector fields for M in args.M_values: for s_type in args.vec_s_types: tag = f'vec_{s_type}_M{M}' print(f"\n--- {tag} ---") torch.manual_seed(args.seed + 4000 + M * 100 + (0 if s_type == 'eT' else 1)) vnet = train_vector_field_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb, s_type=s_type, M=M, eps=args.pert_eps, beta=args.pert_beta, term_weight=args.term_weight_vec) estimators[tag] = {'type': 'vec', 'net': vnet, 's_type': s_type} # Step 3: Evaluate print(f"\n{'='*60}") print("Evaluating credit quality") print(f"{'='*60}") results = evaluate_all(model, test_loader, device, estimators) # Print summary all_methods = ['dfa', 'sb_eT', 'cb_eT', 'cb_deltaL'] + \ [f'vec_{st}_M{M}' for M in args.M_values for st in args.vec_s_types] labels = { 'dfa': 'DFA', 'sb_eT': 'StateBridge_eT', 'cb_eT': 'ScalarCB_eT', 'cb_deltaL': 'ScalarCB_deltaL', } for M in args.M_values: for st in args.vec_s_types: labels[f'vec_{st}_M{M}'] = f'Vec_{st}_M{M}' print(f"\n{'Method':<25} {'Gamma':>8} {'rho':>8} {'nudge':>10}") print("-" * 55) summary = {} for m in all_methods: if m not in results: continue r = results[m] mg = np.mean(r['bp_cosine']) mr = np.mean(r['perturbation_rho']) mn = np.mean(r['nudging_0.003']) summary[m] = {'mean_gamma': float(mg), 'mean_rho': float(mr), 'mean_nudge': float(mn)} print(f"{labels.get(m, m):<25} {mg:>8.4f} {mr:>8.4f} {mn:>10.6f}") # Per-layer detail print(f"\n--- Per-layer Gamma ---") for l in range(L): row = f" L{l}: " for m in all_methods: if m in results: row += f" {results[m]['bp_cosine'][l]:>8.4f}" print(row) print(f"\n--- Per-layer rho ---") for l in range(L): row = f" L{l}: " for m in all_methods: if m in results: row += f" {results[m]['perturbation_rho'][l]:>8.4f}" print(row) # Save save_data = { 'config': { 'num_blocks': L, 'd_hidden': d, 'seed': args.seed, 'bp_acc': float(bp_acc), 'estimator_epochs': args.estimator_epochs, }, 'summary': summary, 'per_layer': {m: { 'bp_cosine': results[m]['bp_cosine'], 'perturbation_rho': results[m]['perturbation_rho'], 'nudging_0.003': results[m]['nudging_0.003'], } for m in all_methods if m in results}, } out_path = os.path.join(args.output_dir, f'frozen_vec_L{L}_d{d}_s{args.seed}.json') with open(out_path, 'w') as f: json.dump(save_data, f, indent=2) print(f"\nResults saved to {out_path}") # Judgment cb_eT_gamma = summary.get('cb_eT', {}).get('mean_gamma', 0) cb_eT_rho = summary.get('cb_eT', {}).get('mean_rho', 0) best_vec_gamma = max(summary.get(m, {}).get('mean_gamma', 0) for m in summary if m.startswith('vec_')) best_vec_rho = max(summary.get(m, {}).get('mean_rho', 0) for m in summary if m.startswith('vec_')) best_vec_name = max((m for m in summary if m.startswith('vec_')), key=lambda m: summary[m]['mean_gamma'] + summary[m]['mean_rho'], default='none') print(f"\n{'='*60}") print("JUDGMENT") print(f"{'='*60}") print(f"ScalarCB_eT: Gamma={cb_eT_gamma:.4f}, rho={cb_eT_rho:.4f}") print(f"Best vector ({best_vec_name}): Gamma={best_vec_gamma:.4f}, rho={best_vec_rho:.4f}") dg = best_vec_gamma - cb_eT_gamma dr = best_vec_rho - cb_eT_rho print(f"Delta: Gamma={dg:+.4f}, rho={dr:+.4f}") if dg >= 0.05 and dr >= 0.05: print("TRANSFER SUCCESS: Vector field significantly outperforms scalar CB on frozen CIFAR.") elif dg > 0 and dr > 0: print("MARGINAL: Vector field slightly better, but deltas below 0.05 threshold.") else: print("TRANSFER FAILED: Vector field does not outperform scalar CB on frozen CIFAR.") return save_data def main(): parser = argparse.ArgumentParser(description='Phase 5B: Frozen CIFAR Vector Transfer') parser.add_argument('--num_blocks', type=int, default=4) parser.add_argument('--d_hidden', type=int, default=256) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--bp_epochs', type=int, default=100) parser.add_argument('--estimator_epochs', type=int, default=100) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--lr_fb', type=float, default=1e-3) parser.add_argument('--wd', type=float, default=0.01) parser.add_argument('--term_grad_weight', type=float, default=1.0) parser.add_argument('--term_weight_vec', type=float, default=1.0) parser.add_argument('--pert_eps', type=float, default=1e-3) parser.add_argument('--pert_beta', type=float, default=1.0) parser.add_argument('--M_values', type=int, nargs='+', default=[4, 8, 16]) parser.add_argument('--vec_s_types', type=str, nargs='+', default=['eT']) parser.add_argument('--seed', type=int, default=42) parser.add_argument('--gpu', type=int, default=2) parser.add_argument('--output_dir', type=str, default='results/frozen_cifar_vec') parser.add_argument('--retrain_bp', action='store_true') args = parser.parse_args() run_experiment(args) if __name__ == '__main__': main()