"""
Phase 6.5A: Same-batch infinitesimal descent test.

Strict protocol:
- Fixed train minibatch B
- Compute credits on B
- Do local update with B
- Evaluate loss change on THE SAME B (same-batch)
- Sweep eta over multiple orders of magnitude
- Test raw credit (no normalization) and normalized credit separately
- No gradient clamping
- Separate update scopes: last-block-only, last-2, all-blocks
"""
import os
import sys
import json
import math
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
from metrics.credit_metrics import cosine_similarity_batch

# Reuse VectorCreditNet and estimator trainers
from experiments.snapshot_exploitability import (
    train_scalar_cb_on_snapshot, train_vector_on_snapshot, VectorCreditNet
)


def get_cifar10(batch_size=128):
    """Return (train_loader, test_loader) for CIFAR-10.

    The train split uses random crop + horizontal flip augmentation and is
    shuffled; the test split is only normalized. Both use the same per-channel
    normalization constants.
    """
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)
    return train_loader, test_loader


def _oracle_bp_credits(model, x, y, L):
    """Exact backprop credit for every block: per-sample dCE/d(block output).

    Block ``l`` is applied to ``hiddens[l]`` (see ``do_local_update_clean``) and
    ``hiddens[-1]`` feeds the output head, so block ``l``'s output is
    ``hiddens[l + 1]`` and its BP credit is the gradient there.

    The CE loss is sum-reduced so each row is the per-sample gradient (no 1/B
    factor). This matches the per-sample scale of ``s``/DFA credits, and makes the
    batch-mean local loss in ``do_local_update_clean`` reproduce exactly the
    gradient of the mean CE loss.

    The model's ``requires_grad`` flags are restored and parameter ``.grad``
    buffers cleared afterwards, even if an exception occurs.
    """
    prev_flags = [p.requires_grad for p in model.parameters()]
    for p in model.parameters():
        p.requires_grad_(True)
    try:
        model.zero_grad(set_to_none=True)
        logits_bp, hiddens_bp = model(x, return_hidden=True)
        for l in range(1, L + 1):
            hiddens_bp[l].retain_grad()
        loss_bp = F.cross_entropy(logits_bp, y, reduction='sum')
        loss_bp.backward()
        credits = {l: hiddens_bp[l + 1].grad.detach().clone() for l in range(L)}
    finally:
        # Drop stale grads so they are not carried into deep copies of the model.
        model.zero_grad(set_to_none=True)
        for p, flag in zip(model.parameters(), prev_flags):
            p.requires_grad_(flag)
    return credits


def get_credits_on_batch(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
    """Compute per-block credits on a fixed batch.

    Args:
        model: network with ``num_blocks`` and ``model(x, return_hidden=True)``
            returning ``(logits, hiddens)``.
        x, y: the batch (flattened inputs, integer labels).
        device: device for the time tensors fed to the estimators.
        credit_source: one of ``'dfa'``, ``'scalar_cb'``, ``'vec'``, ``'oracle_bp'``.
        estimator: trained credit estimator (required for ``'scalar_cb'``/``'vec'``).
        dfa_Bs: list of per-block feedback matrices (required for ``'dfa'``).

    Returns:
        ``(credits, hiddens, s)``. ``credits`` is ``{l: tensor}`` for
        ``l in range(num_blocks)``. ``hiddens`` comes from a no-grad forward pass.
        ``s = softmax(logits) - onehot(y)`` (detached).

    Raises:
        ValueError: unknown ``credit_source`` or missing estimator / dfa_Bs.
    """
    L = model.num_blocks
    batch = x.size(0)

    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        e_T = logits.softmax(dim=-1)
        # Per-sample dCE/dlogits = softmax - onehot(y).
        e_T[torch.arange(batch, device=e_T.device), y] -= 1
        s = e_T.detach()

    # NOTE(review): the learned/DFA credits are queried at the block *input*
    # hiddens[l] with t = l/L; confirm this matches how the estimators'
    # training targets are defined in experiments.snapshot_exploitability.
    credits = {}
    if credit_source == 'dfa':
        if dfa_Bs is None:
            raise ValueError("credit_source='dfa' requires dfa_Bs")
        for l in range(L):
            credits[l] = (s @ dfa_Bs[l].T).detach()
    elif credit_source == 'scalar_cb':
        if estimator is None:
            raise ValueError("credit_source='scalar_cb' requires an estimator")
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach().requires_grad_(True)
            t_l = torch.full((batch,), l / L, device=device)
            V = estimator(h_l, t_l, s)
            # Credit is the gradient of the scalar value w.r.t. the hidden state.
            credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach()
    elif credit_source == 'vec':
        if estimator is None:
            raise ValueError("credit_source='vec' requires an estimator")
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            credits[l] = estimator(h_l, t_l, s).detach()
    elif credit_source == 'oracle_bp':
        credits = _oracle_bp_credits(model, x, y, L)
    else:
        raise ValueError(f"Unknown credit_source: {credit_source!r}")

    return credits, hiddens, s


def do_local_update_clean(model, x, y, credits, device, eta, update_layers,
                          normalize_credit=False, update_head=True):
    """
    Clean local update: no gradient clamping, explicit raw/norm control.

    Each updated block ``l`` takes one SGD step of size ``eta`` on the local loss
    ``mean_i <blocks[l](h_l)_i, a_i>``, where ``h_l`` is the block input from a
    forward pass with the current parameters. All blocks use hiddens computed
    before any block is modified.

    Args:
        model: modified in place; its parameters must have requires_grad=True
        x, y: the batch
        credits: dict {l: (batch, d)} raw credit vectors
        device: unused; kept for interface compatibility
        eta: step size
        update_layers: list of block indices to update
        normalize_credit: if True, normalize credit by per-sample RMS before use
        update_head: if True, also update output head (out_ln + out_head) with
            the exact CE gradient
    """
    # Recompute hiddens with current params (important after any param change)
    with torch.no_grad():
        _, hiddens = model(x, return_hidden=True)

    if update_head:
        hL = hiddens[-1].detach()
        logits_out = model.out_head(model.out_ln(hL))
        loss_out = F.cross_entropy(logits_out, y)
        head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
        grads_head = torch.autograd.grad(loss_out, head_params)
        with torch.no_grad():
            for p, g in zip(head_params, grads_head):
                p.sub_(eta * g)

    for l in update_layers:
        h_l = hiddens[l].detach()
        a = credits[l]
        if normalize_credit:
            rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
            a_used = a / rms
        else:
            a_used = a
        f_l = model.blocks[l](h_l)
        # Gradient of <f_l, a> w.r.t. block params is J^T a (averaged over batch).
        local_loss = (f_l * a_used.detach()).sum(dim=-1).mean()
        block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
        with torch.no_grad():
            for p, g in zip(model.blocks[l].parameters(), block_grads):
                p.sub_(eta * g)  # NO clamping


def eval_loss_on_batch(model, x, y):
    """Return (CE loss, accuracy) of ``model`` on one batch; leaves the model in eval mode."""
    model.eval()
    with torch.no_grad():
        logits = model(x)
        loss = F.cross_entropy(logits, y).item()
        acc = (logits.argmax(1) == y).float().mean().item()
    return loss, acc


def _nan_as_inf(value):
    """Sort key that maps NaN (e.g. a diverged eta) to +inf so it is never 'best'."""
    return float('inf') if math.isnan(value) else value


# =============================================================================
# Main
# =============================================================================

def run_experiment(args):
    """Run the same-batch line search over update ranges, credit norms, methods and etas.

    Loads a frozen BP snapshot, fixes one training batch (plus the next training
    batch as a held-out comparison), computes credits for each requested method,
    then for every configuration applies one local update to a fresh copy of the
    snapshot and records the loss change. Results are printed and saved to JSON.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden

    # Load BP snapshot
    model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
    bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
    model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
    model_bp.eval()
    for p in model_bp.parameters():
        p.requires_grad_(False)
    print(f"Loaded BP snapshot from {bp_ckpt}")

    # Get a FIXED train batch (same-batch protocol)
    train_iter = iter(train_loader)
    x_batch, y_batch = next(train_iter)
    x_batch = x_batch.view(x_batch.size(0), -1).to(device)
    y_batch = y_batch.to(device)
    print(f"Fixed train batch: {x_batch.shape[0]} samples")

    # The next training batch serves as the "held-out" comparison batch
    x_held, y_held = next(train_iter)
    x_held = x_held.view(x_held.size(0), -1).to(device)
    y_held = y_held.to(device)

    # Baseline losses
    loss_before_same, acc_before_same = eval_loss_on_batch(model_bp, x_batch, y_batch)
    loss_before_held, acc_before_held = eval_loss_on_batch(model_bp, x_held, y_held)
    print(f"Before: same_batch_loss={loss_before_same:.6f}, held_out_loss={loss_before_held:.6f}")

    # =========================================================
    # Prepare credit sources
    # =========================================================
    credit_configs = {}

    # DFA
    dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
    credit_configs['dfa'] = ('dfa', None, dfa_Bs)

    # Scalar CB (train on frozen snapshot)
    if 'scalar_cb' in args.methods:
        print("\nTraining ScalarCB on snapshot...")
        torch.manual_seed(args.seed + 2000)
        cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
                                         epochs=args.estimator_epochs, lr_fb=args.lr_fb)
        credit_configs['scalar_cb'] = ('scalar_cb', cb, None)

    # Vec M4 (train on frozen snapshot)
    if 'vec_eT_M4' in args.methods:
        print("\nTraining Vec_M4 on snapshot...")
        torch.manual_seed(args.seed + 4000)
        vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
                                        epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
        credit_configs['vec_eT_M4'] = ('vec', vec4, None)

    # Oracle BP
    credit_configs['oracle_bp'] = ('oracle_bp', None, None)

    # =========================================================
    # Compute credits on the fixed batch
    # =========================================================
    print("\nComputing credits on fixed batch...")
    all_credits = {}
    for name, (src, est, Bs) in credit_configs.items():
        if name not in args.methods:
            continue
        credits, _, _ = get_credits_on_batch(model_bp, x_batch, y_batch, device, src,
                                             estimator=est, dfa_Bs=Bs)
        all_credits[name] = credits

        # Report credit magnitudes
        mean_rms = np.mean([credits[l].pow(2).mean().sqrt().item() for l in range(L)])
        print(f" {name}: mean_credit_RMS={mean_rms:.6f}")

    # =========================================================
    # Line search
    # =========================================================
    etas = args.etas
    update_ranges = {}
    if 'last1' in args.update_ranges:
        update_ranges['last1'] = [L - 1]
    if 'last2' in args.update_ranges:
        # Clamp so a 1-block model does not index blocks[-1] / credits[-1].
        update_ranges['last2'] = list(range(max(L - 2, 0), L))
    if 'all' in args.update_ranges:
        update_ranges['all'] = list(range(L))

    norm_modes = args.norm_modes  # ['raw'] or ['raw', 'norm']

    results = {}
    for ur_name, layers in update_ranges.items():
        for norm_mode in norm_modes:
            normalize = (norm_mode == 'norm')
            print(f"\n{'='*60}")
            print(f"Update range: {ur_name}, credit: {norm_mode}")
            print(f"{'='*60}")
            print(f"{'Method':<15} {'eta':>10} {'dL_same':>12} {'dL_held':>12} {'dAcc_same':>10}")
            print("-" * 62)

            for name in args.methods:
                if name not in all_credits:
                    continue
                credits = all_credits[name]

                for eta in etas:
                    # Deep copy snapshot
                    model_test = copy.deepcopy(model_bp)
                    for p in model_test.parameters():
                        p.requires_grad_(True)

                    # Do update
                    do_local_update_clean(model_test, x_batch, y_batch, credits, device,
                                          eta=eta, update_layers=layers,
                                          normalize_credit=normalize,
                                          update_head=(ur_name != 'last1'))  # skip head for last1 only

                    for p in model_test.parameters():
                        p.requires_grad_(False)

                    # Evaluate on same batch
                    loss_same, acc_same = eval_loss_on_batch(model_test, x_batch, y_batch)
                    loss_held, acc_held = eval_loss_on_batch(model_test, x_held, y_held)

                    dl_same = loss_same - loss_before_same
                    dl_held = loss_held - loss_before_held
                    da_same = acc_same - acc_before_same

                    key = f"{name}_{ur_name}_{norm_mode}_eta{eta}"
                    results[key] = {
                        'method': name,
                        'update_range': ur_name,
                        'norm_mode': norm_mode,
                        'eta': eta,
                        'loss_before_same': loss_before_same,
                        'loss_after_same': loss_same,
                        'delta_loss_same': dl_same,
                        'delta_loss_held': dl_held,
                        'delta_acc_same': da_same,
                    }
                    print(f"{name:<15} {eta:>10.1e} {dl_same:>+12.6f} {dl_held:>+12.6f} {da_same:>+10.4f}")

    # =========================================================
    # Summary: best eta per method
    # =========================================================
    print(f"\n{'='*60}")
    print("BEST ETA PER METHOD (minimum same-batch DeltaLoss)")
    print(f"{'='*60}")
    for ur_name in update_ranges:
        for norm_mode in norm_modes:
            print(f"\n {ur_name}, {norm_mode}:")
            for name in args.methods:
                relevant = {k: v for k, v in results.items()
                            if v['method'] == name and v['update_range'] == ur_name
                            and v['norm_mode'] == norm_mode}
                if not relevant:
                    continue
                best_key = min(relevant,
                               key=lambda k: _nan_as_inf(relevant[k]['delta_loss_same']))
                best = relevant[best_key]
                print(f" {name:<15} best_eta={best['eta']:.1e}, "
                      f"dL_same={best['delta_loss_same']:+.6f}, "
                      f"dL_held={best['delta_loss_held']:+.6f}")

    # Save
    out_path = os.path.join(args.output_dir, f'linesearch_L{L}_d{d}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")

    # =========================================================
    # Key diagnostic: does Oracle BP descend at small eta?
    # =========================================================
    print(f"\n{'='*60}")
    print("KEY DIAGNOSTIC")
    print(f"{'='*60}")
    for ur_name in update_ranges:
        oracle_results = {k: v for k, v in results.items()
                          if v['method'] == 'oracle_bp' and v['update_range'] == ur_name
                          and v['norm_mode'] == 'raw'}
        if not oracle_results:
            continue
        best = min(oracle_results.values(), key=lambda v: _nan_as_inf(v['delta_loss_same']))
        worst = max(oracle_results.values(), key=lambda v: _nan_as_inf(v['delta_loss_same']))
        smallest = min(oracle_results.values(), key=lambda v: v['eta'])
        print(f"\n Oracle BP ({ur_name}, raw):")
        print(f" Best: eta={best['eta']:.1e}, dL_same={best['delta_loss_same']:+.6f}")
        print(f" Worst: eta={worst['eta']:.1e}, dL_same={worst['delta_loss_same']:+.6f}")
        print(f" Smallest eta: eta={smallest['eta']:.1e}, dL_same={smallest['delta_loss_same']:+.6f}")
        if best['delta_loss_same'] < -1e-6:
            print(f" -> Oracle BP CAN descend on same-batch at small eta. Protocol is OK.")
        else:
            print(f" -> WARNING: Oracle BP cannot descend! Check implementation.")


def main():
    """Parse command-line arguments and run the same-batch line search."""
    parser = argparse.ArgumentParser(description='Phase 6.5A: Same-batch Line Search')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--estimator_epochs', type=int, default=100)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--methods', type=str, nargs='+',
                        default=['oracle_bp', 'vec_eT_M4'])
    parser.add_argument('--etas', type=float, nargs='+',
                        default=[1e-5, 3e-5, 1e-4, 3e-4, 1e-3])
    parser.add_argument('--update_ranges', type=str, nargs='+', default=['last1'])
    parser.add_argument('--norm_modes', type=str, nargs='+', default=['raw'])
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--output_dir', type=str, default='results/exploit_linesearch')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()