diff options
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/__init__.py | 0 | ||||
| -rw-r--r-- | experiments/__pycache__/__init__.cpython-313.pyc | bin | 0 -> 137 bytes | |||
| -rw-r--r-- | experiments/__pycache__/toy_lq.cpython-313.pyc | bin | 0 -> 19620 bytes | |||
| -rw-r--r-- | experiments/cifar_resmlp.py | 775 | ||||
| -rw-r--r-- | experiments/plot_results.py | 327 | ||||
| -rw-r--r-- | experiments/plot_toy_final.py | 183 | ||||
| -rw-r--r-- | experiments/toy_lq.py | 395 | ||||
| -rw-r--r-- | experiments/toy_lq_sweep.py | 243 | ||||
| -rw-r--r-- | experiments/toy_lq_v2.py | 327 |
9 files changed, 2250 insertions, 0 deletions
diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/experiments/__init__.py diff --git a/experiments/__pycache__/__init__.cpython-313.pyc b/experiments/__pycache__/__init__.cpython-313.pyc Binary files differnew file mode 100644 index 0000000..5966841 --- /dev/null +++ b/experiments/__pycache__/__init__.cpython-313.pyc diff --git a/experiments/__pycache__/toy_lq.cpython-313.pyc b/experiments/__pycache__/toy_lq.cpython-313.pyc Binary files differnew file mode 100644 index 0000000..d8710a8 --- /dev/null +++ b/experiments/__pycache__/toy_lq.cpython-313.pyc diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py new file mode 100644 index 0000000..1582f6d --- /dev/null +++ b/experiments/cifar_resmlp.py @@ -0,0 +1,775 @@ +""" +Phase B: Deep Residual MLP on CIFAR-10. +Compare BP, DFA, State Bridge, Credit Bridge. + +CRITICAL CONSTRAINT: No hidden BP anchor for non-BP methods. +All block updates use detached hidden states and local surrogates. +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy +import time + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test, + offline_bp_cosine, feature_drift +) + + +def get_data(dataset='cifar10', batch_size=128): + if dataset == 'cifar10': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + input_dim = 32 * 32 * 3 + num_classes = 10 + elif dataset == 'fashionmnist': + transform_train = transforms.Compose([ + transforms.RandomCrop(28, padding=2), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.2860,), (0.3530,)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.2860,), (0.3530,)), + ]) + trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test) + input_dim = 28 * 28 + num_classes = 10 + else: + raise ValueError(f"Unknown dataset: {dataset}") + + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader, input_dim, num_classes + + +def evaluate(model, test_loader, device): + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return correct / total + + +# ============================================================================= +# BP Baseline +# ============================================================================= +def train_bp(model, train_loader, test_loader, device, args): + """Standard end-to-end backprop training.""" + optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + loss = F.cross_entropy(logits, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + + scheduler.step() + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + if epoch % 10 == 0 or epoch == 1: + print(f" [BP] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}") + + return log + + +# ============================================================================= +# DFA Baseline +# ============================================================================= +def train_dfa(model, train_loader, test_loader, device, args): + """ + DFA training with fixed random feedback matrices. + Each block updated with local surrogate: L_l = <F_l(h_l), sg[a_{l+1}^DFA]>. + Output head updated with exact CE gradient (h_L detached). + Embedding updated via DFA credit at h_0. + """ + d = model.d_hidden + num_classes = args.num_classes + L = model.num_blocks + + # Fixed random feedback matrices, one per block + Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)] + + # Separate optimizers + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + # Forward pass (no grad for hidden states) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + # e_T = softmax(logits) - one_hot(y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 # (batch, num_classes) + + # 1. Update output head: exact CE gradient, h_L detached + hL_det = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # 2. Update each block with DFA local surrogate + for l in range(L): + h_l = hiddens[l].detach() + # DFA credit: a_{l+1} = B_l @ e_T^T -> (d, batch) -> transpose + a_dfa = (e_T @ Bs[l].T).detach() # (batch, d) = (batch, C) @ (C, d) + # Normalize + rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_dfa_norm = a_dfa / rms + # Local surrogate + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() + + # 3. Update embedding with DFA credit at h_0 + a_0_dfa = (e_T @ Bs[0].T).detach() + rms_0 = (a_0_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0_dfa / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for s in all_schedulers: + s.step() + + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + if epoch % 10 == 0 or epoch == 1: + print(f" [DFA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}") + + return log, Bs + + +# ============================================================================= +# State Bridge +# ============================================================================= +def train_state_bridge(model, train_loader, test_loader, device, args): + """ + State Bridge: predict terminal h_L from (h_l, t_l, s), derive credit as + a_l = grad_{h_l} CE(W_out * LN(G_psi(h_l, t_l, s)), y). + """ + d = model.d_hidden + num_classes = args.num_classes + L = model.num_blocks + + state_pred = StateBridgeNet( + d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3 + ).to(device) + + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []} + + for epoch in range(1, args.epochs + 1): + model.train() + state_pred.train() + total_loss, correct, total = 0, 0, 0 + total_se = 0 + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + hL_det = hiddens[-1].detach() + + # Train state predictor: G_psi(h_l, t_l, s) -> h_L + # Predict the *residual* from h_l to h_L for numerical stability + state_loss = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + pred_hL = state_pred(h_l_det, t_l, s) + # Target: h_L (use normalized MSE for stability) + target = hL_det + target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0) + state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + state_opt.zero_grad() + state_loss.backward() + state_opt.step() + total_se += state_loss.item() * batch + + # Compute credits: a_l = grad_{h_l} CE(out_head(LN(G(h_l, t_l, s))), y) + credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + pred_hL = state_pred(h_l_det, t_l, s) + pred_logits = model.out_head(model.out_ln(pred_hL)) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0] + credits.append(a_l.detach()) + + # Update output head + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # Update blocks + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() + + # Update embedding with credit at layer 0 + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0 / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for sch in all_schedulers: + sch.step() + + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + se = total_se / total + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + log['state_pred_error'].append(se) + if epoch % 10 == 0 or epoch == 1: + print(f" [SB] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, " + f"test={test_acc:.4f}, state_err={se:.4f}") + + return log, state_pred + + +# ============================================================================= +# Credit Bridge +# ============================================================================= +def train_credit_bridge(model, train_loader, test_loader, device, args): + """ + Credit Bridge: learn V_phi(h_l, t_l, s) -> scalar value. + Credit: a_l = grad_{h_l} V_phi. + Training: terminal boundary + bridge consistency + terminal gradient matching. + The terminal gradient is local (output layer only), NOT hidden BP. + + Uses a warmup phase: first warmup_epochs, only train value net + output head, + then start using credit bridge signals to update blocks. + During warmup, blocks get DFA-style updates as a fallback. + """ + d = model.d_hidden + num_classes = args.num_classes + L = model.num_blocks + warmup_epochs = max(1, args.epochs // 5) # 20% warmup + + value_net = ValueNet( + d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3 + ).to(device) + value_net_ema = create_ema_model(value_net) + + # DFA fallback matrices for warmup + Bs_fallback = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) + for _ in range(L)] + + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + lam = args.lam + K_samples = args.K + sigma_bridge = args.sigma_bridge + ema_momentum = args.ema_momentum + term_grad_weight = args.term_grad_weight + + log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []} + + print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)") + + for epoch in range(1, args.epochs + 1): + model.train() + value_net.train() + total_loss, correct, total = 0, 0, 0 + total_vloss = 0 + + # Blend factor: 0 during warmup, linearly increases to 1 after warmup + if epoch <= warmup_epochs: + credit_blend = 0.0 + else: + credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + + hL_det = hiddens[-1].detach() + + # ---- Train value net (always) ---- + t_L = torch.ones(batch, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Terminal gradient matching + loss_tgrad = torch.tensor(0.0, device=device) + if term_grad_weight > 0: + hL_req = hL_det.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + hL_req2 = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req2)) + ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') + a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() + loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_l_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K_samples): + noise = sigma_bridge * torch.randn_like(h_next_det) + V_next = value_net_ema(h_next_det + noise, t_l_next, s) + log_terms.append(-V_next / lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + loss_bridge = loss_bridge / L + + value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad + + value_opt.zero_grad() + value_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, ema_momentum) + total_vloss += value_loss.item() * batch + + # ---- Compute credits ---- + # Credit bridge credits + cb_credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + V_l = value_net(h_l_det, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] + cb_credits.append(a_l.detach()) + + # DFA fallback credits + dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] + + # Blend credits + credits = [] + for l in range(L): + if credit_blend >= 1.0: + a = cb_credits[l] + elif credit_blend <= 0.0: + a = dfa_credits[l] + else: + # Normalize both before blending + cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms) + credits.append(a) + + # ---- Update output head ---- + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # ---- Update blocks ---- + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() + + # ---- Update embedding ---- + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0 / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for sch in all_schedulers: + sch.step() + + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + vloss = total_vloss / total + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + log['value_loss'].append(vloss) + if epoch % 10 == 0 or epoch == 1: + phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}" + print(f" [CB] Epoch {epoch} ({phase}): loss={train_loss:.4f}, train={train_acc:.4f}, " + f"test={test_acc:.4f}, vloss={vloss:.6f}") + + return log, value_net, value_net_ema + + +# ============================================================================= +# Diagnostics +# ============================================================================= +def compute_diagnostics(model, method_name, test_loader, device, args, + value_net=None, state_predictor=None, dfa_Bs=None): + """Compute all diagnostic metrics for a trained model.""" + model.eval() + if value_net is not None: + value_net.eval() + if state_predictor is not None: + state_predictor.eval() + + d = model.d_hidden + L = model.num_blocks + num_classes = args.num_classes + + # Get one batch for diagnostics + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + break + + batch = x.size(0) + + # Forward with hidden states, need grad for BP cosine + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + loss_bp = F.cross_entropy(logits_bp, y) + loss_bp.backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} + + # Forward again without grad for clean hidden states + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + results = { + 'bp_cosine': [], + 'perturbation_rho': [], + 'nudging': {'0.001': [], '0.003': [], '0.01': []}, + } + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + # Get credit + if method_name == 'bp': + a_l = bp_grads[l] + elif method_name == 'dfa': + a_l = (e_T @ dfa_Bs[l].T).detach() + elif method_name == 'state_bridge': + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = state_predictor(h_l_req, t_l, s) + pred_logits = model.out_head(model.out_ln(pred_hL)) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() + elif method_name == 'credit_bridge': + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + else: + raise ValueError(f"Unknown method: {method_name}") + + # BP cosine + bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) + results['bp_cosine'].append(bp_cos) + + # Forward function for perturbation and nudging + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + curr = h + for i in range(start_l, L): + curr = curr + model.blocks[i](curr) + out = model.out_head(model.out_ln(curr)) + return F.cross_entropy(out, y, reduction='none') + return fwd_fn + + fwd_fn = make_fwd_fn(l) + rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) + results['perturbation_rho'].append(rho) + + for eta in [0.001, 0.003, 0.01]: + nud = nudging_test(h_l, a_l, fwd_fn, eta=eta) + results['nudging'][str(eta)].append(nud) + + return results + + +# ============================================================================= +# Main +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + all_results = {} + + for seed in args.seeds: + print(f"\n{'='*60}") + print(f"Seed {seed}") + print(f"{'='*60}") + + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + + train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size) + args.num_classes = num_classes + + seed_results = {} + + # ---- BP ---- + print("\n--- BP ---") + model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()} + bp_log = train_bp(model_bp, train_loader, test_loader, device, args) + bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args) + bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()}) + seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift} + print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}") + + # ---- DFA ---- + print("\n--- DFA ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()} + dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args) + dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs) + dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()}) + seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift} + print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}") + + # ---- State Bridge ---- + print("\n--- State Bridge ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()} + sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args) + sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args, + state_predictor=state_pred) + sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()}) + seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift} + print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}") + + # ---- Credit Bridge ---- + print("\n--- Credit Bridge ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()} + cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args) + cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args, + value_net=vnet) + cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()}) + seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift} + print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}") + + all_results[seed] = seed_results + + # Save + def serialize(obj): + if isinstance(obj, dict): + return {str(k): serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [serialize(v) for v in obj] + elif isinstance(obj, (np.floating, np.integer)): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + return obj + + save_data = serialize(all_results) + save_data['config'] = serialize(vars(args)) + out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json') + with open(out_path, 'w') as f: + json.dump(save_data, f, indent=2) + print(f"\nAll results saved to {out_path}") + return all_results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='cifar10') + parser.add_argument('--d_hidden', type=int, default=512) + parser.add_argument('--num_blocks', type=int, default=12) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--sigma_bridge', type=float, default=0.05) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456]) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/cifar10') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/experiments/plot_results.py b/experiments/plot_results.py new file mode 100644 index 0000000..e3e2754 --- /dev/null +++ b/experiments/plot_results.py @@ -0,0 +1,327 @@ +"""Generate plots for toy LQ and CIFAR-10 experiments.""" +import os +import sys +import json +import argparse +import numpy as np + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +def plot_toy_results(results_dir='results/toy_lq', output_dir='report'): + """Plot toy LQ experiment results.""" + os.makedirs(output_dir, exist_ok=True) + + # Collect results across seeds + files = [f for f in os.listdir(results_dir) if f.startswith('toy_lq_seed') and f.endswith('.json')] + if not files: + print(f"No toy results found in {results_dir}") + return + + all_data = [] + for f in sorted(files): + with open(os.path.join(results_dir, f)) as fp: + all_data.append(json.load(fp)) + + # Use the last result for per-layer plots (or average if multiple seeds) + data = all_data[-1] + per_layer = data['final_per_layer'] + log_data = data['log'] + + num_layers = len(per_layer['dfa_costate_cos']) + layers = list(range(num_layers)) + + # 1. Per-layer costate cosine + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(layers, per_layer['dfa_costate_cos'], 'o-', label='DFA', color='blue') + ax.plot(layers, per_layer['state_costate_cos'], 's-', label='State Bridge', color='orange') + ax.plot(layers, per_layer['credit_costate_cos'], '^-', label='Credit Bridge', color='green') + ax.set_xlabel('Layer') + ax.set_ylabel('Cosine Similarity with Exact Costate') + ax.set_title('Exact Costate Cosine (Toy LQ)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.2, 1.05) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_costate_cosine.png'), dpi=150) + plt.close(fig) + + # 2. Per-layer perturbation correlation + num_rho_layers = len(per_layer['dfa_rho']) + rho_layers = list(range(num_rho_layers)) + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(rho_layers, per_layer['dfa_rho'], 'o-', label='DFA', color='blue') + ax.plot(rho_layers, per_layer['state_rho'], 's-', label='State Bridge', color='orange') + ax.plot(rho_layers, per_layer['credit_rho'], '^-', label='Credit Bridge', color='green') + ax.set_xlabel('Layer') + ax.set_ylabel('Perturbation Correlation (rho)') + ax.set_title('Local Perturbation Correlation (Toy LQ)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_perturbation_rho.png'), dpi=150) + plt.close(fig) + + # 3. Per-layer nudging test + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(rho_layers, per_layer['dfa_nudge'], 'o-', label='DFA', color='blue') + ax.plot(rho_layers, per_layer['state_nudge'], 's-', label='State Bridge', color='orange') + ax.plot(rho_layers, per_layer['credit_nudge'], '^-', label='Credit Bridge', color='green') + ax.set_xlabel('Layer') + ax.set_ylabel('Nudge Delta (negative = good)') + ax.set_title('Nudging Test (Toy LQ)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_nudging.png'), dpi=150) + plt.close(fig) + + # 4. Bridge residual over training + if log_data['bridge_residual']: + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(log_data['steps'], log_data['bridge_residual'], '-', color='green') + ax.set_xlabel('Training Step') + ax.set_ylabel('Bridge Residual') + ax.set_title('Bridge Residual Over Training (Toy LQ)') + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150) + plt.close(fig) + + # 5. Training curves (costate cosine over time) + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + for ax, key, title in zip(axes, + ['dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos'], + ['DFA', 'State Bridge', 'Credit Bridge']): + ax.plot(log_data['steps'], log_data[key], '-') + ax.set_xlabel('Training Step') + ax.set_ylabel('Avg Costate Cosine') + ax.set_title(f'{title} - Costate Cosine Over Training') + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_cosine_training.png'), dpi=150) + plt.close(fig) + + # 6. Per-layer bridge residual + if per_layer.get('bridge_residual'): + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + br_layers = list(range(len(per_layer['bridge_residual']))) + ax.plot(br_layers, per_layer['bridge_residual'], '^-', color='green') + ax.set_xlabel('Layer') + ax.set_ylabel('Bridge Residual') + ax.set_title('Per-Layer Bridge Residual (Toy LQ)') + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_bridge_residual_per_layer.png'), dpi=150) + plt.close(fig) + + print(f"Toy LQ plots saved to {output_dir}/") + + +def plot_cifar_results(results_path='results/cifar10/cifar_results_cifar10.json', output_dir='report'): + """Plot CIFAR-10 experiment results.""" + os.makedirs(output_dir, exist_ok=True) + + if not os.path.exists(results_path): + print(f"No CIFAR results found at {results_path}") + return + + with open(results_path) as f: + data = json.load(f) + + config = data.pop('config', {}) + methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + colors = {'bp': 'red', 'dfa': 'blue', 'state_bridge': 'orange', 'credit_bridge': 'green'} + labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'} + + seeds = [k for k in data.keys() if k != 'config'] + + # 1. Accuracy curves (mean ± std across seeds) + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + for method in methods: + train_accs = [] + test_accs = [] + for seed in seeds: + if method in data[seed]: + log = data[seed][method]['log'] + train_accs.append(log['train_acc']) + test_accs.append(log['test_acc']) + + if train_accs: + train_arr = np.array(train_accs) + test_arr = np.array(test_accs) + epochs = np.arange(1, train_arr.shape[1] + 1) + + mean_train = train_arr.mean(0) + std_train = train_arr.std(0) + mean_test = test_arr.mean(0) + std_test = test_arr.std(0) + + axes[0].plot(epochs, mean_train, '-', color=colors[method], label=labels[method]) + axes[0].fill_between(epochs, mean_train - std_train, mean_train + std_train, + alpha=0.15, color=colors[method]) + axes[1].plot(epochs, mean_test, '-', color=colors[method], label=labels[method]) + axes[1].fill_between(epochs, mean_test - std_test, mean_test + std_test, + alpha=0.15, color=colors[method]) + + axes[0].set_xlabel('Epoch') + axes[0].set_ylabel('Train Accuracy') + axes[0].set_title('Train Accuracy') + axes[0].legend() + axes[0].grid(True, alpha=0.3) + axes[1].set_xlabel('Epoch') + axes[1].set_ylabel('Test Accuracy') + axes[1].set_title('Test Accuracy') + axes[1].legend() + axes[1].grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_accuracy.png'), dpi=150) + plt.close(fig) + + # 2. Per-layer diagnostics (from last seed) + last_seed = seeds[-1] + + # BP cosine per layer + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + for method in methods: + if method in data[last_seed] and 'diagnostics' in data[last_seed][method]: + diag = data[last_seed][method]['diagnostics'] + if 'bp_cosine' in diag: + layers = list(range(len(diag['bp_cosine']))) + ax.plot(layers, diag['bp_cosine'], 'o-', color=colors[method], label=labels[method]) + ax.set_xlabel('Layer') + ax.set_ylabel('Cosine with BP Gradient') + ax.set_title('Offline BP Cosine (CIFAR-10)') + ax.legend() + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_bp_cosine.png'), dpi=150) + plt.close(fig) + + # Perturbation rho per layer + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + for method in methods: + if method in data[last_seed] and 'diagnostics' in data[last_seed][method]: + diag = data[last_seed][method]['diagnostics'] + if 'perturbation_rho' in diag: + layers = list(range(len(diag['perturbation_rho']))) + ax.plot(layers, diag['perturbation_rho'], 'o-', color=colors[method], label=labels[method]) + ax.set_xlabel('Layer') + ax.set_ylabel('Perturbation Correlation (rho)') + ax.set_title('Local Perturbation Correlation (CIFAR-10)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_perturbation_rho.png'), dpi=150) + plt.close(fig) + + # Nudging test per layer (eta=0.01) + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + for method in methods: + if method in data[last_seed] and 'diagnostics' in data[last_seed][method]: + diag = data[last_seed][method]['diagnostics'] + if 'nudging' in diag and '0.01' in diag['nudging']: + nud = diag['nudging']['0.01'] + layers = list(range(len(nud))) + ax.plot(layers, nud, 'o-', color=colors[method], label=labels[method]) + ax.set_xlabel('Layer') + ax.set_ylabel('Nudge Delta (negative = good)') + ax.set_title('Nudging Test eta=0.01 (CIFAR-10)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_nudging.png'), dpi=150) + plt.close(fig) + + # Feature drift per layer + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + for method in methods: + if method in data[last_seed] and 'drift' in data[last_seed][method]: + drift = data[last_seed][method]['drift'] + # Extract per-block drift (only block weights) + block_drifts = [] + for l in range(12): + key = f'blocks.{l}.w1.weight' + if key in drift: + block_drifts.append(drift[key]) + if block_drifts: + ax.plot(range(len(block_drifts)), block_drifts, 'o-', color=colors[method], label=labels[method]) + ax.set_xlabel('Block') + ax.set_ylabel('Feature Drift (||W_final - W_init||/||W_init||)') + ax.set_title('Feature Drift (CIFAR-10)') + ax.legend() + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_feature_drift.png'), dpi=150) + plt.close(fig) + + print(f"CIFAR-10 plots saved to {output_dir}/") + + +def print_summary_table(results_path='results/cifar10/cifar_results_cifar10.json'): + """Print summary table of results.""" + if not os.path.exists(results_path): + print(f"No results at {results_path}") + return + + with open(results_path) as f: + data = json.load(f) + + config = data.pop('config', {}) + methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'} + + seeds = [k for k in data.keys() if k != 'config'] + + print("\n" + "="*80) + print("SUMMARY TABLE") + print("="*80) + print(f"{'Method':<20} {'Test Acc':<15} {'Avg rho':<15} {'Avg Nudge(0.01)':<15} {'Avg BP Cos':<15}") + print("-"*80) + + for method in methods: + test_accs = [] + avg_rhos = [] + avg_nudges = [] + avg_bp_cos = [] + + for seed in seeds: + if method in data[seed]: + log = data[seed][method]['log'] + test_accs.append(log['test_acc'][-1]) + + if 'diagnostics' in data[seed][method]: + diag = data[seed][method]['diagnostics'] + if 'perturbation_rho' in diag: + avg_rhos.append(np.mean(diag['perturbation_rho'])) + if 'nudging' in diag and '0.01' in diag['nudging']: + avg_nudges.append(np.mean(diag['nudging']['0.01'])) + if 'bp_cosine' in diag: + avg_bp_cos.append(np.mean(diag['bp_cosine'])) + + ta = f"{np.mean(test_accs):.4f}±{np.std(test_accs):.4f}" if test_accs else "N/A" + rho = f"{np.mean(avg_rhos):.4f}" if avg_rhos else "N/A" + nud = f"{np.mean(avg_nudges):.4f}" if avg_nudges else "N/A" + bpc = f"{np.mean(avg_bp_cos):.4f}" if avg_bp_cos else "N/A" + + print(f"{labels[method]:<20} {ta:<15} {rho:<15} {nud:<15} {bpc:<15}") + + print("="*80) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--toy_dir', type=str, default='results/toy_lq') + parser.add_argument('--cifar_path', type=str, default='results/cifar10/cifar_results_cifar10.json') + parser.add_argument('--output_dir', type=str, default='report') + args = parser.parse_args() + + plot_toy_results(args.toy_dir, args.output_dir) + plot_cifar_results(args.cifar_path, args.output_dir) + print_summary_table(args.cifar_path) diff --git a/experiments/plot_toy_final.py b/experiments/plot_toy_final.py new file mode 100644 index 0000000..2f7c109 --- /dev/null +++ b/experiments/plot_toy_final.py @@ -0,0 +1,183 @@ +"""Generate final toy LQ experiment plots from v2 results across 3 seeds.""" +import os +import json +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +output_dir = 'report' +os.makedirs(output_dir, exist_ok=True) + +# Load all v2 results with term_grad_weight=1.0, fm=0.0 +seeds = [42, 123, 456] +all_data = [] +for seed in seeds: + path = f'results/toy_lq/toy_lq_v2_seed{seed}_lam0.1_sig0.1_tgw1.0_fm0.0.json' + if os.path.exists(path): + with open(path) as f: + all_data.append(json.load(f)) + +if not all_data: + print("No results found!") + exit() + +# Also load v1 baseline (no term_grad) for comparison +v1_path = 'results/toy_lq/toy_lq_seed42.json' +v1_data = None +if os.path.exists(v1_path): + with open(v1_path) as f: + v1_data = json.load(f) + +# Aggregate final per-layer results across seeds +methods = ['dfa', 'state', 'credit'] +colors = {'dfa': '#2196F3', 'state': '#FF9800', 'credit': '#4CAF50'} +labels = {'dfa': 'DFA', 'state': 'State Bridge', 'credit': 'Credit Bridge'} + +# Per-layer costate cosine +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +for ax, metric, title, ylabel in zip( + axes, + ['costate_cos', 'rho', 'nudge'], + ['Exact Costate Cosine', 'Perturbation Correlation (ρ)', 'Nudging Test'], + ['Cosine Similarity', 'Pearson Correlation', 'Loss Change (negative=good)'] +): + for method in methods: + key = f'{method}_{metric}' + values_per_seed = [] + for data in all_data: + pl = data['final_per_layer'] + if key in pl: + values_per_seed.append(pl[key]) + + if values_per_seed: + arr = np.array(values_per_seed) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + layers = np.arange(len(mean)) + ax.plot(layers, mean, 'o-', color=colors[method], label=labels[method], markersize=5) + ax.fill_between(layers, mean - std, mean + std, alpha=0.15, color=colors[method]) + + ax.set_xlabel('Layer', fontsize=12) + ax.set_ylabel(ylabel, fontsize=12) + ax.set_title(title, fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + if metric == 'costate_cos': + ax.set_ylim(-0.15, 1.05) + elif metric == 'rho': + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + elif metric == 'nudge': + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + +fig.suptitle('Toy LQ Sanity Check: Per-Layer Diagnostics (3 seeds)', fontsize=14, y=1.02) +fig.tight_layout() +fig.savefig(os.path.join(output_dir, 'toy_per_layer_diagnostics.png'), dpi=150, bbox_inches='tight') +plt.close(fig) +print("Saved toy_per_layer_diagnostics.png") + +# Training curves +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) +metric_keys = [ + ('costate_cos', 'Avg Costate Cosine', 'Cosine Similarity'), + ('rho', 'Avg Perturbation ρ', 'Pearson Correlation'), + ('nudge', 'Avg Nudging', 'Loss Change'), +] + +for ax, (metric, title, ylabel) in zip(axes, metric_keys): + for method in methods: + key = f'{method}_{metric}' + all_curves = [] + for data in all_data: + log = data['log'] + full_key = f'{method}_costate_cos' if metric == 'costate_cos' else f'{method}_{metric}' + if full_key in log: + all_curves.append(np.array(log[full_key])) + + if all_curves: + # All should have same length, use shortest + min_len = min(len(c) for c in all_curves) + arr = np.array([c[:min_len] for c in all_curves]) + steps = np.array(all_data[0]['log']['steps'][:min_len]) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + ax.plot(steps, mean, '-', color=colors[method], label=labels[method]) + ax.fill_between(steps, mean - std, mean + std, alpha=0.15, color=colors[method]) + + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel(ylabel, fontsize=12) + ax.set_title(title, fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + +fig.suptitle('Toy LQ: Training Curves (3 seeds)', fontsize=14, y=1.02) +fig.tight_layout() +fig.savefig(os.path.join(output_dir, 'toy_training_curves.png'), dpi=150, bbox_inches='tight') +plt.close(fig) +print("Saved toy_training_curves.png") + +# Compare v1 (no term grad) vs v2 (with term grad) for credit bridge +if v1_data: + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + + # v1 credit bridge (no term grad matching) + v1_log = v1_data['log'] + ax.plot(v1_log['steps'], v1_log['credit_costate_cos'], + '--', color='red', label='Credit Bridge (w/o terminal grad)', alpha=0.8) + + # v2 credit bridge (with term grad) + v2_log = all_data[0]['log'] # seed 42 + ax.plot(v2_log['steps'], v2_log['credit_costate_cos'], + '-', color='green', label='Credit Bridge (w/ terminal grad)') + + # State bridge for reference + ax.plot(v2_log['steps'], v2_log['state_costate_cos'], + '-', color='orange', label='State Bridge') + + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel('Avg Costate Cosine', fontsize=12) + ax.set_title('Effect of Terminal Gradient Matching', fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.1, 1.05) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_term_grad_effect.png'), dpi=150) + plt.close(fig) + print("Saved toy_term_grad_effect.png") + +# Bridge residual (from v1 which has it) +if v1_data and v1_data['log'].get('bridge_residual'): + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(v1_data['log']['steps'], v1_data['log']['bridge_residual'], '-', color='green') + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel('Bridge Residual', fontsize=12) + ax.set_title('Credit Bridge: Bridge Residual Over Training', fontsize=13) + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150) + plt.close(fig) + print("Saved toy_bridge_residual.png") + +# Print summary table +print("\n" + "="*80) +print("TOY LQ FINAL RESULTS (3 seeds, 8000 steps)") +print("="*80) + +for method in methods: + cos_vals = [] + rho_vals = [] + nudge_vals = [] + for data in all_data: + pl = data['final_per_layer'] + cos_vals.append(np.mean(pl[f'{method}_costate_cos'])) + rho_vals.append(np.mean(pl[f'{method}_rho'])) + nudge_vals.append(np.mean(pl[f'{method}_nudge'])) + + cos_mean, cos_std = np.mean(cos_vals), np.std(cos_vals) + rho_mean, rho_std = np.mean(rho_vals), np.std(rho_vals) + nudge_mean, nudge_std = np.mean(nudge_vals), np.std(nudge_vals) + + print(f"{labels[method]:<20} Cosine: {cos_mean:.4f}±{cos_std:.4f} " + f"ρ: {rho_mean:.4f}±{rho_std:.4f} " + f"Nudge: {nudge_mean:.4f}±{nudge_std:.4f}") diff --git a/experiments/toy_lq.py b/experiments/toy_lq.py new file mode 100644 index 0000000..4fd8919 --- /dev/null +++ b/experiments/toy_lq.py @@ -0,0 +1,395 @@ +""" +Phase A: Linear-Quadratic Residual Sanity Check. + +Fixed forward dynamics (no forward net training). +Only train feedback/bridge models. +Compare DFA, State Bridge, Credit Bridge against exact costate. + +System: + h_{l+1} = M_l h_l + sigma * xi_l, xi_l ~ N(0, I) + Phi(h_L, y) = 0.5 * ||C h_L - y||^2 + Exact costate: a_L = C^T (C h_L - y), a_l = M_l^T a_{l+1} +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from datetime import datetime + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test, bridge_residual +) + + +def generate_stable_dynamics(d, L, spectral_max=0.05, seed=42): + """Generate stable linear maps M_l = I + A_l with ||A_l||_2 <= spectral_max.""" + rng = np.random.RandomState(seed) + Ms = [] + for _ in range(L): + A = rng.randn(d, d).astype(np.float32) + # Scale to desired spectral norm + u, s, v = np.linalg.svd(A, full_matrices=False) + A = A * (spectral_max / s[0]) + M = np.eye(d, dtype=np.float32) + A + Ms.append(torch.from_numpy(M)) + return Ms # list of (d, d) + + +def rollout_forward(h0, Ms, sigma, L, device): + """Roll out forward dynamics: h_{l+1} = M_l h_l + sigma * xi_l.""" + batch = h0.shape[0] + d = h0.shape[1] + hiddens = [h0] + h = h0 + for l in range(L): + M = Ms[l].to(device) + noise = sigma * torch.randn(batch, d, device=device) + h = h @ M.T + noise + hiddens.append(h) + return hiddens # [h_0, ..., h_L] + + +def terminal_loss(hL, C, y): + """Phi(hL, y) = 0.5 * ||C hL - y||^2, returns per-sample loss.""" + diff = hL @ C.T - y # (batch, m) + return 0.5 * (diff ** 2).sum(dim=-1) # (batch,) + + +def exact_costate(hiddens, Ms, C, y, device): + """Compute exact costate a_l for all layers.""" + L = len(hiddens) - 1 + hL = hiddens[L] + # Terminal: a_L = C^T (C h_L - y) + diff = hL @ C.T - y # (batch, m) + a_L = diff @ C # (batch, d) + + costates = [None] * (L + 1) + costates[L] = a_L + for l in range(L - 1, -1, -1): + M = Ms[l].to(device) + costates[l] = costates[l + 1] @ M # a_l = M_l^T a_{l+1} -> a_{l+1} @ M + return costates + + +def make_forward_fn_from_layer(hiddens, Ms, C, y, sigma, start_layer, device): + """Create a function that rolls forward from layer start_layer and returns per-sample loss.""" + L = len(Ms) + + def forward_fn(h): + current = h + for l in range(start_layer, L): + M = Ms[l].to(device) + # No noise for perturbation test (deterministic rollout) + current = current @ M.T + return terminal_loss(current, C, y) + + return forward_fn + + +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # Hyperparams + d = args.d_hidden # 64 + m = args.output_dim # 10 + L = args.num_layers # 12 + sigma = args.sigma # 0.03 + batch_size = args.batch_size # 256 + num_steps = args.num_steps # 5000 + lr_fb = args.lr_fb # 1e-3 + lam = args.lam # 0.1 + K = args.K # 8 + ema_momentum = args.ema_momentum # 0.995 + sigma_bridge = args.sigma_bridge # 0.03 + + print(f"=== Toy LQ Experiment ===") + print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}") + print(f"device={device}") + + # Generate fixed dynamics + Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed) + C = torch.randn(m, d, device=device) / np.sqrt(d) + + # DFA random feedback matrices + Bs_dfa = [] + for l in range(L + 1): + B = torch.randn(d, m, device=device) / np.sqrt(m) + Bs_dfa.append(B) + + # State Bridge model + state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=128, num_layers=2).to(device) + opt_state = optim.Adam(state_bridge.parameters(), lr=lr_fb) + + # Credit Bridge value net + value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=128, num_layers=2).to(device) + value_net_ema = create_ema_model(value_net) + opt_value = optim.Adam(value_net.parameters(), lr=lr_fb) + + # Training logs + log = { + 'steps': [], + 'state_bridge_loss': [], + 'credit_bridge_loss': [], + 'dfa_costate_cos': [], + 'state_costate_cos': [], + 'credit_costate_cos': [], + 'dfa_rho': [], + 'state_rho': [], + 'credit_rho': [], + 'dfa_nudge': [], + 'state_nudge': [], + 'credit_nudge': [], + 'bridge_residual': [], + } + + for step in range(1, num_steps + 1): + # Generate data + h0 = torch.randn(batch_size, d, device=device) + y = torch.randn(batch_size, m, device=device) + + # Forward rollout + hiddens = rollout_forward(h0, Ms, sigma, L, device) + hL = hiddens[L] + + # Terminal error + e_T = (hL @ C.T - y) # (batch, m) - gradient of Phi w.r.t. prediction + + # Terminal modulation code s = e_T (P=I) + s = e_T.detach() + + # ---- Train State Bridge ---- + state_loss = 0.0 + hL_detached = hL.detach() + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + pred_hL = state_bridge(h_l_det, t_l, s) + state_loss = state_loss + ((pred_hL - hL_detached) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + + opt_state.zero_grad() + state_loss.backward() + opt_state.step() + + # ---- Train Credit Bridge (value net) ---- + # Terminal boundary: V(h_L, 1, s) should equal Phi(h_L, y) + hL_det = hL.detach().requires_grad_(False) + t_L = torch.ones(batch_size, device=device) + true_loss = terminal_loss(hL_det, C, y).detach() + + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) + + V_l = value_net(h_l_det, t_l, s) + + # Generate noisy next states + with torch.no_grad(): + M = Ms[l].to(device) + h_next_det = hiddens[l + 1].detach() + + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn(batch_size, d, device=device) + h_next_noisy = h_next_det + noise + V_next = value_net_ema(h_next_noisy, t_l_next, s) + log_terms.append(-V_next / lam) + + log_terms_stack = torch.stack(log_terms, dim=-1) # (batch, K) + V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + + loss_bridge = loss_bridge / L + loss_value = loss_term + loss_bridge + + opt_value.zero_grad() + loss_value.backward() + opt_value.step() + update_ema(value_net, value_net_ema, ema_momentum) + + # ---- Evaluation ---- + if step % args.eval_every == 0 or step == 1: + with torch.no_grad(): + eval_batch = min(batch_size, 128) + h0_eval = torch.randn(eval_batch, d, device=device) + y_eval = torch.randn(eval_batch, m, device=device) + hiddens_eval = rollout_forward(h0_eval, Ms, sigma, L, device) + hL_eval = hiddens_eval[L] + e_T_eval = hL_eval @ C.T - y_eval + s_eval = e_T_eval.detach() + + # Exact costate + costates_exact = exact_costate(hiddens_eval, Ms, C, y_eval, device) + + # Compute credits for each method at each layer + dfa_cos_layers = [] + state_cos_layers = [] + credit_cos_layers = [] + dfa_rho_layers = [] + state_rho_layers = [] + credit_rho_layers = [] + dfa_nudge_layers = [] + state_nudge_layers = [] + credit_nudge_layers = [] + bridge_res_layers = [] + + for l in range(L + 1): + h_l = hiddens_eval[l].detach() + a_exact = costates_exact[l].detach() + t_l = torch.full((eval_batch,), l / L, device=device) + + # DFA credit + a_dfa = e_T_eval @ Bs_dfa[l].T # (batch, d) + + # State bridge credit + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = state_bridge(h_l_req, t_l, s_eval) + # Loss through state bridge prediction + pred_out = pred_hL @ C.T # Use C as output projection for consistency + pred_loss = 0.5 * ((pred_out - y_eval) ** 2).sum(dim=-1) + a_state = torch.autograd.grad(pred_loss.sum(), h_l_req, create_graph=False)[0] + + # Credit bridge credit + h_l_req2 = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req2, t_l, s_eval) + a_credit = torch.autograd.grad(V_l.sum(), h_l_req2, create_graph=False)[0] + + # Costate cosine + dfa_cos_layers.append(cosine_similarity_batch(a_dfa, a_exact)) + state_cos_layers.append(cosine_similarity_batch(a_state, a_exact)) + credit_cos_layers.append(cosine_similarity_batch(a_credit, a_exact)) + + # Perturbation correlation and nudging (skip terminal layer for forward_fn) + if l < L: + fwd_fn = make_forward_fn_from_layer(hiddens_eval, Ms, C, y_eval, sigma, l, device) + + dfa_rho = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16) + state_rho = perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16) + credit_rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16) + dfa_rho_layers.append(dfa_rho) + state_rho_layers.append(state_rho) + credit_rho_layers.append(credit_rho) + + dfa_nud = nudging_test(h_l, a_dfa, fwd_fn, eta=0.01) + state_nud = nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01) + credit_nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01) + dfa_nudge_layers.append(dfa_nud) + state_nudge_layers.append(state_nud) + credit_nudge_layers.append(credit_nud) + + # Bridge residual for credit bridge + if l < L: + t_l_next = torch.full((eval_batch,), (l + 1) / L, device=device) + h_next = hiddens_eval[l + 1].detach() + noisy_list = [h_next + sigma_bridge * torch.randn_like(h_next) for _ in range(K)] + br = bridge_residual(value_net, value_net_ema, h_l, t_l, s_eval, + noisy_list, t_l_next, lam) + bridge_res_layers.append(br) + + # Average across layers + avg_dfa_cos = np.mean(dfa_cos_layers) + avg_state_cos = np.mean(state_cos_layers) + avg_credit_cos = np.mean(credit_cos_layers) + avg_dfa_rho = np.mean(dfa_rho_layers) + avg_state_rho = np.mean(state_rho_layers) + avg_credit_rho = np.mean(credit_rho_layers) + avg_dfa_nudge = np.mean(dfa_nudge_layers) + avg_state_nudge = np.mean(state_nudge_layers) + avg_credit_nudge = np.mean(credit_nudge_layers) + avg_bridge_res = np.mean(bridge_res_layers) if bridge_res_layers else 0.0 + + log['steps'].append(step) + log['dfa_costate_cos'].append(avg_dfa_cos) + log['state_costate_cos'].append(avg_state_cos) + log['credit_costate_cos'].append(avg_credit_cos) + log['dfa_rho'].append(avg_dfa_rho) + log['state_rho'].append(avg_state_rho) + log['credit_rho'].append(avg_credit_rho) + log['dfa_nudge'].append(avg_dfa_nudge) + log['state_nudge'].append(avg_state_nudge) + log['credit_nudge'].append(avg_credit_nudge) + log['bridge_residual'].append(avg_bridge_res) + log['state_bridge_loss'].append(state_loss.item()) + log['credit_bridge_loss'].append(loss_value.item()) + + print(f"Step {step}/{num_steps}") + print(f" Costate cos - DFA: {avg_dfa_cos:.4f}, State: {avg_state_cos:.4f}, Credit: {avg_credit_cos:.4f}") + print(f" Perturb rho - DFA: {avg_dfa_rho:.4f}, State: {avg_state_rho:.4f}, Credit: {avg_credit_rho:.4f}") + print(f" Nudging - DFA: {avg_dfa_nudge:.4f}, State: {avg_state_nudge:.4f}, Credit: {avg_credit_nudge:.4f}") + print(f" Bridge res - {avg_bridge_res:.4f}") + print(f" Losses - State: {state_loss.item():.4f}, Credit: {loss_value.item():.4f}") + print(f" Per-layer costate cos (credit): {['%.3f' % x for x in credit_cos_layers]}") + + # Save results + os.makedirs(args.output_dir, exist_ok=True) + results = { + 'config': vars(args), + 'log': log, + 'final_per_layer': { + 'dfa_costate_cos': dfa_cos_layers, + 'state_costate_cos': state_cos_layers, + 'credit_costate_cos': credit_cos_layers, + 'dfa_rho': dfa_rho_layers, + 'state_rho': state_rho_layers, + 'credit_rho': credit_rho_layers, + 'dfa_nudge': dfa_nudge_layers, + 'state_nudge': state_nudge_layers, + 'credit_nudge': credit_nudge_layers, + 'bridge_residual': bridge_res_layers, + } + } + + out_path = os.path.join(args.output_dir, f'toy_lq_seed{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {out_path}") + + # Also save models + torch.save(value_net.state_dict(), os.path.join(args.output_dir, f'value_net_seed{args.seed}.pt')) + torch.save(state_bridge.state_dict(), os.path.join(args.output_dir, f'state_bridge_seed{args.seed}.pt')) + + return results + + +def main(): + parser = argparse.ArgumentParser(description='Toy LQ Sanity Check') + parser.add_argument('--d_hidden', type=int, default=64) + parser.add_argument('--output_dim', type=int, default=10) + parser.add_argument('--num_layers', type=int, default=12) + parser.add_argument('--sigma', type=float, default=0.03) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--num_steps', type=int, default=5000) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=8) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--sigma_bridge', type=float, default=0.03) + parser.add_argument('--eval_every', type=int, default=200) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/toy_lq') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/experiments/toy_lq_sweep.py b/experiments/toy_lq_sweep.py new file mode 100644 index 0000000..ae82ef0 --- /dev/null +++ b/experiments/toy_lq_sweep.py @@ -0,0 +1,243 @@ +""" +Sweep over credit bridge hyperparameters to find a configuration +where the value field gradient actually aligns with the costate. + +Key hypothesis: the credit bridge needs sufficient noise (sigma_bridge) +and temperature (lambda) to make V_phi sensitive to cost-relevant directions. +""" +import os +import sys +import json +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from itertools import product + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from experiments.toy_lq import ( + generate_stable_dynamics, rollout_forward, terminal_loss, + exact_costate, make_forward_fn_from_layer +) +from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test + + +def run_credit_bridge_config(config, device): + """Run credit bridge with specific hyperparameters and return final metrics.""" + d = 64 + m = 10 + L = 12 + sigma = 0.03 + batch_size = 256 + num_steps = config['num_steps'] + lr = config['lr'] + lam = config['lam'] + K = config['K'] + ema_momentum = config['ema_momentum'] + sigma_bridge = config['sigma_bridge'] + hidden_dim = config.get('hidden_dim', 128) + use_ln = config.get('use_ln', True) + + torch.manual_seed(42) + np.random.seed(42) + + Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=42) + C = torch.randn(m, d, device=device) / np.sqrt(d) + + # Value net - optionally without LayerNorm + value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=hidden_dim, num_layers=2).to(device) + if not use_ln: + value_net.ln = nn.Identity() + + value_net_ema = create_ema_model(value_net) + opt_value = optim.Adam(value_net.parameters(), lr=lr) + + best_cos = -1.0 + best_step = 0 + history = [] + + for step in range(1, num_steps + 1): + h0 = torch.randn(batch_size, d, device=device) + y = torch.randn(batch_size, m, device=device) + hiddens = rollout_forward(h0, Ms, sigma, L, device) + hL = hiddens[L] + e_T = hL @ C.T - y + s = e_T.detach() + true_loss = terminal_loss(hL.detach(), C, y).detach() + + # Terminal boundary + hL_det = hL.detach() + t_L = torch.ones(batch_size, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) + + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn(batch_size, d, device=device) + h_noisy = h_next_det + noise + V_next = value_net_ema(h_noisy, t_l_next, s) + log_terms.append(-V_next / lam) + + log_terms_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + + loss_bridge = loss_bridge / L + total_loss = loss_term + loss_bridge + + opt_value.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + opt_value.step() + update_ema(value_net, value_net_ema, ema_momentum) + + # Quick evaluation + if step % 500 == 0 or step == num_steps: + with torch.no_grad(): + eval_batch = 128 + h0_e = torch.randn(eval_batch, d, device=device) + y_e = torch.randn(eval_batch, m, device=device) + hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device) + hL_e = hiddens_e[L] + e_T_e = hL_e @ C.T - y_e + s_e = e_T_e.detach() + costates = exact_costate(hiddens_e, Ms, C, y_e, device) + + cos_list = [] + rho_list = [] + nudge_list = [] + for l in range(L): + h_l = hiddens_e[l].detach() + t_l = torch.full((eval_batch,), l / L, device=device) + a_exact = costates[l].detach() + + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s_e) + a_credit = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0] + + cos_list.append(cosine_similarity_batch(a_credit, a_exact)) + + fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device) + rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16) + rho_list.append(rho) + nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01) + nudge_list.append(nud) + + avg_cos = np.mean(cos_list) + avg_rho = np.mean(rho_list) + avg_nudge = np.mean(nudge_list) + + if avg_cos > best_cos: + best_cos = avg_cos + best_step = step + + history.append({ + 'step': step, + 'avg_cos': avg_cos, + 'avg_rho': avg_rho, + 'avg_nudge': avg_nudge, + 'loss_term': loss_term.item(), + 'loss_bridge': loss_bridge.item(), + }) + + return { + 'best_cos': best_cos, + 'best_step': best_step, + 'final_cos': history[-1]['avg_cos'], + 'final_rho': history[-1]['avg_rho'], + 'final_nudge': history[-1]['avg_nudge'], + 'history': history, + } + + +def main(): + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + print(f"Device: {device}") + + # Sweep configurations + configs = [ + # Baseline (original) + {'name': 'base', 'lam': 0.1, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger noise + {'name': 'noise_0.1', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Much larger noise + {'name': 'noise_0.3', 'lam': 0.1, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger lambda + {'name': 'lam_1.0', 'lam': 1.0, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Large noise + large lambda + {'name': 'noise_lam', 'lam': 1.0, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # No LayerNorm + {'name': 'no_ln', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False}, + # Larger value net + {'name': 'big_vnet', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 256, 'use_ln': True}, + # Slower EMA + {'name': 'ema_0.999', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.999, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # More K samples + {'name': 'K16', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 16, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger noise + large lambda + no LN + {'name': 'best_combo', 'lam': 1.0, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False}, + # Very large sigma + {'name': 'noise_1.0', 'lam': 1.0, 'sigma_bridge': 1.0, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Lower lr + {'name': 'lr_3e-4', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 3e-4, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + ] + + results = {} + for cfg in configs: + name = cfg.pop('name') + print(f"\n{'='*50}") + print(f"Config: {name}") + print(f" {cfg}") + res = run_credit_bridge_config(cfg, device) + results[name] = res + print(f" Best cos: {res['best_cos']:.4f} (step {res['best_step']})") + print(f" Final cos: {res['final_cos']:.4f}, rho: {res['final_rho']:.4f}, nudge: {res['final_nudge']:.4f}") + cfg['name'] = name # restore + + # Print summary + print("\n" + "="*80) + print("SWEEP SUMMARY") + print("="*80) + print(f"{'Config':<20} {'Best Cos':<12} {'Final Cos':<12} {'Final Rho':<12} {'Final Nudge':<12}") + print("-"*68) + for name, res in results.items(): + print(f"{name:<20} {res['best_cos']:<12.4f} {res['final_cos']:<12.4f} " + f"{res['final_rho']:<12.4f} {res['final_nudge']:<12.4f}") + + # Save + os.makedirs('results/toy_lq', exist_ok=True) + with open('results/toy_lq/sweep_results.json', 'w') as f: + json.dump(results, f, indent=2) + print("\nSaved to results/toy_lq/sweep_results.json") + + +if __name__ == '__main__': + main() diff --git a/experiments/toy_lq_v2.py b/experiments/toy_lq_v2.py new file mode 100644 index 0000000..ab766b6 --- /dev/null +++ b/experiments/toy_lq_v2.py @@ -0,0 +1,327 @@ +""" +Phase A v2: Enhanced toy LQ experiment. + +Key improvements over v1: +1. Terminal gradient matching: V_phi at terminal layer should have grad_h V matching + the exact terminal gradient (this is LOCAL info, no hidden BP needed). +2. Larger noise sweep integrated. +3. Optional FM auxiliary for gradient smoothness. +4. Better diagnostics. + +The terminal gradient a_L = C^T(C h_L - y) is computed from output layer only, +so using it is allowed under the "no hidden BP anchor" constraint. +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from experiments.toy_lq import ( + generate_stable_dynamics, rollout_forward, terminal_loss, + exact_costate, make_forward_fn_from_layer +) +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test +) + + +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + d = args.d_hidden + m = args.output_dim + L = args.num_layers + sigma = args.sigma + batch_size = args.batch_size + num_steps = args.num_steps + lr = args.lr_fb + lam = args.lam + K = args.K + ema_momentum = args.ema_momentum + sigma_bridge = args.sigma_bridge + + print(f"=== Toy LQ v2 Experiment ===") + print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}") + print(f"lam={lam}, sigma_bridge={sigma_bridge}, K={K}") + print(f"terminal_grad_weight={args.term_grad_weight}") + print(f"fm_weight={args.fm_weight}") + print(f"device={device}") + + Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed) + C = torch.randn(m, d, device=device) / np.sqrt(d) + + # DFA + Bs_dfa = [torch.randn(d, m, device=device) / np.sqrt(m) for _ in range(L + 1)] + + # State Bridge + state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=128, num_layers=2).to(device) + opt_state = optim.Adam(state_bridge.parameters(), lr=lr) + + # Credit Bridge + value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=args.vnet_hidden, num_layers=args.vnet_layers).to(device) + value_net_ema = create_ema_model(value_net) + opt_value = optim.Adam(value_net.parameters(), lr=lr) + + log = {key: [] for key in [ + 'steps', + 'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos', + 'dfa_rho', 'state_rho', 'credit_rho', + 'dfa_nudge', 'state_nudge', 'credit_nudge', + 'bridge_residual', 'state_bridge_loss', 'credit_bridge_loss', + 'term_loss', 'bridge_loss', 'term_grad_loss', 'fm_loss', + ]} + + for step in range(1, num_steps + 1): + h0 = torch.randn(batch_size, d, device=device) + y = torch.randn(batch_size, m, device=device) + hiddens = rollout_forward(h0, Ms, sigma, L, device) + hL = hiddens[L] + e_T = hL @ C.T - y + s = e_T.detach() + + # ---- Train State Bridge ---- + state_loss = 0.0 + hL_det = hL.detach() + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + pred_hL = state_bridge(h_l_det, t_l, s) + state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + + opt_state.zero_grad() + state_loss.backward() + opt_state.step() + + # ---- Train Credit Bridge ---- + # 1. Terminal boundary: V(h_L, 1, s) ≈ Phi(h_L, y) + hL_det = hL.detach() + t_L = torch.ones(batch_size, device=device) + true_loss = terminal_loss(hL_det, C, y).detach() + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # 2. Terminal gradient matching: grad_h V(h_L, 1, s) ≈ a_L^exact + # This uses only terminal-local information (no hidden BP) + loss_term_grad = torch.tensor(0.0, device=device) + if args.term_grad_weight > 0: + hL_req = hL.detach().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + # Exact terminal gradient: C^T (C h_L - y) + a_L_exact = (e_T @ C).detach() # (batch, d) -- stop grad on target + loss_term_grad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + # 3. Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) + + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn(batch_size, d, device=device) + h_noisy = h_next_det + noise + V_next = value_net_ema(h_noisy, t_l_next, s) + log_terms.append(-V_next / lam) + log_terms_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + loss_bridge = loss_bridge / L + + # 4. FM auxiliary (optional): enforce gradient smoothness + loss_fm = torch.tensor(0.0, device=device) + if args.fm_weight > 0: + for l in range(L): + tau = torch.rand(batch_size, 1, device=device) + h_l_det = hiddens[l].detach() + h_next_det = hiddens[l + 1].detach() + f_l = h_next_det - h_l_det # residual + + eps = torch.randn(batch_size, d, device=device) + h_mid = h_l_det + tau * f_l + (tau * (1 - tau)).sqrt() * sigma_bridge * eps + h_mid.requires_grad_(True) + + t_mid = torch.full((batch_size, 1), 0, device=device) + t_mid = (l + tau) / L + t_mid_flat = t_mid.squeeze(-1) + + V_mid = value_net(h_mid, t_mid_flat, s) + grad_V_mid = torch.autograd.grad(V_mid.sum(), h_mid, create_graph=True)[0] + + # Interpolated target gradient + # Get a_l and a_{l+1} from current value net (no create_graph for targets) + h_l_r = h_l_det.clone().requires_grad_(True) + t_l_v = torch.full((batch_size,), l / L, device=device) + V_l_ = value_net(h_l_r, t_l_v, s) + a_l = torch.autograd.grad(V_l_.sum(), h_l_r, create_graph=False)[0].detach() + + h_next_r = h_next_det.clone().requires_grad_(True) + t_next_v = torch.full((batch_size,), (l + 1) / L, device=device) + V_next_ = value_net(h_next_r, t_next_v, s) + a_next = torch.autograd.grad(V_next_.sum(), h_next_r, create_graph=False)[0].detach() + + target_grad = ((1 - tau) * a_l + tau * a_next).detach() + loss_fm = loss_fm + ((grad_V_mid - target_grad) ** 2).sum(dim=-1).mean() + loss_fm = loss_fm / L + + total_loss = (loss_term + + loss_bridge + + args.term_grad_weight * loss_term_grad + + args.fm_weight * loss_fm) + + opt_value.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + opt_value.step() + update_ema(value_net, value_net_ema, ema_momentum) + + # ---- Evaluation ---- + if step % args.eval_every == 0 or step == 1: + with torch.no_grad(): + eval_batch = 128 + h0_e = torch.randn(eval_batch, d, device=device) + y_e = torch.randn(eval_batch, m, device=device) + hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device) + hL_e = hiddens_e[L] + e_T_e = hL_e @ C.T - y_e + s_e = e_T_e.detach() + costates = exact_costate(hiddens_e, Ms, C, y_e, device) + + dfa_cos, state_cos, credit_cos = [], [], [] + dfa_rho, state_rho, credit_rho = [], [], [] + dfa_nudge, state_nudge, credit_nudge = [], [], [] + bridge_res_list = [] + + for l in range(L): + h_l = hiddens_e[l].detach() + a_exact = costates[l].detach() + t_l = torch.full((eval_batch,), l / L, device=device) + + # DFA + a_dfa = e_T_e @ Bs_dfa[l].T + # State bridge + h_l_r1 = h_l.clone().requires_grad_(True) + pred_hL = state_bridge(h_l_r1, t_l, s_e) + pred_out = pred_hL @ C.T + pred_loss = 0.5 * ((pred_out - y_e) ** 2).sum(dim=-1) + a_state = torch.autograd.grad(pred_loss.sum(), h_l_r1, create_graph=False)[0] + # Credit bridge + h_l_r2 = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_r2, t_l, s_e) + a_credit = torch.autograd.grad(V_l.sum(), h_l_r2, create_graph=False)[0] + + dfa_cos.append(cosine_similarity_batch(a_dfa, a_exact)) + state_cos.append(cosine_similarity_batch(a_state, a_exact)) + credit_cos.append(cosine_similarity_batch(a_credit, a_exact)) + + fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device) + + dfa_rho.append(perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16)) + state_rho.append(perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16)) + credit_rho.append(perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16)) + + dfa_nudge.append(nudging_test(h_l, a_dfa, fwd_fn, eta=0.01)) + state_nudge.append(nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01)) + credit_nudge.append(nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01)) + + avg = lambda x: float(np.mean(x)) + log['steps'].append(step) + log['dfa_costate_cos'].append(avg(dfa_cos)) + log['state_costate_cos'].append(avg(state_cos)) + log['credit_costate_cos'].append(avg(credit_cos)) + log['dfa_rho'].append(avg(dfa_rho)) + log['state_rho'].append(avg(state_rho)) + log['credit_rho'].append(avg(credit_rho)) + log['dfa_nudge'].append(avg(dfa_nudge)) + log['state_nudge'].append(avg(state_nudge)) + log['credit_nudge'].append(avg(credit_nudge)) + log['state_bridge_loss'].append(state_loss.item()) + log['credit_bridge_loss'].append(total_loss.item()) + log['term_loss'].append(loss_term.item()) + log['bridge_loss'].append(loss_bridge.item()) + log['term_grad_loss'].append(loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else loss_term_grad) + log['fm_loss'].append(loss_fm.item() if isinstance(loss_fm, torch.Tensor) else loss_fm) + + print(f"Step {step}/{num_steps}") + print(f" Costate cos - DFA: {avg(dfa_cos):.4f}, State: {avg(state_cos):.4f}, Credit: {avg(credit_cos):.4f}") + print(f" Perturb rho - DFA: {avg(dfa_rho):.4f}, State: {avg(state_rho):.4f}, Credit: {avg(credit_rho):.4f}") + print(f" Nudging - DFA: {avg(dfa_nudge):.4f}, State: {avg(state_nudge):.4f}, Credit: {avg(credit_nudge):.4f}") + print(f" Losses - term: {loss_term.item():.4f}, bridge: {loss_bridge.item():.4f}, " + f"tgrad: {loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else 0:.4f}, " + f"fm: {loss_fm.item() if isinstance(loss_fm, torch.Tensor) else 0:.4f}") + print(f" Per-layer credit cos: {['%.3f' % x for x in credit_cos]}") + + # Save + os.makedirs(args.output_dir, exist_ok=True) + results = { + 'config': vars(args), + 'log': log, + 'final_per_layer': { + 'dfa_costate_cos': dfa_cos, + 'state_costate_cos': state_cos, + 'credit_costate_cos': credit_cos, + 'dfa_rho': dfa_rho, + 'state_rho': state_rho, + 'credit_rho': credit_rho, + 'dfa_nudge': dfa_nudge, + 'state_nudge': state_nudge, + 'credit_nudge': credit_nudge, + } + } + tag = f"seed{args.seed}_lam{args.lam}_sig{args.sigma_bridge}_tgw{args.term_grad_weight}_fm{args.fm_weight}" + out_path = os.path.join(args.output_dir, f'toy_lq_v2_{tag}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {out_path}") + return results + + +def main(): + parser = argparse.ArgumentParser(description='Toy LQ v2') + parser.add_argument('--d_hidden', type=int, default=64) + parser.add_argument('--output_dim', type=int, default=10) + parser.add_argument('--num_layers', type=int, default=12) + parser.add_argument('--sigma', type=float, default=0.03) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--num_steps', type=int, default=8000) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=8) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--sigma_bridge', type=float, default=0.1) + parser.add_argument('--eval_every', type=int, default=500) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/toy_lq') + parser.add_argument('--vnet_hidden', type=int, default=256) + parser.add_argument('--vnet_layers', type=int, default=3) + # Key new options + parser.add_argument('--term_grad_weight', type=float, default=1.0, + help='Weight for terminal gradient matching loss') + parser.add_argument('--fm_weight', type=float, default=0.0, + help='Weight for FM gradient smoothness auxiliary') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() |
