summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--NOTE.md47
-rw-r--r--experiments/exploitability_samebatch_linesearch.py392
-rw-r--r--report_explore/MEMO_6.5A_samebatch_linesearch.md44
3 files changed, 482 insertions, 1 deletions
diff --git a/NOTE.md b/NOTE.md
index 6242a41..a57fd30 100644
--- a/NOTE.md
+++ b/NOTE.md
@@ -5,7 +5,7 @@
- **pilot**: Controlled iteration (commits 0b9ebb2, 7baf7ae)
- **frozen**: Code at commit 0b9ebb2 for all reported results
-## Status: PHASE 6 EXPLOITABILITY DISSECTION COMPLETE
+## Status: PHASE 6.5 PROTOCOL AUDIT — PHASE 6A CONCLUSION REVISED
---
@@ -373,3 +373,48 @@ Better credit does NOT lead to better snapshot loss decrease.
### Experiment IDs (Phase 6)
- `snapshot_exploit/`: Phase 6A snapshot exploitability
- `update_swap/`: Phase 6C local update rule comparison
+
+---
+
+## Phase 6.5: Protocol Audit (REVISES Phase 6A conclusion)
+
+### Phase 6.5A: Same-Batch Linesearch
+
+**CRITICAL REVISION**: Phase 6A's "better credit → worse loss" was a protocol artifact.
+
+Phase 6A used: normalized credit + held-out evaluation + gradient clamping.
+Phase 6.5A uses: raw + norm credit, same-batch + held-out eval, no clamping, eta sweep.
+
+**With same-batch evaluation, better credit DOES produce more loss decrease:**
+
+| Method | Gamma | dL_same (norm, all, best eta) | dL_held |
+|--------|-------|-------------------------------|---------|
+| DFA | 0.01 | -0.003 | +0.004 |
+| ScalarCB | 0.12 | -0.025 | +0.027 |
+| Vec_M4 | 0.38 | **-0.135** | +0.045 |
+| Oracle BP | 1.00 | **-0.406** | +0.094 |
+
+Same-batch loss decrease is MONOTONIC with credit quality.
+But held-out loss INCREASES for all non-DFA methods.
+
+**This is Case D: the local surrogate exploits credit correctly on training data,
+but the update overfits to the batch. Better credit = more effective overfitting.**
+
+### Key confounds identified in Phase 6A:
+1. **Normalization** inflated DFA's weak credits to same magnitude as Vec's
+2. **Held-out evaluation** showed generalization failure, not exploitability failure
+3. **Gradient clamping** distorted the natural credit quality ordering
+
+### Raw vs Norm:
+- Raw credit: tiny updates (BP grad RMS ≈ 0.00004). Vec raw best dL_same=-0.005
+- Norm credit: amplifies to useful magnitude but also amplifies overfitting
+
+### Revised diagnosis:
+The bottleneck is NOT "surrogate can't exploit credit" (Phase 6A was wrong).
+It IS "local surrogate with good credit overfits to mini-batch."
+This suggests: regularization of local updates (larger batches, weight decay,
+gradient noise) could make better credit usable.
+
+### Experiment IDs (Phase 6.5)
+- `exploit_linesearch/`: Phase 6.5A smoke test (Oracle + Vec, last1, raw)
+- `exploit_linesearch_full/`: Phase 6.5A full sweep (all methods, ranges, norm modes)
diff --git a/experiments/exploitability_samebatch_linesearch.py b/experiments/exploitability_samebatch_linesearch.py
new file mode 100644
index 0000000..c6c9ba8
--- /dev/null
+++ b/experiments/exploitability_samebatch_linesearch.py
@@ -0,0 +1,392 @@
+"""
+Phase 6.5A: Same-batch infinitesimal descent test.
+
+Strict protocol:
+- Fixed train minibatch B
+- Compute credits on B
+- Do local update with B
+- Evaluate loss change on THE SAME B (same-batch)
+- Sweep eta over multiple orders of magnitude
+- Test raw credit (no normalization) and normalized credit separately
+- No gradient clamping
+- Separate update scopes: last-block-only, last-2, all-blocks
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import cosine_similarity_batch
+
+# Reuse VectorCreditNet and estimator trainers
+from experiments.snapshot_exploitability import (
+ train_scalar_cb_on_snapshot, train_vector_on_snapshot, VectorCreditNet
+)
+
+
+def get_cifar10(batch_size=128):
+ import torchvision
+ import torchvision.transforms as transforms
+ from torch.utils.data import DataLoader
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def get_credits_on_batch(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
+ """Compute per-layer credits. Returns credits dict, hiddens list, conditioning s."""
+ L = model.num_blocks
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ credits = {}
+
+ if credit_source == 'dfa':
+ for l in range(L):
+ credits[l] = (s @ dfa_Bs[l].T).detach()
+
+ elif credit_source == 'scalar_cb':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V = estimator(h_l, t_l, s)
+ credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach()
+
+ elif credit_source == 'vec':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ credits[l] = estimator(h_l, t_l, s).detach()
+
+ elif credit_source == 'oracle_bp':
+ for p in model.parameters():
+ p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ for l in range(L):
+ credits[l] = hiddens_bp[l].grad.detach().clone()
+ for p in model.parameters():
+ p.requires_grad_(False)
+
+ return credits, hiddens, s
+
+
+def do_local_update_clean(model, x, y, credits, device, eta,
+ update_layers, normalize_credit=False,
+ update_head=True):
+ """
+ Clean local update: no gradient clamping, explicit raw/norm control.
+
+ Args:
+ model: in-place modified
+ x, y: the batch
+ credits: dict {l: (batch, d)} raw credit vectors
+ eta: step size
+ update_layers: list of block indices to update
+ normalize_credit: if True, normalize credit by RMS before use
+ update_head: if True, also update output head with exact CE gradient
+ """
+ L = model.num_blocks
+
+ # Recompute hiddens with current params (important after any param change)
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+
+ if update_head:
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(eta * g)
+
+ for l in update_layers:
+ h_l = hiddens[l].detach()
+ a = credits[l]
+
+ if normalize_credit:
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_used = a / rms
+ else:
+ a_used = a
+
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_used.detach()).sum(dim=-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(eta * g) # NO clamping
+
+
+def eval_loss_on_batch(model, x, y):
+ """Evaluate CE loss on a specific batch."""
+ model.eval()
+ with torch.no_grad():
+ logits = model(x)
+ loss = F.cross_entropy(logits, y).item()
+ acc = (logits.argmax(1) == y).float().mean().item()
+ return loss, acc
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # Load BP snapshot
+ model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
+ bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
+ model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
+ model_bp.eval()
+ for p in model_bp.parameters():
+ p.requires_grad_(False)
+ print(f"Loaded BP snapshot from {bp_ckpt}")
+
+ # Get a FIXED train batch (same-batch protocol)
+ train_iter = iter(train_loader)
+ x_batch, y_batch = next(train_iter)
+ x_batch = x_batch.view(x_batch.size(0), -1).to(device)
+ y_batch = y_batch.to(device)
+ print(f"Fixed train batch: {x_batch.shape[0]} samples")
+
+ # Get a separate held-out batch for comparison
+ x_held, y_held = next(train_iter)
+ x_held = x_held.view(x_held.size(0), -1).to(device)
+ y_held = y_held.to(device)
+
+ # Baseline losses
+ loss_before_same, acc_before_same = eval_loss_on_batch(model_bp, x_batch, y_batch)
+ loss_before_held, acc_before_held = eval_loss_on_batch(model_bp, x_held, y_held)
+ print(f"Before: same_batch_loss={loss_before_same:.6f}, held_out_loss={loss_before_held:.6f}")
+
+ # =========================================================
+ # Prepare credit sources
+ # =========================================================
+ credit_configs = {}
+
+ # DFA
+ dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+ credit_configs['dfa'] = ('dfa', None, dfa_Bs)
+
+ # Scalar CB (train on frozen snapshot)
+ if 'scalar_cb' in args.methods:
+ print("\nTraining ScalarCB on snapshot...")
+ torch.manual_seed(args.seed + 2000)
+ cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb)
+ credit_configs['scalar_cb'] = ('scalar_cb', cb, None)
+
+ # Vec M4 (train on frozen snapshot)
+ if 'vec_eT_M4' in args.methods:
+ print("\nTraining Vec_M4 on snapshot...")
+ torch.manual_seed(args.seed + 4000)
+ vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
+ credit_configs['vec_eT_M4'] = ('vec', vec4, None)
+
+ # Oracle BP
+ credit_configs['oracle_bp'] = ('oracle_bp', None, None)
+
+ # =========================================================
+ # Compute credits on the fixed batch
+ # =========================================================
+ print("\nComputing credits on fixed batch...")
+ all_credits = {}
+ for name, (src, est, Bs) in credit_configs.items():
+ if name not in args.methods:
+ continue
+ credits, hiddens, s = get_credits_on_batch(model_bp, x_batch, y_batch, device,
+ src, estimator=est, dfa_Bs=Bs)
+ all_credits[name] = credits
+ # Report credit magnitudes
+ mean_rms = np.mean([credits[l].pow(2).mean().sqrt().item() for l in range(L)])
+ print(f" {name}: mean_credit_RMS={mean_rms:.6f}")
+
+ # =========================================================
+ # Line search
+ # =========================================================
+ etas = args.etas
+ update_ranges = {}
+ if 'last1' in args.update_ranges:
+ update_ranges['last1'] = [L - 1]
+ if 'last2' in args.update_ranges:
+ update_ranges['last2'] = [L - 2, L - 1]
+ if 'all' in args.update_ranges:
+ update_ranges['all'] = list(range(L))
+
+ norm_modes = args.norm_modes # ['raw'] or ['raw', 'norm']
+
+ results = {}
+
+ for ur_name, layers in update_ranges.items():
+ for norm_mode in norm_modes:
+ normalize = (norm_mode == 'norm')
+ print(f"\n{'='*60}")
+ print(f"Update range: {ur_name}, credit: {norm_mode}")
+ print(f"{'='*60}")
+ print(f"{'Method':<15} {'eta':>10} {'dL_same':>12} {'dL_held':>12} {'dAcc_same':>10}")
+ print("-" * 62)
+
+ for name in args.methods:
+ if name not in all_credits:
+ continue
+ credits = all_credits[name]
+
+ for eta in etas:
+ # Deep copy snapshot
+ model_test = copy.deepcopy(model_bp)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+
+ # Do update
+ do_local_update_clean(model_test, x_batch, y_batch, credits, device,
+ eta=eta, update_layers=layers,
+ normalize_credit=normalize,
+ update_head=(ur_name != 'last1')) # skip head for last1 only
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+
+ # Evaluate on same batch
+ loss_same, acc_same = eval_loss_on_batch(model_test, x_batch, y_batch)
+ loss_held, acc_held = eval_loss_on_batch(model_test, x_held, y_held)
+
+ dl_same = loss_same - loss_before_same
+ dl_held = loss_held - loss_before_held
+ da_same = acc_same - acc_before_same
+
+ key = f"{name}_{ur_name}_{norm_mode}_eta{eta}"
+ results[key] = {
+ 'method': name, 'update_range': ur_name,
+ 'norm_mode': norm_mode, 'eta': eta,
+ 'loss_before_same': loss_before_same,
+ 'loss_after_same': loss_same,
+ 'delta_loss_same': dl_same,
+ 'delta_loss_held': dl_held,
+ 'delta_acc_same': da_same,
+ }
+
+ print(f"{name:<15} {eta:>10.1e} {dl_same:>+12.6f} {dl_held:>+12.6f} {da_same:>+10.4f}")
+
+ # =========================================================
+ # Summary: best eta per method
+ # =========================================================
+ print(f"\n{'='*60}")
+ print("BEST ETA PER METHOD (minimum same-batch DeltaLoss)")
+ print(f"{'='*60}")
+
+ for ur_name in update_ranges:
+ for norm_mode in norm_modes:
+ print(f"\n {ur_name}, {norm_mode}:")
+ for name in args.methods:
+ relevant = {k: v for k, v in results.items()
+ if v['method'] == name and v['update_range'] == ur_name
+ and v['norm_mode'] == norm_mode}
+ if not relevant:
+ continue
+ best_key = min(relevant, key=lambda k: relevant[k]['delta_loss_same'])
+ best = relevant[best_key]
+ print(f" {name:<15} best_eta={best['eta']:.1e}, "
+ f"dL_same={best['delta_loss_same']:+.6f}, "
+ f"dL_held={best['delta_loss_held']:+.6f}")
+
+ # Save
+ out_path = os.path.join(args.output_dir, f'linesearch_L{L}_d{d}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # =========================================================
+ # Key diagnostic: does Oracle BP descend at small eta?
+ # =========================================================
+ print(f"\n{'='*60}")
+ print("KEY DIAGNOSTIC")
+ print(f"{'='*60}")
+
+ for ur_name in update_ranges:
+ oracle_results = {k: v for k, v in results.items()
+ if v['method'] == 'oracle_bp' and v['update_range'] == ur_name
+ and v['norm_mode'] == 'raw'}
+ if not oracle_results:
+ continue
+ best = min(oracle_results.values(), key=lambda v: v['delta_loss_same'])
+ worst = max(oracle_results.values(), key=lambda v: v['delta_loss_same'])
+ print(f"\n Oracle BP ({ur_name}, raw):")
+ print(f" Best: eta={best['eta']:.1e}, dL_same={best['delta_loss_same']:+.6f}")
+ print(f" Worst: eta={worst['eta']:.1e}, dL_same={worst['delta_loss_same']:+.6f}")
+
+ if best['delta_loss_same'] < -1e-6:
+ print(f" -> Oracle BP CAN descend on same-batch at small eta. Protocol is OK.")
+ else:
+ print(f" -> WARNING: Oracle BP cannot descend! Check implementation.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 6.5A: Same-batch Line Search')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--estimator_epochs', type=int, default=100)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--methods', type=str, nargs='+',
+ default=['oracle_bp', 'vec_eT_M4'])
+ parser.add_argument('--etas', type=float, nargs='+',
+ default=[1e-5, 3e-5, 1e-4, 3e-4, 1e-3])
+ parser.add_argument('--update_ranges', type=str, nargs='+', default=['last1'])
+ parser.add_argument('--norm_modes', type=str, nargs='+', default=['raw'])
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/exploit_linesearch')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/report_explore/MEMO_6.5A_samebatch_linesearch.md b/report_explore/MEMO_6.5A_samebatch_linesearch.md
new file mode 100644
index 0000000..733db12
--- /dev/null
+++ b/report_explore/MEMO_6.5A_samebatch_linesearch.md
@@ -0,0 +1,44 @@
+# Phase 6.5A Memo: Same-Batch Linesearch
+
+**Date**: 2026-03-25
+
+## Question
+Under strict same-batch evaluation, does better credit produce better loss decrease?
+
+## Answer: YES.
+
+Phase 6A's conclusion was wrong due to protocol confounds. With same-batch evaluation:
+
+### All blocks, normalized credit (closest to Phase 6A protocol):
+
+| Method | best eta | dL_same | dL_held |
+|--------|---------|---------|---------|
+| DFA | 1e-2 | -0.003 | +0.004 |
+| ScalarCB | 3e-3 | -0.025 | +0.027 |
+| Vec_M4 | 3e-3 | **-0.135** | +0.045 |
+| Oracle BP | 1e-2 | **-0.406** | +0.094 |
+
+**Same-batch loss decreases monotonically with credit quality**: Oracle > Vec > ScalarCB > DFA.
+
+### But held-out loss increases for all non-DFA methods.
+
+This is **Case D**: the local surrogate correctly exploits credit to decrease training loss, but the update overfits to the batch. The better the credit, the more effective the overfitting.
+
+### Key confounds in Phase 6A:
+1. **Normalization**: Phase 6A always normalized credit, which amplified DFA's weak signals to the same magnitude as Vec's strong signals, erasing the natural ordering
+2. **Held-out evaluation**: Phase 6A evaluated on held-out batches, showing the generalization failure rather than the exploitability success
+3. **Gradient clamping**: Phase 6A clamped gradients to [-1, 1], further distorting the relationship
+
+### Raw vs Normalized (all blocks):
+| Method | raw dL_same (best) | norm dL_same (best) |
+|--------|--------------------|---------------------|
+| Vec_M4 | -0.005 | -0.135 |
+| Oracle | -0.003 | -0.406 |
+
+Raw credit produces tiny updates because BP gradients have RMS ≈ 0.00004. Normalization brings all methods to comparable magnitude but introduces overfitting.
+
+## Revised Diagnosis
+
+The bottleneck is NOT "local surrogate cannot exploit good credit" (Case B from Phase 6A). It IS:
+- **Generalization/overfitting**: local surrogate with good credit decreases train loss but increases held-out loss
+- This means the project direction should be about **regularizing local updates** rather than replacing the surrogate entirely