From 7e01fbc0ce871857c1e1879ed0d3559e8bfae7c7 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Wed, 25 Mar 2026 08:22:04 -0500 Subject: Add Phase 6.5A: same-batch linesearch REVISES Phase 6A conclusion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 6A's "better credit → worse loss" was a protocol artifact caused by: 1. Credit normalization (inflated DFA, suppressed Vec magnitude ordering) 2. Held-out evaluation (measured generalization failure, not exploitability) 3. Gradient clamping With strict same-batch evaluation: - Oracle BP: dL_same = -0.406 (strongest descent) - Vec_M4: dL_same = -0.135 - ScalarCB: dL_same = -0.025 - DFA: dL_same = -0.003 Same-batch loss decrease is MONOTONIC with credit quality. But held-out loss INCREASES for all non-DFA methods (Case D: overfitting). The bottleneck is batch-level generalization, not surrogate exploitability. Co-Authored-By: Claude Opus 4.6 (1M context) --- NOTE.md | 47 ++- experiments/exploitability_samebatch_linesearch.py | 392 +++++++++++++++++++++ report_explore/MEMO_6.5A_samebatch_linesearch.md | 44 +++ 3 files changed, 482 insertions(+), 1 deletion(-) create mode 100644 experiments/exploitability_samebatch_linesearch.py create mode 100644 report_explore/MEMO_6.5A_samebatch_linesearch.md diff --git a/NOTE.md b/NOTE.md index 6242a41..a57fd30 100644 --- a/NOTE.md +++ b/NOTE.md @@ -5,7 +5,7 @@ - **pilot**: Controlled iteration (commits 0b9ebb2, 7baf7ae) - **frozen**: Code at commit 0b9ebb2 for all reported results -## Status: PHASE 6 EXPLOITABILITY DISSECTION COMPLETE +## Status: PHASE 6.5 PROTOCOL AUDIT — PHASE 6A CONCLUSION REVISED --- @@ -373,3 +373,48 @@ Better credit does NOT lead to better snapshot loss decrease. 
### Experiment IDs (Phase 6) - `snapshot_exploit/`: Phase 6A snapshot exploitability - `update_swap/`: Phase 6C local update rule comparison + +--- + +## Phase 6.5: Protocol Audit (REVISES Phase 6A conclusion) + +### Phase 6.5A: Same-Batch Linesearch + +**CRITICAL REVISION**: Phase 6A's "better credit → worse loss" was a protocol artifact. + +Phase 6A used: normalized credit + held-out evaluation + gradient clamping. +Phase 6.5A uses: raw + norm credit, same-batch + held-out eval, no clamping, eta sweep. + +**With same-batch evaluation, better credit DOES produce more loss decrease:** + +| Method | Gamma | dL_same (norm, all, best eta) | dL_held | +|--------|-------|-------------------------------|---------| +| DFA | 0.01 | -0.003 | +0.004 | +| ScalarCB | 0.12 | -0.025 | +0.027 | +| Vec_M4 | 0.38 | **-0.135** | +0.045 | +| Oracle BP | 1.00 | **-0.406** | +0.094 | + +Same-batch loss decrease is MONOTONIC with credit quality. +But held-out loss INCREASES for all non-DFA methods. + +**This is Case D: the local surrogate exploits credit correctly on training data, +but the update overfits to the batch. Better credit = more effective overfitting.** + +### Key confounds identified in Phase 6A: +1. **Normalization** inflated DFA's weak credits to same magnitude as Vec's +2. **Held-out evaluation** showed generalization failure, not exploitability failure +3. **Gradient clamping** distorted the natural credit quality ordering + +### Raw vs Norm: +- Raw credit: tiny updates (BP grad RMS ≈ 0.00004). Vec raw best dL_same=-0.005 +- Norm credit: amplifies to useful magnitude but also amplifies overfitting + +### Revised diagnosis: +The bottleneck is NOT "surrogate can't exploit credit" (Phase 6A was wrong). +It IS "local surrogate with good credit overfits to mini-batch." +This suggests: regularization of local updates (larger batches, weight decay, +gradient noise) could make better credit usable. 
def get_cifar10(batch_size=128):
    """Build the CIFAR-10 train and test DataLoaders.

    The training pipeline applies random crop (padding 4) and random
    horizontal flip before tensor conversion and per-channel normalization;
    the test pipeline only converts and normalizes. Both datasets are stored
    under ``./data`` and downloaded on demand.

    Args:
        batch_size: batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``; the train loader shuffles, the test
        loader does not.
    """
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)

    def tensor_and_normalize():
        # Fresh transform objects per pipeline, as in the two explicit lists.
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    pipeline_for = {
        True: transforms.Compose(augmentation + tensor_and_normalize()),
        False: transforms.Compose(tensor_and_normalize()),
    }

    # Build both datasets first (train, then test), then wrap them in loaders.
    splits = (True, False)
    datasets = [
        torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True,
                                     transform=pipeline_for[is_train])
        for is_train in splits
    ]
    train_loader, test_loader = (
        DataLoader(dataset, batch_size=batch_size, shuffle=is_train,
                   num_workers=4, pin_memory=True)
        for dataset, is_train in zip(datasets, splits)
    )
    return train_loader, test_loader
def get_credits_on_batch(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
    """Compute per-block credit vectors for one batch.

    Args:
        model: network exposing ``num_blocks`` and callable as
            ``model(x, return_hidden=True) -> (logits, hiddens)``.
        x: input batch; ``x.size(0)`` is taken as the batch size.
        y: integer class labels, used to index the class axis of the logits.
        device: device on which the per-block time tensors are created.
        credit_source: one of ``'dfa'``, ``'scalar_cb'``, ``'vec'``,
            ``'oracle_bp'``.
        estimator: credit estimator called as ``estimator(h, t, s)``;
            required for ``'scalar_cb'`` and ``'vec'``.
        dfa_Bs: per-block feedback matrices; required for ``'dfa'``.

    Returns:
        ``(credits, hiddens, s)`` where ``credits`` maps block index
        ``l in range(num_blocks)`` to a detached credit tensor, ``hiddens`` is
        the hidden list from the no-grad forward pass and ``s`` is the output
        error ``softmax(logits) - onehot(y)``.

    Raises:
        ValueError: if ``credit_source`` is unknown, or the estimator /
            feedback matrices needed by the chosen source are missing.

    Side effects:
        ``estimator.eval()`` is called for the estimator-based sources. For
        ``'oracle_bp'`` every parameter is temporarily switched to
        ``requires_grad=True`` and restored to its previous flag afterwards;
        parameter ``.grad`` fields are left untouched.
    """
    known_sources = ('dfa', 'scalar_cb', 'vec', 'oracle_bp')
    if credit_source not in known_sources:
        # Previously an unknown source fell through and returned empty credits,
        # which only surfaced later as a KeyError in the update loop.
        raise ValueError(
            f"Unknown credit_source {credit_source!r}; expected one of {known_sources}")
    if credit_source == 'dfa' and dfa_Bs is None:
        raise ValueError("credit_source='dfa' requires dfa_Bs")
    if credit_source in ('scalar_cb', 'vec') and estimator is None:
        raise ValueError(f"credit_source={credit_source!r} requires an estimator")

    L = model.num_blocks
    batch = x.size(0)

    # Output error s = softmax(logits) - onehot(y), computed without a graph.
    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        e_T = logits.softmax(dim=-1)
        e_T[torch.arange(batch, device=e_T.device), y] -= 1
    s = e_T.detach()

    credits = {}

    if credit_source == 'dfa':
        # Project the output error through a fixed feedback matrix per block.
        for l in range(L):
            credits[l] = (s @ dfa_Bs[l].T).detach()

    elif credit_source == 'scalar_cb':
        # Credit is the gradient of the estimator output w.r.t. the hidden state.
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach().requires_grad_(True)
            t_l = torch.full((batch,), l / L, device=device)
            V = estimator(h_l, t_l, s)
            credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach()

    elif credit_source == 'vec':
        # The estimator predicts the credit vector directly.
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            credits[l] = estimator(h_l, t_l, s).detach()

    else:  # 'oracle_bp'
        # Exact gradient of the CE loss w.r.t. each hidden state. Using
        # torch.autograd.grad on the hiddens avoids filling parameter .grad
        # (which would be carried along by every later deepcopy of the model).
        params = list(model.parameters())
        saved_flags = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(True)
        try:
            logits_bp, hiddens_bp = model(x, return_hidden=True)
            loss_bp = F.cross_entropy(logits_bp, y)
            grads = torch.autograd.grad(loss_bp, list(hiddens_bp[:L]))
        finally:
            # Restore the caller's flags instead of forcing everything to False.
            for p, flag in zip(params, saved_flags):
                p.requires_grad_(flag)
        for l, g in enumerate(grads):
            credits[l] = g.detach().clone()

    return credits, hiddens, s
def do_local_update_clean(model, x, y, credits, device, eta,
                          update_layers, normalize_credit=False,
                          update_head=True):
    """Apply one unclamped local SGD step to ``model`` in place.

    Hidden states are computed once, without a graph, from the parameters the
    model has on entry; every update below reuses those cached states.

    Args:
        model: network with ``blocks``, ``out_ln`` and ``out_head`` and a
            forward accepting ``return_hidden=True``; modified in place.
        x, y: the batch used for the update.
        credits: dict mapping block index to a credit tensor.
        device: accepted for call compatibility; not used here.
        eta: step size.
        update_layers: iterable of block indices to update.
        normalize_credit: if True, divide each sample's credit vector by its
            RMS over the last axis (plus 1e-6) before use.
        update_head: if True, also step ``out_head`` and ``out_ln`` along the
            exact cross-entropy gradient computed from the last hidden state.

    Each listed block ``l`` is stepped along the gradient of
    ``mean_over_batch(sum(blocks[l](hiddens[l]) * credit_l))``.

    NOTE(review): ``credits[l]`` is paired with the output of block ``l``
    evaluated at ``hiddens[l]``; confirm that this indexing matches the
    convention of whatever produced ``credits``.
    """
    def sgd_step(params, grads):
        # Plain in-place descent step; gradients are deliberately not clamped.
        with torch.no_grad():
            for param, grad in zip(params, grads):
                param.sub_(eta * grad)

    with torch.no_grad():
        _, cached_hiddens = model(x, return_hidden=True)

    if update_head:
        final_hidden = cached_hiddens[-1].detach()
        head_logits = model.out_head(model.out_ln(final_hidden))
        head_loss = F.cross_entropy(head_logits, y)
        head_params = [*model.out_head.parameters(), *model.out_ln.parameters()]
        sgd_step(head_params, torch.autograd.grad(head_loss, head_params))

    for idx in update_layers:
        block_input = cached_hiddens[idx].detach()
        credit = credits[idx]
        if normalize_credit:
            per_sample_rms = (credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
            credit = credit / per_sample_rms

        block = model.blocks[idx]
        block_output = block(block_input)
        surrogate = (block_output * credit.detach()).sum(dim=-1).mean()
        block_params = list(block.parameters())
        sgd_step(block_params, torch.autograd.grad(surrogate, block_params))


def eval_loss_on_batch(model, x, y):
    """Return ``(cross_entropy, accuracy)`` of ``model`` on one batch.

    The model is switched to eval mode as a side effect and the forward pass
    runs without gradient tracking. Both results are Python floats.
    """
    model.eval()
    with torch.no_grad():
        scores = model(x)
        ce_value = F.cross_entropy(scores, y).item()
        hit_rate = (scores.argmax(1) == y).float().mean().item()
    return ce_value, hit_rate
def run_experiment(args):
    """Run the same-batch line search and write the results to a JSON file.

    Steps, in order:
      1. Seed the RNGs, build the CIFAR-10 loaders and load a frozen
         ResidualMLP checkpoint from ``results/frozen_cifar/``.
      2. Draw two consecutive batches from the train loader: the first is the
         fixed update batch, the second is used as the "held-out" batch.
      3. Build the credit sources (DFA and Oracle BP always; ScalarCB and
         Vec_M4 only when named in ``args.methods``) and compute credits for
         the requested methods on the fixed batch.
      4. For every (update range, norm mode, method, eta) combination, deep
         copy the snapshot, apply one local update and record the change in
         loss on both batches and the change in accuracy on the fixed batch.
      5. Print the best eta per method, dump all records to
         ``<output_dir>/linesearch_L{L}_d{d}_s{seed}.json`` and print an
         Oracle-BP sanity check for the ``raw`` norm mode.

    Args:
        args: namespace providing ``gpu``, ``output_dir``, ``seed``,
            ``batch_size``, ``num_blocks``, ``d_hidden``, ``methods``,
            ``estimator_epochs``, ``lr_fb``, ``etas``, ``update_ranges`` and
            ``norm_modes``.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Seed before any data loading or random feedback matrices are created;
    # the statement order below determines which batches and matrices are drawn.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden

    # Load BP snapshot
    # NOTE(review): torch.load is called without weights_only; confirm the
    # checkpoint source is trusted.
    model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
    bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
    model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
    model_bp.eval()
    for p in model_bp.parameters():
        p.requires_grad_(False)
    print(f"Loaded BP snapshot from {bp_ckpt}")

    # Get a FIXED train batch (same-batch protocol)
    train_iter = iter(train_loader)
    x_batch, y_batch = next(train_iter)
    x_batch = x_batch.view(x_batch.size(0), -1).to(device)  # flatten images
    y_batch = y_batch.to(device)
    print(f"Fixed train batch: {x_batch.shape[0]} samples")

    # Get a separate held-out batch for comparison
    # NOTE(review): this is the next batch from the same train_loader iterator
    # (training split, with the train-time augmentation); test_loader is not
    # used anywhere in this function.
    x_held, y_held = next(train_iter)
    x_held = x_held.view(x_held.size(0), -1).to(device)
    y_held = y_held.to(device)

    # Baseline losses
    loss_before_same, acc_before_same = eval_loss_on_batch(model_bp, x_batch, y_batch)
    loss_before_held, acc_before_held = eval_loss_on_batch(model_bp, x_held, y_held)
    print(f"Before: same_batch_loss={loss_before_same:.6f}, held_out_loss={loss_before_held:.6f}")

    # =========================================================
    # Prepare credit sources
    # =========================================================
    # Maps method name -> (credit_source, estimator, dfa feedback matrices).
    credit_configs = {}

    # DFA
    # One random (d x 10) feedback matrix per block, scaled by 1/sqrt(10).
    dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
    credit_configs['dfa'] = ('dfa', None, dfa_Bs)

    # Scalar CB (train on frozen snapshot)
    if 'scalar_cb' in args.methods:
        print("\nTraining ScalarCB on snapshot...")
        torch.manual_seed(args.seed + 2000)  # dedicated seed for this estimator
        cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
                                         epochs=args.estimator_epochs, lr_fb=args.lr_fb)
        credit_configs['scalar_cb'] = ('scalar_cb', cb, None)

    # Vec M4 (train on frozen snapshot)
    if 'vec_eT_M4' in args.methods:
        print("\nTraining Vec_M4 on snapshot...")
        torch.manual_seed(args.seed + 4000)  # dedicated seed for this estimator
        vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
                                        epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
        credit_configs['vec_eT_M4'] = ('vec', vec4, None)

    # Oracle BP
    credit_configs['oracle_bp'] = ('oracle_bp', None, None)

    # =========================================================
    # Compute credits on the fixed batch
    # =========================================================
    print("\nComputing credits on fixed batch...")
    all_credits = {}
    for name, (src, est, Bs) in credit_configs.items():
        # Only methods explicitly requested on the command line are evaluated.
        if name not in args.methods:
            continue
        credits, hiddens, s = get_credits_on_batch(model_bp, x_batch, y_batch, device,
                                                   src, estimator=est, dfa_Bs=Bs)
        all_credits[name] = credits
        # Report credit magnitudes
        mean_rms = np.mean([credits[l].pow(2).mean().sqrt().item() for l in range(L)])
        print(f" {name}: mean_credit_RMS={mean_rms:.6f}")

    # =========================================================
    # Line search
    # =========================================================
    etas = args.etas
    # Maps range name -> list of block indices that receive the local update.
    # NOTE(review): with L < 2, 'last2' yields a negative block index; confirm
    # num_blocks >= 2 whenever 'last2' is requested.
    update_ranges = {}
    if 'last1' in args.update_ranges:
        update_ranges['last1'] = [L - 1]
    if 'last2' in args.update_ranges:
        update_ranges['last2'] = [L - 2, L - 1]
    if 'all' in args.update_ranges:
        update_ranges['all'] = list(range(L))

    norm_modes = args.norm_modes  # ['raw'] or ['raw', 'norm']

    # One record per (method, update range, norm mode, eta), keyed by a string.
    results = {}

    for ur_name, layers in update_ranges.items():
        for norm_mode in norm_modes:
            normalize = (norm_mode == 'norm')
            print(f"\n{'='*60}")
            print(f"Update range: {ur_name}, credit: {norm_mode}")
            print(f"{'='*60}")
            print(f"{'Method':<15} {'eta':>10} {'dL_same':>12} {'dL_held':>12} {'dAcc_same':>10}")
            print("-" * 62)

            for name in args.methods:
                if name not in all_credits:
                    continue
                credits = all_credits[name]

                for eta in etas:
                    # Deep copy snapshot
                    # Every eta starts from the same frozen parameters.
                    model_test = copy.deepcopy(model_bp)
                    for p in model_test.parameters():
                        p.requires_grad_(True)

                    # Do update
                    do_local_update_clean(model_test, x_batch, y_batch, credits, device,
                                          eta=eta, update_layers=layers,
                                          normalize_credit=normalize,
                                          update_head=(ur_name != 'last1'))  # skip head for last1 only

                    for p in model_test.parameters():
                        p.requires_grad_(False)

                    # Evaluate on same batch
                    loss_same, acc_same = eval_loss_on_batch(model_test, x_batch, y_batch)
                    loss_held, acc_held = eval_loss_on_batch(model_test, x_held, y_held)

                    # Negative deltas mean the loss went down after the update.
                    dl_same = loss_same - loss_before_same
                    dl_held = loss_held - loss_before_held
                    da_same = acc_same - acc_before_same

                    key = f"{name}_{ur_name}_{norm_mode}_eta{eta}"
                    results[key] = {
                        'method': name, 'update_range': ur_name,
                        'norm_mode': norm_mode, 'eta': eta,
                        'loss_before_same': loss_before_same,
                        'loss_after_same': loss_same,
                        'delta_loss_same': dl_same,
                        'delta_loss_held': dl_held,
                        'delta_acc_same': da_same,
                    }

                    print(f"{name:<15} {eta:>10.1e} {dl_same:>+12.6f} {dl_held:>+12.6f} {da_same:>+10.4f}")

    # =========================================================
    # Summary: best eta per method
    # =========================================================
    print(f"\n{'='*60}")
    print("BEST ETA PER METHOD (minimum same-batch DeltaLoss)")
    print(f"{'='*60}")

    for ur_name in update_ranges:
        for norm_mode in norm_modes:
            print(f"\n {ur_name}, {norm_mode}:")
            for name in args.methods:
                relevant = {k: v for k, v in results.items()
                            if v['method'] == name and v['update_range'] == ur_name
                            and v['norm_mode'] == norm_mode}
                if not relevant:
                    continue
                # "Best" = most negative change in same-batch loss.
                best_key = min(relevant, key=lambda k: relevant[k]['delta_loss_same'])
                best = relevant[best_key]
                print(f" {name:<15} best_eta={best['eta']:.1e}, "
                      f"dL_same={best['delta_loss_same']:+.6f}, "
                      f"dL_held={best['delta_loss_held']:+.6f}")

    # Save
    out_path = os.path.join(args.output_dir, f'linesearch_L{L}_d{d}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")

    # =========================================================
    # Key diagnostic: does Oracle BP descend at small eta?
    # =========================================================
    print(f"\n{'='*60}")
    print("KEY DIAGNOSTIC")
    print(f"{'='*60}")

    for ur_name in update_ranges:
        # Only the 'raw' norm mode is inspected here.
        oracle_results = {k: v for k, v in results.items()
                          if v['method'] == 'oracle_bp' and v['update_range'] == ur_name
                          and v['norm_mode'] == 'raw'}
        if not oracle_results:
            continue
        best = min(oracle_results.values(), key=lambda v: v['delta_loss_same'])
        worst = max(oracle_results.values(), key=lambda v: v['delta_loss_same'])
        print(f"\n Oracle BP ({ur_name}, raw):")
        print(f" Best: eta={best['eta']:.1e}, dL_same={best['delta_loss_same']:+.6f}")
        print(f" Worst: eta={worst['eta']:.1e}, dL_same={worst['delta_loss_same']:+.6f}")

        # Any strictly negative best delta counts as "can descend".
        if best['delta_loss_same'] < -1e-6:
            print(f" -> Oracle BP CAN descend on same-batch at small eta. Protocol is OK.")
        else:
            print(f" -> WARNING: Oracle BP cannot descend! Check implementation.")
def main():
    """Parse the command-line options and launch the line-search experiment.

    Options (with defaults): ``--num_blocks`` 4, ``--d_hidden`` 256,
    ``--batch_size`` 128, ``--estimator_epochs`` 100, ``--lr_fb`` 1e-3,
    ``--methods`` oracle_bp vec_eT_M4, ``--etas`` 1e-5 3e-5 1e-4 3e-4 1e-3,
    ``--update_ranges`` last1, ``--norm_modes`` raw, ``--seed`` 42,
    ``--gpu`` 3, ``--output_dir`` results/exploit_linesearch.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--num_blocks', {'type': int, 'default': 4}),
        ('--d_hidden', {'type': int, 'default': 256}),
        ('--batch_size', {'type': int, 'default': 128}),
        ('--estimator_epochs', {'type': int, 'default': 100}),
        ('--lr_fb', {'type': float, 'default': 1e-3}),
        ('--methods', {'type': str, 'nargs': '+',
                       'default': ['oracle_bp', 'vec_eT_M4']}),
        ('--etas', {'type': float, 'nargs': '+',
                    'default': [1e-5, 3e-5, 1e-4, 3e-4, 1e-3]}),
        ('--update_ranges', {'type': str, 'nargs': '+', 'default': ['last1']}),
        ('--norm_modes', {'type': str, 'nargs': '+', 'default': ['raw']}),
        ('--seed', {'type': int, 'default': 42}),
        ('--gpu', {'type': int, 'default': 3}),
        ('--output_dir', {'type': str, 'default': 'results/exploit_linesearch'}),
    )

    cli = argparse.ArgumentParser(description='Phase 6.5A: Same-batch Line Search')
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    run_experiment(cli.parse_args())
+ +### But held-out loss increases for all non-DFA methods. + +This is **Case D**: the local surrogate correctly exploits credit to decrease training loss, but the update overfits to the batch. The better the credit, the more effective the overfitting. + +### Key confounds in Phase 6A: +1. **Normalization**: Phase 6A always normalized credit, which amplified DFA's weak signals to the same magnitude as Vec's strong signals, erasing the natural ordering +2. **Held-out evaluation**: Phase 6A evaluated on held-out batches, showing the generalization failure rather than the exploitability success +3. **Gradient clamping**: Phase 6A clamped gradients to [-1, 1], further distorting the relationship + +### Raw vs Normalized (all blocks): +| Method | raw dL_same (best) | norm dL_same (best) | +|--------|--------------------|---------------------| +| Vec_M4 | -0.005 | -0.135 | +| Oracle | -0.003 | -0.406 | + +Raw credit produces tiny updates because BP gradients have RMS ≈ 0.00004. Normalization brings all methods to comparable magnitude but introduces overfitting. + +## Revised Diagnosis + +The bottleneck is NOT "local surrogate cannot exploit good credit" (Case B from Phase 6A). It IS: +- **Generalization/overfitting**: local surrogate with good credit decreases train loss but increases held-out loss +- This means the project direction should be about **regularizing local updates** rather than replacing the surrogate entirely -- cgit v1.2.3