"""
Quick test: Credit Bridge on CIFAR-10 with s=deltaL conditioning.

deltaL = grad_{h_L} CE(out_head(h_L), y) -- output-layer-local, dim=d_hidden.
This gives 512-dim conditioning instead of 10-dim e_T.
"""
import os
import sys
import json
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
from metrics.credit_metrics import (
    cosine_similarity_batch, perturbation_correlation, nudging_test
)


class ValueNetLargeS(nn.Module):
    """Value net with larger s_dim (for deltaL conditioning).

    The conditioning vector ``s`` is first projected down to 64 features and
    then concatenated with the layer-normalised hidden state and a time
    embedding before being fed to a GELU MLP that outputs one scalar per sample.

    Args:
        d_hidden: Feature size of the hidden state ``h``.
        s_dim: Feature size of the conditioning vector ``s``.
        time_embed_dim: Size passed to ``SinusoidalTimeEmbed``.
        hidden_dim: Width of every hidden layer of the MLP.
        num_layers: Number of ``Linear -> GELU`` pairs before the scalar head.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        # Compress s to a fixed dim to keep value net manageable
        self.s_compress = nn.Linear(s_dim, 64)

        input_dim = d_hidden + time_embed_dim + 64
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the predicted value for each sample, with the last axis squeezed.

        Args:
            h: Hidden state; its last axis is normalised by ``LayerNorm(d_hidden)``.
            t: Time input handed to the sinusoidal embedding.
            s: Conditioning vector whose last axis has ``s_dim`` features.
        """
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        s_compressed = self.s_compress(s)
        inp = torch.cat([h_normed, t_emb, s_compressed], dim=-1)
        return self.net(inp).squeeze(-1)


def get_cifar10(batch_size=128):
    """Build CIFAR-10 train/test loaders (data is downloaded to ``./data``).

    The training split uses random-crop (padding 4) and horizontal-flip
    augmentation; both splits are normalised with the same per-channel
    mean/std constants.

    Returns:
        ``(train_loader, test_loader)``; only the training loader is shuffled.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                            transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                           transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)
    return train_loader, test_loader


def evaluate(model, test_loader, device):
    """Return top-1 accuracy of ``model`` over ``test_loader``.

    Inputs are flattened to ``(batch, -1)`` before the forward pass. The model
    is switched to eval mode and is not switched back.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            logits = model(x)
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


def compute_deltaL(model, hL_det, y):
    """Compute delta_L = grad_{h_L} CE(out_head(out_ln(h_L)), y). Output-layer-local.

    The loss uses ``reduction='sum'`` so every row of the result is the gradient
    of that sample's own loss (no 1/batch factor).

    Args:
        model: Object exposing ``out_ln`` and ``out_head`` callables.
        hL_det: Final hidden state; it is cloned, so the argument is not modified.
        y: Integer class targets for ``F.cross_entropy``.

    Returns:
        A detached tensor with the same shape as ``hL_det``.
    """
    hL_req = hL_det.clone().requires_grad_(True)
    logits_local = model.out_head(model.out_ln(hL_req))
    loss_local = F.cross_entropy(logits_local, y, reduction='sum')
    # autograd.grad (rather than .backward()) keeps parameter .grad buffers untouched.
    delta_L = torch.autograd.grad(loss_local, hL_req, create_graph=False)[0].detach()
    return delta_L


def train_cb_deltaL(model, train_loader, test_loader, device, args):
    """Credit bridge with s=deltaL conditioning.

    Each step (1) fits the value net with a terminal loss, an optional terminal
    gradient-matching loss and a bridge-consistency loss against an EMA copy,
    (2) derives per-layer credits as the value net's input gradients, blended
    with fixed random-feedback (DFA-style) credits, and (3) updates the output
    head with cross-entropy and each block / the embedding with a local loss
    ``<output, normalised credit>``.

    For the first ``warmup_epochs = max(1, epochs // 5)`` epochs only the random
    feedback credits are used; the blend then ramps linearly to pure value-net
    credits over another ``warmup_epochs`` epochs.

    Args:
        model: Network exposing ``d_hidden``, ``num_blocks``, ``embed``, ``blocks``,
            ``out_ln``, ``out_head`` and ``model(x, return_hidden=True)`` returning
            ``(logits, hiddens)`` with ``hiddens`` indexable from 0 to ``num_blocks``.
        train_loader: Iterable of ``(x, y)`` training batches.
        test_loader: Iterable of ``(x, y)`` batches used for per-epoch evaluation.
        device: Torch device the batches and auxiliary tensors are placed on.
        args: Namespace providing ``epochs``, ``lr``, ``lr_fb``, ``wd``, ``lam``,
            ``K``, ``sigma_bridge``, ``ema_momentum`` and ``term_grad_weight``.

    Returns:
        ``(log, value_net)`` where ``log`` maps ``'train_loss'``, ``'train_acc'``,
        ``'test_acc'`` and ``'value_loss'`` to per-epoch lists.
    """
    d = model.d_hidden
    L = model.num_blocks
    C = 10
    warmup_epochs = max(1, args.epochs // 5)

    value_net = ValueNetLargeS(d_hidden=d, s_dim=d, time_embed_dim=32,
                               hidden_dim=256, num_layers=3).to(device)
    value_net_ema = create_ema_model(value_net)

    # Fixed random feedback matrices used while the value net is warming up.
    Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]

    block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
                  for block in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=args.lr, weight_decay=args.wd
    )
    value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)

    # Only the model optimisers are annealed; the value-net optimiser keeps a constant LR.
    all_schedulers = (
        [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
        + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
           optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]
    )

    log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []}

    for epoch in range(1, args.epochs + 1):
        model.train()
        value_net.train()
        total_loss, correct, total = 0, 0, 0
        total_vloss = 0

        if epoch <= warmup_epochs:
            credit_blend = 0.0
        else:
            credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))

        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                # e_T = softmax(logits) - onehot(y), the output error signal.
                e_T = logits.softmax(dim=-1)
                # Build the row index on `device` so no host tensor is created per step.
                e_T[torch.arange(batch, device=device), y] -= 1
                true_loss = F.cross_entropy(logits, y, reduction='none').detach()

            hL_det = hiddens[-1].detach()

            # Compute s = deltaL (output-layer-local gradient)
            s = compute_deltaL(model, hL_det, y)

            # Train value net
            t_L = torch.ones(batch, device=device)
            V_terminal = value_net(hL_det, t_L, s)
            loss_term = ((V_terminal - true_loss) ** 2).mean()

            # Terminal gradient matching
            loss_tgrad = torch.tensor(0.0, device=device)
            if args.term_grad_weight > 0:
                hL_req = hL_det.clone().requires_grad_(True)
                V_at_L = value_net(hL_req, t_L, s)
                # create_graph=True so the matching loss can be backpropagated
                # into the value-net parameters.
                grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
                # a_L_exact is just s (deltaL) itself
                a_L_exact = s
                loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()

            # Bridge consistency
            loss_bridge = 0.0
            for l in range(L):
                h_l_det = hiddens[l].detach()
                t_l = torch.full((batch,), l / L, device=device)
                t_l_next = torch.full((batch,), (l + 1) / L, device=device)
                V_l = value_net(h_l_det, t_l, s)

                with torch.no_grad():
                    h_next_det = hiddens[l + 1].detach()
                    log_terms = []
                    for k in range(args.K):
                        noise = args.sigma_bridge * torch.randn_like(h_next_det)
                        V_next = value_net_ema(h_next_det + noise, t_l_next, s)
                        log_terms.append(-V_next / args.lam)
                    log_stack = torch.stack(log_terms, dim=-1)
                    # Soft-min target: -lam * log( mean_k exp(-V_next_k / lam) ).
                    V_target = -args.lam * (torch.logsumexp(log_stack, dim=-1) - np.log(args.K))

                loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
            loss_bridge = loss_bridge / L

            value_loss = loss_term + loss_bridge + args.term_grad_weight * loss_tgrad
            value_opt.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
            value_opt.step()
            update_ema(value_net, value_net_ema, args.ema_momentum)
            total_vloss += value_loss.item() * batch

            # Compute credits
            cb_credits = []
            for l in range(L):
                h_l_det = hiddens[l].detach().requires_grad_(True)
                t_l = torch.full((batch,), l / L, device=device)
                V_l = value_net(h_l_det, t_l, s)
                a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
                cb_credits.append(a_l.detach())

            dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]

            credits = []
            for l in range(L):
                if credit_blend >= 1.0:
                    a = cb_credits[l]
                elif credit_blend <= 0.0:
                    a = dfa_credits[l]
                else:
                    # Bring both signals to unit RMS before mixing so neither dominates.
                    cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                    dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                    a = (credit_blend * (cb_credits[l] / cb_rms)
                         + (1 - credit_blend) * (dfa_credits[l] / dfa_rms))
                credits.append(a)

            # Update output head
            logits_out = model.out_head(model.out_ln(hL_det))
            loss_out = F.cross_entropy(logits_out, y)
            head_opt.zero_grad()
            loss_out.backward()
            head_opt.step()

            # Update blocks
            for l in range(L):
                h_l = hiddens[l].detach()
                a = credits[l]
                rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                a_norm = a / rms
                f_l = model.blocks[l](h_l)
                # Local surrogate loss: its gradient w.r.t. f_l is the normalised credit.
                local_loss = (f_l * a_norm).sum(dim=-1).mean()
                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

            # Update embedding
            a_0 = credits[0]
            rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
            a_0_norm = a_0 / rms_0
            h0 = model.embed(x)
            embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for sch in all_schedulers:
            sch.step()

        log['train_loss'].append(total_loss / total)
        log['train_acc'].append(correct / total)
        log['test_acc'].append(evaluate(model, test_loader, device))
        log['value_loss'].append(total_vloss / total)

        if epoch % 10 == 0 or epoch == 1:
            phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
            print(f" [CB-deltaL] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} "
                  f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} "
                  f"vloss={log['value_loss'][-1]:.6f}")

    return log, value_net


def compute_diagnostics(model, value_net, test_loader, device, args):
    """Compare value-net credits with backprop gradients on the first test batch.

    For every layer ``l`` in ``range(num_blocks)`` the credit ``a_l`` (gradient of
    the value net w.r.t. the hidden state) is scored with
    ``cosine_similarity_batch`` against the true backprop gradient, and with
    ``perturbation_correlation`` / ``nudging_test`` against the loss obtained by
    running the remaining blocks forward from that layer.

    Args:
        model: Trained network (same interface as in ``train_cb_deltaL``).
        value_net: Trained value network.
        test_loader: Iterable of ``(x, y)`` batches; only the first is used.
        device: Torch device.
        args: Unused; kept for interface compatibility.

    Returns:
        Dict with per-layer lists under ``'bp_cosine'``, ``'perturbation_rho'``
        and ``'nudging'['0.01']``.
    """
    model.eval()
    value_net.eval()
    L = model.num_blocks

    # Only the first batch is needed for the diagnostics.
    x, y = next(iter(test_loader))
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)
    batch = x.size(0)

    # BP gradients
    logits_bp, hiddens_bp = model(x, return_hidden=True)
    loss_bp = F.cross_entropy(logits_bp, y)
    # Differentiate w.r.t. the hidden states only: unlike loss.backward(), this
    # does not write into the .grad buffers of the model parameters.
    hidden_grads = torch.autograd.grad(loss_bp, [hiddens_bp[l] for l in range(L + 1)])
    bp_grads = {l: g.detach().clone() for l, g in enumerate(hidden_grads)}

    with torch.no_grad():
        _, hiddens = model(x, return_hidden=True)

    hL_det = hiddens[-1].detach()
    s = compute_deltaL(model, hL_det, y)

    def make_fwd_fn(start_l):
        """Return a function mapping a layer-``start_l`` hidden state to per-sample loss."""
        def fwd_fn(h):
            with torch.no_grad():
                curr = h
                for i in range(start_l, L):
                    curr = curr + model.blocks[i](curr)
                out = model.out_head(model.out_ln(curr))
                return F.cross_entropy(out, y, reduction='none')
        return fwd_fn

    results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.01': []}}

    for l in range(L):
        h_l = hiddens[l].detach()
        t_l = torch.full((batch,), l / L, device=device)
        h_l_req = h_l.clone().requires_grad_(True)
        V_l = value_net(h_l_req, t_l, s)
        a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()

        bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
        results['bp_cosine'].append(bp_cos)

        fwd_fn = make_fwd_fn(l)
        rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
        results['perturbation_rho'].append(rho)

        nud = nudging_test(h_l, a_l, fwd_fn, eta=0.01)
        results['nudging']['0.01'].append(nud)

    return results


def _select_device(gpu_index):
    """Return the device for ``gpu_index``, falling back when it cannot be used.

    Uses the CPU when CUDA is unavailable, and ``cuda:0`` (with a warning) when
    ``gpu_index`` is outside the range of visible GPUs, so the default
    ``--gpu 1`` does not crash on a single-GPU machine.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    n_gpus = torch.cuda.device_count()
    if not 0 <= gpu_index < n_gpus:
        print(f"Warning: --gpu {gpu_index} is out of range for {n_gpus} visible GPU(s); "
              f"falling back to cuda:0")
        gpu_index = 0
    return torch.device(f'cuda:{gpu_index}')


def main():
    """Parse CLI arguments, train with deltaL conditioning, run diagnostics, save JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--d_hidden', type=int, default=512)
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--lam', type=float, default=0.1)
    parser.add_argument('--K', type=int, default=4)
    parser.add_argument('--sigma_bridge', type=float, default=0.05)
    parser.add_argument('--ema_momentum', type=float, default=0.995)
    parser.add_argument('--term_grad_weight', type=float, default=1.0)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--output_dir', type=str, default='results/cifar_deltaL')
    args = parser.parse_args()

    device = _select_device(args.gpu)
    print(f"Device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3

    model = ResidualMLP(input_dim, args.d_hidden, 10, args.num_blocks).to(device)

    print(f"Model: d={args.d_hidden}, L={args.num_blocks}")
    print(f"Conditioning: s=deltaL (dim={args.d_hidden})")

    t0 = time.time()
    log, vnet = train_cb_deltaL(model, train_loader, test_loader, device, args)
    elapsed = time.time() - t0

    diag = compute_diagnostics(model, vnet, test_loader, device, args)

    mean_gamma = np.mean(diag['bp_cosine'])
    mean_rho = np.mean(diag['perturbation_rho'])
    mean_nudge = np.mean(diag['nudging']['0.01'])

    print(f"\nDone in {elapsed:.0f}s")
    print(f"Test acc: {log['test_acc'][-1]:.4f}")
    print(f"Mean Gamma: {mean_gamma:.4f}")
    print(f"Mean rho: {mean_rho:.4f}")
    print(f"Mean nudge: {mean_nudge:.6f}")
    print(f"Gamma per layer: {[round(g, 4) for g in diag['bp_cosine']]}")
    print(f"rho per layer: {[round(r, 4) for r in diag['perturbation_rho']]}")

    result = {
        'test_acc': log['test_acc'][-1],
        'mean_gamma': float(mean_gamma),
        'mean_rho': float(mean_rho),
        'mean_nudge': float(mean_nudge),
        'gamma_per_layer': [float(g) for g in diag['bp_cosine']],
        'rho_per_layer': [float(r) for r in diag['perturbation_rho']],
        'log': log,
    }

    out_path = os.path.join(
        args.output_dir,
        f'cb_deltaL_d{args.d_hidden}_L{args.num_blocks}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(result, f, indent=2)
    print(f"Saved to {out_path}")


if __name__ == '__main__':
    main()