summaryrefslogtreecommitdiff
path: root/experiments/exploitability_samebatch_linesearch.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/exploitability_samebatch_linesearch.py')
-rw-r--r--experiments/exploitability_samebatch_linesearch.py392
1 files changed, 392 insertions, 0 deletions
diff --git a/experiments/exploitability_samebatch_linesearch.py b/experiments/exploitability_samebatch_linesearch.py
new file mode 100644
index 0000000..c6c9ba8
--- /dev/null
+++ b/experiments/exploitability_samebatch_linesearch.py
@@ -0,0 +1,392 @@
+"""
+Phase 6.5A: Same-batch infinitesimal descent test.
+
+Strict protocol:
+- Fixed train minibatch B
+- Compute credits on B
+- Do local update with B
+- Evaluate loss change on THE SAME B (same-batch)
+- Sweep eta over multiple orders of magnitude
+- Test raw credit (no normalization) and normalized credit separately
+- No gradient clamping
+- Separate update scopes: last-block-only, last-2, all-blocks
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import cosine_similarity_batch
+
+# Reuse VectorCreditNet and estimator trainers
+from experiments.snapshot_exploitability import (
+ train_scalar_cb_on_snapshot, train_vector_on_snapshot, VectorCreditNet
+)
+
+
+def get_cifar10(batch_size=128):
+ import torchvision
+ import torchvision.transforms as transforms
+ from torch.utils.data import DataLoader
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def get_credits_on_batch(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
+ """Compute per-layer credits. Returns credits dict, hiddens list, conditioning s."""
+ L = model.num_blocks
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ credits = {}
+
+ if credit_source == 'dfa':
+ for l in range(L):
+ credits[l] = (s @ dfa_Bs[l].T).detach()
+
+ elif credit_source == 'scalar_cb':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V = estimator(h_l, t_l, s)
+ credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach()
+
+ elif credit_source == 'vec':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ credits[l] = estimator(h_l, t_l, s).detach()
+
+ elif credit_source == 'oracle_bp':
+ for p in model.parameters():
+ p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ for l in range(L):
+ credits[l] = hiddens_bp[l].grad.detach().clone()
+ for p in model.parameters():
+ p.requires_grad_(False)
+
+ return credits, hiddens, s
+
+
+def do_local_update_clean(model, x, y, credits, device, eta,
+ update_layers, normalize_credit=False,
+ update_head=True):
+ """
+ Clean local update: no gradient clamping, explicit raw/norm control.
+
+ Args:
+ model: in-place modified
+ x, y: the batch
+ credits: dict {l: (batch, d)} raw credit vectors
+ eta: step size
+ update_layers: list of block indices to update
+ normalize_credit: if True, normalize credit by RMS before use
+ update_head: if True, also update output head with exact CE gradient
+ """
+ L = model.num_blocks
+
+ # Recompute hiddens with current params (important after any param change)
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+
+ if update_head:
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(eta * g)
+
+ for l in update_layers:
+ h_l = hiddens[l].detach()
+ a = credits[l]
+
+ if normalize_credit:
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_used = a / rms
+ else:
+ a_used = a
+
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_used.detach()).sum(dim=-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(eta * g) # NO clamping
+
+
+def eval_loss_on_batch(model, x, y):
+ """Evaluate CE loss on a specific batch."""
+ model.eval()
+ with torch.no_grad():
+ logits = model(x)
+ loss = F.cross_entropy(logits, y).item()
+ acc = (logits.argmax(1) == y).float().mean().item()
+ return loss, acc
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # Load BP snapshot
+ model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
+ bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
+ model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
+ model_bp.eval()
+ for p in model_bp.parameters():
+ p.requires_grad_(False)
+ print(f"Loaded BP snapshot from {bp_ckpt}")
+
+ # Get a FIXED train batch (same-batch protocol)
+ train_iter = iter(train_loader)
+ x_batch, y_batch = next(train_iter)
+ x_batch = x_batch.view(x_batch.size(0), -1).to(device)
+ y_batch = y_batch.to(device)
+ print(f"Fixed train batch: {x_batch.shape[0]} samples")
+
+ # Get a separate held-out batch for comparison
+ x_held, y_held = next(train_iter)
+ x_held = x_held.view(x_held.size(0), -1).to(device)
+ y_held = y_held.to(device)
+
+ # Baseline losses
+ loss_before_same, acc_before_same = eval_loss_on_batch(model_bp, x_batch, y_batch)
+ loss_before_held, acc_before_held = eval_loss_on_batch(model_bp, x_held, y_held)
+ print(f"Before: same_batch_loss={loss_before_same:.6f}, held_out_loss={loss_before_held:.6f}")
+
+ # =========================================================
+ # Prepare credit sources
+ # =========================================================
+ credit_configs = {}
+
+ # DFA
+ dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+ credit_configs['dfa'] = ('dfa', None, dfa_Bs)
+
+ # Scalar CB (train on frozen snapshot)
+ if 'scalar_cb' in args.methods:
+ print("\nTraining ScalarCB on snapshot...")
+ torch.manual_seed(args.seed + 2000)
+ cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb)
+ credit_configs['scalar_cb'] = ('scalar_cb', cb, None)
+
+ # Vec M4 (train on frozen snapshot)
+ if 'vec_eT_M4' in args.methods:
+ print("\nTraining Vec_M4 on snapshot...")
+ torch.manual_seed(args.seed + 4000)
+ vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
+ credit_configs['vec_eT_M4'] = ('vec', vec4, None)
+
+ # Oracle BP
+ credit_configs['oracle_bp'] = ('oracle_bp', None, None)
+
+ # =========================================================
+ # Compute credits on the fixed batch
+ # =========================================================
+ print("\nComputing credits on fixed batch...")
+ all_credits = {}
+ for name, (src, est, Bs) in credit_configs.items():
+ if name not in args.methods:
+ continue
+ credits, hiddens, s = get_credits_on_batch(model_bp, x_batch, y_batch, device,
+ src, estimator=est, dfa_Bs=Bs)
+ all_credits[name] = credits
+ # Report credit magnitudes
+ mean_rms = np.mean([credits[l].pow(2).mean().sqrt().item() for l in range(L)])
+ print(f" {name}: mean_credit_RMS={mean_rms:.6f}")
+
+ # =========================================================
+ # Line search
+ # =========================================================
+ etas = args.etas
+ update_ranges = {}
+ if 'last1' in args.update_ranges:
+ update_ranges['last1'] = [L - 1]
+ if 'last2' in args.update_ranges:
+ update_ranges['last2'] = [L - 2, L - 1]
+ if 'all' in args.update_ranges:
+ update_ranges['all'] = list(range(L))
+
+ norm_modes = args.norm_modes # ['raw'] or ['raw', 'norm']
+
+ results = {}
+
+ for ur_name, layers in update_ranges.items():
+ for norm_mode in norm_modes:
+ normalize = (norm_mode == 'norm')
+ print(f"\n{'='*60}")
+ print(f"Update range: {ur_name}, credit: {norm_mode}")
+ print(f"{'='*60}")
+ print(f"{'Method':<15} {'eta':>10} {'dL_same':>12} {'dL_held':>12} {'dAcc_same':>10}")
+ print("-" * 62)
+
+ for name in args.methods:
+ if name not in all_credits:
+ continue
+ credits = all_credits[name]
+
+ for eta in etas:
+ # Deep copy snapshot
+ model_test = copy.deepcopy(model_bp)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+
+ # Do update
+ do_local_update_clean(model_test, x_batch, y_batch, credits, device,
+ eta=eta, update_layers=layers,
+ normalize_credit=normalize,
+ update_head=(ur_name != 'last1')) # skip head for last1 only
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+
+ # Evaluate on same batch
+ loss_same, acc_same = eval_loss_on_batch(model_test, x_batch, y_batch)
+ loss_held, acc_held = eval_loss_on_batch(model_test, x_held, y_held)
+
+ dl_same = loss_same - loss_before_same
+ dl_held = loss_held - loss_before_held
+ da_same = acc_same - acc_before_same
+
+ key = f"{name}_{ur_name}_{norm_mode}_eta{eta}"
+ results[key] = {
+ 'method': name, 'update_range': ur_name,
+ 'norm_mode': norm_mode, 'eta': eta,
+ 'loss_before_same': loss_before_same,
+ 'loss_after_same': loss_same,
+ 'delta_loss_same': dl_same,
+ 'delta_loss_held': dl_held,
+ 'delta_acc_same': da_same,
+ }
+
+ print(f"{name:<15} {eta:>10.1e} {dl_same:>+12.6f} {dl_held:>+12.6f} {da_same:>+10.4f}")
+
+ # =========================================================
+ # Summary: best eta per method
+ # =========================================================
+ print(f"\n{'='*60}")
+ print("BEST ETA PER METHOD (minimum same-batch DeltaLoss)")
+ print(f"{'='*60}")
+
+ for ur_name in update_ranges:
+ for norm_mode in norm_modes:
+ print(f"\n {ur_name}, {norm_mode}:")
+ for name in args.methods:
+ relevant = {k: v for k, v in results.items()
+ if v['method'] == name and v['update_range'] == ur_name
+ and v['norm_mode'] == norm_mode}
+ if not relevant:
+ continue
+ best_key = min(relevant, key=lambda k: relevant[k]['delta_loss_same'])
+ best = relevant[best_key]
+ print(f" {name:<15} best_eta={best['eta']:.1e}, "
+ f"dL_same={best['delta_loss_same']:+.6f}, "
+ f"dL_held={best['delta_loss_held']:+.6f}")
+
+ # Save
+ out_path = os.path.join(args.output_dir, f'linesearch_L{L}_d{d}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # =========================================================
+ # Key diagnostic: does Oracle BP descend at small eta?
+ # =========================================================
+ print(f"\n{'='*60}")
+ print("KEY DIAGNOSTIC")
+ print(f"{'='*60}")
+
+ for ur_name in update_ranges:
+ oracle_results = {k: v for k, v in results.items()
+ if v['method'] == 'oracle_bp' and v['update_range'] == ur_name
+ and v['norm_mode'] == 'raw'}
+ if not oracle_results:
+ continue
+ best = min(oracle_results.values(), key=lambda v: v['delta_loss_same'])
+ worst = max(oracle_results.values(), key=lambda v: v['delta_loss_same'])
+ print(f"\n Oracle BP ({ur_name}, raw):")
+ print(f" Best: eta={best['eta']:.1e}, dL_same={best['delta_loss_same']:+.6f}")
+ print(f" Worst: eta={worst['eta']:.1e}, dL_same={worst['delta_loss_same']:+.6f}")
+
+ if best['delta_loss_same'] < -1e-6:
+ print(f" -> Oracle BP CAN descend on same-batch at small eta. Protocol is OK.")
+ else:
+ print(f" -> WARNING: Oracle BP cannot descend! Check implementation.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 6.5A: Same-batch Line Search')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--estimator_epochs', type=int, default=100)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--methods', type=str, nargs='+',
+ default=['oracle_bp', 'vec_eT_M4'])
+ parser.add_argument('--etas', type=float, nargs='+',
+ default=[1e-5, 3e-5, 1e-4, 3e-4, 1e-3])
+ parser.add_argument('--update_ranges', type=str, nargs='+', default=['last1'])
+ parser.add_argument('--norm_modes', type=str, nargs='+', default=['raw'])
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/exploit_linesearch')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()