summaryrefslogtreecommitdiff
path: root/experiments/cifar_deltaL_test.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-24 01:44:34 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-24 01:44:34 -0500
commitc09ae4244033a7a2703f0c36279d598ca869a95f (patch)
treeac09c1dc29d228865df5796b2a842ca0a42add88 /experiments/cifar_deltaL_test.py
parent8f786597d1007f0ef6012f53c22958d9c4e9b81a (diff)
Add CIFAR deltaL test (failed) and pivot design memo
- CIFAR deltaL: s=grad_hL CE (dim=512) -> acc=17.2%, Gamma≈0 Confirms scalar value field has dimensionality bottleneck on CIFAR - Pivot memo: direct vector credit field a_phi(h,t,s) -> R^d Trained with perturbation-based target, avoids curvature problem Still satisfies no hidden BP anchor constraint Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments/cifar_deltaL_test.py')
-rw-r--r--experiments/cifar_deltaL_test.py393
1 files changed, 393 insertions, 0 deletions
diff --git a/experiments/cifar_deltaL_test.py b/experiments/cifar_deltaL_test.py
new file mode 100644
index 0000000..c085489
--- /dev/null
+++ b/experiments/cifar_deltaL_test.py
@@ -0,0 +1,393 @@
+"""
+Quick test: Credit Bridge on CIFAR-10 with s=deltaL conditioning.
+deltaL = grad_{h_L} CE(out_head(h_L), y) -- output-layer-local, dim=d_hidden.
+This gives 512-dim conditioning instead of 10-dim e_T.
+"""
+import os
+import sys
+import json
+import argparse
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+class ValueNetLargeS(nn.Module):
+ """Value net with larger s_dim (for deltaL conditioning)."""
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ # Compress s to a fixed dim to keep value net manageable
+ self.s_compress = nn.Linear(s_dim, 64)
+ input_dim = d_hidden + time_embed_dim + 64
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, 1))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ s_compressed = self.s_compress(s)
+ inp = torch.cat([h_normed, t_emb, s_compressed], dim=-1)
+ return self.net(inp).squeeze(-1)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def compute_deltaL(model, hL_det, y):
+ """Compute delta_L = grad_{h_L} CE(out_head(out_ln(h_L)), y). Output-layer-local."""
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_local = model.out_head(model.out_ln(hL_req))
+ loss_local = F.cross_entropy(logits_local, y, reduction='sum')
+ delta_L = torch.autograd.grad(loss_local, hL_req, create_graph=False)[0].detach()
+ return delta_L
+
+
+def train_cb_deltaL(model, train_loader, test_loader, device, args):
+ """Credit bridge with s=deltaL conditioning."""
+ d = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ warmup_epochs = max(1, args.epochs // 5)
+
+ value_net = ValueNetLargeS(d_hidden=d, s_dim=d, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+ total_loss, correct, total = 0, 0, 0
+ total_vloss = 0
+
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Compute s = deltaL (output-layer-local gradient)
+ s = compute_deltaL(model, hL_det, y)
+
+ # Train value net
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Terminal gradient matching
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if args.term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ # a_L_exact is just s (deltaL) itself
+ a_L_exact = s
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(args.K):
+ noise = args.sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+ log_terms.append(-V_next / args.lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -args.lam * (torch.logsumexp(log_stack, dim=-1) - np.log(args.K))
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ value_loss = loss_term + loss_bridge + args.term_grad_weight * loss_tgrad
+ value_opt.zero_grad()
+ value_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, args.ema_momentum)
+ total_vloss += value_loss.item() * batch
+
+ # Compute credits
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ a = cb_credits[l]
+ elif credit_blend <= 0.0:
+ a = dfa_credits[l]
+ else:
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)
+ credits.append(a)
+
+ # Update output head
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ # Update embedding
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0 / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ log['test_acc'].append(evaluate(model, test_loader, device))
+ log['value_loss'].append(total_vloss / total)
+ if epoch % 10 == 0 or epoch == 1:
+ phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" [CB-deltaL] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} "
+ f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} "
+ f"vloss={log['value_loss'][-1]:.6f}")
+ return log, value_net
+
+
+def compute_diagnostics(model, value_net, test_loader, device, args):
+ model.eval()
+ value_net.eval()
+ d = model.d_hidden
+ L = model.num_blocks
+
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ break
+
+ batch = x.size(0)
+
+ # BP gradients
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+
+ hL_det = hiddens[-1].detach()
+ s = compute_deltaL(model, hL_det, y)
+
+ results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.01': []}}
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(bp_cos)
+
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(model.out_ln(curr))
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
+ results['perturbation_rho'].append(rho)
+
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=0.01)
+ results['nudging']['0.01'].append(nud)
+
+ return results
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--d_hidden', type=int, default=512)
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/cifar_deltaL')
+ args = parser.parse_args()
+
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+
+ model = ResidualMLP(input_dim, args.d_hidden, 10, args.num_blocks).to(device)
+ print(f"Model: d={args.d_hidden}, L={args.num_blocks}")
+ print(f"Conditioning: s=deltaL (dim={args.d_hidden})")
+
+ t0 = time.time()
+ log, vnet = train_cb_deltaL(model, train_loader, test_loader, device, args)
+ elapsed = time.time() - t0
+
+ diag = compute_diagnostics(model, vnet, test_loader, device, args)
+
+ mean_gamma = np.mean(diag['bp_cosine'])
+ mean_rho = np.mean(diag['perturbation_rho'])
+ mean_nudge = np.mean(diag['nudging']['0.01'])
+
+ print(f"\nDone in {elapsed:.0f}s")
+ print(f"Test acc: {log['test_acc'][-1]:.4f}")
+ print(f"Mean Gamma: {mean_gamma:.4f}")
+ print(f"Mean rho: {mean_rho:.4f}")
+ print(f"Mean nudge: {mean_nudge:.6f}")
+ print(f"Gamma per layer: {[round(g, 4) for g in diag['bp_cosine']]}")
+ print(f"rho per layer: {[round(r, 4) for r in diag['perturbation_rho']]}")
+
+ result = {
+ 'test_acc': log['test_acc'][-1],
+ 'mean_gamma': float(mean_gamma),
+ 'mean_rho': float(mean_rho),
+ 'mean_nudge': float(mean_nudge),
+ 'gamma_per_layer': [float(g) for g in diag['bp_cosine']],
+ 'rho_per_layer': [float(r) for r in diag['perturbation_rho']],
+ 'log': log,
+ }
+
+ out_path = os.path.join(args.output_dir, f'cb_deltaL_d{args.d_hidden}_L{args.num_blocks}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(result, f, indent=2)
+ print(f"Saved to {out_path}")
+
+
+if __name__ == '__main__':
+ main()