summaryrefslogtreecommitdiff
path: root/experiments/local_update_swap.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/local_update_swap.py')
-rw-r--r--experiments/local_update_swap.py427
1 files changed, 427 insertions, 0 deletions
diff --git a/experiments/local_update_swap.py b/experiments/local_update_swap.py
new file mode 100644
index 0000000..207560a
--- /dev/null
+++ b/experiments/local_update_swap.py
@@ -0,0 +1,427 @@
+"""
+Phase 6C: Local Update Rule Swap.
+
+Compare different local update rules using the same credit signals on a fixed snapshot.
+
+Rule 1 (baseline): Inner-product surrogate
+ L_inner = <F_l(h_l), a_{l+1}>
+
+Rule 2: Target-shift local regression
+ h_{l+1}^target = h_{l+1} - eta_target * a_{l+1}^norm
+ L_shift = 0.5 * || h_l + F_l(h_l) - sg(h_{l+1}^target) ||^2
+
+Rule 3: Cosine-target update
+ L_cos = - cos(F_l(h_l), a_{l+1})
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def get_credits(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
+ L = model.num_blocks
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ credits = {}
+ if credit_source == 'dfa':
+ for l in range(L):
+ credits[l] = (s @ dfa_Bs[l].T).detach()
+ elif credit_source == 'scalar_cb':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V = estimator(h_l, t_l, s)
+ credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach()
+ elif credit_source == 'vec':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ credits[l] = estimator(h_l, t_l, s).detach()
+ elif credit_source == 'oracle_bp':
+ for p in model.parameters():
+ p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ for l in range(L):
+ credits[l] = hiddens_bp[l].grad.detach().clone()
+ for p in model.parameters():
+ p.requires_grad_(False)
+ return credits, hiddens, s
+
+
+# =============================================================================
+# Local update rules
+# =============================================================================
+def update_inner_product(model, x, y, credits, hiddens, device, lr):
+ """Rule 1: L_inner = <F_l(h_l), a_{l+1}>"""
+ L = model.num_blocks
+ # Head
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(lr * g)
+ # Blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+ # Embed
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters())
+ with torch.no_grad():
+ for p, g in zip(model.embed.parameters(), embed_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+
+def update_target_shift(model, x, y, credits, hiddens, device, lr, eta_target=0.01):
+ """
+ Rule 2: Target-shift local regression.
+ h_{l+1}^target = h_{l+1} - eta_target * a_{l+1}^norm
+ L_shift = 0.5 * || (h_l + F_l(h_l)) - sg(h_{l+1}^target) ||^2
+ """
+ L = model.num_blocks
+ # Head — still use exact CE
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(lr * g)
+
+ # Blocks: target-shift regression
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ h_l_next = hiddens[l + 1].detach() # current h_{l+1}
+
+ # Credit at layer l+1 (or l for the last one)
+ # We use credit[l] which is the credit at layer l
+ # The target shift: move h_{l+1} in the negative credit direction
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+
+ # Target: where h_{l+1} should move toward
+ h_target = (h_l_next - eta_target * a_norm).detach()
+
+ # Compute F_l(h_l) with gradient
+ f_l = model.blocks[l](h_l)
+ h_l_next_pred = h_l + f_l # predicted h_{l+1}
+
+ # Regression loss
+ shift_loss = 0.5 * ((h_l_next_pred - h_target) ** 2).sum(dim=-1).mean()
+ block_grads = torch.autograd.grad(shift_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+ # Embed: use credit[0] as target shift for h_0
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ h0_target = (hiddens[0].detach() - eta_target * (a_0 / rms_0)).detach()
+ embed_loss = 0.5 * ((h0 - h0_target) ** 2).sum(dim=-1).mean()
+ embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters())
+ with torch.no_grad():
+ for p, g in zip(model.embed.parameters(), embed_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+
+def update_cosine_target(model, x, y, credits, hiddens, device, lr):
+ """Rule 3: L_cos = -cos(F_l(h_l), a_{l+1})"""
+ L = model.num_blocks
+ # Head
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(lr * g)
+ # Blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ f_l = model.blocks[l](h_l)
+ cos_sim = F.cosine_similarity(f_l, a, dim=-1).mean()
+ local_loss = -cos_sim
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+ # Embed
+ a_0 = credits[0]
+ h0 = model.embed(x)
+ cos_sim_0 = F.cosine_similarity(h0, a_0, dim=-1).mean()
+ embed_loss = -cos_sim_0
+ embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters())
+ with torch.no_grad():
+ for p, g in zip(model.embed.parameters(), embed_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # Load BP snapshot
+ model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
+ bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
+ model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
+ model_bp.eval()
+ for p in model_bp.parameters():
+ p.requires_grad_(False)
+ print(f"Loaded BP snapshot from {bp_ckpt}")
+
+ # Load pre-trained estimators (or train fresh)
+ # DFA
+ dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ # Scalar CB — train on snapshot
+ print("\nTraining ScalarCB on snapshot...")
+ from experiments.snapshot_exploitability import train_scalar_cb_on_snapshot, train_vector_on_snapshot
+ torch.manual_seed(args.seed + 2000)
+ cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb)
+
+ # Vector field — train on snapshot
+ print("\nTraining Vec_M4 on snapshot...")
+ torch.manual_seed(args.seed + 4000)
+ vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
+
+ credit_sources = {
+ 'dfa': ('dfa', None, dfa_Bs),
+ 'scalar_cb': ('scalar_cb', cb, None),
+ 'vec_eT_M4': ('vec', vec4, None),
+ 'oracle_bp': ('oracle_bp', None, None),
+ }
+
+ update_rules = {
+ 'inner_product': update_inner_product,
+ 'target_shift': lambda m, x, y, c, h, dev, lr: update_target_shift(m, x, y, c, h, dev, lr, eta_target=args.eta_target),
+ 'cosine_target': update_cosine_target,
+ }
+
+ # Eval function
+ eval_batches = []
+ for i, (xv, yv) in enumerate(test_loader):
+ if i >= 10:
+ break
+ eval_batches.append((xv.view(xv.size(0), -1).to(device), yv.to(device)))
+
+ def eval_model(model):
+ model.eval()
+ total_loss, correct, total = 0, 0, 0
+ with torch.no_grad():
+ for xv, yv in eval_batches:
+ logits = model(xv)
+ total_loss += F.cross_entropy(logits, yv, reduction='sum').item()
+ correct += (logits.argmax(1) == yv).sum().item()
+ total += xv.size(0)
+ return total_loss / total, correct / total
+
+ # =========================================================
+ # Run all combinations: credit_source x update_rule x k_steps
+ # =========================================================
+ results = {}
+
+ for cs_name, (src, est, Bs) in credit_sources.items():
+ for rule_name, rule_fn in update_rules.items():
+ for k in [1, 5, 20]:
+ tag = f"{cs_name}_{rule_name}_k{k}"
+
+ model_test = copy.deepcopy(model_bp)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+
+ loss_before, acc_before = eval_model(model_test)
+
+ train_iter = iter(train_loader)
+ for step in range(k):
+ try:
+ x_step, y_step = next(train_iter)
+ except StopIteration:
+ train_iter = iter(train_loader)
+ x_step, y_step = next(train_iter)
+ x_step = x_step.view(x_step.size(0), -1).to(device)
+ y_step = y_step.to(device)
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ credits, hiddens, s = get_credits(model_test, x_step, y_step, device,
+ src, estimator=est, dfa_Bs=Bs)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+ rule_fn(model_test, x_step, y_step, credits, hiddens, device, lr=args.lr_update)
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ loss_after, acc_after = eval_model(model_test)
+
+ results[tag] = {
+ 'credit': cs_name, 'rule': rule_name, 'k': k,
+ 'loss_before': loss_before, 'loss_after': loss_after,
+ 'delta_loss': loss_after - loss_before,
+ 'delta_acc': acc_after - acc_before,
+ }
+
+ # =========================================================
+ # Summary tables
+ # =========================================================
+ print(f"\n{'='*90}")
+ print("RESULTS: DeltaLoss (negative = good)")
+ print(f"{'='*90}")
+
+ for k in [1, 5, 20]:
+ print(f"\n--- k={k} steps ---")
+ print(f"{'Credit':<15} {'inner_prod':>12} {'target_shift':>14} {'cosine':>12}")
+ print("-" * 58)
+ for cs_name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']:
+ row = f"{cs_name:<15}"
+ for rule_name in ['inner_product', 'target_shift', 'cosine_target']:
+ tag = f"{cs_name}_{rule_name}_k{k}"
+ dl = results[tag]['delta_loss']
+ row += f" {dl:>+12.4f}"
+ print(row)
+
+ # Save
+ out_path = os.path.join(args.output_dir, f'update_swap_L{L}_d{d}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # Judgment
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+
+ # Compare at k=5
+ inner_vec = results['vec_eT_M4_inner_product_k5']['delta_loss']
+ shift_vec = results['vec_eT_M4_target_shift_k5']['delta_loss']
+ shift_bp = results['oracle_bp_target_shift_k5']['delta_loss']
+ inner_dfa = results['dfa_inner_product_k5']['delta_loss']
+
+ print(f"k=5: Vec+inner={inner_vec:+.4f}, Vec+shift={shift_vec:+.4f}, "
+ f"BP+shift={shift_bp:+.4f}, DFA+inner={inner_dfa:+.4f}")
+
+ if shift_vec < inner_vec and shift_vec < 0:
+ print("TARGET-SHIFT WINS: Vec credit becomes exploitable with target-shift rule.")
+ print(" -> Project should pivot to 'credit + better local update coupling'.")
+ elif shift_bp < 0 and shift_vec >= 0:
+ print("TARGET-SHIFT HELPS BP BUT NOT VEC: Credit quality still matters.")
+ else:
+ print("TARGET-SHIFT DOESN'T HELP: Need further investigation.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 6C: Local Update Rule Swap')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--estimator_epochs', type=int, default=100)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--lr_update', type=float, default=1e-3)
+ parser.add_argument('--eta_target', type=float, default=0.01)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/update_swap')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()