diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
| commit | 6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch) | |
| tree | d7c63adcd19c4f5d46c8a937e5047fece55dea62 /experiments/cifar_resmlp.py | |
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching.
Credit bridge matches state bridge on linear system (~0.94 cosine).
CIFAR experiments in progress.
Diffstat (limited to 'experiments/cifar_resmlp.py')
| -rw-r--r-- | experiments/cifar_resmlp.py | 775 |
1 files changed, 775 insertions, 0 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py new file mode 100644 index 0000000..1582f6d --- /dev/null +++ b/experiments/cifar_resmlp.py @@ -0,0 +1,775 @@ +""" +Phase B: Deep Residual MLP on CIFAR-10. +Compare BP, DFA, State Bridge, Credit Bridge. + +CRITICAL CONSTRAINT: No hidden BP anchor for non-BP methods. +All block updates use detached hidden states and local surrogates. +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy +import time + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test, + offline_bp_cosine, feature_drift +) + + +def get_data(dataset='cifar10', batch_size=128): + if dataset == 'cifar10': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + input_dim = 32 * 32 * 3 + num_classes = 10 + elif dataset == 'fashionmnist': + transform_train = transforms.Compose([ + transforms.RandomCrop(28, padding=2), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.2860,), (0.3530,)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.2860,), (0.3530,)), + ]) + trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test) + input_dim = 28 * 28 + num_classes = 10 + else: + raise ValueError(f"Unknown dataset: {dataset}") + + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader, input_dim, num_classes + + +def evaluate(model, test_loader, device): + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return correct / total + + +# ============================================================================= +# BP Baseline +# ============================================================================= +def train_bp(model, train_loader, test_loader, device, args): + """Standard end-to-end backprop training.""" + optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + loss = F.cross_entropy(logits, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + + scheduler.step() + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + if epoch % 10 == 0 or epoch == 1: + print(f" [BP] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}") + + return log + + +# ============================================================================= +# DFA Baseline +# ============================================================================= +def train_dfa(model, train_loader, test_loader, device, args): + """ + DFA training with fixed random feedback matrices. + Each block updated with local surrogate: L_l = <F_l(h_l), sg[a_{l+1}^DFA]>. + Output head updated with exact CE gradient (h_L detached). + Embedding updated via DFA credit at h_0. + """ + d = model.d_hidden + num_classes = args.num_classes + L = model.num_blocks + + # Fixed random feedback matrices, one per block + Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)] + + # Separate optimizers + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + # Forward pass (no grad for hidden states) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + # e_T = softmax(logits) - one_hot(y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 # (batch, num_classes) + + # 1. Update output head: exact CE gradient, h_L detached + hL_det = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # 2. Update each block with DFA local surrogate + for l in range(L): + h_l = hiddens[l].detach() + # DFA credit: a_{l+1} = B_l @ e_T^T -> (d, batch) -> transpose + a_dfa = (e_T @ Bs[l].T).detach() # (batch, d) = (batch, C) @ (C, d) + # Normalize + rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_dfa_norm = a_dfa / rms + # Local surrogate + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() + + # 3. Update embedding with DFA credit at h_0 + a_0_dfa = (e_T @ Bs[0].T).detach() + rms_0 = (a_0_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0_dfa / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for s in all_schedulers: + s.step() + + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + if epoch % 10 == 0 or epoch == 1: + print(f" [DFA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}") + + return log, Bs + + +# ============================================================================= +# State Bridge +# ============================================================================= +def train_state_bridge(model, train_loader, test_loader, device, args): + """ + State Bridge: predict terminal h_L from (h_l, t_l, s), derive credit as + a_l = grad_{h_l} CE(W_out * LN(G_psi(h_l, t_l, s)), y). + """ + d = model.d_hidden + num_classes = args.num_classes + L = model.num_blocks + + state_pred = StateBridgeNet( + d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3 + ).to(device) + + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []} + + for epoch in range(1, args.epochs + 1): + model.train() + state_pred.train() + total_loss, correct, total = 0, 0, 0 + total_se = 0 + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + hL_det = hiddens[-1].detach() + + # Train state predictor: G_psi(h_l, t_l, s) -> h_L + # Predict the *residual* from h_l to h_L for numerical stability + state_loss = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + pred_hL = state_pred(h_l_det, t_l, s) + # Target: h_L (use normalized MSE for stability) + target = hL_det + target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0) + state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + state_opt.zero_grad() + state_loss.backward() + state_opt.step() + total_se += state_loss.item() * batch + + # Compute credits: a_l = grad_{h_l} CE(out_head(LN(G(h_l, t_l, s))), y) + credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + pred_hL = state_pred(h_l_det, t_l, s) + pred_logits = model.out_head(model.out_ln(pred_hL)) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0] + credits.append(a_l.detach()) + + # Update output head + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # Update blocks + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() + + # Update embedding with credit at layer 0 + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0 / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for sch in all_schedulers: + sch.step() + + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + se = total_se / total + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + log['state_pred_error'].append(se) + if epoch % 10 == 0 or epoch == 1: + print(f" [SB] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, " + f"test={test_acc:.4f}, state_err={se:.4f}") + + return log, state_pred + + +# ============================================================================= +# Credit Bridge +# ============================================================================= +def train_credit_bridge(model, train_loader, test_loader, device, args): + """ + Credit Bridge: learn V_phi(h_l, t_l, s) -> scalar value. + Credit: a_l = grad_{h_l} V_phi. + Training: terminal boundary + bridge consistency + terminal gradient matching. + The terminal gradient is local (output layer only), NOT hidden BP. + + Uses a warmup phase: first warmup_epochs, only train value net + output head, + then start using credit bridge signals to update blocks. + During warmup, blocks get DFA-style updates as a fallback. + """ + d = model.d_hidden + num_classes = args.num_classes + L = model.num_blocks + warmup_epochs = max(1, args.epochs // 5) # 20% warmup + + value_net = ValueNet( + d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3 + ).to(device) + value_net_ema = create_ema_model(value_net) + + # DFA fallback matrices for warmup + Bs_fallback = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) + for _ in range(L)] + + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + lam = args.lam + K_samples = args.K + sigma_bridge = args.sigma_bridge + ema_momentum = args.ema_momentum + term_grad_weight = args.term_grad_weight + + log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []} + + print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)") + + for epoch in range(1, args.epochs + 1): + model.train() + value_net.train() + total_loss, correct, total = 0, 0, 0 + total_vloss = 0 + + # Blend factor: 0 during warmup, linearly increases to 1 after warmup + if epoch <= warmup_epochs: + credit_blend = 0.0 + else: + credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + + hL_det = hiddens[-1].detach() + + # ---- Train value net (always) ---- + t_L = torch.ones(batch, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Terminal gradient matching + loss_tgrad = torch.tensor(0.0, device=device) + if term_grad_weight > 0: + hL_req = hL_det.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + hL_req2 = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req2)) + ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') + a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() + loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_l_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K_samples): + noise = sigma_bridge * torch.randn_like(h_next_det) + V_next = value_net_ema(h_next_det + noise, t_l_next, s) + log_terms.append(-V_next / lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + loss_bridge = loss_bridge / L + + value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad + + value_opt.zero_grad() + value_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, ema_momentum) + total_vloss += value_loss.item() * batch + + # ---- Compute credits ---- + # Credit bridge credits + cb_credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + V_l = value_net(h_l_det, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] + cb_credits.append(a_l.detach()) + + # DFA fallback credits + dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] + + # Blend credits + credits = [] + for l in range(L): + if credit_blend >= 1.0: + a = cb_credits[l] + elif credit_blend <= 0.0: + a = dfa_credits[l] + else: + # Normalize both before blending + cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms) + credits.append(a) + + # ---- Update output head ---- + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # ---- Update blocks ---- + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() + + # ---- Update embedding ---- + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0 / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for sch in all_schedulers: + sch.step() + + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + vloss = total_vloss / total + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + log['value_loss'].append(vloss) + if epoch % 10 == 0 or epoch == 1: + phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}" + print(f" [CB] Epoch {epoch} ({phase}): loss={train_loss:.4f}, train={train_acc:.4f}, " + f"test={test_acc:.4f}, vloss={vloss:.6f}") + + return log, value_net, value_net_ema + + +# ============================================================================= +# Diagnostics +# ============================================================================= +def compute_diagnostics(model, method_name, test_loader, device, args, + value_net=None, state_predictor=None, dfa_Bs=None): + """Compute all diagnostic metrics for a trained model.""" + model.eval() + if value_net is not None: + value_net.eval() + if state_predictor is not None: + state_predictor.eval() + + d = model.d_hidden + L = model.num_blocks + num_classes = args.num_classes + + # Get one batch for diagnostics + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + break + + batch = x.size(0) + + # Forward with hidden states, need grad for BP cosine + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + loss_bp = F.cross_entropy(logits_bp, y) + loss_bp.backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} + + # Forward again without grad for clean hidden states + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + results = { + 'bp_cosine': [], + 'perturbation_rho': [], + 'nudging': {'0.001': [], '0.003': [], '0.01': []}, + } + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + # Get credit + if method_name == 'bp': + a_l = bp_grads[l] + elif method_name == 'dfa': + a_l = (e_T @ dfa_Bs[l].T).detach() + elif method_name == 'state_bridge': + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = state_predictor(h_l_req, t_l, s) + pred_logits = model.out_head(model.out_ln(pred_hL)) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() + elif method_name == 'credit_bridge': + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + else: + raise ValueError(f"Unknown method: {method_name}") + + # BP cosine + bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) + results['bp_cosine'].append(bp_cos) + + # Forward function for perturbation and nudging + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + curr = h + for i in range(start_l, L): + curr = curr + model.blocks[i](curr) + out = model.out_head(model.out_ln(curr)) + return F.cross_entropy(out, y, reduction='none') + return fwd_fn + + fwd_fn = make_fwd_fn(l) + rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) + results['perturbation_rho'].append(rho) + + for eta in [0.001, 0.003, 0.01]: + nud = nudging_test(h_l, a_l, fwd_fn, eta=eta) + results['nudging'][str(eta)].append(nud) + + return results + + +# ============================================================================= +# Main +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + all_results = {} + + for seed in args.seeds: + print(f"\n{'='*60}") + print(f"Seed {seed}") + print(f"{'='*60}") + + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + + train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size) + args.num_classes = num_classes + + seed_results = {} + + # ---- BP ---- + print("\n--- BP ---") + model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()} + bp_log = train_bp(model_bp, train_loader, test_loader, device, args) + bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args) + bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()}) + seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift} + print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}") + + # ---- DFA ---- + print("\n--- DFA ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()} + dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args) + dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs) + dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()}) + seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift} + print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}") + + # ---- State Bridge ---- + print("\n--- State Bridge ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()} + sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args) + sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args, + state_predictor=state_pred) + sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()}) + seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift} + print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}") + + # ---- Credit Bridge ---- + print("\n--- Credit Bridge ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()} + cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args) + cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args, + value_net=vnet) + cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()}) + seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift} + print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}") + + all_results[seed] = seed_results + + # Save + def serialize(obj): + if isinstance(obj, dict): + return {str(k): serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [serialize(v) for v in obj] + elif isinstance(obj, (np.floating, np.integer)): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + return obj + + save_data = serialize(all_results) + save_data['config'] = serialize(vars(args)) + out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json') + with open(out_path, 'w') as f: + json.dump(save_data, f, indent=2) + print(f"\nAll results saved to {out_path}") + return all_results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='cifar10') + parser.add_argument('--d_hidden', type=int, default=512) + parser.add_argument('--num_blocks', type=int, default=12) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--sigma_bridge', type=float, default=0.05) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456]) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/cifar10') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() |
