summaryrefslogtreecommitdiff
path: root/experiments/cifar_resmlp.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
commit6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch)
treed7c63adcd19c4f5d46c8a937e5047fece55dea62 /experiments/cifar_resmlp.py
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching. Credit bridge matches state bridge on linear system (~0.94 cosine). CIFAR experiments in progress.
Diffstat (limited to 'experiments/cifar_resmlp.py')
-rw-r--r--experiments/cifar_resmlp.py775
1 files changed, 775 insertions, 0 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
new file mode 100644
index 0000000..1582f6d
--- /dev/null
+++ b/experiments/cifar_resmlp.py
@@ -0,0 +1,775 @@
+"""
+Phase B: Deep Residual MLP on CIFAR-10.
+Compare BP, DFA, State Bridge, Credit Bridge.
+
+CRITICAL CONSTRAINT: No hidden BP anchor for non-BP methods.
+All block updates use detached hidden states and local surrogates.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test,
+ offline_bp_cosine, feature_drift
+)
+
+
+def get_data(dataset='cifar10', batch_size=128):
+ if dataset == 'cifar10':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ input_dim = 32 * 32 * 3
+ num_classes = 10
+ elif dataset == 'fashionmnist':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(28, padding=2),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.2860,), (0.3530,)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.2860,), (0.3530,)),
+ ])
+ trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test)
+ input_dim = 28 * 28
+ num_classes = 10
+ else:
+ raise ValueError(f"Unknown dataset: {dataset}")
+
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader, input_dim, num_classes
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+# =============================================================================
+# BP Baseline
+# =============================================================================
+def train_bp(model, train_loader, test_loader, device, args):
+ """Standard end-to-end backprop training."""
+ optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+
+ scheduler.step()
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [BP] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+ return log
+
+
+# =============================================================================
+# DFA Baseline
+# =============================================================================
+def train_dfa(model, train_loader, test_loader, device, args):
+ """
+ DFA training with fixed random feedback matrices.
+ Each block updated with local surrogate: L_l = <F_l(h_l), sg[a_{l+1}^DFA]>.
+ Output head updated with exact CE gradient (h_L detached).
+ Embedding updated via DFA credit at h_0.
+ """
+ d = model.d_hidden
+ num_classes = args.num_classes
+ L = model.num_blocks
+
+ # Fixed random feedback matrices, one per block
+ Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+ # Separate optimizers
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ # Forward pass (no grad for hidden states)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ # e_T = softmax(logits) - one_hot(y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1 # (batch, num_classes)
+
+ # 1. Update output head: exact CE gradient, h_L detached
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # 2. Update each block with DFA local surrogate
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ # DFA credit: a_{l+1} = B_l @ e_T^T -> (d, batch) -> transpose
+ a_dfa = (e_T @ Bs[l].T).detach() # (batch, d) = (batch, C) @ (C, d)
+ # Normalize
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_dfa_norm = a_dfa / rms
+ # Local surrogate
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step()
+
+ # 3. Update embedding with DFA credit at h_0
+ a_0_dfa = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0_dfa / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for s in all_schedulers:
+ s.step()
+
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [DFA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+ return log, Bs
+
+
+# =============================================================================
+# State Bridge
+# =============================================================================
+def train_state_bridge(model, train_loader, test_loader, device, args):
+ """
+ State Bridge: predict terminal h_L from (h_l, t_l, s), derive credit as
+ a_l = grad_{h_l} CE(W_out * LN(G_psi(h_l, t_l, s)), y).
+ """
+ d = model.d_hidden
+ num_classes = args.num_classes
+ L = model.num_blocks
+
+ state_pred = StateBridgeNet(
+ d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
+ ).to(device)
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+ state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ state_pred.train()
+ total_loss, correct, total = 0, 0, 0
+ total_se = 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Train state predictor: G_psi(h_l, t_l, s) -> h_L
+ # Predict the *residual* from h_l to h_L for numerical stability
+ state_loss = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ pred_hL = state_pred(h_l_det, t_l, s)
+ # Target: h_L (use normalized MSE for stability)
+ target = hL_det
+ target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+ state_opt.zero_grad()
+ state_loss.backward()
+ state_opt.step()
+ total_se += state_loss.item() * batch
+
+ # Compute credits: a_l = grad_{h_l} CE(out_head(LN(G(h_l, t_l, s))), y)
+ credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ pred_hL = state_pred(h_l_det, t_l, s)
+ pred_logits = model.out_head(model.out_ln(pred_hL))
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0]
+ credits.append(a_l.detach())
+
+ # Update output head
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step()
+
+ # Update embedding with credit at layer 0
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0 / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ se = total_se / total
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ log['state_pred_error'].append(se)
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [SB] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, "
+ f"test={test_acc:.4f}, state_err={se:.4f}")
+
+ return log, state_pred
+
+
+# =============================================================================
+# Credit Bridge
+# =============================================================================
+def train_credit_bridge(model, train_loader, test_loader, device, args):
+ """
+ Credit Bridge: learn V_phi(h_l, t_l, s) -> scalar value.
+ Credit: a_l = grad_{h_l} V_phi.
+ Training: terminal boundary + bridge consistency + terminal gradient matching.
+ The terminal gradient is local (output layer only), NOT hidden BP.
+
+ Uses a warmup phase: first warmup_epochs, only train value net + output head,
+ then start using credit bridge signals to update blocks.
+ During warmup, blocks get DFA-style updates as a fallback.
+ """
+ d = model.d_hidden
+ num_classes = args.num_classes
+ L = model.num_blocks
+ warmup_epochs = max(1, args.epochs // 5) # 20% warmup
+
+ value_net = ValueNet(
+ d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
+ ).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ # DFA fallback matrices for warmup
+ Bs_fallback = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes)
+ for _ in range(L)]
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ lam = args.lam
+ K_samples = args.K
+ sigma_bridge = args.sigma_bridge
+ ema_momentum = args.ema_momentum
+ term_grad_weight = args.term_grad_weight
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []}
+
+ print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)")
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+ total_loss, correct, total = 0, 0, 0
+ total_vloss = 0
+
+ # Blend factor: 0 during warmup, linearly increases to 1 after warmup
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # ---- Train value net (always) ----
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Terminal gradient matching
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req2))
+ ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K_samples):
+ noise = sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+
+ value_opt.zero_grad()
+ value_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+ total_vloss += value_loss.item() * batch
+
+ # ---- Compute credits ----
+ # Credit bridge credits
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ # DFA fallback credits
+ dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+ # Blend credits
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ a = cb_credits[l]
+ elif credit_blend <= 0.0:
+ a = dfa_credits[l]
+ else:
+ # Normalize both before blending
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)
+ credits.append(a)
+
+ # ---- Update output head ----
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # ---- Update blocks ----
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step()
+
+ # ---- Update embedding ----
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0 / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ vloss = total_vloss / total
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ log['value_loss'].append(vloss)
+ if epoch % 10 == 0 or epoch == 1:
+ phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" [CB] Epoch {epoch} ({phase}): loss={train_loss:.4f}, train={train_acc:.4f}, "
+ f"test={test_acc:.4f}, vloss={vloss:.6f}")
+
+ return log, value_net, value_net_ema
+
+
+# =============================================================================
+# Diagnostics
+# =============================================================================
+def compute_diagnostics(model, method_name, test_loader, device, args,
+ value_net=None, state_predictor=None, dfa_Bs=None):
+ """Compute all diagnostic metrics for a trained model."""
+ model.eval()
+ if value_net is not None:
+ value_net.eval()
+ if state_predictor is not None:
+ state_predictor.eval()
+
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ # Get one batch for diagnostics
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ break
+
+ batch = x.size(0)
+
+ # Forward with hidden states, need grad for BP cosine
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
+
+ # Forward again without grad for clean hidden states
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ results = {
+ 'bp_cosine': [],
+ 'perturbation_rho': [],
+ 'nudging': {'0.001': [], '0.003': [], '0.01': []},
+ }
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ # Get credit
+ if method_name == 'bp':
+ a_l = bp_grads[l]
+ elif method_name == 'dfa':
+ a_l = (e_T @ dfa_Bs[l].T).detach()
+ elif method_name == 'state_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = state_predictor(h_l_req, t_l, s)
+ pred_logits = model.out_head(model.out_ln(pred_hL))
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach()
+ elif method_name == 'credit_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ else:
+ raise ValueError(f"Unknown method: {method_name}")
+
+ # BP cosine
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(bp_cos)
+
+ # Forward function for perturbation and nudging
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(model.out_ln(curr))
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
+ results['perturbation_rho'].append(rho)
+
+ for eta in [0.001, 0.003, 0.01]:
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=eta)
+ results['nudging'][str(eta)].append(nud)
+
+ return results
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ all_results = {}
+
+ for seed in args.seeds:
+ print(f"\n{'='*60}")
+ print(f"Seed {seed}")
+ print(f"{'='*60}")
+
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size)
+ args.num_classes = num_classes
+
+ seed_results = {}
+
+ # ---- BP ----
+ print("\n--- BP ---")
+ model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()}
+ bp_log = train_bp(model_bp, train_loader, test_loader, device, args)
+ bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args)
+ bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()})
+ seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift}
+ print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}")
+
+ # ---- DFA ----
+ print("\n--- DFA ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()}
+ dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args)
+ dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs)
+ dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()})
+ seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift}
+ print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}")
+
+ # ---- State Bridge ----
+ print("\n--- State Bridge ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()}
+ sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args)
+ sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args,
+ state_predictor=state_pred)
+ sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()})
+ seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift}
+ print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}")
+
+ # ---- Credit Bridge ----
+ print("\n--- Credit Bridge ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()}
+ cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args)
+ cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args,
+ value_net=vnet)
+ cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()})
+ seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift}
+ print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}")
+
+ all_results[seed] = seed_results
+
+ # Save
+ def serialize(obj):
+ if isinstance(obj, dict):
+ return {str(k): serialize(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [serialize(v) for v in obj]
+ elif isinstance(obj, (np.floating, np.integer)):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, torch.Tensor):
+ return obj.cpu().numpy().tolist()
+ return obj
+
+ save_data = serialize(all_results)
+ save_data['config'] = serialize(vars(args))
+ out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print(f"\nAll results saved to {out_path}")
+ return all_results
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset', type=str, default='cifar10')
+ parser.add_argument('--d_hidden', type=int, default=512)
+ parser.add_argument('--num_blocks', type=int, default=12)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/cifar10')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()