""" Phase 7A: Snapshot-time sweep. Test whether "same-batch descent + held-out ascent" is a late-snapshot artifact or persists across training time. For each snapshot epoch, train estimators on frozen features, then measure: - DeltaL_same (same-batch 1-step and 5-step) - DeltaL_held (held-out 1-step and 5-step) - PUR = -DeltaL_held / (-DeltaL_same + 1e-12) - Cross-batch update cosine and variance """ import os import sys import json import argparse import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader import torchvision import torchvision.transforms as transforms import copy sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from models.value_net import SinusoidalTimeEmbed class VectorCreditNet(nn.Module): def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.time_embed = SinusoidalTimeEmbed(time_embed_dim) input_dim = d_hidden + time_embed_dim + s_dim layers = [] for i in range(num_layers): in_d = input_dim if i == 0 else hidden_dim layers.append(nn.Linear(in_d, hidden_dim)) layers.append(nn.GELU()) layers.append(nn.Linear(hidden_dim, d_hidden)) self.net = nn.Sequential(*layers) def forward(self, h, t, s): h_normed = self.ln(h) t_emb = self.time_embed(t) inp = torch.cat([h_normed, t_emb, s], dim=-1) return self.net(inp) def get_cifar10(batch_size=128): transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) return train_loader, test_loader def evaluate_acc(model, test_loader, device): model.eval() c, t = 0, 0 with torch.no_grad(): for x, y in test_loader: x = x.view(x.size(0), -1).to(device); y = y.to(device) c += (model(x).argmax(1) == y).sum().item(); t += x.size(0) return c / t # ============================================================================= # BP training with checkpoint saving # ============================================================================= def train_bp_with_checkpoints(model, train_loader, test_loader, device, epochs, save_epochs, ckpt_dir, lr=1e-3, wd=0.01): """Train BP and save checkpoints at specified epochs.""" os.makedirs(ckpt_dir, exist_ok=True) optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) # Save epoch 0 (init) if 0 in save_epochs: torch.save(model.state_dict(), os.path.join(ckpt_dir, 'epoch_0.pt')) acc = evaluate_acc(model, test_loader, device) print(f" Saved epoch 0 (acc={acc:.4f})") for epoch in range(1, epochs + 1): model.train() for x, y in train_loader: x = x.view(x.size(0), -1).to(device); y = y.to(device) loss = F.cross_entropy(model(x), y) optimizer.zero_grad(); loss.backward(); optimizer.step() scheduler.step() if epoch in save_epochs: torch.save(model.state_dict(), os.path.join(ckpt_dir, f'epoch_{epoch}.pt')) acc = evaluate_acc(model, test_loader, device) print(f" Saved epoch {epoch} (acc={acc:.4f})") # ============================================================================= # Train vector field on frozen snapshot # ============================================================================= def train_vec_on_snapshot(model, train_loader, device, epochs=60, lr_fb=1e-3, M=4): d = model.d_hidden L = model.num_blocks vec_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb) eps = 1e-3 model.eval() for ep in range(1, epochs + 1): vec_net.train() for x, y in train_loader: x = x.view(x.size(0), -1).to(device); y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() hL = hiddens[-1].detach() # Terminal matching t_L = torch.ones(batch, device=device) a_term = vec_net(hL, t_L, s) hL_req = hL.clone().requires_grad_(True) logits_tgt = model.out_head(model.out_ln(hL_req)) ce = F.cross_entropy(logits_tgt, y, reduction='sum') delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach() loss_term = ((a_term - delta_L) ** 2).sum(-1).mean() # Perturbation target (subsample 1 layer) l = np.random.randint(0, L) h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) a_l = vec_net(h_l, t_l, s) loss_proj = torch.tensor(0.0, device=device) for _ in range(M): v = torch.randn_like(h_l) v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) with torch.no_grad(): lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y, reduction='none') lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y, reduction='none') g_j = (lp - lm) / (2 * eps) loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean() loss_proj /= M vloss = loss_term + loss_proj vec_opt.zero_grad(); vloss.backward() torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0) vec_opt.step() if ep % 20 == 0 or ep == 1: print(f" [Vec] Ep {ep}") return vec_net # ============================================================================= # Credit computation # ============================================================================= def get_credits(model, x, y, device, source, estimator=None, dfa_Bs=None): L = model.num_blocks batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() credits = {} if source == 'dfa': for l in range(L): credits[l] = (s @ dfa_Bs[l].T).detach() elif source == 'vec': estimator.eval() for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) credits[l] = estimator(h_l, t_l, s).detach() elif source == 'oracle_bp': for p in model.parameters(): p.requires_grad_(True) model.zero_grad() logits_bp, hbp = model(x, return_hidden=True) for l in range(L + 1): hbp[l].retain_grad() F.cross_entropy(logits_bp, y).backward() for l in range(L): credits[l] = hbp[l].grad.detach().clone() for p in model.parameters(): p.requires_grad_(False) return credits, hiddens # ============================================================================= # Local update and evaluation # ============================================================================= def compute_update_vector(model, x, y, credits, device, eta, update_layers, normalize=False): """Compute the parameter update direction (as a flat vector) without applying it.""" L = model.num_blocks with torch.no_grad(): _, hiddens = model(x, return_hidden=True) all_grads = [] # Head update hL = hiddens[-1].detach() logits_out = model.out_head(model.out_ln(hL)) loss_out = F.cross_entropy(logits_out, y) head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) grads_head = torch.autograd.grad(loss_out, head_params) for g in grads_head: all_grads.append(g.detach().flatten()) # Block updates for l in update_layers: h_l = hiddens[l].detach() a = credits[l] if normalize: rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 a = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a.detach()).sum(-1).mean() block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters()) for g in block_grads: all_grads.append(g.detach().flatten()) return torch.cat(all_grads) def apply_update(model, x, y, credits, device, eta, update_layers, normalize=False): """Apply one local surrogate update step. Returns model (modified in-place).""" L = model.num_blocks with torch.no_grad(): _, hiddens = model(x, return_hidden=True) hL = hiddens[-1].detach() logits_out = model.out_head(model.out_ln(hL)) loss_out = F.cross_entropy(logits_out, y) head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) grads_head = torch.autograd.grad(loss_out, head_params) with torch.no_grad(): for p, g in zip(head_params, grads_head): p.sub_(eta * g) for l in update_layers: h_l = hiddens[l].detach() a = credits[l] if normalize: rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 a = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a.detach()).sum(-1).mean() block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters()) with torch.no_grad(): for p, g in zip(model.blocks[l].parameters(), block_grads): p.sub_(eta * g) def eval_loss(model, x, y): model.eval() with torch.no_grad(): return F.cross_entropy(model(x), y).item() # ============================================================================= # Main # ============================================================================= def run_experiment(args): device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") os.makedirs(args.output_dir, exist_ok=True) torch.manual_seed(args.seed) np.random.seed(args.seed) torch.cuda.manual_seed_all(args.seed) train_loader, test_loader = get_cifar10(args.batch_size) input_dim = 32 * 32 * 3 L = args.num_blocks d = args.d_hidden # ========================================================= # Step 1: Train BP model with checkpoint saving # ========================================================= ckpt_dir = os.path.join(args.output_dir, f'bp_ckpts_L{L}_d{d}_s{args.seed}') save_epochs = args.snapshot_epochs # Check if checkpoints already exist all_exist = all(os.path.exists(os.path.join(ckpt_dir, f'epoch_{e}.pt')) for e in save_epochs) if not all_exist or args.retrain: print(f"\nTraining BP model with checkpoints at epochs {save_epochs}...") model_train = ResidualMLP(input_dim, d, 10, L).to(device) train_bp_with_checkpoints(model_train, train_loader, test_loader, device, epochs=max(save_epochs), save_epochs=save_epochs, ckpt_dir=ckpt_dir) else: print(f"\nAll checkpoints exist in {ckpt_dir}") # ========================================================= # Step 2: For each snapshot, train estimators and test exploitability # ========================================================= # Fixed batches for consistent evaluation train_iter = iter(train_loader) x_same, y_same = next(train_iter) x_same = x_same.view(x_same.size(0), -1).to(device); y_same = y_same.to(device) x_held, y_held = next(train_iter) x_held = x_held.view(x_held.size(0), -1).to(device); y_held = y_held.to(device) # Extra batches for cross-batch variance extra_batches = [] for _ in range(8): xb, yb = next(train_iter) extra_batches.append((xb.view(xb.size(0), -1).to(device), yb.to(device))) # DFA matrices (fixed across snapshots) dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] update_layers = [L - 1] # last block only all_results = [] for epoch in save_epochs: print(f"\n{'='*60}") print(f"Snapshot: epoch {epoch}") print(f"{'='*60}") # Load snapshot model = ResidualMLP(input_dim, d, 10, L).to(device) ckpt_path = os.path.join(ckpt_dir, f'epoch_{epoch}.pt') model.load_state_dict(torch.load(ckpt_path, map_location=device)) model.eval() for p in model.parameters(): p.requires_grad_(False) acc = evaluate_acc(model, test_loader, device) print(f" Accuracy: {acc:.4f}") loss_same_before = eval_loss(model, x_same, y_same) loss_held_before = eval_loss(model, x_held, y_held) print(f" Loss: same={loss_same_before:.4f}, held={loss_held_before:.4f}") # Train Vec on this snapshot print(f" Training Vec_M4...") torch.manual_seed(args.seed + epoch * 100 + 4000) vec_net = train_vec_on_snapshot(model, train_loader, device, epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4) credit_sources = { 'dfa': ('dfa', None, dfa_Bs), 'vec_eT_M4': ('vec', vec_net, None), 'oracle_bp': ('oracle_bp', None, None), } # Eta line search for each method etas = args.etas for name, (src, est, Bs) in credit_sources.items(): if name not in args.methods: continue # Compute credits on same batch credits_same, _ = get_credits(model, x_same, y_same, device, src, estimator=est, dfa_Bs=Bs) best_eta = None best_dl_same = float('inf') for eta in etas: # 1-step test model_test = copy.deepcopy(model) for p in model_test.parameters(): p.requires_grad_(True) apply_update(model_test, x_same, y_same, credits_same, device, eta=eta, update_layers=update_layers, normalize=False) for p in model_test.parameters(): p.requires_grad_(False) dl_same = eval_loss(model_test, x_same, y_same) - loss_same_before dl_held = eval_loss(model_test, x_held, y_held) - loss_held_before if dl_same < best_dl_same: best_dl_same = dl_same best_eta = eta best_dl_held = dl_held # 5-step rollout at best eta model_5 = copy.deepcopy(model) for p in model_5.parameters(): p.requires_grad_(True) train_iter2 = iter(train_loader) for step in range(5): try: xs, ys = next(train_iter2) except StopIteration: train_iter2 = iter(train_loader); xs, ys = next(train_iter2) xs = xs.view(xs.size(0), -1).to(device); ys = ys.to(device) for p in model_5.parameters(): p.requires_grad_(False) creds_step, _ = get_credits(model_5, xs, ys, device, src, estimator=est, dfa_Bs=Bs) for p in model_5.parameters(): p.requires_grad_(True) apply_update(model_5, xs, ys, creds_step, device, eta=best_eta, update_layers=update_layers, normalize=False) for p in model_5.parameters(): p.requires_grad_(False) dl_same_5 = eval_loss(model_5, x_same, y_same) - loss_same_before dl_held_5 = eval_loss(model_5, x_held, y_held) - loss_held_before # Cross-batch update variance update_vecs = [] for xb, yb in extra_batches[:4]: # get_credits may toggle requires_grad for oracle_bp for p in model.parameters(): p.requires_grad_(False) creds_b, _ = get_credits(model, xb, yb, device, src, estimator=est, dfa_Bs=Bs) # compute_update_vector needs requires_grad=True for p in model.parameters(): p.requires_grad_(True) u = compute_update_vector(model, xb, yb, creds_b, device, eta=best_eta, update_layers=update_layers, normalize=False) update_vecs.append(u) for p in model.parameters(): p.requires_grad_(False) # Update cosine (mean pairwise cosine) cosines = [] for i in range(len(update_vecs)): for j in range(i + 1, len(update_vecs)): cos = F.cosine_similarity(update_vecs[i].unsqueeze(0), update_vecs[j].unsqueeze(0)).item() cosines.append(cos) update_cos = float(np.mean(cosines)) if cosines else 0.0 # Update variance stacked = torch.stack(update_vecs) mean_u = stacked.mean(0) update_var = ((stacked - mean_u) ** 2).sum(-1).mean().item() # PUR pur_1 = -best_dl_held / (-best_dl_same + 1e-12) if best_dl_same < 0 else float('nan') pur_5 = -dl_held_5 / (-dl_same_5 + 1e-12) if dl_same_5 < 0 else float('nan') result = { 'snapshot_epoch': epoch, 'method': name, 'snapshot_acc': float(acc), 'best_eta': best_eta, 'dl_same_1': best_dl_same, 'dl_held_1': best_dl_held, 'pur_1': pur_1, 'dl_same_5': dl_same_5, 'dl_held_5': dl_held_5, 'pur_5': pur_5, 'update_cos': update_cos, 'update_var': update_var, } all_results.append(result) print(f" {name:>12}: eta={best_eta:.0e}, dL_same_1={best_dl_same:+.6f}, " f"dL_held_1={best_dl_held:+.6f}, PUR_1={pur_1:.3f}, " f"dL_same_5={dl_same_5:+.6f}, dL_held_5={dl_held_5:+.6f}, PUR_5={pur_5:.3f}, " f"u_cos={update_cos:.3f}, u_var={update_var:.2e}") # ========================================================= # Summary # ========================================================= print(f"\n{'='*100}") print("SUMMARY") print(f"{'='*100}") print(f"{'Epoch':>6} {'Acc':>6} {'Method':>12} {'eta':>8} {'dL_same_1':>10} {'dL_held_1':>10} " f"{'PUR_1':>7} {'dL_same_5':>10} {'dL_held_5':>10} {'PUR_5':>7} {'u_cos':>6} {'u_var':>10}") print("-" * 110) for r in all_results: print(f"{r['snapshot_epoch']:>6} {r['snapshot_acc']:>6.3f} {r['method']:>12} {r['best_eta']:>8.0e} " f"{r['dl_same_1']:>+10.6f} {r['dl_held_1']:>+10.6f} {r['pur_1']:>7.3f} " f"{r['dl_same_5']:>+10.6f} {r['dl_held_5']:>+10.6f} {r['pur_5']:>7.3f} " f"{r['update_cos']:>6.3f} {r['update_var']:>10.2e}") # Save out_path = os.path.join(args.output_dir, f'time_sweep_L{L}_d{d}_s{args.seed}.json') with open(out_path, 'w') as f: json.dump(all_results, f, indent=2, default=float) print(f"\nSaved to {out_path}") # Judgment print(f"\n{'='*60}") print("JUDGMENT") print(f"{'='*60}") early_held_failures = 0 late_held_failures = 0 for r in all_results: if r['method'] == 'vec_eT_M4': if r['snapshot_epoch'] <= 20 and r['dl_held_1'] > 0: early_held_failures += 1 if r['snapshot_epoch'] >= 50 and r['dl_held_1'] > 0: late_held_failures += 1 early_epochs = [e for e in save_epochs if e <= 20] late_epochs = [e for e in save_epochs if e >= 50] if early_held_failures == 0 and late_held_failures > 0: print("LATE-SNAPSHOT ARTIFACT: held-out failure only at late snapshots.") print(" -> Early-training local updates with good credit DO generalize.") elif early_held_failures > 0 and late_held_failures > 0: print("ACROSS-TRAINING FAILURE: held-out degradation at both early and late snapshots.") print(" -> Problem is NOT just late-snapshot overfitting.") else: print("NEED MORE DATA: check results table above.") def main(): parser = argparse.ArgumentParser(description='Phase 7A: Snapshot Time Sweep') parser.add_argument('--num_blocks', type=int, default=4) parser.add_argument('--d_hidden', type=int, default=256) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--snapshot_epochs', type=int, nargs='+', default=[5, 20, 100]) parser.add_argument('--estimator_epochs', type=int, default=60) parser.add_argument('--lr_fb', type=float, default=1e-3) parser.add_argument('--etas', type=float, nargs='+', default=[1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2]) parser.add_argument('--methods', type=str, nargs='+', default=['dfa', 'vec_eT_M4', 'oracle_bp']) parser.add_argument('--seed', type=int, default=42) parser.add_argument('--gpu', type=int, default=3) parser.add_argument('--output_dir', type=str, default='results/snapshot_time') parser.add_argument('--retrain', action='store_true') args = parser.parse_args() run_experiment(args) if __name__ == '__main__': main()