"""
Phase 6A: Snapshot Exploitability Test.

Core question: on a fixed network snapshot, does better credit lead to a
larger real loss decrease under the current local surrogate update?

Tests:
    A1. Single global step DeltaLoss
    A2. k-step short rollout (k = 1, 5, 20)
    A3. Per-layer ablation (last-1, last-2, all blocks)

Credit sources: DFA, ScalarCB_eT, Vec_eT_M4, Oracle BP gradient
Snapshot sources: BP-trained, DFA-warmup (this script currently runs only the
BP-trained snapshot).
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import copy

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
from metrics.credit_metrics import (
    cosine_similarity_batch, perturbation_correlation, nudging_test
)


class VectorCreditNet(nn.Module):
    """MLP that predicts a credit vector for a hidden state.

    The input is the LayerNorm-ed hidden state ``h``, a sinusoidal embedding
    of the depth coordinate ``t`` and a conditioning vector ``s``,
    concatenated. The output has width ``d_hidden``, the same as ``h``.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        # Final projection back to the hidden width so the output can be
        # compared against / dotted with gradients w.r.t. h.
        layers.append(nn.Linear(hidden_dim, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)


def get_cifar10(batch_size=128):
    """Build CIFAR-10 train/test loaders (train: random crop + flip; both normalized)."""
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                            transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                           transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)
    return train_loader, test_loader


def evaluate_loss_acc(model, data_iter, device, n_batches=5):
    """Evaluate mean cross-entropy loss and accuracy on the first ``n_batches`` batches.

    Images are flattened to vectors before the forward pass.

    Raises:
        ValueError: if no samples were evaluated (empty iterable or n_batches <= 0).
    """
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for i, (x, y) in enumerate(data_iter):
            if i >= n_batches:
                break
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y, reduction='sum')
            total_loss += loss.item()
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
    if total == 0:
        raise ValueError("evaluate_loss_acc: no samples evaluated "
                         "(empty data_iter or n_batches <= 0)")
    return total_loss / total, correct / total


# =============================================================================
# Train estimators on frozen snapshot
# =============================================================================

def train_scalar_cb_on_snapshot(model, train_loader, device, epochs=100, lr_fb=1e-3):
    """Train a scalar value network ``V(h, t, s)`` on a frozen snapshot.

    ``s`` is the terminal error signal softmax(logits) - onehot(y). The loss
    combines three terms:
      * terminal value matching: V at the last hidden state vs. per-sample CE;
      * terminal gradient matching: dV/dh_L vs. the exact dCE/dh_L through
        the output head;
      * a bridge term regressing V at layer l onto a soft-min (logsumexp)
        of the EMA network's value at noisy copies of h_{l+1}.

    Returns the (non-EMA) value network.
    """
    d = model.d_hidden
    L = model.num_blocks
    value_net = ValueNet(d_hidden=d, s_dim=10, time_embed_dim=32, hidden_dim=256,
                         num_layers=3).to(device)
    value_net_ema = create_ema_model(value_net)
    value_opt = optim.Adam(value_net.parameters(), lr=lr_fb)
    model.eval()

    for epoch in range(1, epochs + 1):
        value_net.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                e_T = logits.softmax(dim=-1)
                e_T[torch.arange(batch), y] -= 1
                s = e_T.detach()
                true_loss = F.cross_entropy(logits, y, reduction='none').detach()

            # Terminal value matching.
            hL = hiddens[-1].detach()
            t_L = torch.ones(batch, device=device)
            V_term = value_net(hL, t_L, s)
            loss_term = ((V_term - true_loss) ** 2).mean()

            # Terminal gradient matching: dV/dh_L should equal dCE/dh_L.
            hL_req = hL.clone().requires_grad_(True)
            V_at_L = value_net(hL_req, t_L, s)
            grad_V = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
            hL_req2 = hL.clone().requires_grad_(True)
            logits_tgt = model.out_head(model.out_ln(hL_req2))
            ce = F.cross_entropy(logits_tgt, y, reduction='sum')
            a_L_exact = torch.autograd.grad(ce, hL_req2, create_graph=False)[0].detach()
            loss_tgrad = ((grad_V - a_L_exact) ** 2).sum(dim=-1).mean()

            # Bridge consistency between consecutive layers (EMA targets).
            loss_bridge = 0.0
            for l in range(L):
                h_l = hiddens[l].detach()
                t_l = torch.full((batch,), l / L, device=device)
                t_next = torch.full((batch,), (l + 1) / L, device=device)
                V_l = value_net(h_l, t_l, s)
                with torch.no_grad():
                    h_next = hiddens[l + 1].detach()
                    log_terms = []
                    for k in range(4):
                        noise = 0.05 * torch.randn_like(h_next)
                        V_next = value_net_ema(h_next + noise, t_next, s)
                        log_terms.append(-V_next / 0.1)
                    log_stack = torch.stack(log_terms, dim=-1)
                    # Soft-min over 4 noisy samples with temperature 0.1.
                    V_target = -0.1 * (torch.logsumexp(log_stack, dim=-1) - np.log(4))
                loss_bridge += ((V_l - V_target.detach()) ** 2).mean()
            loss_bridge /= L

            vloss = loss_term + loss_bridge + 1.0 * loss_tgrad
            value_opt.zero_grad()
            vloss.backward()
            torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
            value_opt.step()
            update_ema(value_net, value_net_ema, 0.995)

        if epoch % 20 == 0 or epoch == 1:
            print(f" [CB] Ep {epoch}")
    return value_net


def train_vector_on_snapshot(model, train_loader, device, epochs=100, lr_fb=1e-3, M=4):
    """Train a ``VectorCreditNet`` on a frozen snapshot.

    Loss terms:
      * terminal matching: prediction at h_L vs. the exact dCE/dh_L;
      * projection matching at one randomly chosen layer l per batch: the
        prediction's component along M random unit directions v vs. a
        central finite-difference estimate (step ``eps``) of the per-sample
        loss's directional derivative along v, obtained via
        ``model.forward_from_layer``.
    """
    d = model.d_hidden
    L = model.num_blocks
    vector_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32, hidden_dim=256,
                                 num_layers=3).to(device)
    vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb)
    eps = 1e-3
    model.eval()

    for epoch in range(1, epochs + 1):
        vector_net.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                e_T = logits.softmax(dim=-1)
                e_T[torch.arange(batch), y] -= 1
                s = e_T.detach()
            hL = hiddens[-1].detach()

            # Terminal matching
            t_L = torch.ones(batch, device=device)
            a_term = vector_net(hL, t_L, s)
            hL_req = hL.clone().requires_grad_(True)
            logits_tgt = model.out_head(model.out_ln(hL_req))
            ce = F.cross_entropy(logits_tgt, y, reduction='sum')
            delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
            loss_term = ((a_term - delta_L) ** 2).sum(dim=-1).mean()

            # Perturbation target — subsample 1 layer
            l = np.random.randint(0, L)
            h_l = hiddens[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            a_l = vector_net(h_l, t_l, s)
            loss_proj = torch.tensor(0.0, device=device)
            for _ in range(M):
                v = torch.randn_like(h_l)
                v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
                with torch.no_grad():
                    lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y,
                                         reduction='none')
                    lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y,
                                         reduction='none')
                    g_j = (lp - lm) / (2 * eps)
                loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean()
            loss_proj = loss_proj / M

            vloss = loss_term + loss_proj
            vec_opt.zero_grad()
            vloss.backward()
            torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
            vec_opt.step()

        if epoch % 20 == 0 or epoch == 1:
            print(f" [Vec_M{M}] Ep {epoch}")
    return vector_net


# =============================================================================
# Credit computation
# =============================================================================

def get_credits(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
    """Compute per-layer credits for a single batch.

    Args:
        credit_source: one of 'dfa', 'scalar_cb', 'vec', 'oracle_bp'.
        estimator: trained ValueNet ('scalar_cb') or VectorCreditNet ('vec').
        dfa_Bs: list of fixed feedback matrices, one per block ('dfa').

    Returns:
        (credits, hiddens, s) where ``credits`` maps layer index l in
        [0, L) to a (batch, d) tensor, ``hiddens`` are the snapshot's hidden
        states (computed without grad) and ``s`` is softmax(logits) - onehot(y).
    """
    L = model.num_blocks
    batch = x.size(0)
    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        e_T = logits.softmax(dim=-1)
        e_T[torch.arange(batch), y] -= 1
        s = e_T.detach()

    credits = {}
    if credit_source == 'dfa':
        for l in range(L):
            credits[l] = (s @ dfa_Bs[l].T).detach()
    elif credit_source == 'scalar_cb':
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach().requires_grad_(True)
            t_l = torch.full((batch,), l / L, device=device)
            V = estimator(h_l, t_l, s)
            a = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0]
            credits[l] = a.detach()
    elif credit_source == 'vec':
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            credits[l] = estimator(h_l, t_l, s).detach()
    elif credit_source == 'oracle_bp':
        # Compute true BP gradients — evaluation only, never used for training
        for p in model.parameters():
            p.requires_grad_(True)
        model.zero_grad()
        logits_bp, hiddens_bp = model(x, return_hidden=True)
        for l in range(L + 1):
            hiddens_bp[l].retain_grad()
        loss_bp = F.cross_entropy(logits_bp, y)
        loss_bp.backward()
        for l in range(L):
            credits[l] = hiddens_bp[l].grad.detach().clone()
        # Drop the parameter grads left by backward() so they are not carried
        # into later deep copies of this model.
        model.zero_grad(set_to_none=True)
        for p in model.parameters():
            p.requires_grad_(False)
    return credits, hiddens, s


def compute_credit_quality(model, credits, hiddens, x, y, device):
    """Compute mean Gamma, rho and nudge across layers for the given credits.

    Gamma is the cosine similarity to the true BP gradient at each hidden
    state. rho and nudge come from ``perturbation_correlation`` and
    ``nudging_test``, evaluated against the per-sample loss of the remaining
    blocks + output head.
    """
    L = model.num_blocks

    # BP gradients for Gamma
    for p in model.parameters():
        p.requires_grad_(True)
    model.zero_grad()
    logits_bp, hiddens_bp = model(x, return_hidden=True)
    for l in range(L + 1):
        hiddens_bp[l].retain_grad()
    F.cross_entropy(logits_bp, y).backward()
    bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
    model.zero_grad(set_to_none=True)
    for p in model.parameters():
        p.requires_grad_(False)

    def make_fwd(sl):
        """Per-sample loss as a function of the hidden state entering block ``sl``."""
        def f(h):
            with torch.no_grad():
                c = h
                for i in range(sl, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
        return f

    gammas, rhos, nudges = [], [], []
    for l in range(L):
        h_l = hiddens[l].detach()
        a_l = credits[l]
        gammas.append(cosine_similarity_batch(a_l, bp_grads[l]))
        fwd = make_fwd(l)
        rhos.append(perturbation_correlation(h_l, a_l, fwd, epsilon=1e-3, M=16))
        nudges.append(nudging_test(h_l, a_l, fwd, eta=0.003))
    return float(np.mean(gammas)), float(np.mean(rhos)), float(np.mean(nudges))


# =============================================================================
# Local update step (block-local inner-product surrogate)
# =============================================================================

def do_local_update_step(model, x, y, credits, device, lr=1e-3, update_layers=None):
    """Apply one local surrogate SGD step to ``model`` in place.

    * The output head (``out_head`` + ``out_ln``) always takes an exact CE
      gradient step on the detached last hidden state.
    * Each block l in ``update_layers`` descends <f_l(h_l), a_l / rms(a_l)>,
      where a_l = ``credits[l]``; elementwise grads are clamped to [-1, 1].
    * If layer 0 is updated, the embedding descends <embed(x), a_0 / rms(a_0)>.

    Args:
        update_layers: list of block indices to update (None = all blocks).

    Returns:
        None; the model's parameters are modified in place. They must have
        ``requires_grad=True`` for the updated modules.
    """
    L = model.num_blocks
    if update_layers is None:
        update_layers = list(range(L))
    model.train()

    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)

    # Update output head (always)
    hL = hiddens[-1].detach()
    logits_out = model.out_head(model.out_ln(hL))
    loss_out = F.cross_entropy(logits_out, y)
    head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
    grads_head = torch.autograd.grad(loss_out, head_params)
    with torch.no_grad():
        for p, g in zip(head_params, grads_head):
            p.sub_(lr * g)

    # Update selected blocks
    for l in update_layers:
        h_l = hiddens[l].detach()
        a = credits[l]
        # Per-sample RMS normalization makes the step insensitive to credit scale.
        rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
        a_norm = a / rms
        f_l = model.blocks[l](h_l)
        local_loss = (f_l * a_norm).sum(dim=-1).mean()
        block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
        with torch.no_grad():
            for p, g in zip(model.blocks[l].parameters(), block_grads):
                p.sub_(lr * g.clamp(-1, 1))  # implicit grad clip

    # Update embedding
    if 0 in update_layers:
        a_0 = credits[0]
        rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
        h0 = model.embed(x)
        embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
        embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters())
        with torch.no_grad():
            for p, g in zip(model.embed.parameters(), embed_grads):
                p.sub_(lr * g.clamp(-1, 1))


# =============================================================================
# Experiment helpers
# =============================================================================

def _flatten_batch(x, y, device):
    """Flatten images to vectors and move an (x, y) batch to ``device``."""
    return x.view(x.size(0), -1).to(device), y.to(device)


def _sample_train_batches(train_loader, n, device):
    """Draw ``n`` training batches, restarting the loader if it is exhausted."""
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches")
    batches = []
    it = iter(train_loader)
    while len(batches) < n:
        try:
            x, y = next(it)
        except StopIteration:
            it = iter(train_loader)
            continue
        batches.append(_flatten_batch(x, y, device))
    return batches


def _eval_on_batches(model, batches):
    """Mean cross-entropy loss and accuracy of ``model`` over pre-flattened batches."""
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for xv, yv in batches:
            logits = model(xv)
            total_loss += F.cross_entropy(logits, yv, reduction='sum').item()
            correct += (logits.argmax(1) == yv).sum().item()
            total += xv.size(0)
    return total_loss / total, correct / total


def _local_rollout(model_bp, batches, device, src, est, Bs, lr, update_layers=None):
    """Copy the snapshot and apply one local update per batch.

    Credits are recomputed on the current (updated) copy before every step.
    The returned copy has all parameters frozen (``requires_grad=False``).
    """
    model_test = copy.deepcopy(model_bp)
    for x_step, y_step in batches:
        for p in model_test.parameters():
            p.requires_grad_(False)
        credits_step, _, _ = get_credits(model_test, x_step, y_step, device, src,
                                         estimator=est, dfa_Bs=Bs)
        for p in model_test.parameters():
            p.requires_grad_(True)
        do_local_update_step(model_test, x_step, y_step, credits_step, device, lr=lr,
                             update_layers=update_layers)
    for p in model_test.parameters():
        p.requires_grad_(False)
    return model_test


def _train_bp_reference(model_bp, train_loader, test_loader, device, epochs=100):
    """Train the BP reference model with AdamW + cosine LR, logging test acc every 20 epochs."""
    print(" Training BP reference...")
    optimizer = optim.AdamW(model_bp.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    for epoch in range(1, epochs + 1):
        model_bp.train()
        for x, y in train_loader:
            x, y = _flatten_batch(x, y, device)
            loss = F.cross_entropy(model_bp(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        if epoch % 20 == 0:
            model_bp.eval()
            c, t = 0, 0
            with torch.no_grad():
                for xv, yv in test_loader:
                    xv, yv = _flatten_batch(xv, yv, device)
                    c += (model_bp(xv).argmax(1) == yv).sum().item()
                    t += xv.size(0)
            print(f" Ep {epoch}: test_acc={c/t:.4f}")


# =============================================================================
# Main experiment
# =============================================================================

def run_experiment(args):
    """Run tests A1-A3 on a BP-trained snapshot and save a JSON summary.

    Every credit source is evaluated on the same fixed sequence of training
    batches, so DeltaLoss differences reflect the credit rather than which
    minibatches were drawn.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(batch_size=args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden
    rollout_ks = (1, 5, 20)
    ablation_steps = 5

    # =========================================================
    # Step 1: Load or train BP snapshot
    # =========================================================
    print(f"\n{'='*60}")
    print(f"Loading BP snapshot (L={L}, d={d})")
    print(f"{'='*60}")
    model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
    bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
    if os.path.exists(bp_ckpt):
        model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
        print(f" Loaded from {bp_ckpt}")
    else:
        _train_bp_reference(model_bp, train_loader, test_loader, device)
        os.makedirs(os.path.dirname(bp_ckpt), exist_ok=True)
        torch.save(model_bp.state_dict(), bp_ckpt)
    model_bp.eval()
    for p in model_bp.parameters():
        p.requires_grad_(False)
    bp_loss, bp_acc = evaluate_loss_acc(model_bp, test_loader, device, n_batches=20)
    print(f" BP snapshot: loss={bp_loss:.4f}, acc={bp_acc:.4f}")

    # =========================================================
    # Step 2: Train estimators on BP snapshot
    # =========================================================
    print(f"\n{'='*60}")
    print(f"Training estimators on BP snapshot ({args.estimator_epochs} epochs)")
    print(f"{'='*60}")

    # DFA
    dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]

    # Scalar CB
    print(" --- ScalarCB_eT ---")
    torch.manual_seed(args.seed + 2000)
    cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
                                     epochs=args.estimator_epochs, lr_fb=args.lr_fb)

    # Vector field M4
    print(" --- Vec_eT_M4 ---")
    torch.manual_seed(args.seed + 4000)
    vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
                                    epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)

    credit_sources = {
        'dfa': ('dfa', None, dfa_Bs),
        'scalar_cb': ('scalar_cb', cb, None),
        'vec_eT_M4': ('vec', vec4, None),
        'oracle_bp': ('oracle_bp', None, None),
    }

    # =========================================================
    # Step 3: Compute credit quality
    # =========================================================
    print(f"\n{'='*60}")
    print(f"Computing credit quality on BP snapshot")
    print(f"{'='*60}")

    # Fixed (unshuffled) first test batch for credit quality.
    x_eval, y_eval = _flatten_batch(*next(iter(test_loader)), device)

    quality = {}
    for name, (src, est, Bs) in credit_sources.items():
        credits, hiddens, s = get_credits(model_bp, x_eval, y_eval, device, src,
                                          estimator=est, dfa_Bs=Bs)
        gamma, rho, nudge = compute_credit_quality(model_bp, credits, hiddens,
                                                   x_eval, y_eval, device)
        quality[name] = {'gamma': gamma, 'rho': rho, 'nudge': nudge}
        print(f" {name}: Gamma={gamma:.4f}, rho={rho:.4f}, nudge={nudge:.6f}")

    # =========================================================
    # Step 4: Exploitability tests
    # =========================================================
    print(f"\n{'='*60}")
    print(f"Test A1: Single-step DeltaLoss")
    print(f"{'='*60}")

    # One fixed batch sequence shared by all credit sources / k / layer
    # configs: batch 0 drives A1, the rest drive the A2/A3 rollouts.
    n_rollout = max(max(rollout_ks), ablation_steps)
    train_batches = _sample_train_batches(train_loader, 1 + n_rollout, device)
    x_train, y_train = train_batches[0]
    rollout_batches = train_batches[1:]

    # Held-out eval batches (different from the train batches)
    eval_batches = []
    for i, (xv, yv) in enumerate(test_loader):
        if i >= 10:
            break
        eval_batches.append(_flatten_batch(xv, yv, device))

    results = {}
    lr_update = args.lr_update
    # Every test starts from the same snapshot, so the baseline is shared.
    loss_before, acc_before = _eval_on_batches(model_bp, eval_batches)

    for name, (src, est, Bs) in credit_sources.items():
        model_test = _local_rollout(model_bp, [(x_train, y_train)], device, src, est, Bs,
                                    lr_update)
        loss_after, acc_after = _eval_on_batches(model_test, eval_batches)
        delta_loss = loss_after - loss_before
        delta_acc = acc_after - acc_before
        results[name] = {
            'gamma': quality[name]['gamma'],
            'rho': quality[name]['rho'],
            'nudge': quality[name]['nudge'],
            'loss_before': loss_before,
            'acc_before': acc_before,
            'loss_after_1step': loss_after,
            'delta_loss_1step': delta_loss,
            'delta_acc_1step': delta_acc,
        }
        print(f" {name}: DeltaLoss={delta_loss:+.6f}, DeltaAcc={delta_acc:+.4f} "
              f"(before={loss_before:.4f}, after={loss_after:.4f})")

    # =========================================================
    # Test A2: k-step rollout
    # =========================================================
    print(f"\n{'='*60}")
    print(f"Test A2: k-step rollout (k=1,5,20)")
    print(f"{'='*60}")

    for name, (src, est, Bs) in credit_sources.items():
        rollout = {}
        for k in rollout_ks:
            model_test = _local_rollout(model_bp, rollout_batches[:k], device, src, est, Bs,
                                        lr_update)
            loss_k, acc_k = _eval_on_batches(model_test, eval_batches)
            rollout[k] = {'loss': loss_k, 'acc': acc_k,
                          'delta_loss': loss_k - results[name]['loss_before'],
                          'delta_acc': acc_k - results[name]['acc_before']}
        results[name]['rollout'] = rollout
        print(f" {name}: k=1 dL={rollout[1]['delta_loss']:+.4f}, "
              f"k=5 dL={rollout[5]['delta_loss']:+.4f}, "
              f"k=20 dL={rollout[20]['delta_loss']:+.4f}")

    # =========================================================
    # Test A3: Per-layer ablation
    # =========================================================
    print(f"\n{'='*60}")
    print(f"Test A3: Per-layer ablation (5-step)")
    print(f"{'='*60}")

    layer_configs = {
        'last_1': [L - 1],
        # Clamp at 0 so L == 1 does not produce a negative layer index.
        'last_2': list(range(max(L - 2, 0), L)),
        'all': list(range(L)),
    }
    layer_results = {}
    for name, (src, est, Bs) in credit_sources.items():
        layer_results[name] = {}
        for lname, layers in layer_configs.items():
            model_test = _local_rollout(model_bp, rollout_batches[:ablation_steps], device,
                                        src, est, Bs, lr_update, update_layers=layers)
            loss_k, acc_k = _eval_on_batches(model_test, eval_batches)
            dl = loss_k - results[name]['loss_before']
            layer_results[name][lname] = {'delta_loss': dl, 'loss': loss_k}
        print(f" {name}: last_1={layer_results[name]['last_1']['delta_loss']:+.4f}, "
              f"last_2={layer_results[name]['last_2']['delta_loss']:+.4f}, "
              f"all={layer_results[name]['all']['delta_loss']:+.4f}")

    # =========================================================
    # Summary
    # =========================================================
    print(f"\n{'='*80}")
    print("SUMMARY TABLE")
    print(f"{'='*80}")
    print(f"{'Method':<15} {'Gamma':>7} {'rho':>7} {'nudge':>10} {'dL_1':>8} {'dL_5':>8} {'dL_20':>8}")
    print("-" * 68)
    for name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']:
        r = results[name]
        dL1 = r['delta_loss_1step']
        dL5 = r['rollout'][5]['delta_loss']
        dL20 = r['rollout'][20]['delta_loss']
        print(f"{name:<15} {r['gamma']:>7.4f} {r['rho']:>7.4f} {r['nudge']:>10.6f} "
              f"{dL1:>+8.4f} {dL5:>+8.4f} {dL20:>+8.4f}")

    print(f"\nPer-layer ablation (5-step DeltaLoss):")
    print(f"{'Method':<15} {'last_1':>8} {'last_2':>8} {'all':>8}")
    print("-" * 42)
    for name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']:
        abl = layer_results[name]
        print(f"{name:<15} {abl['last_1']['delta_loss']:>+8.4f} "
              f"{abl['last_2']['delta_loss']:>+8.4f} {abl['all']['delta_loss']:>+8.4f}")

    # Save
    save_data = {
        'config': {'L': L, 'd': d, 'seed': args.seed, 'lr_update': lr_update,
                   'estimator_epochs': args.estimator_epochs},
        'credit_quality': quality,
        'exploitability': {n: {k: v for k, v in r.items() if k != 'rollout'}
                           for n, r in results.items()},
        'rollout': {n: r['rollout'] for n, r in results.items()},
        'layer_ablation': layer_results,
    }
    out_path = os.path.join(args.output_dir, f'snapshot_L{L}_d{d}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(save_data, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")

    # Judgment
    print(f"\n{'='*60}")
    print("JUDGMENT")
    print(f"{'='*60}")
    vec_dl = results['vec_eT_M4']['rollout'][5]['delta_loss']
    cb_dl = results['scalar_cb']['rollout'][5]['delta_loss']
    bp_dl = results['oracle_bp']['rollout'][5]['delta_loss']
    dfa_dl = results['dfa']['rollout'][5]['delta_loss']
    print(f"5-step DeltaLoss: DFA={dfa_dl:+.4f}, CB={cb_dl:+.4f}, Vec={vec_dl:+.4f}, BP={bp_dl:+.4f}")

    if vec_dl < cb_dl and vec_dl < dfa_dl:
        print("EXPLOITABLE: Vec credit produces better loss decrease than ScalarCB and DFA.")
        print(" -> Online failure is likely tracking/co-adaptation (Case A).")
    elif bp_dl < dfa_dl and vec_dl >= cb_dl:
        print("NOT EXPLOITABLE: Better credit (vec) doesn't translate to better loss decrease.")
        print(" -> Bottleneck is in local update rule (Case B).")
    else:
        print("AMBIGUOUS: Need more investigation.")


def main():
    """Parse CLI arguments and run the snapshot exploitability experiment."""
    parser = argparse.ArgumentParser(description='Phase 6A: Snapshot Exploitability')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--estimator_epochs', type=int, default=100)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--lr_update', type=float, default=1e-3)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--output_dir', type=str, default='results/snapshot_exploit')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()