summaryrefslogtreecommitdiff
path: root/experiments/snapshot_time_sweep.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-25 10:23:19 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-25 10:23:19 -0500
commitef5bd494087a46ee80d8bc17796074efdae81ff4 (patch)
tree3104d9b8c0a07a38961aee54057125e45941db88 /experiments/snapshot_time_sweep.py
parent7e01fbc0ce871857c1e1879ed0d3559e8bfae7c7 (diff)
Add Phase 7A: snapshot time sweep shows early snapshots have positive held-out transfer
At epoch 5 (acc=49%), Vec_M4 5-step: dL_held=-0.005 (PUR=0.70) Oracle BP 5-step: dL_held=-0.009 (PUR=1.05) DFA 5-step: dL_held=+0.003 (always hurts held-out) By epoch 20, generalization window closes. Held-out failure is late-snapshot artifact. Better credit → lower update variance (Vec=0.8 vs DFA=40), not higher. Key implication: DFA warmup delays credit bridge past its useful window. Credit should be used from epoch 0, not after 20% warmup. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments/snapshot_time_sweep.py')
-rw-r--r--experiments/snapshot_time_sweep.py519
1 files changed, 519 insertions, 0 deletions
diff --git a/experiments/snapshot_time_sweep.py b/experiments/snapshot_time_sweep.py
new file mode 100644
index 0000000..fd87927
--- /dev/null
+++ b/experiments/snapshot_time_sweep.py
@@ -0,0 +1,519 @@
+"""
+Phase 7A: Snapshot-time sweep.
+
+Test whether "same-batch descent + held-out ascent" is a late-snapshot artifact
+or persists across training time.
+
+For each snapshot epoch, train estimators on frozen features, then measure:
+- DeltaL_same (same-batch 1-step and 5-step)
+- DeltaL_held (held-out 1-step and 5-step)
+- PUR = -DeltaL_held / (-DeltaL_same + 1e-12)
+- Cross-batch update cosine and variance
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate_acc(model, test_loader, device):
+ model.eval()
+ c, t = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ c += (model(x).argmax(1) == y).sum().item(); t += x.size(0)
+ return c / t
+
+
+# =============================================================================
+# BP training with checkpoint saving
+# =============================================================================
+def train_bp_with_checkpoints(model, train_loader, test_loader, device,
+ epochs, save_epochs, ckpt_dir, lr=1e-3, wd=0.01):
+ """Train BP and save checkpoints at specified epochs."""
+ os.makedirs(ckpt_dir, exist_ok=True)
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+ # Save epoch 0 (init)
+ if 0 in save_epochs:
+ torch.save(model.state_dict(), os.path.join(ckpt_dir, 'epoch_0.pt'))
+ acc = evaluate_acc(model, test_loader, device)
+ print(f" Saved epoch 0 (acc={acc:.4f})")
+
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ optimizer.zero_grad(); loss.backward(); optimizer.step()
+ scheduler.step()
+
+ if epoch in save_epochs:
+ torch.save(model.state_dict(), os.path.join(ckpt_dir, f'epoch_{epoch}.pt'))
+ acc = evaluate_acc(model, test_loader, device)
+ print(f" Saved epoch {epoch} (acc={acc:.4f})")
+
+
+# =============================================================================
+# Train vector field on frozen snapshot
+# =============================================================================
+def train_vec_on_snapshot(model, train_loader, device, epochs=60, lr_fb=1e-3, M=4):
+ d = model.d_hidden
+ L = model.num_blocks
+ vec_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb)
+ eps = 1e-3
+ model.eval()
+ for ep in range(1, epochs + 1):
+ vec_net.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL = hiddens[-1].detach()
+ # Terminal matching
+ t_L = torch.ones(batch, device=device)
+ a_term = vec_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L) ** 2).sum(-1).mean()
+ # Perturbation target (subsample 1 layer)
+ l = np.random.randint(0, L)
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vec_net(h_l, t_l, s)
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y, reduction='none')
+ g_j = (lp - lm) / (2 * eps)
+ loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean()
+ loss_proj /= M
+ vloss = loss_term + loss_proj
+ vec_opt.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
+ vec_opt.step()
+ if ep % 20 == 0 or ep == 1:
+ print(f" [Vec] Ep {ep}")
+ return vec_net
+
+
+# =============================================================================
+# Credit computation
+# =============================================================================
+def get_credits(model, x, y, device, source, estimator=None, dfa_Bs=None):
+ L = model.num_blocks
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ credits = {}
+ if source == 'dfa':
+ for l in range(L):
+ credits[l] = (s @ dfa_Bs[l].T).detach()
+ elif source == 'vec':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ credits[l] = estimator(h_l, t_l, s).detach()
+ elif source == 'oracle_bp':
+ for p in model.parameters(): p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hbp = model(x, return_hidden=True)
+ for l in range(L + 1): hbp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ for l in range(L):
+ credits[l] = hbp[l].grad.detach().clone()
+ for p in model.parameters(): p.requires_grad_(False)
+ return credits, hiddens
+
+
+# =============================================================================
+# Local update and evaluation
+# =============================================================================
+def compute_update_vector(model, x, y, credits, device, eta, update_layers, normalize=False):
+ """Compute the parameter update direction (as a flat vector) without applying it."""
+ L = model.num_blocks
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+
+ all_grads = []
+
+ # Head update
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ for g in grads_head:
+ all_grads.append(g.detach().flatten())
+
+ # Block updates
+ for l in update_layers:
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ if normalize:
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a.detach()).sum(-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ for g in block_grads:
+ all_grads.append(g.detach().flatten())
+
+ return torch.cat(all_grads)
+
+
+def apply_update(model, x, y, credits, device, eta, update_layers, normalize=False):
+ """Apply one local surrogate update step. Returns model (modified in-place)."""
+ L = model.num_blocks
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(eta * g)
+
+ for l in update_layers:
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ if normalize:
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a.detach()).sum(-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(eta * g)
+
+
+def eval_loss(model, x, y):
+ model.eval()
+ with torch.no_grad():
+ return F.cross_entropy(model(x), y).item()
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # =========================================================
+ # Step 1: Train BP model with checkpoint saving
+ # =========================================================
+ ckpt_dir = os.path.join(args.output_dir, f'bp_ckpts_L{L}_d{d}_s{args.seed}')
+ save_epochs = args.snapshot_epochs
+
+ # Check if checkpoints already exist
+ all_exist = all(os.path.exists(os.path.join(ckpt_dir, f'epoch_{e}.pt')) for e in save_epochs)
+
+ if not all_exist or args.retrain:
+ print(f"\nTraining BP model with checkpoints at epochs {save_epochs}...")
+ model_train = ResidualMLP(input_dim, d, 10, L).to(device)
+ train_bp_with_checkpoints(model_train, train_loader, test_loader, device,
+ epochs=max(save_epochs), save_epochs=save_epochs,
+ ckpt_dir=ckpt_dir)
+ else:
+ print(f"\nAll checkpoints exist in {ckpt_dir}")
+
+ # =========================================================
+ # Step 2: For each snapshot, train estimators and test exploitability
+ # =========================================================
+
+ # Fixed batches for consistent evaluation
+ train_iter = iter(train_loader)
+ x_same, y_same = next(train_iter)
+ x_same = x_same.view(x_same.size(0), -1).to(device); y_same = y_same.to(device)
+ x_held, y_held = next(train_iter)
+ x_held = x_held.view(x_held.size(0), -1).to(device); y_held = y_held.to(device)
+
+ # Extra batches for cross-batch variance
+ extra_batches = []
+ for _ in range(8):
+ xb, yb = next(train_iter)
+ extra_batches.append((xb.view(xb.size(0), -1).to(device), yb.to(device)))
+
+ # DFA matrices (fixed across snapshots)
+ dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ update_layers = [L - 1] # last block only
+ all_results = []
+
+ for epoch in save_epochs:
+ print(f"\n{'='*60}")
+ print(f"Snapshot: epoch {epoch}")
+ print(f"{'='*60}")
+
+ # Load snapshot
+ model = ResidualMLP(input_dim, d, 10, L).to(device)
+ ckpt_path = os.path.join(ckpt_dir, f'epoch_{epoch}.pt')
+ model.load_state_dict(torch.load(ckpt_path, map_location=device))
+ model.eval()
+ for p in model.parameters(): p.requires_grad_(False)
+ acc = evaluate_acc(model, test_loader, device)
+ print(f" Accuracy: {acc:.4f}")
+
+ loss_same_before = eval_loss(model, x_same, y_same)
+ loss_held_before = eval_loss(model, x_held, y_held)
+ print(f" Loss: same={loss_same_before:.4f}, held={loss_held_before:.4f}")
+
+ # Train Vec on this snapshot
+ print(f" Training Vec_M4...")
+ torch.manual_seed(args.seed + epoch * 100 + 4000)
+ vec_net = train_vec_on_snapshot(model, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
+
+ credit_sources = {
+ 'dfa': ('dfa', None, dfa_Bs),
+ 'vec_eT_M4': ('vec', vec_net, None),
+ 'oracle_bp': ('oracle_bp', None, None),
+ }
+
+ # Eta line search for each method
+ etas = args.etas
+
+ for name, (src, est, Bs) in credit_sources.items():
+ if name not in args.methods:
+ continue
+
+ # Compute credits on same batch
+ credits_same, _ = get_credits(model, x_same, y_same, device, src,
+ estimator=est, dfa_Bs=Bs)
+
+ best_eta = None
+ best_dl_same = float('inf')
+
+ for eta in etas:
+ # 1-step test
+ model_test = copy.deepcopy(model)
+ for p in model_test.parameters(): p.requires_grad_(True)
+ apply_update(model_test, x_same, y_same, credits_same, device,
+ eta=eta, update_layers=update_layers, normalize=False)
+ for p in model_test.parameters(): p.requires_grad_(False)
+
+ dl_same = eval_loss(model_test, x_same, y_same) - loss_same_before
+ dl_held = eval_loss(model_test, x_held, y_held) - loss_held_before
+
+ if dl_same < best_dl_same:
+ best_dl_same = dl_same
+ best_eta = eta
+ best_dl_held = dl_held
+
+ # 5-step rollout at best eta
+ model_5 = copy.deepcopy(model)
+ for p in model_5.parameters(): p.requires_grad_(True)
+ train_iter2 = iter(train_loader)
+ for step in range(5):
+ try: xs, ys = next(train_iter2)
+ except StopIteration: train_iter2 = iter(train_loader); xs, ys = next(train_iter2)
+ xs = xs.view(xs.size(0), -1).to(device); ys = ys.to(device)
+ for p in model_5.parameters(): p.requires_grad_(False)
+ creds_step, _ = get_credits(model_5, xs, ys, device, src, estimator=est, dfa_Bs=Bs)
+ for p in model_5.parameters(): p.requires_grad_(True)
+ apply_update(model_5, xs, ys, creds_step, device,
+ eta=best_eta, update_layers=update_layers, normalize=False)
+ for p in model_5.parameters(): p.requires_grad_(False)
+ dl_same_5 = eval_loss(model_5, x_same, y_same) - loss_same_before
+ dl_held_5 = eval_loss(model_5, x_held, y_held) - loss_held_before
+
+ # Cross-batch update variance
+ update_vecs = []
+ for xb, yb in extra_batches[:4]:
+ # get_credits may toggle requires_grad for oracle_bp
+ for p in model.parameters(): p.requires_grad_(False)
+ creds_b, _ = get_credits(model, xb, yb, device, src, estimator=est, dfa_Bs=Bs)
+ # compute_update_vector needs requires_grad=True
+ for p in model.parameters(): p.requires_grad_(True)
+ u = compute_update_vector(model, xb, yb, creds_b, device,
+ eta=best_eta, update_layers=update_layers, normalize=False)
+ update_vecs.append(u)
+ for p in model.parameters(): p.requires_grad_(False)
+
+ # Update cosine (mean pairwise cosine)
+ cosines = []
+ for i in range(len(update_vecs)):
+ for j in range(i + 1, len(update_vecs)):
+ cos = F.cosine_similarity(update_vecs[i].unsqueeze(0),
+ update_vecs[j].unsqueeze(0)).item()
+ cosines.append(cos)
+ update_cos = float(np.mean(cosines)) if cosines else 0.0
+
+ # Update variance
+ stacked = torch.stack(update_vecs)
+ mean_u = stacked.mean(0)
+ update_var = ((stacked - mean_u) ** 2).sum(-1).mean().item()
+
+ # PUR
+ pur_1 = -best_dl_held / (-best_dl_same + 1e-12) if best_dl_same < 0 else float('nan')
+ pur_5 = -dl_held_5 / (-dl_same_5 + 1e-12) if dl_same_5 < 0 else float('nan')
+
+ result = {
+ 'snapshot_epoch': epoch, 'method': name, 'snapshot_acc': float(acc),
+ 'best_eta': best_eta,
+ 'dl_same_1': best_dl_same, 'dl_held_1': best_dl_held, 'pur_1': pur_1,
+ 'dl_same_5': dl_same_5, 'dl_held_5': dl_held_5, 'pur_5': pur_5,
+ 'update_cos': update_cos, 'update_var': update_var,
+ }
+ all_results.append(result)
+
+ print(f" {name:>12}: eta={best_eta:.0e}, dL_same_1={best_dl_same:+.6f}, "
+ f"dL_held_1={best_dl_held:+.6f}, PUR_1={pur_1:.3f}, "
+ f"dL_same_5={dl_same_5:+.6f}, dL_held_5={dl_held_5:+.6f}, PUR_5={pur_5:.3f}, "
+ f"u_cos={update_cos:.3f}, u_var={update_var:.2e}")
+
+ # =========================================================
+ # Summary
+ # =========================================================
+ print(f"\n{'='*100}")
+ print("SUMMARY")
+ print(f"{'='*100}")
+ print(f"{'Epoch':>6} {'Acc':>6} {'Method':>12} {'eta':>8} {'dL_same_1':>10} {'dL_held_1':>10} "
+ f"{'PUR_1':>7} {'dL_same_5':>10} {'dL_held_5':>10} {'PUR_5':>7} {'u_cos':>6} {'u_var':>10}")
+ print("-" * 110)
+ for r in all_results:
+ print(f"{r['snapshot_epoch']:>6} {r['snapshot_acc']:>6.3f} {r['method']:>12} {r['best_eta']:>8.0e} "
+ f"{r['dl_same_1']:>+10.6f} {r['dl_held_1']:>+10.6f} {r['pur_1']:>7.3f} "
+ f"{r['dl_same_5']:>+10.6f} {r['dl_held_5']:>+10.6f} {r['pur_5']:>7.3f} "
+ f"{r['update_cos']:>6.3f} {r['update_var']:>10.2e}")
+
+ # Save
+ out_path = os.path.join(args.output_dir, f'time_sweep_L{L}_d{d}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(all_results, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # Judgment
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+
+ early_held_failures = 0
+ late_held_failures = 0
+ for r in all_results:
+ if r['method'] == 'vec_eT_M4':
+ if r['snapshot_epoch'] <= 20 and r['dl_held_1'] > 0:
+ early_held_failures += 1
+ if r['snapshot_epoch'] >= 50 and r['dl_held_1'] > 0:
+ late_held_failures += 1
+
+ early_epochs = [e for e in save_epochs if e <= 20]
+ late_epochs = [e for e in save_epochs if e >= 50]
+
+ if early_held_failures == 0 and late_held_failures > 0:
+ print("LATE-SNAPSHOT ARTIFACT: held-out failure only at late snapshots.")
+ print(" -> Early-training local updates with good credit DO generalize.")
+ elif early_held_failures > 0 and late_held_failures > 0:
+ print("ACROSS-TRAINING FAILURE: held-out degradation at both early and late snapshots.")
+ print(" -> Problem is NOT just late-snapshot overfitting.")
+ else:
+ print("NEED MORE DATA: check results table above.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 7A: Snapshot Time Sweep')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--snapshot_epochs', type=int, nargs='+', default=[5, 20, 100])
+ parser.add_argument('--estimator_epochs', type=int, default=60)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--etas', type=float, nargs='+',
+ default=[1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2])
+ parser.add_argument('--methods', type=str, nargs='+',
+ default=['dfa', 'vec_eT_M4', 'oracle_bp'])
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/snapshot_time')
+ parser.add_argument('--retrain', action='store_true')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()