"""
Phase 6C: Local Update Rule Swap.

Compare different local update rules using the same credit signals on a fixed
snapshot.

For every block l the update uses the block input h_l, the block residual
F_l(h_l) and a credit vector a = ``credits[l]``, normalised per sample to unit
RMS: a^norm = a / (RMS(a) + 1e-6).  The credit is evaluated at the block
*input* h_l and is used as a stand-in for a_{l+1}, because the credit
estimators are only queried at h_0 .. h_{L-1}.

Rule 1 (baseline): Inner-product surrogate
    L_inner = < F_l(h_l), sg(a^norm) >
Rule 2: Target-shift local regression
    h_{l+1}^target = h_{l+1} - eta_target * a^norm
    L_shift = 0.5 * || h_l + F_l(h_l) - sg(h_{l+1}^target) ||^2
Rule 3: Cosine-target update
    L_cos = - cos(F_l(h_l), -a)

All three rules move F_l(h_l) along -a, i.e. downhill when a is a loss
gradient.  The output head is always trained with exact cross-entropy.
"""
import os
import sys
import json
import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test


class VectorCreditNet(nn.Module):
    """MLP mapping (hidden state, time, error signal) to a d_hidden credit vector.

    The hidden state is LayerNorm-ed, concatenated with a sinusoidal time
    embedding and ``s``, then passed through ``num_layers`` Linear+GELU layers
    and a final Linear projection back to ``d_hidden``.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        # Assumes SinusoidalTimeEmbed returns `time_embed_dim` features — TODO confirm.
        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the predicted credit for hidden state ``h`` at time ``t`` given ``s``.

        Presumably ``h`` is (batch, d_hidden), ``t`` is (batch,) and ``s`` is
        (batch, s_dim) — verify against callers.
        """
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)


def get_cifar10(batch_size=128):
    """Return (train_loader, test_loader) for CIFAR-10 under ./data.

    The training set uses random crop + horizontal flip; both sets are
    normalised with the standard CIFAR-10 channel mean/std.  The data is
    downloaded on first use.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                            transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                           transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)
    return train_loader, test_loader


def get_credits(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
    """Compute per-layer credit vectors for one batch.

    Args:
        model: network exposing ``num_blocks`` and ``model(x, return_hidden=True)``
            returning ``(logits, hiddens)``; ``hiddens[l]`` is the input of block l.
        x, y: flattened inputs and integer labels.
        device: device for the time tensors fed to the estimators.
        credit_source: one of ``'dfa'``, ``'scalar_cb'``, ``'vec'``, ``'oracle_bp'``.
        estimator: trained credit estimator (``'scalar_cb'`` / ``'vec'`` only).
        dfa_Bs: list of fixed feedback matrices, one per block (``'dfa'`` only).

    Returns:
        ``(credits, hiddens, s)``: ``credits[l]`` for l in 0..L-1 (detached),
        the no-grad forward hiddens, and ``s = softmax(logits) - onehot(y)``.

    Raises:
        ValueError: if ``credit_source`` is not recognised.
    """
    L = model.num_blocks
    batch = x.size(0)
    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        # softmax - onehot is the per-sample gradient of cross-entropy w.r.t. logits.
        e_T = logits.softmax(dim=-1)
        e_T[torch.arange(batch, device=e_T.device), y] -= 1
        s = e_T.detach()

    credits = {}
    if credit_source == 'dfa':
        # Project the output error through a fixed random matrix per block.
        for l in range(L):
            credits[l] = (s @ dfa_Bs[l].T).detach()
    elif credit_source == 'scalar_cb':
        # Credit = gradient of the scalar value estimate w.r.t. the hidden state.
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach().requires_grad_(True)
            t_l = torch.full((batch,), l / L, device=device)
            V = estimator(h_l, t_l, s)
            credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach()
    elif credit_source == 'vec':
        # The estimator predicts the credit vector directly.
        estimator.eval()
        for l in range(L):
            h_l = hiddens[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            credits[l] = estimator(h_l, t_l, s).detach()
    elif credit_source == 'oracle_bp':
        # Exact dL/dh_l of the batch-mean cross-entropy via backprop.
        for p in model.parameters():
            p.requires_grad_(True)
        model.zero_grad()
        logits_bp, hiddens_bp = model(x, return_hidden=True)
        for l in range(L):
            hiddens_bp[l].retain_grad()
        F.cross_entropy(logits_bp, y).backward()
        for l in range(L):
            credits[l] = hiddens_bp[l].grad.detach().clone()
        for p in model.parameters():
            p.requires_grad_(False)
    else:
        raise ValueError(f"unknown credit_source: {credit_source!r}")
    return credits, hiddens, s


# =============================================================================
# Local update rules
# =============================================================================

def _rms_normalize(a, eps=1e-6):
    """Scale each row of ``a`` to unit RMS; ``eps`` guards against all-zero credit."""
    rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + eps
    return a / rms


def _sgd_step(params, grads, lr, clip=None):
    """In-place plain SGD step, optionally clamping each gradient entry to [-clip, clip]."""
    with torch.no_grad():
        for p, g in zip(params, grads):
            if clip is not None:
                g = g.clamp(-clip, clip)
            p.sub_(lr * g)


def _local_step(module, loss, lr, clip=1.0):
    """Differentiate ``loss`` w.r.t. ``module``'s parameters and apply a clipped SGD step."""
    params = list(module.parameters())
    grads = torch.autograd.grad(loss, params)
    _sgd_step(params, grads, lr, clip=clip)


def _update_head(model, y, hiddens, lr):
    """Train the output LayerNorm + head with exact cross-entropy on the detached h_L."""
    hL = hiddens[-1].detach()
    logits_out = model.out_head(model.out_ln(hL))
    loss_out = F.cross_entropy(logits_out, y)
    head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
    grads_head = torch.autograd.grad(loss_out, head_params)
    # The head step is deliberately unclipped (unlike the block/embed steps).
    _sgd_step(head_params, grads_head, lr)


def update_inner_product(model, x, y, credits, hiddens, device, lr):
    """Rule 1: L_inner = < F_l(h_l), sg(a^norm) > with a = credits[l].

    Minimising the inner product moves F_l(h_l) along -a^norm.  The embedding
    uses the same surrogate on h_0 = embed(x) with credits[0].  ``device`` is
    unused and kept for a uniform rule signature.
    """
    _update_head(model, y, hiddens, lr)

    for l in range(model.num_blocks):
        h_l = hiddens[l].detach()
        a_norm = _rms_normalize(credits[l])
        f_l = model.blocks[l](h_l)
        local_loss = (f_l * a_norm).sum(dim=-1).mean()
        _local_step(model.blocks[l], local_loss, lr)

    h0 = model.embed(x)
    embed_loss = (h0 * _rms_normalize(credits[0])).sum(dim=-1).mean()
    _local_step(model.embed, embed_loss, lr)


def update_target_shift(model, x, y, credits, hiddens, device, lr, eta_target=0.01):
    """Rule 2: Target-shift local regression.

        h_{l+1}^target = h_{l+1} - eta_target * a^norm
        L_shift = 0.5 * || (h_l + F_l(h_l)) - sg(h_{l+1}^target) ||^2

    ``a`` is credits[l] (the credit at the block input h_l, used as a proxy for
    the credit at h_{l+1}).  The embedding regresses h_0 = embed(x) onto
    h_0 - eta_target * a_0^norm.  ``device`` is unused.
    """
    # Head — still use exact CE
    _update_head(model, y, hiddens, lr)

    for l in range(model.num_blocks):
        h_l = hiddens[l].detach()
        h_l_next = hiddens[l + 1].detach()  # current h_{l+1}
        # Target: move h_{l+1} in the negative (normalised) credit direction.
        h_target = (h_l_next - eta_target * _rms_normalize(credits[l])).detach()
        h_l_next_pred = h_l + model.blocks[l](h_l)  # predicted h_{l+1}
        shift_loss = 0.5 * ((h_l_next_pred - h_target) ** 2).sum(dim=-1).mean()
        _local_step(model.blocks[l], shift_loss, lr)

    h0 = model.embed(x)
    h0_target = (hiddens[0].detach() - eta_target * _rms_normalize(credits[0])).detach()
    embed_loss = 0.5 * ((h0 - h0_target) ** 2).sum(dim=-1).mean()
    _local_step(model.embed, embed_loss, lr)


def update_cosine_target(model, x, y, credits, hiddens, device, lr):
    """Rule 3: L_cos = -cos(F_l(h_l), -a) with a = credits[l].

    Aligns F_l(h_l) with the *negative* credit (the descent direction), like
    Rules 1 and 2.  Minimising -cos(F_l, +a) instead would align F_l with the
    loss gradient, i.e. perform ascent.  The embedding aligns h_0 with
    -credits[0].  ``device`` is unused.
    """
    _update_head(model, y, hiddens, lr)

    for l in range(model.num_blocks):
        h_l = hiddens[l].detach()
        f_l = model.blocks[l](h_l)
        local_loss = -F.cosine_similarity(f_l, -credits[l], dim=-1).mean()
        _local_step(model.blocks[l], local_loss, lr)

    h0 = model.embed(x)
    embed_loss = -F.cosine_similarity(h0, -credits[0], dim=-1).mean()
    _local_step(model.embed, embed_loss, lr)


# =============================================================================
# Main
# =============================================================================

def _set_requires_grad(model, flag):
    """Toggle ``requires_grad`` on every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad_(flag)


def _flatten_batches(loader, max_batches, device, cycle=False):
    """Return up to ``max_batches`` (x_flat, y) pairs from ``loader``, moved to ``device``.

    With ``cycle=True`` the loader is restarted when exhausted so that exactly
    ``max_batches`` pairs are returned (unless the loader yields nothing).
    """
    batches = []
    while len(batches) < max_batches:
        n_before = len(batches)
        for xb, yb in loader:
            batches.append((xb.view(xb.size(0), -1).to(device), yb.to(device)))
            if len(batches) >= max_batches:
                break
        if not cycle or len(batches) == n_before:
            break
    return batches


def run_experiment(args):
    """Run every (credit source x update rule x k steps) combination and report DeltaLoss.

    Each combination starts from a fresh copy of the frozen BP snapshot and is
    updated on the SAME first-k training minibatches, so differences between
    combinations come from the credit/rule rather than from data sampling.
    Results are written to ``<output_dir>/update_swap_L{L}_d{d}_s{seed}.json``.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden
    k_steps = (1, 5, 20)

    # Load BP snapshot
    model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
    bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
    model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
    model_bp.eval()
    _set_requires_grad(model_bp, False)
    print(f"Loaded BP snapshot from {bp_ckpt}")

    # DFA: fixed random feedback matrices, one (d, 10) matrix per block.
    dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]

    # Scalar CB — train on snapshot
    print("\nTraining ScalarCB on snapshot...")
    from experiments.snapshot_exploitability import train_scalar_cb_on_snapshot, train_vector_on_snapshot
    torch.manual_seed(args.seed + 2000)
    cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
                                     epochs=args.estimator_epochs, lr_fb=args.lr_fb)

    # Vector field — train on snapshot
    print("\nTraining Vec_M4 on snapshot...")
    torch.manual_seed(args.seed + 4000)
    vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
                                    epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)

    credit_sources = {
        'dfa': ('dfa', None, dfa_Bs),
        'scalar_cb': ('scalar_cb', cb, None),
        'vec_eT_M4': ('vec', vec4, None),
        'oracle_bp': ('oracle_bp', None, None),
    }

    update_rules = {
        'inner_product': update_inner_product,
        'target_shift': lambda m, x, y, c, h, dev, lr: update_target_shift(
            m, x, y, c, h, dev, lr, eta_target=args.eta_target),
        'cosine_target': update_cosine_target,
    }

    # Fixed evaluation set: the first 10 (unshuffled) test batches.
    eval_batches = _flatten_batches(test_loader, 10, device)

    # Fixed training minibatches shared by every combination.
    train_batches = _flatten_batches(train_loader, max(k_steps), device, cycle=True)
    if len(train_batches) < max(k_steps):
        raise RuntimeError("train_loader yielded no batches")

    def eval_model(model):
        """Return (mean CE loss, accuracy) of ``model`` on the fixed eval batches."""
        model.eval()
        total_loss, correct, total = 0, 0, 0
        with torch.no_grad():
            for xv, yv in eval_batches:
                logits = model(xv)
                total_loss += F.cross_entropy(logits, yv, reduction='sum').item()
                correct += (logits.argmax(1) == yv).sum().item()
                total += xv.size(0)
        return total_loss / total, correct / total

    # =========================================================
    # Run all combinations: credit_source x update_rule x k_steps
    # =========================================================
    results = {}
    for cs_name, (src, est, Bs) in credit_sources.items():
        for rule_name, rule_fn in update_rules.items():
            for k in k_steps:
                tag = f"{cs_name}_{rule_name}_k{k}"
                model_test = copy.deepcopy(model_bp)
                loss_before, acc_before = eval_model(model_test)

                for x_step, y_step in train_batches[:k]:
                    # Credits are computed with the model frozen; only the
                    # local rule itself differentiates w.r.t. parameters.
                    _set_requires_grad(model_test, False)
                    credits, hiddens, _ = get_credits(model_test, x_step, y_step, device, src,
                                                      estimator=est, dfa_Bs=Bs)
                    _set_requires_grad(model_test, True)
                    rule_fn(model_test, x_step, y_step, credits, hiddens, device,
                            lr=args.lr_update)

                _set_requires_grad(model_test, False)
                loss_after, acc_after = eval_model(model_test)
                results[tag] = {
                    'credit': cs_name,
                    'rule': rule_name,
                    'k': k,
                    'loss_before': loss_before,
                    'loss_after': loss_after,
                    'delta_loss': loss_after - loss_before,
                    'delta_acc': acc_after - acc_before,
                }

    # =========================================================
    # Summary tables
    # =========================================================
    print(f"\n{'='*90}")
    print("RESULTS: DeltaLoss (negative = good)")
    print(f"{'='*90}")
    for k in k_steps:
        print(f"\n--- k={k} steps ---")
        print(f"{'Credit':<15} {'inner_prod':>12} {'target_shift':>12} {'cosine':>12}")
        print("-" * 54)
        for cs_name in credit_sources:
            row = f"{cs_name:<15}"
            for rule_name in update_rules:
                dl = results[f"{cs_name}_{rule_name}_k{k}"]['delta_loss']
                row += f" {dl:>+12.4f}"
            print(row)

    # Save
    out_path = os.path.join(args.output_dir, f'update_swap_L{L}_d{d}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")

    # Judgment
    print(f"\n{'='*60}")
    print("JUDGMENT")
    print(f"{'='*60}")
    # Compare at k=5
    inner_vec = results['vec_eT_M4_inner_product_k5']['delta_loss']
    shift_vec = results['vec_eT_M4_target_shift_k5']['delta_loss']
    shift_bp = results['oracle_bp_target_shift_k5']['delta_loss']
    inner_dfa = results['dfa_inner_product_k5']['delta_loss']
    print(f"k=5: Vec+inner={inner_vec:+.4f}, Vec+shift={shift_vec:+.4f}, "
          f"BP+shift={shift_bp:+.4f}, DFA+inner={inner_dfa:+.4f}")

    if shift_vec < inner_vec and shift_vec < 0:
        print("TARGET-SHIFT WINS: Vec credit becomes exploitable with target-shift rule.")
        print(" -> Project should pivot to 'credit + better local update coupling'.")
    elif shift_bp < 0 and shift_vec >= 0:
        print("TARGET-SHIFT HELPS BP BUT NOT VEC: Credit quality still matters.")
    else:
        print("TARGET-SHIFT DOESN'T HELP: Need further investigation.")


def main():
    """Parse command-line arguments and run the update-rule swap experiment."""
    parser = argparse.ArgumentParser(description='Phase 6C: Local Update Rule Swap')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--estimator_epochs', type=int, default=100)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--lr_update', type=float, default=1e-3)
    parser.add_argument('--eta_target', type=float, default=0.01)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--output_dir', type=str, default='results/update_swap')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()