diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-24 20:07:03 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-24 20:07:03 -0500 |
| commit | 825d973428450cb24d8cccc8c2604235ef974b7c (patch) | |
| tree | 865bf6f7cc5eabbdbbccfb5c14c927584dd1a4f8 /experiments/local_update_swap.py | |
| parent | 5550e2cac45758e579810ae36bf716a0b819cebc (diff) | |
Add Phase 6: snapshot exploitability reveals local update rule is the bottleneck
Phase 6A: Better credit is ANTI-CORRELATED with loss decrease on fixed snapshot.
DFA (Gamma=0.01) → dL=-0.0001 (only method that decreases loss)
Vec_M4 (Gamma=0.38) → dL=+0.057 (increases loss most)
Oracle BP (Gamma=1.0) → dL=+0.011 (still increases loss)
Phase 6C: Target-shift rule reduces damage but cannot make non-DFA credits productive.
The inner-product surrogate <F_l(h), a_l> is fundamentally mismatched with directional credit.
Conclusion: Case B — the primary bottleneck is the local update paradigm itself,
not the credit estimator quality or tracking/co-adaptation.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments/local_update_swap.py')
| -rw-r--r-- | experiments/local_update_swap.py | 427 |
1 files changed, 427 insertions, 0 deletions
# diff --git a/experiments/local_update_swap.py b/experiments/local_update_swap.py
# new file mode 100644
# index 0000000..207560a
# --- /dev/null
# +++ b/experiments/local_update_swap.py
# @@ -0,0 +1,427 @@
"""
Phase 6C: Local Update Rule Swap.

Compare different local update rules using the same credit signals on a fixed snapshot.

Rule 1 (baseline): Inner-product surrogate
    L_inner = <F_l(h_l), a_{l+1}>

Rule 2: Target-shift local regression
    h_{l+1}^target = h_{l+1} - eta_target * a_{l+1}^norm
    L_shift = 0.5 * || h_l + F_l(h_l) - sg(h_{l+1}^target) ||^2

Rule 3: Cosine-target update
    L_cos = - cos(F_l(h_l), a_{l+1})

Every (credit source, update rule, k) cell starts from a fresh copy of the
snapshot and is trained on the same pre-fetched training batches, so that
differences between cells are not confounded by batch sampling.
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import copy

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test


# Numbers of local update steps evaluated for every (credit, rule) pair.
# The JUDGMENT section of run_experiment reads the k=5 entries, so 5 must stay.
K_STEPS = (1, 5, 20)

# Element-wise gradient clipping bound used for block and embed updates.
_GRAD_CLIP = 1.0


class VectorCreditNet(nn.Module):
    """MLP that maps (hidden state, time, signal) to a vector of size d_hidden.

    The hidden state is layer-normalised, the time input is passed through
    ``SinusoidalTimeEmbed``, and both are concatenated with ``s`` along the
    last dimension before entering a GELU MLP with ``num_layers`` hidden
    layers followed by a linear projection back to ``d_hidden``.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        # assumes SinusoidalTimeEmbed(time_embed_dim) emits time_embed_dim features — TODO confirm
        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the network output for hidden state ``h``, time ``t`` and signal ``s``."""
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)


def get_cifar10(batch_size=128):
    """Build CIFAR-10 train/test loaders rooted at ``./data`` (downloads if missing).

    The training set uses random crop (padding 4) and horizontal flip; both
    sets are converted to tensors and normalised with the same per-channel
    mean/std. The training loader shuffles, the test loader does not.

    Returns:
        (train_loader, test_loader)
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, test_loader


def get_credits(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
    """Compute one credit tensor per block for a batch.

    Args:
        model: network exposing ``num_blocks`` and ``model(x, return_hidden=True)``
            returning ``(logits, hiddens)``; ``hiddens`` is indexed ``0..L``.
        x, y: input batch and integer class labels.
        device: device on which the per-layer time tensors are created.
        credit_source: one of ``'dfa'``, ``'scalar_cb'``, ``'vec'``, ``'oracle_bp'``.
        estimator: callable ``estimator(h, t, s)``; required for ``'scalar_cb'``
            (credit is the gradient of its summed output w.r.t. ``h``) and
            ``'vec'`` (credit is its output). It is switched to eval mode.
        dfa_Bs: per-block feedback matrices; required for ``'dfa'`` where the
            credit is ``s @ dfa_Bs[l].T``.

    Returns:
        ``(credits, hiddens, s)`` where ``credits`` maps block index ``l`` in
        ``range(L)`` to a detached tensor, ``hiddens`` comes from a no-grad
        forward pass, and ``s = softmax(logits) - onehot(y)``.

    Raises:
        ValueError: if ``credit_source`` is not one of the supported names.
    """
    L = model.num_blocks
    batch = x.size(0)
    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        e_T = logits.softmax(dim=-1)
        # Build the row index on the same device as e_T to avoid an implicit copy.
        e_T[torch.arange(batch, device=e_T.device), y] -= 1
    s = e_T.detach()

    credits = {}
    if credit_source == 'dfa':
        for l in range(L):
            credits[l] = (s @ dfa_Bs[l].T).detach()
    elif credit_source == 'scalar_cb':
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach().requires_grad_(True)
            t_l = torch.full((batch,), l / L, device=device)
            V = estimator(h_l, t_l, s)
            credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach()
    elif credit_source == 'vec':
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            credits[l] = estimator(h_l, t_l, s).detach()
    elif credit_source == 'oracle_bp':
        # True gradient of the cross-entropy loss w.r.t. each hidden state.
        # Parameters must require grad for the forward pass to build a graph;
        # the previous flags are restored afterwards, and autograd.grad is used
        # so that no parameter .grad buffers are populated as a side effect.
        params = list(model.parameters())
        prev_flags = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(True)
        try:
            logits_bp, hiddens_bp = model(x, return_hidden=True)
            loss_bp = F.cross_entropy(logits_bp, y)
            hidden_grads = torch.autograd.grad(loss_bp, [hiddens_bp[l] for l in range(L)])
        finally:
            for p, flag in zip(params, prev_flags):
                p.requires_grad_(flag)
        for l, g in enumerate(hidden_grads):
            credits[l] = g.detach().clone()
    else:
        raise ValueError(
            f"Unknown credit_source {credit_source!r}; "
            "expected 'dfa', 'scalar_cb', 'vec' or 'oracle_bp'"
        )
    return credits, hiddens, s


# =============================================================================
# Local update rules
# =============================================================================
def _rms_normalize(a, eps=1e-6):
    """Divide ``a`` by its root-mean-square over the last dimension (plus ``eps``)."""
    rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + eps
    return a / rms


def _sgd_step(params, grads, lr, clip=None):
    """In-place ``p -= lr * g``; gradients are clamped to ``[-clip, clip]`` when ``clip`` is set."""
    with torch.no_grad():
        for p, g in zip(params, grads):
            if clip is not None:
                g = g.clamp(-clip, clip)
            p.sub_(lr * g)


def _local_step(module, loss, lr):
    """Take one clipped SGD step on ``module``'s parameters for a local ``loss``."""
    params = list(module.parameters())
    grads = torch.autograd.grad(loss, params)
    _sgd_step(params, grads, lr, clip=_GRAD_CLIP)


def _update_head(model, hiddens, y, lr):
    """Update ``out_ln``/``out_head`` with the exact cross-entropy gradient (unclipped).

    The last hidden state is detached, so no gradient reaches the blocks.
    """
    hL = hiddens[-1].detach()
    logits_out = model.out_head(model.out_ln(hL))
    loss_out = F.cross_entropy(logits_out, y)
    head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
    grads_head = torch.autograd.grad(loss_out, head_params)
    _sgd_step(head_params, grads_head, lr)


def update_inner_product(model, x, y, credits, hiddens, device, lr):
    """Rule 1: L_inner = <F_l(h_l), a_{l+1}>

    Order of updates: head (exact CE), then each block, then the embedding.
    Each block minimises the batch mean of the inner product between its
    output on the detached ``hiddens[l]`` and the RMS-normalised credit.

    NOTE(review): the formula is written with a_{l+1}, but the code uses
    ``credits[l]`` for block ``l`` (and ``credits[0]`` for the embedding).
    ``device`` is accepted for signature parity with the other rules and is unused.
    """
    _update_head(model, hiddens, y, lr)

    # Blocks
    for l in range(model.num_blocks):
        h_l = hiddens[l].detach()
        a_norm = _rms_normalize(credits[l])
        f_l = model.blocks[l](h_l)
        local_loss = (f_l * a_norm).sum(dim=-1).mean()
        _local_step(model.blocks[l], local_loss, lr)

    # Embed
    h0 = model.embed(x)
    embed_loss = (h0 * _rms_normalize(credits[0])).sum(dim=-1).mean()
    _local_step(model.embed, embed_loss, lr)


def update_target_shift(model, x, y, credits, hiddens, device, lr, eta_target=0.01):
    """
    Rule 2: Target-shift local regression.

        h_{l+1}^target = h_{l+1} - eta_target * a_{l+1}^norm
        L_shift = 0.5 * || (h_l + F_l(h_l)) - sg(h_{l+1}^target) ||^2

    Order of updates: head (exact CE), then each block, then the embedding.
    The target is built from the pre-update ``hiddens[l + 1]`` and is detached
    (stop-gradient), so only block ``l`` receives a gradient.

    NOTE(review): the formula is written with a_{l+1}, but the code shifts
    ``hiddens[l + 1]`` by ``credits[l]`` (and ``hiddens[0]`` by ``credits[0]``
    for the embedding).
    ``device`` is accepted for signature parity with the other rules and is unused.
    """
    # Head — still use exact CE
    _update_head(model, hiddens, y, lr)

    # Blocks: target-shift regression
    for l in range(model.num_blocks):
        h_l = hiddens[l].detach()
        h_l_next = hiddens[l + 1].detach()  # current h_{l+1}

        # Target: move h_{l+1} in the negative (normalised) credit direction.
        a_norm = _rms_normalize(credits[l])
        h_target = (h_l_next - eta_target * a_norm).detach()

        # Predicted h_{l+1} = h_l + F_l(h_l), with gradient through the block.
        f_l = model.blocks[l](h_l)
        h_l_next_pred = h_l + f_l

        # Regression loss
        shift_loss = 0.5 * ((h_l_next_pred - h_target) ** 2).sum(dim=-1).mean()
        _local_step(model.blocks[l], shift_loss, lr)

    # Embed: use credit[0] as target shift for h_0
    h0 = model.embed(x)
    h0_target = (hiddens[0].detach() - eta_target * _rms_normalize(credits[0])).detach()
    embed_loss = 0.5 * ((h0 - h0_target) ** 2).sum(dim=-1).mean()
    _local_step(model.embed, embed_loss, lr)


def update_cosine_target(model, x, y, credits, hiddens, device, lr):
    """Rule 3: L_cos = -cos(F_l(h_l), a_{l+1})

    Order of updates: head (exact CE), then each block, then the embedding.
    Each block maximises the batch-mean cosine similarity between its output
    on the detached ``hiddens[l]`` and the (un-normalised) credit.

    NOTE(review): the formula is written with a_{l+1}, but the code uses
    ``credits[l]`` for block ``l`` (and ``credits[0]`` for the embedding).
    ``device`` is accepted for signature parity with the other rules and is unused.
    """
    _update_head(model, hiddens, y, lr)

    # Blocks
    for l in range(model.num_blocks):
        h_l = hiddens[l].detach()
        f_l = model.blocks[l](h_l)
        local_loss = -F.cosine_similarity(f_l, credits[l], dim=-1).mean()
        _local_step(model.blocks[l], local_loss, lr)

    # Embed
    h0 = model.embed(x)
    embed_loss = -F.cosine_similarity(h0, credits[0], dim=-1).mean()
    _local_step(model.embed, embed_loss, lr)


# =============================================================================
# Main
# =============================================================================
def _select_device(gpu_index):
    """Return ``cuda:<gpu_index>`` if that GPU is visible, else ``cuda:0``, else CPU."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    num_gpus = torch.cuda.device_count()
    if gpu_index < 0 or gpu_index >= num_gpus:
        print(f"Requested GPU {gpu_index} is not available "
              f"({num_gpus} visible); falling back to cuda:0")
        gpu_index = 0
    return torch.device(f'cuda:{gpu_index}')


def _prefetch_batches(loader, num_batches, device):
    """Collect ``num_batches`` flattened ``(x, y)`` batches on ``device``.

    The loader is re-iterated if it is exhausted before enough batches have
    been collected.

    Raises:
        ValueError: if the loader yields no batches at all.
    """
    batches = []
    while len(batches) < num_batches:
        count_before = len(batches)
        for xb, yb in loader:
            batches.append((xb.view(xb.size(0), -1).to(device), yb.to(device)))
            if len(batches) == num_batches:
                break
        if len(batches) == count_before:
            raise ValueError("DataLoader yielded no batches")
    return batches


def run_experiment(args):
    """Run every (credit source, update rule, k) cell and report the loss change.

    Loads the BP snapshot, trains the ScalarCB and Vec_M4 estimators on it,
    then for each cell applies ``k`` local update steps to a fresh copy of the
    snapshot and measures the change in loss/accuracy on the first 10 test
    batches. Prints summary tables and a judgment, and writes the results to
    ``<output_dir>/update_swap_L{L}_d{d}_s{seed}.json``.

    Args:
        args: namespace with ``gpu``, ``output_dir``, ``seed``, ``batch_size``,
            ``num_blocks``, ``d_hidden``, ``estimator_epochs``, ``lr_fb``,
            ``lr_update`` and ``eta_target`` (see ``main``).
    """
    device = _select_device(args.gpu)
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden

    # Load BP snapshot
    model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
    bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
    model_bp.eval()
    for p in model_bp.parameters():
        p.requires_grad_(False)
    print(f"Loaded BP snapshot from {bp_ckpt}")

    # Load pre-trained estimators (or train fresh)
    # DFA
    dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]

    # Scalar CB — train on snapshot
    print("\nTraining ScalarCB on snapshot...")
    from experiments.snapshot_exploitability import train_scalar_cb_on_snapshot, train_vector_on_snapshot
    torch.manual_seed(args.seed + 2000)
    cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
                                     epochs=args.estimator_epochs, lr_fb=args.lr_fb)

    # Vector field — train on snapshot
    print("\nTraining Vec_M4 on snapshot...")
    torch.manual_seed(args.seed + 4000)
    vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
                                    epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)

    # name -> (credit_source for get_credits, estimator, DFA matrices)
    credit_sources = {
        'dfa': ('dfa', None, dfa_Bs),
        'scalar_cb': ('scalar_cb', cb, None),
        'vec_eT_M4': ('vec', vec4, None),
        'oracle_bp': ('oracle_bp', None, None),
    }

    update_rules = {
        'inner_product': update_inner_product,
        'target_shift': lambda m, x, y, c, h, dev, lr: update_target_shift(m, x, y, c, h, dev, lr, eta_target=args.eta_target),
        'cosine_target': update_cosine_target,
    }

    # Eval function: fixed subset made of the first 10 test batches.
    eval_batches = []
    for i, (xv, yv) in enumerate(test_loader):
        if i >= 10:
            break
        eval_batches.append((xv.view(xv.size(0), -1).to(device), yv.to(device)))

    def eval_model(model):
        """Return (mean cross-entropy, accuracy) over ``eval_batches``."""
        model.eval()
        total_loss, correct, total = 0, 0, 0
        with torch.no_grad():
            for xv, yv in eval_batches:
                logits = model(xv)
                total_loss += F.cross_entropy(logits, yv, reduction='sum').item()
                correct += (logits.argmax(1) == yv).sum().item()
                total += xv.size(0)
        return total_loss / total, correct / total

    # Every cell starts from an identical copy of the snapshot, so the
    # "before" metrics are the same for all cells and are computed once.
    loss_before, acc_before = eval_model(model_bp)

    # Draw the training batches once. Every cell sees the same data
    # (k=1 and k=5 use prefixes of the k=20 sequence), which makes the
    # comparison across credit sources and update rules paired.
    train_batches = _prefetch_batches(train_loader, max(K_STEPS), device)

    # =========================================================
    # Run all combinations: credit_source x update_rule x k_steps
    # =========================================================
    results = {}

    for cs_name, (src, est, Bs) in credit_sources.items():
        for rule_name, rule_fn in update_rules.items():
            for k in K_STEPS:
                tag = f"{cs_name}_{rule_name}_k{k}"

                model_test = copy.deepcopy(model_bp)
                for p in model_test.parameters():
                    p.requires_grad_(True)

                for x_step, y_step in train_batches[:k]:
                    # Credits are recomputed on the current (already updated) model.
                    credits, hiddens, _ = get_credits(model_test, x_step, y_step, device,
                                                      src, estimator=est, dfa_Bs=Bs)
                    rule_fn(model_test, x_step, y_step, credits, hiddens, device, lr=args.lr_update)

                loss_after, acc_after = eval_model(model_test)

                results[tag] = {
                    'credit': cs_name, 'rule': rule_name, 'k': k,
                    'loss_before': loss_before, 'loss_after': loss_after,
                    'delta_loss': loss_after - loss_before,
                    'delta_acc': acc_after - acc_before,
                }

    # =========================================================
    # Summary tables
    # =========================================================
    print(f"\n{'='*90}")
    print("RESULTS: DeltaLoss (negative = good)")
    print(f"{'='*90}")

    for k in K_STEPS:
        print(f"\n--- k={k} steps ---")
        print(f"{'Credit':<15} {'inner_prod':>12} {'target_shift':>14} {'cosine':>12}")
        print("-" * 58)
        for cs_name in credit_sources:
            row = f"{cs_name:<15}"
            for rule_name in update_rules:
                tag = f"{cs_name}_{rule_name}_k{k}"
                dl = results[tag]['delta_loss']
                row += f" {dl:>+12.4f}"
            print(row)

    # Save
    out_path = os.path.join(args.output_dir, f'update_swap_L{L}_d{d}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")

    # Judgment
    print(f"\n{'='*60}")
    print("JUDGMENT")
    print(f"{'='*60}")

    # Compare at k=5
    inner_vec = results['vec_eT_M4_inner_product_k5']['delta_loss']
    shift_vec = results['vec_eT_M4_target_shift_k5']['delta_loss']
    shift_bp = results['oracle_bp_target_shift_k5']['delta_loss']
    inner_dfa = results['dfa_inner_product_k5']['delta_loss']

    print(f"k=5: Vec+inner={inner_vec:+.4f}, Vec+shift={shift_vec:+.4f}, "
          f"BP+shift={shift_bp:+.4f}, DFA+inner={inner_dfa:+.4f}")

    if shift_vec < inner_vec and shift_vec < 0:
        print("TARGET-SHIFT WINS: Vec credit becomes exploitable with target-shift rule.")
        print("  -> Project should pivot to 'credit + better local update coupling'.")
    elif shift_bp < 0 and shift_vec >= 0:
        print("TARGET-SHIFT HELPS BP BUT NOT VEC: Credit quality still matters.")
    else:
        print("TARGET-SHIFT DOESN'T HELP: Need further investigation.")


def main():
    """Parse command-line arguments and run the experiment."""
    parser = argparse.ArgumentParser(description='Phase 6C: Local Update Rule Swap')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--estimator_epochs', type=int, default=100)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--lr_update', type=float, default=1e-3)
    parser.add_argument('--eta_target', type=float, default=0.01)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--output_dir', type=str, default='results/update_swap')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()
