diff options
| -rw-r--r-- | NOTE.md | 46 | ||||
| -rw-r--r-- | experiments/local_update_swap.py | 427 | ||||
| -rw-r--r-- | experiments/snapshot_exploitability.py | 713 | ||||
| -rw-r--r-- | report_explore/MEMO_6A_snapshot_exploitability.md | 39 | ||||
| -rw-r--r-- | report_explore/MEMO_6_exploitability.md | 53 |
5 files changed, 1277 insertions, 1 deletions
@@ -5,7 +5,7 @@ - **pilot**: Controlled iteration (commits 0b9ebb2, 7baf7ae) - **frozen**: Code at commit 0b9ebb2 for all reported results -## Status: PHASE 5 VECTOR FIELD AUDIT + TRANSFER COMPLETE +## Status: PHASE 6 EXPLOITABILITY DISSECTION COMPLETE --- @@ -329,3 +329,47 @@ the signal. - `vector_audit_full/`: Phase 5A full 3-seed audit - `frozen_cifar_vec/`: Phase 5B frozen CIFAR vector transfer - `online_vec_pilot/`: Phase 5C online CIFAR vector pilot + +--- + +## Phase 6: Exploitability Dissection + +### Phase 6A: Snapshot Exploitability + +**Setup**: BP-trained CIFAR snapshot (L=4, d=256, 61.9% acc). +Offline-trained estimators. k-step local updates with real loss measurement. + +**CRITICAL FINDING: Better credit → worse loss decrease.** + +| Credit | Gamma | rho | dL_5step (inner_product) | +|--------|-------|-----|-------------------------| +| DFA | 0.009 | -0.023 | **-0.0001** (only negative!) | +| ScalarCB | 0.122 | 0.090 | +0.042 | +| Vec_M4 | 0.378 | 0.411 | +0.057 | +| Oracle BP | 1.000 | 0.998 | +0.011 | + +Credit quality is ANTI-CORRELATED with loss decrease. +DFA (worst credit) is the only method not increasing loss. + +### Phase 6C: Local Update Rule Swap + +Tested target-shift (`h_target = h_{l+1} - eta * a_norm`) at eta in {0.01, 0.1, 0.3, 1.0}. + +Target-shift reduces damage (Vec dL: +0.057 → +0.002 at eta=0.1) but never achieves +negative DeltaLoss for any non-DFA credit. Cosine rule produces near-zero effects. + +### Root Cause + +The inner-product surrogate `<F_l(h), a>` is not a valid proxy for global loss minimization. +The gradient of this surrogate w.r.t. block parameters ≠ gradient of global loss w.r.t. same parameters. +A BP-trained snapshot is at a minimum reachable only by full BP; local updates systematically push uphill. + +DFA works because its credits are weak enough to produce near-zero updates, effectively doing nothing. + +### This is Case B from the diagnostic logic tree: +Better credit does NOT lead to better snapshot loss decrease. +**The primary bottleneck is the local update rule itself, not the estimator or tracking.** + +### Experiment IDs (Phase 6) +- `snapshot_exploit/`: Phase 6A snapshot exploitability +- `update_swap/`: Phase 6C local update rule comparison diff --git a/experiments/local_update_swap.py b/experiments/local_update_swap.py new file mode 100644 index 0000000..207560a --- /dev/null +++ b/experiments/local_update_swap.py @@ -0,0 +1,427 @@ +""" +Phase 6C: Local Update Rule Swap. + +Compare different local update rules using the same credit signals on a fixed snapshot. + +Rule 1 (baseline): Inner-product surrogate + L_inner = <F_l(h_l), a_{l+1}> + +Rule 2: Target-shift local regression + h_{l+1}^target = h_{l+1} - eta_target * a_{l+1}^norm + L_shift = 0.5 * || h_l + F_l(h_l) - sg(h_{l+1}^target) ||^2 + +Rule 3: Cosine-target update + L_cos = - cos(F_l(h_l), a_{l+1}) +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test + + +class VectorCreditNet(nn.Module): + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader + + +def get_credits(model, x, y, device, credit_source, estimator=None, dfa_Bs=None): + L = model.num_blocks + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + credits = {} + if credit_source == 'dfa': + for l in range(L): + credits[l] = (s @ dfa_Bs[l].T).detach() + elif credit_source == 'scalar_cb': + estimator.eval() + for l in range(L): + h_l = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + V = estimator(h_l, t_l, s) + credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach() + elif credit_source == 'vec': + estimator.eval() + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + credits[l] = estimator(h_l, t_l, s).detach() + elif credit_source == 'oracle_bp': + for p in model.parameters(): + p.requires_grad_(True) + model.zero_grad() + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + F.cross_entropy(logits_bp, y).backward() + for l in range(L): + credits[l] = hiddens_bp[l].grad.detach().clone() + for p in model.parameters(): + p.requires_grad_(False) + return credits, hiddens, s + + +# ============================================================================= +# Local update rules +# ============================================================================= +def update_inner_product(model, x, y, credits, hiddens, device, lr): + """Rule 1: L_inner = <F_l(h_l), a_{l+1}>""" + L = model.num_blocks + # Head + hL = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL)) + loss_out = F.cross_entropy(logits_out, y) + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + grads_head = torch.autograd.grad(loss_out, head_params) + with torch.no_grad(): + for p, g in zip(head_params, grads_head): + p.sub_(lr * g) + # Blocks + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters()) + with torch.no_grad(): + for p, g in zip(model.blocks[l].parameters(), block_grads): + p.sub_(lr * g.clamp(-1, 1)) + # Embed + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() + embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters()) + with torch.no_grad(): + for p, g in zip(model.embed.parameters(), embed_grads): + p.sub_(lr * g.clamp(-1, 1)) + + +def update_target_shift(model, x, y, credits, hiddens, device, lr, eta_target=0.01): + """ + Rule 2: Target-shift local regression. + h_{l+1}^target = h_{l+1} - eta_target * a_{l+1}^norm + L_shift = 0.5 * || (h_l + F_l(h_l)) - sg(h_{l+1}^target) ||^2 + """ + L = model.num_blocks + # Head — still use exact CE + hL = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL)) + loss_out = F.cross_entropy(logits_out, y) + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + grads_head = torch.autograd.grad(loss_out, head_params) + with torch.no_grad(): + for p, g in zip(head_params, grads_head): + p.sub_(lr * g) + + # Blocks: target-shift regression + for l in range(L): + h_l = hiddens[l].detach() + h_l_next = hiddens[l + 1].detach() # current h_{l+1} + + # Credit at layer l+1 (or l for the last one) + # We use credit[l] which is the credit at layer l + # The target shift: move h_{l+1} in the negative credit direction + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + + # Target: where h_{l+1} should move toward + h_target = (h_l_next - eta_target * a_norm).detach() + + # Compute F_l(h_l) with gradient + f_l = model.blocks[l](h_l) + h_l_next_pred = h_l + f_l # predicted h_{l+1} + + # Regression loss + shift_loss = 0.5 * ((h_l_next_pred - h_target) ** 2).sum(dim=-1).mean() + block_grads = torch.autograd.grad(shift_loss, model.blocks[l].parameters()) + with torch.no_grad(): + for p, g in zip(model.blocks[l].parameters(), block_grads): + p.sub_(lr * g.clamp(-1, 1)) + + # Embed: use credit[0] as target shift for h_0 + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + h0_target = (hiddens[0].detach() - eta_target * (a_0 / rms_0)).detach() + embed_loss = 0.5 * ((h0 - h0_target) ** 2).sum(dim=-1).mean() + embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters()) + with torch.no_grad(): + for p, g in zip(model.embed.parameters(), embed_grads): + p.sub_(lr * g.clamp(-1, 1)) + + +def update_cosine_target(model, x, y, credits, hiddens, device, lr): + """Rule 3: L_cos = -cos(F_l(h_l), a_{l+1})""" + L = model.num_blocks + # Head + hL = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL)) + loss_out = F.cross_entropy(logits_out, y) + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + grads_head = torch.autograd.grad(loss_out, head_params) + with torch.no_grad(): + for p, g in zip(head_params, grads_head): + p.sub_(lr * g) + # Blocks + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + f_l = model.blocks[l](h_l) + cos_sim = F.cosine_similarity(f_l, a, dim=-1).mean() + local_loss = -cos_sim + block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters()) + with torch.no_grad(): + for p, g in zip(model.blocks[l].parameters(), block_grads): + p.sub_(lr * g.clamp(-1, 1)) + # Embed + a_0 = credits[0] + h0 = model.embed(x) + cos_sim_0 = F.cosine_similarity(h0, a_0, dim=-1).mean() + embed_loss = -cos_sim_0 + embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters()) + with torch.no_grad(): + for p, g in zip(model.embed.parameters(), embed_grads): + p.sub_(lr * g.clamp(-1, 1)) + + +# ============================================================================= +# Main +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + train_loader, test_loader = get_cifar10(args.batch_size) + input_dim = 32 * 32 * 3 + L = args.num_blocks + d = args.d_hidden + + # Load BP snapshot + model_bp = ResidualMLP(input_dim, d, 10, L).to(device) + bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt' + model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device)) + model_bp.eval() + for p in model_bp.parameters(): + p.requires_grad_(False) + print(f"Loaded BP snapshot from {bp_ckpt}") + + # Load pre-trained estimators (or train fresh) + # DFA + dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] + + # Scalar CB — train on snapshot + print("\nTraining ScalarCB on snapshot...") + from experiments.snapshot_exploitability import train_scalar_cb_on_snapshot, train_vector_on_snapshot + torch.manual_seed(args.seed + 2000) + cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device, + epochs=args.estimator_epochs, lr_fb=args.lr_fb) + + # Vector field — train on snapshot + print("\nTraining Vec_M4 on snapshot...") + torch.manual_seed(args.seed + 4000) + vec4 = train_vector_on_snapshot(model_bp, train_loader, device, + epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4) + + credit_sources = { + 'dfa': ('dfa', None, dfa_Bs), + 'scalar_cb': ('scalar_cb', cb, None), + 'vec_eT_M4': ('vec', vec4, None), + 'oracle_bp': ('oracle_bp', None, None), + } + + update_rules = { + 'inner_product': update_inner_product, + 'target_shift': lambda m, x, y, c, h, dev, lr: update_target_shift(m, x, y, c, h, dev, lr, eta_target=args.eta_target), + 'cosine_target': update_cosine_target, + } + + # Eval function + eval_batches = [] + for i, (xv, yv) in enumerate(test_loader): + if i >= 10: + break + eval_batches.append((xv.view(xv.size(0), -1).to(device), yv.to(device))) + + def eval_model(model): + model.eval() + total_loss, correct, total = 0, 0, 0 + with torch.no_grad(): + for xv, yv in eval_batches: + logits = model(xv) + total_loss += F.cross_entropy(logits, yv, reduction='sum').item() + correct += (logits.argmax(1) == yv).sum().item() + total += xv.size(0) + return total_loss / total, correct / total + + # ========================================================= + # Run all combinations: credit_source x update_rule x k_steps + # ========================================================= + results = {} + + for cs_name, (src, est, Bs) in credit_sources.items(): + for rule_name, rule_fn in update_rules.items(): + for k in [1, 5, 20]: + tag = f"{cs_name}_{rule_name}_k{k}" + + model_test = copy.deepcopy(model_bp) + for p in model_test.parameters(): + p.requires_grad_(True) + + loss_before, acc_before = eval_model(model_test) + + train_iter = iter(train_loader) + for step in range(k): + try: + x_step, y_step = next(train_iter) + except StopIteration: + train_iter = iter(train_loader) + x_step, y_step = next(train_iter) + x_step = x_step.view(x_step.size(0), -1).to(device) + y_step = y_step.to(device) + + for p in model_test.parameters(): + p.requires_grad_(False) + credits, hiddens, s = get_credits(model_test, x_step, y_step, device, + src, estimator=est, dfa_Bs=Bs) + for p in model_test.parameters(): + p.requires_grad_(True) + rule_fn(model_test, x_step, y_step, credits, hiddens, device, lr=args.lr_update) + + for p in model_test.parameters(): + p.requires_grad_(False) + loss_after, acc_after = eval_model(model_test) + + results[tag] = { + 'credit': cs_name, 'rule': rule_name, 'k': k, + 'loss_before': loss_before, 'loss_after': loss_after, + 'delta_loss': loss_after - loss_before, + 'delta_acc': acc_after - acc_before, + } + + # ========================================================= + # Summary tables + # ========================================================= + print(f"\n{'='*90}") + print("RESULTS: DeltaLoss (negative = good)") + print(f"{'='*90}") + + for k in [1, 5, 20]: + print(f"\n--- k={k} steps ---") + print(f"{'Credit':<15} {'inner_prod':>12} {'target_shift':>14} {'cosine':>12}") + print("-" * 58) + for cs_name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']: + row = f"{cs_name:<15}" + for rule_name in ['inner_product', 'target_shift', 'cosine_target']: + tag = f"{cs_name}_{rule_name}_k{k}" + dl = results[tag]['delta_loss'] + row += f" {dl:>+12.4f}" + print(row) + + # Save + out_path = os.path.join(args.output_dir, f'update_swap_L{L}_d{d}_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2, default=float) + print(f"\nSaved to {out_path}") + + # Judgment + print(f"\n{'='*60}") + print("JUDGMENT") + print(f"{'='*60}") + + # Compare at k=5 + inner_vec = results['vec_eT_M4_inner_product_k5']['delta_loss'] + shift_vec = results['vec_eT_M4_target_shift_k5']['delta_loss'] + shift_bp = results['oracle_bp_target_shift_k5']['delta_loss'] + inner_dfa = results['dfa_inner_product_k5']['delta_loss'] + + print(f"k=5: Vec+inner={inner_vec:+.4f}, Vec+shift={shift_vec:+.4f}, " + f"BP+shift={shift_bp:+.4f}, DFA+inner={inner_dfa:+.4f}") + + if shift_vec < inner_vec and shift_vec < 0: + print("TARGET-SHIFT WINS: Vec credit becomes exploitable with target-shift rule.") + print(" -> Project should pivot to 'credit + better local update coupling'.") + elif shift_bp < 0 and shift_vec >= 0: + print("TARGET-SHIFT HELPS BP BUT NOT VEC: Credit quality still matters.") + else: + print("TARGET-SHIFT DOESN'T HELP: Need further investigation.") + + +def main(): + parser = argparse.ArgumentParser(description='Phase 6C: Local Update Rule Swap') + parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--estimator_epochs', type=int, default=100) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--lr_update', type=float, default=1e-3) + parser.add_argument('--eta_target', type=float, default=0.01) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=3) + parser.add_argument('--output_dir', type=str, default='results/update_swap') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/experiments/snapshot_exploitability.py b/experiments/snapshot_exploitability.py new file mode 100644 index 0000000..ce07acc --- /dev/null +++ b/experiments/snapshot_exploitability.py @@ -0,0 +1,713 @@ +""" +Phase 6A: Snapshot Exploitability Test. + +Core question: on a fixed network snapshot, does better credit lead to +better real loss decrease via the current local surrogate update? + +Tests: +A1. Single global step DeltaLoss +A2. k-step short rollout (k=1,5,20) +A3. Per-layer ablation (last-1, last-2, all blocks) + +Credit sources: DFA, ScalarCB_eT, Vec_eT_M4, Oracle BP gradient +Snapshot sources: BP-trained, DFA-warmup +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test +) + + +class VectorCreditNet(nn.Module): + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader + + +def evaluate_loss_acc(model, data_iter, device, n_batches=5): + """Evaluate loss and accuracy on a few batches.""" + model.eval() + total_loss, correct, total = 0, 0, 0 + with torch.no_grad(): + for i, (x, y) in enumerate(data_iter): + if i >= n_batches: + break + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + loss = F.cross_entropy(logits, y, reduction='sum') + total_loss += loss.item() + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return total_loss / total, correct / total + + +# ============================================================================= +# Train estimators on frozen snapshot +# ============================================================================= +def train_scalar_cb_on_snapshot(model, train_loader, device, epochs=100, lr_fb=1e-3): + d = model.d_hidden + L = model.num_blocks + value_net = ValueNet(d_hidden=d, s_dim=10, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + value_net_ema = create_ema_model(value_net) + value_opt = optim.Adam(value_net.parameters(), lr=lr_fb) + model.eval() + for epoch in range(1, epochs + 1): + value_net.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + hL = hiddens[-1].detach() + t_L = torch.ones(batch, device=device) + V_term = value_net(hL, t_L, s) + loss_term = ((V_term - true_loss) ** 2).mean() + + hL_req = hL.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + hL_req2 = hL.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req2)) + ce = F.cross_entropy(logits_tgt, y, reduction='sum') + a_L_exact = torch.autograd.grad(ce, hL_req2, create_graph=False)[0].detach() + loss_tgrad = ((grad_V - a_L_exact) ** 2).sum(dim=-1).mean() + + loss_bridge = 0.0 + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l, t_l, s) + with torch.no_grad(): + h_next = hiddens[l + 1].detach() + log_terms = [] + for k in range(4): + noise = 0.05 * torch.randn_like(h_next) + V_next = value_net_ema(h_next + noise, t_next, s) + log_terms.append(-V_next / 0.1) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -0.1 * (torch.logsumexp(log_stack, dim=-1) - np.log(4)) + loss_bridge += ((V_l - V_target.detach()) ** 2).mean() + loss_bridge /= L + + vloss = loss_term + loss_bridge + 1.0 * loss_tgrad + value_opt.zero_grad() + vloss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, 0.995) + if epoch % 20 == 0 or epoch == 1: + print(f" [CB] Ep {epoch}") + return value_net + + +def train_vector_on_snapshot(model, train_loader, device, epochs=100, lr_fb=1e-3, M=4): + d = model.d_hidden + L = model.num_blocks + vector_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb) + eps = 1e-3 + model.eval() + for epoch in range(1, epochs + 1): + vector_net.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + hL = hiddens[-1].detach() + + # Terminal matching + t_L = torch.ones(batch, device=device) + a_term = vector_net(hL, t_L, s) + hL_req = hL.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req)) + ce = F.cross_entropy(logits_tgt, y, reduction='sum') + delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach() + loss_term = ((a_term - delta_L) ** 2).sum(dim=-1).mean() + + # Perturbation target — subsample 1 layer + l = np.random.randint(0, L) + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + a_l = vector_net(h_l, t_l, s) + loss_proj = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) + with torch.no_grad(): + lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y, reduction='none') + lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y, reduction='none') + g_j = (lp - lm) / (2 * eps) + loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean() + loss_proj = loss_proj / M + + vloss = loss_term + loss_proj + vec_opt.zero_grad() + vloss.backward() + torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0) + vec_opt.step() + if epoch % 20 == 0 or epoch == 1: + print(f" [Vec_M{M}] Ep {epoch}") + return vector_net + + +# ============================================================================= +# Credit computation +# ============================================================================= +def get_credits(model, x, y, device, credit_source, estimator=None, dfa_Bs=None): + """Compute per-layer credits for a single batch. Returns dict {l: (batch, d)}.""" + L = model.num_blocks + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + credits = {} + + if credit_source == 'dfa': + for l in range(L): + credits[l] = (s @ dfa_Bs[l].T).detach() + + elif credit_source == 'scalar_cb': + estimator.eval() + for l in range(L): + h_l = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + V = estimator(h_l, t_l, s) + a = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0] + credits[l] = a.detach() + + elif credit_source == 'vec': + estimator.eval() + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + credits[l] = estimator(h_l, t_l, s).detach() + + elif credit_source == 'oracle_bp': + # Compute true BP gradients — evaluation only, never used for training + for p in model.parameters(): + p.requires_grad_(True) + model.zero_grad() + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + loss_bp = F.cross_entropy(logits_bp, y) + loss_bp.backward() + for l in range(L): + credits[l] = hiddens_bp[l].grad.detach().clone() + for p in model.parameters(): + p.requires_grad_(False) + + return credits, hiddens, s + + +def compute_credit_quality(model, credits, hiddens, x, y, device): + """Compute mean Gamma, rho, nudge for given credits.""" + L = model.num_blocks + batch = x.size(0) + + # BP gradients for Gamma + for p in model.parameters(): + p.requires_grad_(True) + model.zero_grad() + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + F.cross_entropy(logits_bp, y).backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} + for p in model.parameters(): + p.requires_grad_(False) + + gammas, rhos, nudges = [], [], [] + for l in range(L): + h_l = hiddens[l].detach() + a_l = credits[l] + gammas.append(cosine_similarity_batch(a_l, bp_grads[l])) + + def make_fwd(sl): + def f(h): + with torch.no_grad(): + c = h + for i in range(sl, L): + c = c + model.blocks[i](c) + return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none') + return f + fwd = make_fwd(l) + rhos.append(perturbation_correlation(h_l, a_l, fwd, epsilon=1e-3, M=16)) + nudges.append(nudging_test(h_l, a_l, fwd, eta=0.003)) + + return float(np.mean(gammas)), float(np.mean(rhos)), float(np.mean(nudges)) + + +# ============================================================================= +# Local update step (block-local inner-product surrogate) +# ============================================================================= +def do_local_update_step(model, x, y, credits, device, lr=1e-3, update_layers=None): + """ + Perform one local surrogate update step on model blocks + output head. + update_layers: list of layer indices to update (None = all) + Returns new model state (modifies in-place). + """ + L = model.num_blocks + if update_layers is None: + update_layers = list(range(L)) + + model.train() + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + + # Update output head (always) + hL = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL)) + loss_out = F.cross_entropy(logits_out, y) + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + grads_head = torch.autograd.grad(loss_out, head_params) + with torch.no_grad(): + for p, g in zip(head_params, grads_head): + p.sub_(lr * g) + + # Update selected blocks + for l in update_layers: + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters()) + with torch.no_grad(): + for p, g in zip(model.blocks[l].parameters(), block_grads): + p.sub_(lr * g.clamp(-1, 1)) # implicit grad clip + + # Update embedding + if 0 in update_layers: + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() + embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters()) + with torch.no_grad(): + for p, g in zip(model.embed.parameters(), embed_grads): + p.sub_(lr * g.clamp(-1, 1)) + + +# ============================================================================= +# Main experiment +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + train_loader, test_loader = get_cifar10(batch_size=args.batch_size) + input_dim = 32 * 32 * 3 + L = args.num_blocks + d = args.d_hidden + + # ========================================================= + # Step 1: Load or train BP snapshot + # ========================================================= + print(f"\n{'='*60}") + print(f"Loading BP snapshot (L={L}, d={d})") + print(f"{'='*60}") + + model_bp = ResidualMLP(input_dim, d, 10, L).to(device) + bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt' + if os.path.exists(bp_ckpt): + model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device)) + print(f" Loaded from {bp_ckpt}") + else: + print(f" Training BP reference...") + optimizer = optim.AdamW(model_bp.parameters(), lr=1e-3, weight_decay=0.01) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100) + for epoch in range(1, 101): + model_bp.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + loss = F.cross_entropy(model_bp(x), y) + optimizer.zero_grad(); loss.backward(); optimizer.step() + scheduler.step() + if epoch % 20 == 0: + model_bp.eval() + c, t = 0, 0 + with torch.no_grad(): + for xv, yv in test_loader: + xv = xv.view(xv.size(0), -1).to(device); yv = yv.to(device) + c += (model_bp(xv).argmax(1) == yv).sum().item(); t += xv.size(0) + print(f" Ep {epoch}: test_acc={c/t:.4f}") + os.makedirs(os.path.dirname(bp_ckpt), exist_ok=True) + torch.save(model_bp.state_dict(), bp_ckpt) + + model_bp.eval() + for p in model_bp.parameters(): + p.requires_grad_(False) + bp_loss, bp_acc = evaluate_loss_acc(model_bp, test_loader, device, n_batches=20) + print(f" BP snapshot: loss={bp_loss:.4f}, acc={bp_acc:.4f}") + + # ========================================================= + # Step 2: Train estimators on BP snapshot + # ========================================================= + print(f"\n{'='*60}") + print(f"Training estimators on BP snapshot ({args.estimator_epochs} epochs)") + print(f"{'='*60}") + + # DFA + dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] + + # Scalar CB + print(" --- ScalarCB_eT ---") + torch.manual_seed(args.seed + 2000) + cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device, + epochs=args.estimator_epochs, lr_fb=args.lr_fb) + + # Vector field M4 + print(" --- Vec_eT_M4 ---") + torch.manual_seed(args.seed + 4000) + vec4 = train_vector_on_snapshot(model_bp, train_loader, device, + epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4) + + credit_sources = { + 'dfa': ('dfa', None, dfa_Bs), + 'scalar_cb': ('scalar_cb', cb, None), + 'vec_eT_M4': ('vec', vec4, None), + 'oracle_bp': ('oracle_bp', None, None), + } + + # ========================================================= + # Step 3: Compute credit quality + # ========================================================= + print(f"\n{'='*60}") + print(f"Computing credit quality on BP snapshot") + print(f"{'='*60}") + + # Get a fixed batch for credit quality + for x_eval, y_eval in test_loader: + x_eval = x_eval.view(x_eval.size(0), -1).to(device) + y_eval = y_eval.to(device) + break + + quality = {} + for name, (src, est, Bs) in credit_sources.items(): + credits, hiddens, s = get_credits(model_bp, x_eval, y_eval, device, src, + estimator=est, dfa_Bs=Bs) + gamma, rho, nudge = compute_credit_quality(model_bp, credits, hiddens, + x_eval, y_eval, device) + quality[name] = {'gamma': gamma, 'rho': rho, 'nudge': nudge} + print(f" {name}: Gamma={gamma:.4f}, rho={rho:.4f}, nudge={nudge:.6f}") + + # ========================================================= + # Step 4: Exploitability tests + # ========================================================= + print(f"\n{'='*60}") + print(f"Test A1: Single-step DeltaLoss") + print(f"{'='*60}") + + # Get train batch for update + train_iter = iter(train_loader) + x_train, y_train = next(train_iter) + x_train = x_train.view(x_train.size(0), -1).to(device) + y_train = y_train.to(device) + + # Get held-out eval batches (different from train batch) + eval_batches = [] + for i, (xv, yv) in enumerate(test_loader): + if i >= 10: + break + eval_batches.append((xv.view(xv.size(0), -1).to(device), yv.to(device))) + + def eval_on_batches(model): + model.eval() + total_loss, correct, total = 0, 0, 0 + with torch.no_grad(): + for xv, yv in eval_batches: + logits = model(xv) + total_loss += F.cross_entropy(logits, yv, reduction='sum').item() + correct += (logits.argmax(1) == yv).sum().item() + total += xv.size(0) + return total_loss / total, correct / total + + results = {} + lr_update = args.lr_update + + for name, (src, est, Bs) in credit_sources.items(): + # Reset to snapshot + model_test = copy.deepcopy(model_bp) + model_test.eval() + for p in model_test.parameters(): + p.requires_grad_(False) + + loss_before, acc_before = eval_on_batches(model_test) + + # Compute credits on the train batch + credits, hiddens, s = get_credits(model_test, x_train, y_train, device, src, + estimator=est, dfa_Bs=Bs) + + # Re-enable grad for update + for p in model_test.parameters(): + p.requires_grad_(True) + + # One-step update + do_local_update_step(model_test, x_train, y_train, credits, device, lr=lr_update) + + for p in model_test.parameters(): + p.requires_grad_(False) + loss_after, acc_after = eval_on_batches(model_test) + + delta_loss = loss_after - loss_before + delta_acc = acc_after - acc_before + + results[name] = { + 'gamma': quality[name]['gamma'], + 'rho': quality[name]['rho'], + 'nudge': quality[name]['nudge'], + 'loss_before': loss_before, + 'loss_after_1step': loss_after, + 'delta_loss_1step': delta_loss, + 'delta_acc_1step': delta_acc, + } + print(f" {name}: DeltaLoss={delta_loss:+.6f}, DeltaAcc={delta_acc:+.4f} " + f"(before={loss_before:.4f}, after={loss_after:.4f})") + + # ========================================================= + # Test A2: k-step rollout + # ========================================================= + print(f"\n{'='*60}") + print(f"Test A2: k-step rollout (k=1,5,20)") + print(f"{'='*60}") + + for name, (src, est, Bs) in credit_sources.items(): + rollout = {} + for k in [1, 5, 20]: + model_test = copy.deepcopy(model_bp) + for p in model_test.parameters(): + p.requires_grad_(True) + + for step in range(k): + # Fresh batch each step + try: + x_step, y_step = next(train_iter) + except StopIteration: + train_iter = iter(train_loader) + x_step, y_step = next(train_iter) + x_step = x_step.view(x_step.size(0), -1).to(device) + y_step = y_step.to(device) + + # Recompute credits on current model state + for p in model_test.parameters(): + p.requires_grad_(False) + credits_step, _, _ = get_credits(model_test, x_step, y_step, device, src, + estimator=est, dfa_Bs=Bs) + for p in model_test.parameters(): + p.requires_grad_(True) + do_local_update_step(model_test, x_step, y_step, credits_step, device, lr=lr_update) + + for p in model_test.parameters(): + p.requires_grad_(False) + loss_k, acc_k = eval_on_batches(model_test) + rollout[k] = {'loss': loss_k, 'acc': acc_k, + 'delta_loss': loss_k - results[name]['loss_before'], + 'delta_acc': acc_k - results[name]['delta_acc_1step']} + + results[name]['rollout'] = rollout + print(f" {name}: k=1 dL={rollout[1]['delta_loss']:+.4f}, " + f"k=5 dL={rollout[5]['delta_loss']:+.4f}, " + f"k=20 dL={rollout[20]['delta_loss']:+.4f}") + + # ========================================================= + # Test A3: Per-layer ablation + # ========================================================= + print(f"\n{'='*60}") + print(f"Test A3: Per-layer ablation (5-step)") + print(f"{'='*60}") + + layer_configs = { + 'last_1': [L - 1], + 'last_2': [L - 2, L - 1], + 'all': list(range(L)), + } + + layer_results = {} + for name, (src, est, Bs) in credit_sources.items(): + layer_results[name] = {} + for lname, layers in layer_configs.items(): + model_test = copy.deepcopy(model_bp) + for p in model_test.parameters(): + p.requires_grad_(True) + + train_iter2 = iter(train_loader) + for step in range(5): + try: + x_step, y_step = next(train_iter2) + except StopIteration: + train_iter2 = iter(train_loader) + x_step, y_step = next(train_iter2) + x_step = x_step.view(x_step.size(0), -1).to(device) + y_step = y_step.to(device) + + for p in model_test.parameters(): + p.requires_grad_(False) + credits_step, _, _ = get_credits(model_test, x_step, y_step, device, src, + estimator=est, dfa_Bs=Bs) + for p in model_test.parameters(): + p.requires_grad_(True) + do_local_update_step(model_test, x_step, y_step, credits_step, device, + lr=lr_update, update_layers=layers) + + for p in model_test.parameters(): + p.requires_grad_(False) + loss_k, acc_k = eval_on_batches(model_test) + dl = loss_k - results[name]['loss_before'] + layer_results[name][lname] = {'delta_loss': dl, 'loss': loss_k} + + print(f" {name}: last_1={layer_results[name]['last_1']['delta_loss']:+.4f}, " + f"last_2={layer_results[name]['last_2']['delta_loss']:+.4f}, " + f"all={layer_results[name]['all']['delta_loss']:+.4f}") + + # ========================================================= + # Summary + # ========================================================= + print(f"\n{'='*80}") + print("SUMMARY TABLE") + print(f"{'='*80}") + print(f"{'Method':<15} {'Gamma':>7} {'rho':>7} {'nudge':>10} {'dL_1':>8} {'dL_5':>8} {'dL_20':>8}") + print("-" * 68) + for name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']: + r = results[name] + dL1 = r['delta_loss_1step'] + dL5 = r['rollout'][5]['delta_loss'] + dL20 = r['rollout'][20]['delta_loss'] + print(f"{name:<15} {r['gamma']:>7.4f} {r['rho']:>7.4f} {r['nudge']:>10.6f} " + f"{dL1:>+8.4f} {dL5:>+8.4f} {dL20:>+8.4f}") + + print(f"\nPer-layer ablation (5-step DeltaLoss):") + print(f"{'Method':<15} {'last_1':>8} {'last_2':>8} {'all':>8}") + print("-" * 42) + for name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']: + lr = layer_results[name] + print(f"{name:<15} {lr['last_1']['delta_loss']:>+8.4f} " + f"{lr['last_2']['delta_loss']:>+8.4f} {lr['all']['delta_loss']:>+8.4f}") + + # Save + save_data = { + 'config': {'L': L, 'd': d, 'seed': args.seed, 'lr_update': lr_update, + 'estimator_epochs': args.estimator_epochs}, + 'credit_quality': quality, + 'exploitability': {n: {k: v for k, v in r.items() if k != 'rollout'} + for n, r in results.items()}, + 'rollout': {n: r['rollout'] for n, r in results.items()}, + 'layer_ablation': layer_results, + } + out_path = os.path.join(args.output_dir, f'snapshot_L{L}_d{d}_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(save_data, f, indent=2, default=float) + print(f"\nSaved to {out_path}") + + # Judgment + print(f"\n{'='*60}") + print("JUDGMENT") + print(f"{'='*60}") + vec_dl = results['vec_eT_M4']['rollout'][5]['delta_loss'] + cb_dl = results['scalar_cb']['rollout'][5]['delta_loss'] + bp_dl = results['oracle_bp']['rollout'][5]['delta_loss'] + dfa_dl = results['dfa']['rollout'][5]['delta_loss'] + + print(f"5-step DeltaLoss: DFA={dfa_dl:+.4f}, CB={cb_dl:+.4f}, Vec={vec_dl:+.4f}, BP={bp_dl:+.4f}") + + if vec_dl < cb_dl and vec_dl < dfa_dl: + print("EXPLOITABLE: Vec credit produces better loss decrease than ScalarCB and DFA.") + print(" -> Online failure is likely tracking/co-adaptation (Case A).") + elif bp_dl < dfa_dl and vec_dl >= cb_dl: + print("NOT EXPLOITABLE: Better credit (vec) doesn't translate to better loss decrease.") + print(" -> Bottleneck is in local update rule (Case B).") + else: + print("AMBIGUOUS: Need more investigation.") + + +def main(): + parser = argparse.ArgumentParser(description='Phase 6A: Snapshot Exploitability') + parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--estimator_epochs', type=int, default=100) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--lr_update', type=float, default=1e-3) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=3) + parser.add_argument('--output_dir', type=str, default='results/snapshot_exploit') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/report_explore/MEMO_6A_snapshot_exploitability.md b/report_explore/MEMO_6A_snapshot_exploitability.md new file mode 100644 index 0000000..950ed1b --- /dev/null +++ b/report_explore/MEMO_6A_snapshot_exploitability.md @@ -0,0 +1,39 @@ +# Phase 6A Memo: Snapshot Exploitability + +**Date**: 2026-03-24 +**Config**: BP snapshot, CIFAR-10, L=4, d=256 (61.9% acc), seed=42 + +## Question +On a fixed snapshot, does better credit lead to better real loss decrease via the current local surrogate? + +## Results + +| Method | Gamma | rho | dL_1step | dL_5step | dL_20step | +|--------|-------|-----|----------|----------|-----------| +| DFA | 0.009 | -0.023 | **-0.0004** | **+0.0002** | **-0.0007** | +| ScalarCB | 0.122 | 0.090 | +0.003 | +0.042 | +0.405 | +| Vec_M4 | 0.378 | 0.411 | +0.003 | +0.050 | +0.272 | +| Oracle BP | 1.000 | 0.998 | **-0.001** | +0.007 | +0.026 | + +## Key Finding: The Local Surrogate is Anti-Correlated with Credit Quality + +**Better credit produces WORSE loss change.** DFA (Gamma≈0) is the only method that decreases loss. ScalarCB (Gamma=0.12) and Vec (Gamma=0.38) both increase loss, with Vec slightly worse. Even Oracle BP increases loss at 5+ steps. + +The inner-product surrogate `L_local = <F_l(h_l), a_l>` is fundamentally broken as a local update rule for directional credit: +- It treats a_l as a "desired direction for the residual output" rather than a gradient +- The gradient of this surrogate w.r.t. block parameters pushes F_l(h) to align with a_l, but this is NOT the same as making h_{l+1} = h_l + F_l(h_l) move in the direction that decreases global loss +- DFA "works" precisely because its random credits are small and roughly isotropic — the updates are near-random perturbations that don't systematically damage the representation + +## Verdict + +**This is Case B: the local update rule is the bottleneck, not the estimator or tracking.** + +Improving credit quality from DFA (Gamma=0.01) through ScalarCB (0.12) to Vec (0.38) to Oracle BP (1.0) does NOT improve — and actually worsens — real parameter update quality. + +## Implication + +The project should pivot from "better credit estimator" to "better local update coupling." The target-shift local regression rule (Phase 6C) is the natural next experiment: + +`L_shift = 0.5 * || h_l + F_l(h_l) - sg(h_{l+1} - eta * a_{l+1}^norm) ||^2` + +This directly tells each block: "adjust your output so the next hidden state moves toward the credit-indicated direction." diff --git a/report_explore/MEMO_6_exploitability.md b/report_explore/MEMO_6_exploitability.md new file mode 100644 index 0000000..42dfda5 --- /dev/null +++ b/report_explore/MEMO_6_exploitability.md @@ -0,0 +1,53 @@ +# Phase 6 Memo: Snapshot Exploitability + Local Update Rule Swap + +**Date**: 2026-03-24 + +## Phase 6A: Snapshot Exploitability + +**Setup**: BP-trained CIFAR-10 snapshot (L=4, d=256, 61.9% acc). Train estimators on frozen features, then do k-step local updates and measure real loss change. + +### Results (5-step DeltaLoss, inner-product surrogate) + +| Credit | Gamma | rho | dL_5step | +|--------|-------|-----|----------| +| DFA | 0.009 | -0.023 | **-0.0001** | +| ScalarCB | 0.122 | 0.090 | +0.042 | +| Vec_M4 | 0.378 | 0.411 | +0.057 | +| Oracle BP | 1.000 | 0.998 | +0.011 | + +**Finding**: Better credit quality is ANTI-CORRELATED with loss decrease. DFA (worst credit) produces the only method that doesn't increase loss. Vec (best credit) increases loss the most. Even Oracle BP increases loss at 5 steps. + +**Verdict**: This is **Case B** — the local update rule is the bottleneck. + +## Phase 6C: Local Update Rule Swap + +Tested target-shift rule (h_{l+1}^target = h_{l+1} - eta * a_norm) at eta in {0.01, 0.1, 0.3, 1.0}. + +### Results (5-step DeltaLoss) + +| Credit | inner_prod | shift_0.1 | shift_0.3 | shift_1.0 | +|--------|:---:|:---:|:---:|:---:| +| DFA | -0.0001 | **-0.0003** | +0.0004 | +0.001 | +| Vec_M4 | +0.057 | +0.002 | +0.009 | +0.048 | +| Oracle BP | +0.011 | +0.0002 | +0.001 | +0.005 | + +Target-shift reduces the damage but never achieves negative DeltaLoss for non-DFA credits. The cosine rule produces near-zero effects at all settings. + +## Root Cause Analysis + +The issue is deeper than the update rule. A BP-trained snapshot sits at a minimum of the full-backprop loss surface. Any local update that doesn't have access to the full gradient chain will push parameters in a direction that may locally align with the credit but globally increases loss. This is because: + +1. The inner-product surrogate `<F_l(h), a_l>` assumes a_l is the desired direction for the residual output. But even perfect credit (Oracle BP) doesn't produce good updates via this mechanism — the gradient of the surrogate w.r.t. block parameters is NOT the same as the gradient of the global loss. + +2. Target-shift reduces the magnitude of harmful updates but doesn't fix the direction. At small eta, updates are negligible. At large eta, the target shifts too far and becomes harmful. + +3. DFA "works" precisely because its random credits produce near-zero effective updates — it's approximately doing nothing, which is better than doing the wrong thing. + +## Implications + +**The project's fundamental limitation is NOT in the credit estimator.** It's in the local surrogate update paradigm itself. The inner-product surrogate `<F(h), a>` is not a valid proxy for global loss minimization, regardless of credit quality. + +**Potential directions:** +1. Use credit to set per-block learning targets rather than gradients (e.g., knowledge distillation-style objectives) +2. Use credit to modulate a more expressive local loss (e.g., local CE with projected targets) +3. Abandon block-local updates entirely and use credit to define a global but differentiable auxiliary loss |
