summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
Diffstat (limited to 'experiments')
-rw-r--r--experiments/__init__.py0
-rw-r--r--experiments/__pycache__/__init__.cpython-313.pycbin0 -> 137 bytes
-rw-r--r--experiments/__pycache__/toy_lq.cpython-313.pycbin0 -> 19620 bytes
-rw-r--r--experiments/cifar_resmlp.py775
-rw-r--r--experiments/plot_results.py327
-rw-r--r--experiments/plot_toy_final.py183
-rw-r--r--experiments/toy_lq.py395
-rw-r--r--experiments/toy_lq_sweep.py243
-rw-r--r--experiments/toy_lq_v2.py327
9 files changed, 2250 insertions, 0 deletions
diff --git a/experiments/__init__.py b/experiments/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/experiments/__init__.py
diff --git a/experiments/__pycache__/__init__.cpython-313.pyc b/experiments/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..5966841
--- /dev/null
+++ b/experiments/__pycache__/__init__.cpython-313.pyc
Binary files differ
diff --git a/experiments/__pycache__/toy_lq.cpython-313.pyc b/experiments/__pycache__/toy_lq.cpython-313.pyc
new file mode 100644
index 0000000..d8710a8
--- /dev/null
+++ b/experiments/__pycache__/toy_lq.cpython-313.pyc
Binary files differ
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
new file mode 100644
index 0000000..1582f6d
--- /dev/null
+++ b/experiments/cifar_resmlp.py
@@ -0,0 +1,775 @@
+"""
+Phase B: Deep Residual MLP on CIFAR-10.
+Compare BP, DFA, State Bridge, Credit Bridge.
+
+CRITICAL CONSTRAINT: No hidden BP anchor for non-BP methods.
+All block updates use detached hidden states and local surrogates.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test,
+ offline_bp_cosine, feature_drift
+)
+
+
+def get_data(dataset='cifar10', batch_size=128):
+ if dataset == 'cifar10':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ input_dim = 32 * 32 * 3
+ num_classes = 10
+ elif dataset == 'fashionmnist':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(28, padding=2),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.2860,), (0.3530,)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.2860,), (0.3530,)),
+ ])
+ trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test)
+ input_dim = 28 * 28
+ num_classes = 10
+ else:
+ raise ValueError(f"Unknown dataset: {dataset}")
+
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader, input_dim, num_classes
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+# =============================================================================
+# BP Baseline
+# =============================================================================
+def train_bp(model, train_loader, test_loader, device, args):
+ """Standard end-to-end backprop training."""
+ optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+
+ scheduler.step()
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [BP] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+ return log
+
+
+# =============================================================================
+# DFA Baseline
+# =============================================================================
+def train_dfa(model, train_loader, test_loader, device, args):
+ """
+ DFA training with fixed random feedback matrices.
+ Each block updated with local surrogate: L_l = <F_l(h_l), sg[a_{l+1}^DFA]>.
+ Output head updated with exact CE gradient (h_L detached).
+ Embedding updated via DFA credit at h_0.
+ """
+ d = model.d_hidden
+ num_classes = args.num_classes
+ L = model.num_blocks
+
+ # Fixed random feedback matrices, one per block
+ Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+ # Separate optimizers
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ # Forward pass (no grad for hidden states)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ # e_T = softmax(logits) - one_hot(y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1 # (batch, num_classes)
+
+ # 1. Update output head: exact CE gradient, h_L detached
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # 2. Update each block with DFA local surrogate
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ # DFA credit: a_{l+1} = B_l @ e_T^T -> (d, batch) -> transpose
+ a_dfa = (e_T @ Bs[l].T).detach() # (batch, d) = (batch, C) @ (C, d)
+ # Normalize
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_dfa_norm = a_dfa / rms
+ # Local surrogate
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step()
+
+ # 3. Update embedding with DFA credit at h_0
+ a_0_dfa = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0_dfa / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for s in all_schedulers:
+ s.step()
+
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [DFA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+ return log, Bs
+
+
+# =============================================================================
+# State Bridge
+# =============================================================================
+def train_state_bridge(model, train_loader, test_loader, device, args):
+ """
+ State Bridge: predict terminal h_L from (h_l, t_l, s), derive credit as
+ a_l = grad_{h_l} CE(W_out * LN(G_psi(h_l, t_l, s)), y).
+ """
+ d = model.d_hidden
+ num_classes = args.num_classes
+ L = model.num_blocks
+
+ state_pred = StateBridgeNet(
+ d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
+ ).to(device)
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+ state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ state_pred.train()
+ total_loss, correct, total = 0, 0, 0
+ total_se = 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Train state predictor: G_psi(h_l, t_l, s) -> h_L
+ # Predict the *residual* from h_l to h_L for numerical stability
+ state_loss = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ pred_hL = state_pred(h_l_det, t_l, s)
+ # Target: h_L (use normalized MSE for stability)
+ target = hL_det
+ target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+ state_opt.zero_grad()
+ state_loss.backward()
+ state_opt.step()
+ total_se += state_loss.item() * batch
+
+ # Compute credits: a_l = grad_{h_l} CE(out_head(LN(G(h_l, t_l, s))), y)
+ credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ pred_hL = state_pred(h_l_det, t_l, s)
+ pred_logits = model.out_head(model.out_ln(pred_hL))
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0]
+ credits.append(a_l.detach())
+
+ # Update output head
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step()
+
+ # Update embedding with credit at layer 0
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0 / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ se = total_se / total
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ log['state_pred_error'].append(se)
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [SB] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, "
+ f"test={test_acc:.4f}, state_err={se:.4f}")
+
+ return log, state_pred
+
+
+# =============================================================================
+# Credit Bridge
+# =============================================================================
+def train_credit_bridge(model, train_loader, test_loader, device, args):
+ """
+ Credit Bridge: learn V_phi(h_l, t_l, s) -> scalar value.
+ Credit: a_l = grad_{h_l} V_phi.
+ Training: terminal boundary + bridge consistency + terminal gradient matching.
+ The terminal gradient is local (output layer only), NOT hidden BP.
+
+ Uses a warmup phase: first warmup_epochs, only train value net + output head,
+ then start using credit bridge signals to update blocks.
+ During warmup, blocks get DFA-style updates as a fallback.
+ """
+ d = model.d_hidden
+ num_classes = args.num_classes
+ L = model.num_blocks
+ warmup_epochs = max(1, args.epochs // 5) # 20% warmup
+
+ value_net = ValueNet(
+ d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
+ ).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ # DFA fallback matrices for warmup
+ Bs_fallback = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes)
+ for _ in range(L)]
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ lam = args.lam
+ K_samples = args.K
+ sigma_bridge = args.sigma_bridge
+ ema_momentum = args.ema_momentum
+ term_grad_weight = args.term_grad_weight
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []}
+
+ print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)")
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+ total_loss, correct, total = 0, 0, 0
+ total_vloss = 0
+
+ # Blend factor: 0 during warmup, linearly increases to 1 after warmup
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # ---- Train value net (always) ----
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Terminal gradient matching
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req2))
+ ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K_samples):
+ noise = sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+
+ value_opt.zero_grad()
+ value_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+ total_vloss += value_loss.item() * batch
+
+ # ---- Compute credits ----
+ # Credit bridge credits
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ # DFA fallback credits
+ dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+ # Blend credits
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ a = cb_credits[l]
+ elif credit_blend <= 0.0:
+ a = dfa_credits[l]
+ else:
+ # Normalize both before blending
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)
+ credits.append(a)
+
+ # ---- Update output head ----
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # ---- Update blocks ----
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step()
+
+ # ---- Update embedding ----
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0 / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ vloss = total_vloss / total
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ log['value_loss'].append(vloss)
+ if epoch % 10 == 0 or epoch == 1:
+ phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" [CB] Epoch {epoch} ({phase}): loss={train_loss:.4f}, train={train_acc:.4f}, "
+ f"test={test_acc:.4f}, vloss={vloss:.6f}")
+
+ return log, value_net, value_net_ema
+
+
+# =============================================================================
+# Diagnostics
+# =============================================================================
+def compute_diagnostics(model, method_name, test_loader, device, args,
+ value_net=None, state_predictor=None, dfa_Bs=None):
+ """Compute all diagnostic metrics for a trained model."""
+ model.eval()
+ if value_net is not None:
+ value_net.eval()
+ if state_predictor is not None:
+ state_predictor.eval()
+
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ # Get one batch for diagnostics
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ break
+
+ batch = x.size(0)
+
+ # Forward with hidden states, need grad for BP cosine
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
+
+ # Forward again without grad for clean hidden states
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ results = {
+ 'bp_cosine': [],
+ 'perturbation_rho': [],
+ 'nudging': {'0.001': [], '0.003': [], '0.01': []},
+ }
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ # Get credit
+ if method_name == 'bp':
+ a_l = bp_grads[l]
+ elif method_name == 'dfa':
+ a_l = (e_T @ dfa_Bs[l].T).detach()
+ elif method_name == 'state_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = state_predictor(h_l_req, t_l, s)
+ pred_logits = model.out_head(model.out_ln(pred_hL))
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach()
+ elif method_name == 'credit_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ else:
+ raise ValueError(f"Unknown method: {method_name}")
+
+ # BP cosine
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(bp_cos)
+
+ # Forward function for perturbation and nudging
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(model.out_ln(curr))
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
+ results['perturbation_rho'].append(rho)
+
+ for eta in [0.001, 0.003, 0.01]:
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=eta)
+ results['nudging'][str(eta)].append(nud)
+
+ return results
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ all_results = {}
+
+ for seed in args.seeds:
+ print(f"\n{'='*60}")
+ print(f"Seed {seed}")
+ print(f"{'='*60}")
+
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size)
+ args.num_classes = num_classes
+
+ seed_results = {}
+
+ # ---- BP ----
+ print("\n--- BP ---")
+ model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()}
+ bp_log = train_bp(model_bp, train_loader, test_loader, device, args)
+ bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args)
+ bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()})
+ seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift}
+ print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}")
+
+ # ---- DFA ----
+ print("\n--- DFA ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()}
+ dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args)
+ dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs)
+ dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()})
+ seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift}
+ print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}")
+
+ # ---- State Bridge ----
+ print("\n--- State Bridge ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()}
+ sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args)
+ sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args,
+ state_predictor=state_pred)
+ sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()})
+ seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift}
+ print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}")
+
+ # ---- Credit Bridge ----
+ print("\n--- Credit Bridge ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()}
+ cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args)
+ cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args,
+ value_net=vnet)
+ cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()})
+ seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift}
+ print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}")
+
+ all_results[seed] = seed_results
+
+ # Save
+ def serialize(obj):
+ if isinstance(obj, dict):
+ return {str(k): serialize(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [serialize(v) for v in obj]
+ elif isinstance(obj, (np.floating, np.integer)):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, torch.Tensor):
+ return obj.cpu().numpy().tolist()
+ return obj
+
+ save_data = serialize(all_results)
+ save_data['config'] = serialize(vars(args))
+ out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print(f"\nAll results saved to {out_path}")
+ return all_results
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset', type=str, default='cifar10')
+ parser.add_argument('--d_hidden', type=int, default=512)
+ parser.add_argument('--num_blocks', type=int, default=12)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/cifar10')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/plot_results.py b/experiments/plot_results.py
new file mode 100644
index 0000000..e3e2754
--- /dev/null
+++ b/experiments/plot_results.py
@@ -0,0 +1,327 @@
+"""Generate plots for toy LQ and CIFAR-10 experiments."""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+
+def plot_toy_results(results_dir='results/toy_lq', output_dir='report'):
+ """Plot toy LQ experiment results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Collect results across seeds
+ files = [f for f in os.listdir(results_dir) if f.startswith('toy_lq_seed') and f.endswith('.json')]
+ if not files:
+ print(f"No toy results found in {results_dir}")
+ return
+
+ all_data = []
+ for f in sorted(files):
+ with open(os.path.join(results_dir, f)) as fp:
+ all_data.append(json.load(fp))
+
+ # Use the last result for per-layer plots (or average if multiple seeds)
+ data = all_data[-1]
+ per_layer = data['final_per_layer']
+ log_data = data['log']
+
+ num_layers = len(per_layer['dfa_costate_cos'])
+ layers = list(range(num_layers))
+
+ # 1. Per-layer costate cosine
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(layers, per_layer['dfa_costate_cos'], 'o-', label='DFA', color='blue')
+ ax.plot(layers, per_layer['state_costate_cos'], 's-', label='State Bridge', color='orange')
+ ax.plot(layers, per_layer['credit_costate_cos'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Cosine Similarity with Exact Costate')
+ ax.set_title('Exact Costate Cosine (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(-0.2, 1.05)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_costate_cosine.png'), dpi=150)
+ plt.close(fig)
+
+ # 2. Per-layer perturbation correlation
+ num_rho_layers = len(per_layer['dfa_rho'])
+ rho_layers = list(range(num_rho_layers))
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(rho_layers, per_layer['dfa_rho'], 'o-', label='DFA', color='blue')
+ ax.plot(rho_layers, per_layer['state_rho'], 's-', label='State Bridge', color='orange')
+ ax.plot(rho_layers, per_layer['credit_rho'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Perturbation Correlation (rho)')
+ ax.set_title('Local Perturbation Correlation (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_perturbation_rho.png'), dpi=150)
+ plt.close(fig)
+
+ # 3. Per-layer nudging test
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(rho_layers, per_layer['dfa_nudge'], 'o-', label='DFA', color='blue')
+ ax.plot(rho_layers, per_layer['state_nudge'], 's-', label='State Bridge', color='orange')
+ ax.plot(rho_layers, per_layer['credit_nudge'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Nudge Delta (negative = good)')
+ ax.set_title('Nudging Test (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_nudging.png'), dpi=150)
+ plt.close(fig)
+
+ # 4. Bridge residual over training
+ if log_data['bridge_residual']:
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(log_data['steps'], log_data['bridge_residual'], '-', color='green')
+ ax.set_xlabel('Training Step')
+ ax.set_ylabel('Bridge Residual')
+ ax.set_title('Bridge Residual Over Training (Toy LQ)')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150)
+ plt.close(fig)
+
+ # 5. Training curves (costate cosine over time)
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+ for ax, key, title in zip(axes,
+ ['dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos'],
+ ['DFA', 'State Bridge', 'Credit Bridge']):
+ ax.plot(log_data['steps'], log_data[key], '-')
+ ax.set_xlabel('Training Step')
+ ax.set_ylabel('Avg Costate Cosine')
+ ax.set_title(f'{title} - Costate Cosine Over Training')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_cosine_training.png'), dpi=150)
+ plt.close(fig)
+
+ # 6. Per-layer bridge residual
+ if per_layer.get('bridge_residual'):
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ br_layers = list(range(len(per_layer['bridge_residual'])))
+ ax.plot(br_layers, per_layer['bridge_residual'], '^-', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Bridge Residual')
+ ax.set_title('Per-Layer Bridge Residual (Toy LQ)')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_bridge_residual_per_layer.png'), dpi=150)
+ plt.close(fig)
+
+ print(f"Toy LQ plots saved to {output_dir}/")
+
+
+def plot_cifar_results(results_path='results/cifar10/cifar_results_cifar10.json', output_dir='report'):
+ """Plot CIFAR-10 experiment results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ if not os.path.exists(results_path):
+ print(f"No CIFAR results found at {results_path}")
+ return
+
+ with open(results_path) as f:
+ data = json.load(f)
+
+ config = data.pop('config', {})
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ colors = {'bp': 'red', 'dfa': 'blue', 'state_bridge': 'orange', 'credit_bridge': 'green'}
+ labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}
+
+ seeds = [k for k in data.keys() if k != 'config']
+
+ # 1. Accuracy curves (mean ± std across seeds)
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+ for method in methods:
+ train_accs = []
+ test_accs = []
+ for seed in seeds:
+ if method in data[seed]:
+ log = data[seed][method]['log']
+ train_accs.append(log['train_acc'])
+ test_accs.append(log['test_acc'])
+
+ if train_accs:
+ train_arr = np.array(train_accs)
+ test_arr = np.array(test_accs)
+ epochs = np.arange(1, train_arr.shape[1] + 1)
+
+ mean_train = train_arr.mean(0)
+ std_train = train_arr.std(0)
+ mean_test = test_arr.mean(0)
+ std_test = test_arr.std(0)
+
+ axes[0].plot(epochs, mean_train, '-', color=colors[method], label=labels[method])
+ axes[0].fill_between(epochs, mean_train - std_train, mean_train + std_train,
+ alpha=0.15, color=colors[method])
+ axes[1].plot(epochs, mean_test, '-', color=colors[method], label=labels[method])
+ axes[1].fill_between(epochs, mean_test - std_test, mean_test + std_test,
+ alpha=0.15, color=colors[method])
+
+ axes[0].set_xlabel('Epoch')
+ axes[0].set_ylabel('Train Accuracy')
+ axes[0].set_title('Train Accuracy')
+ axes[0].legend()
+ axes[0].grid(True, alpha=0.3)
+ axes[1].set_xlabel('Epoch')
+ axes[1].set_ylabel('Test Accuracy')
+ axes[1].set_title('Test Accuracy')
+ axes[1].legend()
+ axes[1].grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_accuracy.png'), dpi=150)
+ plt.close(fig)
+
+ # 2. Per-layer diagnostics (from last seed)
+ last_seed = seeds[-1]
+
+ # BP cosine per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'bp_cosine' in diag:
+ layers = list(range(len(diag['bp_cosine'])))
+ ax.plot(layers, diag['bp_cosine'], 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Cosine with BP Gradient')
+ ax.set_title('Offline BP Cosine (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_bp_cosine.png'), dpi=150)
+ plt.close(fig)
+
+ # Perturbation rho per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'perturbation_rho' in diag:
+ layers = list(range(len(diag['perturbation_rho'])))
+ ax.plot(layers, diag['perturbation_rho'], 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Perturbation Correlation (rho)')
+ ax.set_title('Local Perturbation Correlation (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_perturbation_rho.png'), dpi=150)
+ plt.close(fig)
+
+ # Nudging test per layer (eta=0.01)
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'nudging' in diag and '0.01' in diag['nudging']:
+ nud = diag['nudging']['0.01']
+ layers = list(range(len(nud)))
+ ax.plot(layers, nud, 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Nudge Delta (negative = good)')
+ ax.set_title('Nudging Test eta=0.01 (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_nudging.png'), dpi=150)
+ plt.close(fig)
+
+ # Feature drift per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'drift' in data[last_seed][method]:
+ drift = data[last_seed][method]['drift']
+ # Extract per-block drift (only block weights)
+ block_drifts = []
+ for l in range(12):
+ key = f'blocks.{l}.w1.weight'
+ if key in drift:
+ block_drifts.append(drift[key])
+ if block_drifts:
+ ax.plot(range(len(block_drifts)), block_drifts, 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Block')
+ ax.set_ylabel('Feature Drift (||W_final - W_init||/||W_init||)')
+ ax.set_title('Feature Drift (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_feature_drift.png'), dpi=150)
+ plt.close(fig)
+
+ print(f"CIFAR-10 plots saved to {output_dir}/")
+
+
+def print_summary_table(results_path='results/cifar10/cifar_results_cifar10.json'):
+ """Print summary table of results."""
+ if not os.path.exists(results_path):
+ print(f"No results at {results_path}")
+ return
+
+ with open(results_path) as f:
+ data = json.load(f)
+
+ config = data.pop('config', {})
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}
+
+ seeds = [k for k in data.keys() if k != 'config']
+
+ print("\n" + "="*80)
+ print("SUMMARY TABLE")
+ print("="*80)
+ print(f"{'Method':<20} {'Test Acc':<15} {'Avg rho':<15} {'Avg Nudge(0.01)':<15} {'Avg BP Cos':<15}")
+ print("-"*80)
+
+ for method in methods:
+ test_accs = []
+ avg_rhos = []
+ avg_nudges = []
+ avg_bp_cos = []
+
+ for seed in seeds:
+ if method in data[seed]:
+ log = data[seed][method]['log']
+ test_accs.append(log['test_acc'][-1])
+
+ if 'diagnostics' in data[seed][method]:
+ diag = data[seed][method]['diagnostics']
+ if 'perturbation_rho' in diag:
+ avg_rhos.append(np.mean(diag['perturbation_rho']))
+ if 'nudging' in diag and '0.01' in diag['nudging']:
+ avg_nudges.append(np.mean(diag['nudging']['0.01']))
+ if 'bp_cosine' in diag:
+ avg_bp_cos.append(np.mean(diag['bp_cosine']))
+
+ ta = f"{np.mean(test_accs):.4f}±{np.std(test_accs):.4f}" if test_accs else "N/A"
+ rho = f"{np.mean(avg_rhos):.4f}" if avg_rhos else "N/A"
+ nud = f"{np.mean(avg_nudges):.4f}" if avg_nudges else "N/A"
+ bpc = f"{np.mean(avg_bp_cos):.4f}" if avg_bp_cos else "N/A"
+
+ print(f"{labels[method]:<20} {ta:<15} {rho:<15} {nud:<15} {bpc:<15}")
+
+ print("="*80)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--toy_dir', type=str, default='results/toy_lq')
+ parser.add_argument('--cifar_path', type=str, default='results/cifar10/cifar_results_cifar10.json')
+ parser.add_argument('--output_dir', type=str, default='report')
+ args = parser.parse_args()
+
+ plot_toy_results(args.toy_dir, args.output_dir)
+ plot_cifar_results(args.cifar_path, args.output_dir)
+ print_summary_table(args.cifar_path)
diff --git a/experiments/plot_toy_final.py b/experiments/plot_toy_final.py
new file mode 100644
index 0000000..2f7c109
--- /dev/null
+++ b/experiments/plot_toy_final.py
@@ -0,0 +1,183 @@
+"""Generate final toy LQ experiment plots from v2 results across 3 seeds."""
+import os
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+output_dir = 'report'
+os.makedirs(output_dir, exist_ok=True)
+
+# Load all v2 results with term_grad_weight=1.0, fm=0.0
+seeds = [42, 123, 456]
+all_data = []
+for seed in seeds:
+ path = f'results/toy_lq/toy_lq_v2_seed{seed}_lam0.1_sig0.1_tgw1.0_fm0.0.json'
+ if os.path.exists(path):
+ with open(path) as f:
+ all_data.append(json.load(f))
+
+if not all_data:
+ print("No results found!")
+ exit()
+
+# Also load v1 baseline (no term_grad) for comparison
+v1_path = 'results/toy_lq/toy_lq_seed42.json'
+v1_data = None
+if os.path.exists(v1_path):
+ with open(v1_path) as f:
+ v1_data = json.load(f)
+
+# Aggregate final per-layer results across seeds
+methods = ['dfa', 'state', 'credit']
+colors = {'dfa': '#2196F3', 'state': '#FF9800', 'credit': '#4CAF50'}
+labels = {'dfa': 'DFA', 'state': 'State Bridge', 'credit': 'Credit Bridge'}
+
+# Per-layer costate cosine
+fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+
+for ax, metric, title, ylabel in zip(
+ axes,
+ ['costate_cos', 'rho', 'nudge'],
+ ['Exact Costate Cosine', 'Perturbation Correlation (ρ)', 'Nudging Test'],
+ ['Cosine Similarity', 'Pearson Correlation', 'Loss Change (negative=good)']
+):
+ for method in methods:
+ key = f'{method}_{metric}'
+ values_per_seed = []
+ for data in all_data:
+ pl = data['final_per_layer']
+ if key in pl:
+ values_per_seed.append(pl[key])
+
+ if values_per_seed:
+ arr = np.array(values_per_seed)
+ mean = arr.mean(axis=0)
+ std = arr.std(axis=0)
+ layers = np.arange(len(mean))
+ ax.plot(layers, mean, 'o-', color=colors[method], label=labels[method], markersize=5)
+ ax.fill_between(layers, mean - std, mean + std, alpha=0.15, color=colors[method])
+
+ ax.set_xlabel('Layer', fontsize=12)
+ ax.set_ylabel(ylabel, fontsize=12)
+ ax.set_title(title, fontsize=13)
+ ax.legend(fontsize=11)
+ ax.grid(True, alpha=0.3)
+ if metric == 'costate_cos':
+ ax.set_ylim(-0.15, 1.05)
+ elif metric == 'rho':
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ elif metric == 'nudge':
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+fig.suptitle('Toy LQ Sanity Check: Per-Layer Diagnostics (3 seeds)', fontsize=14, y=1.02)
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'toy_per_layer_diagnostics.png'), dpi=150, bbox_inches='tight')
+plt.close(fig)
+print("Saved toy_per_layer_diagnostics.png")
+
+# Training curves
+fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+metric_keys = [
+ ('costate_cos', 'Avg Costate Cosine', 'Cosine Similarity'),
+ ('rho', 'Avg Perturbation ρ', 'Pearson Correlation'),
+ ('nudge', 'Avg Nudging', 'Loss Change'),
+]
+
+for ax, (metric, title, ylabel) in zip(axes, metric_keys):
+ for method in methods:
+ key = f'{method}_{metric}'
+ all_curves = []
+ for data in all_data:
+ log = data['log']
+ full_key = f'{method}_costate_cos' if metric == 'costate_cos' else f'{method}_{metric}'
+ if full_key in log:
+ all_curves.append(np.array(log[full_key]))
+
+ if all_curves:
+ # All should have same length, use shortest
+ min_len = min(len(c) for c in all_curves)
+ arr = np.array([c[:min_len] for c in all_curves])
+ steps = np.array(all_data[0]['log']['steps'][:min_len])
+ mean = arr.mean(axis=0)
+ std = arr.std(axis=0)
+ ax.plot(steps, mean, '-', color=colors[method], label=labels[method])
+ ax.fill_between(steps, mean - std, mean + std, alpha=0.15, color=colors[method])
+
+ ax.set_xlabel('Training Step', fontsize=12)
+ ax.set_ylabel(ylabel, fontsize=12)
+ ax.set_title(title, fontsize=13)
+ ax.legend(fontsize=11)
+ ax.grid(True, alpha=0.3)
+
+fig.suptitle('Toy LQ: Training Curves (3 seeds)', fontsize=14, y=1.02)
+fig.tight_layout()
+fig.savefig(os.path.join(output_dir, 'toy_training_curves.png'), dpi=150, bbox_inches='tight')
+plt.close(fig)
+print("Saved toy_training_curves.png")
+
+# Compare v1 (no term grad) vs v2 (with term grad) for credit bridge
+if v1_data:
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+
+ # v1 credit bridge (no term grad matching)
+ v1_log = v1_data['log']
+ ax.plot(v1_log['steps'], v1_log['credit_costate_cos'],
+ '--', color='red', label='Credit Bridge (w/o terminal grad)', alpha=0.8)
+
+ # v2 credit bridge (with term grad)
+ v2_log = all_data[0]['log'] # seed 42
+ ax.plot(v2_log['steps'], v2_log['credit_costate_cos'],
+ '-', color='green', label='Credit Bridge (w/ terminal grad)')
+
+ # State bridge for reference
+ ax.plot(v2_log['steps'], v2_log['state_costate_cos'],
+ '-', color='orange', label='State Bridge')
+
+ ax.set_xlabel('Training Step', fontsize=12)
+ ax.set_ylabel('Avg Costate Cosine', fontsize=12)
+ ax.set_title('Effect of Terminal Gradient Matching', fontsize=13)
+ ax.legend(fontsize=11)
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(-0.1, 1.05)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_term_grad_effect.png'), dpi=150)
+ plt.close(fig)
+ print("Saved toy_term_grad_effect.png")
+
+# Bridge residual (from v1 which has it)
+if v1_data and v1_data['log'].get('bridge_residual'):
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(v1_data['log']['steps'], v1_data['log']['bridge_residual'], '-', color='green')
+ ax.set_xlabel('Training Step', fontsize=12)
+ ax.set_ylabel('Bridge Residual', fontsize=12)
+ ax.set_title('Credit Bridge: Bridge Residual Over Training', fontsize=13)
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150)
+ plt.close(fig)
+ print("Saved toy_bridge_residual.png")
+
+# Print summary table
+print("\n" + "="*80)
+print("TOY LQ FINAL RESULTS (3 seeds, 8000 steps)")
+print("="*80)
+
+for method in methods:
+ cos_vals = []
+ rho_vals = []
+ nudge_vals = []
+ for data in all_data:
+ pl = data['final_per_layer']
+ cos_vals.append(np.mean(pl[f'{method}_costate_cos']))
+ rho_vals.append(np.mean(pl[f'{method}_rho']))
+ nudge_vals.append(np.mean(pl[f'{method}_nudge']))
+
+ cos_mean, cos_std = np.mean(cos_vals), np.std(cos_vals)
+ rho_mean, rho_std = np.mean(rho_vals), np.std(rho_vals)
+ nudge_mean, nudge_std = np.mean(nudge_vals), np.std(nudge_vals)
+
+ print(f"{labels[method]:<20} Cosine: {cos_mean:.4f}±{cos_std:.4f} "
+ f"ρ: {rho_mean:.4f}±{rho_std:.4f} "
+ f"Nudge: {nudge_mean:.4f}±{nudge_std:.4f}")
diff --git a/experiments/toy_lq.py b/experiments/toy_lq.py
new file mode 100644
index 0000000..4fd8919
--- /dev/null
+++ b/experiments/toy_lq.py
@@ -0,0 +1,395 @@
+"""
+Phase A: Linear-Quadratic Residual Sanity Check.
+
+Fixed forward dynamics (no forward net training).
+Only train feedback/bridge models.
+Compare DFA, State Bridge, Credit Bridge against exact costate.
+
+System:
+ h_{l+1} = M_l h_l + sigma * xi_l, xi_l ~ N(0, I)
+ Phi(h_L, y) = 0.5 * ||C h_L - y||^2
+ Exact costate: a_L = C^T (C h_L - y), a_l = M_l^T a_{l+1}
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from datetime import datetime
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test, bridge_residual
+)
+
+
+def generate_stable_dynamics(d, L, spectral_max=0.05, seed=42):
+ """Generate stable linear maps M_l = I + A_l with ||A_l||_2 <= spectral_max."""
+ rng = np.random.RandomState(seed)
+ Ms = []
+ for _ in range(L):
+ A = rng.randn(d, d).astype(np.float32)
+ # Scale to desired spectral norm
+ u, s, v = np.linalg.svd(A, full_matrices=False)
+ A = A * (spectral_max / s[0])
+ M = np.eye(d, dtype=np.float32) + A
+ Ms.append(torch.from_numpy(M))
+ return Ms # list of (d, d)
+
+
+def rollout_forward(h0, Ms, sigma, L, device):
+ """Roll out forward dynamics: h_{l+1} = M_l h_l + sigma * xi_l."""
+ batch = h0.shape[0]
+ d = h0.shape[1]
+ hiddens = [h0]
+ h = h0
+ for l in range(L):
+ M = Ms[l].to(device)
+ noise = sigma * torch.randn(batch, d, device=device)
+ h = h @ M.T + noise
+ hiddens.append(h)
+ return hiddens # [h_0, ..., h_L]
+
+
+def terminal_loss(hL, C, y):
+ """Phi(hL, y) = 0.5 * ||C hL - y||^2, returns per-sample loss."""
+ diff = hL @ C.T - y # (batch, m)
+ return 0.5 * (diff ** 2).sum(dim=-1) # (batch,)
+
+
+def exact_costate(hiddens, Ms, C, y, device):
+ """Compute exact costate a_l for all layers."""
+ L = len(hiddens) - 1
+ hL = hiddens[L]
+ # Terminal: a_L = C^T (C h_L - y)
+ diff = hL @ C.T - y # (batch, m)
+ a_L = diff @ C # (batch, d)
+
+ costates = [None] * (L + 1)
+ costates[L] = a_L
+ for l in range(L - 1, -1, -1):
+ M = Ms[l].to(device)
+ costates[l] = costates[l + 1] @ M # a_l = M_l^T a_{l+1} -> a_{l+1} @ M
+ return costates
+
+
+def make_forward_fn_from_layer(hiddens, Ms, C, y, sigma, start_layer, device):
+ """Create a function that rolls forward from layer start_layer and returns per-sample loss."""
+ L = len(Ms)
+
+ def forward_fn(h):
+ current = h
+ for l in range(start_layer, L):
+ M = Ms[l].to(device)
+ # No noise for perturbation test (deterministic rollout)
+ current = current @ M.T
+ return terminal_loss(current, C, y)
+
+ return forward_fn
+
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+
+ # Hyperparams
+ d = args.d_hidden # 64
+ m = args.output_dim # 10
+ L = args.num_layers # 12
+ sigma = args.sigma # 0.03
+ batch_size = args.batch_size # 256
+ num_steps = args.num_steps # 5000
+ lr_fb = args.lr_fb # 1e-3
+ lam = args.lam # 0.1
+ K = args.K # 8
+ ema_momentum = args.ema_momentum # 0.995
+ sigma_bridge = args.sigma_bridge # 0.03
+
+ print(f"=== Toy LQ Experiment ===")
+ print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}")
+ print(f"device={device}")
+
+ # Generate fixed dynamics
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # DFA random feedback matrices
+ Bs_dfa = []
+ for l in range(L + 1):
+ B = torch.randn(d, m, device=device) / np.sqrt(m)
+ Bs_dfa.append(B)
+
+ # State Bridge model
+ state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ opt_state = optim.Adam(state_bridge.parameters(), lr=lr_fb)
+
+ # Credit Bridge value net
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr_fb)
+
+ # Training logs
+ log = {
+ 'steps': [],
+ 'state_bridge_loss': [],
+ 'credit_bridge_loss': [],
+ 'dfa_costate_cos': [],
+ 'state_costate_cos': [],
+ 'credit_costate_cos': [],
+ 'dfa_rho': [],
+ 'state_rho': [],
+ 'credit_rho': [],
+ 'dfa_nudge': [],
+ 'state_nudge': [],
+ 'credit_nudge': [],
+ 'bridge_residual': [],
+ }
+
+ for step in range(1, num_steps + 1):
+ # Generate data
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+
+ # Forward rollout
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+
+ # Terminal error
+ e_T = (hL @ C.T - y) # (batch, m) - gradient of Phi w.r.t. prediction
+
+ # Terminal modulation code s = e_T (P=I)
+ s = e_T.detach()
+
+ # ---- Train State Bridge ----
+ state_loss = 0.0
+ hL_detached = hL.detach()
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ pred_hL = state_bridge(h_l_det, t_l, s)
+ state_loss = state_loss + ((pred_hL - hL_detached) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+
+ opt_state.zero_grad()
+ state_loss.backward()
+ opt_state.step()
+
+ # ---- Train Credit Bridge (value net) ----
+ # Terminal boundary: V(h_L, 1, s) should equal Phi(h_L, y)
+ hL_det = hL.detach().requires_grad_(False)
+ t_L = torch.ones(batch_size, device=device)
+ true_loss = terminal_loss(hL_det, C, y).detach()
+
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ # Generate noisy next states
+ with torch.no_grad():
+ M = Ms[l].to(device)
+ h_next_det = hiddens[l + 1].detach()
+
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_next_noisy = h_next_det + noise
+ V_next = value_net_ema(h_next_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+
+ log_terms_stack = torch.stack(log_terms, dim=-1) # (batch, K)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+
+ loss_bridge = loss_bridge / L
+ loss_value = loss_term + loss_bridge
+
+ opt_value.zero_grad()
+ loss_value.backward()
+ opt_value.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ # ---- Evaluation ----
+ if step % args.eval_every == 0 or step == 1:
+ with torch.no_grad():
+ eval_batch = min(batch_size, 128)
+ h0_eval = torch.randn(eval_batch, d, device=device)
+ y_eval = torch.randn(eval_batch, m, device=device)
+ hiddens_eval = rollout_forward(h0_eval, Ms, sigma, L, device)
+ hL_eval = hiddens_eval[L]
+ e_T_eval = hL_eval @ C.T - y_eval
+ s_eval = e_T_eval.detach()
+
+ # Exact costate
+ costates_exact = exact_costate(hiddens_eval, Ms, C, y_eval, device)
+
+ # Compute credits for each method at each layer
+ dfa_cos_layers = []
+ state_cos_layers = []
+ credit_cos_layers = []
+ dfa_rho_layers = []
+ state_rho_layers = []
+ credit_rho_layers = []
+ dfa_nudge_layers = []
+ state_nudge_layers = []
+ credit_nudge_layers = []
+ bridge_res_layers = []
+
+ for l in range(L + 1):
+ h_l = hiddens_eval[l].detach()
+ a_exact = costates_exact[l].detach()
+ t_l = torch.full((eval_batch,), l / L, device=device)
+
+ # DFA credit
+ a_dfa = e_T_eval @ Bs_dfa[l].T # (batch, d)
+
+ # State bridge credit
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = state_bridge(h_l_req, t_l, s_eval)
+ # Loss through state bridge prediction
+ pred_out = pred_hL @ C.T # Use C as output projection for consistency
+ pred_loss = 0.5 * ((pred_out - y_eval) ** 2).sum(dim=-1)
+ a_state = torch.autograd.grad(pred_loss.sum(), h_l_req, create_graph=False)[0]
+
+ # Credit bridge credit
+ h_l_req2 = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req2, t_l, s_eval)
+ a_credit = torch.autograd.grad(V_l.sum(), h_l_req2, create_graph=False)[0]
+
+ # Costate cosine
+ dfa_cos_layers.append(cosine_similarity_batch(a_dfa, a_exact))
+ state_cos_layers.append(cosine_similarity_batch(a_state, a_exact))
+ credit_cos_layers.append(cosine_similarity_batch(a_credit, a_exact))
+
+ # Perturbation correlation and nudging (skip terminal layer for forward_fn)
+ if l < L:
+ fwd_fn = make_forward_fn_from_layer(hiddens_eval, Ms, C, y_eval, sigma, l, device)
+
+ dfa_rho = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16)
+ state_rho = perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16)
+ credit_rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16)
+ dfa_rho_layers.append(dfa_rho)
+ state_rho_layers.append(state_rho)
+ credit_rho_layers.append(credit_rho)
+
+ dfa_nud = nudging_test(h_l, a_dfa, fwd_fn, eta=0.01)
+ state_nud = nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01)
+ credit_nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01)
+ dfa_nudge_layers.append(dfa_nud)
+ state_nudge_layers.append(state_nud)
+ credit_nudge_layers.append(credit_nud)
+
+ # Bridge residual for credit bridge
+ if l < L:
+ t_l_next = torch.full((eval_batch,), (l + 1) / L, device=device)
+ h_next = hiddens_eval[l + 1].detach()
+ noisy_list = [h_next + sigma_bridge * torch.randn_like(h_next) for _ in range(K)]
+ br = bridge_residual(value_net, value_net_ema, h_l, t_l, s_eval,
+ noisy_list, t_l_next, lam)
+ bridge_res_layers.append(br)
+
+ # Average across layers
+ avg_dfa_cos = np.mean(dfa_cos_layers)
+ avg_state_cos = np.mean(state_cos_layers)
+ avg_credit_cos = np.mean(credit_cos_layers)
+ avg_dfa_rho = np.mean(dfa_rho_layers)
+ avg_state_rho = np.mean(state_rho_layers)
+ avg_credit_rho = np.mean(credit_rho_layers)
+ avg_dfa_nudge = np.mean(dfa_nudge_layers)
+ avg_state_nudge = np.mean(state_nudge_layers)
+ avg_credit_nudge = np.mean(credit_nudge_layers)
+ avg_bridge_res = np.mean(bridge_res_layers) if bridge_res_layers else 0.0
+
+ log['steps'].append(step)
+ log['dfa_costate_cos'].append(avg_dfa_cos)
+ log['state_costate_cos'].append(avg_state_cos)
+ log['credit_costate_cos'].append(avg_credit_cos)
+ log['dfa_rho'].append(avg_dfa_rho)
+ log['state_rho'].append(avg_state_rho)
+ log['credit_rho'].append(avg_credit_rho)
+ log['dfa_nudge'].append(avg_dfa_nudge)
+ log['state_nudge'].append(avg_state_nudge)
+ log['credit_nudge'].append(avg_credit_nudge)
+ log['bridge_residual'].append(avg_bridge_res)
+ log['state_bridge_loss'].append(state_loss.item())
+ log['credit_bridge_loss'].append(loss_value.item())
+
+ print(f"Step {step}/{num_steps}")
+ print(f" Costate cos - DFA: {avg_dfa_cos:.4f}, State: {avg_state_cos:.4f}, Credit: {avg_credit_cos:.4f}")
+ print(f" Perturb rho - DFA: {avg_dfa_rho:.4f}, State: {avg_state_rho:.4f}, Credit: {avg_credit_rho:.4f}")
+ print(f" Nudging - DFA: {avg_dfa_nudge:.4f}, State: {avg_state_nudge:.4f}, Credit: {avg_credit_nudge:.4f}")
+ print(f" Bridge res - {avg_bridge_res:.4f}")
+ print(f" Losses - State: {state_loss.item():.4f}, Credit: {loss_value.item():.4f}")
+ print(f" Per-layer costate cos (credit): {['%.3f' % x for x in credit_cos_layers]}")
+
+ # Save results
+ os.makedirs(args.output_dir, exist_ok=True)
+ results = {
+ 'config': vars(args),
+ 'log': log,
+ 'final_per_layer': {
+ 'dfa_costate_cos': dfa_cos_layers,
+ 'state_costate_cos': state_cos_layers,
+ 'credit_costate_cos': credit_cos_layers,
+ 'dfa_rho': dfa_rho_layers,
+ 'state_rho': state_rho_layers,
+ 'credit_rho': credit_rho_layers,
+ 'dfa_nudge': dfa_nudge_layers,
+ 'state_nudge': state_nudge_layers,
+ 'credit_nudge': credit_nudge_layers,
+ 'bridge_residual': bridge_res_layers,
+ }
+ }
+
+ out_path = os.path.join(args.output_dir, f'toy_lq_seed{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+
+ # Also save models
+ torch.save(value_net.state_dict(), os.path.join(args.output_dir, f'value_net_seed{args.seed}.pt'))
+ torch.save(state_bridge.state_dict(), os.path.join(args.output_dir, f'state_bridge_seed{args.seed}.pt'))
+
+ return results
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Toy LQ Sanity Check')
+ parser.add_argument('--d_hidden', type=int, default=64)
+ parser.add_argument('--output_dim', type=int, default=10)
+ parser.add_argument('--num_layers', type=int, default=12)
+ parser.add_argument('--sigma', type=float, default=0.03)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--num_steps', type=int, default=5000)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=8)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--sigma_bridge', type=float, default=0.03)
+ parser.add_argument('--eval_every', type=int, default=200)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/toy_lq')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/toy_lq_sweep.py b/experiments/toy_lq_sweep.py
new file mode 100644
index 0000000..ae82ef0
--- /dev/null
+++ b/experiments/toy_lq_sweep.py
@@ -0,0 +1,243 @@
+"""
+Sweep over credit bridge hyperparameters to find a configuration
+where the value field gradient actually aligns with the costate.
+
+Key hypothesis: the credit bridge needs sufficient noise (sigma_bridge)
+and temperature (lambda) to make V_phi sensitive to cost-relevant directions.
+"""
+import os
+import sys
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from itertools import product
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from experiments.toy_lq import (
+ generate_stable_dynamics, rollout_forward, terminal_loss,
+ exact_costate, make_forward_fn_from_layer
+)
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test
+
+
+def run_credit_bridge_config(config, device):
+ """Run credit bridge with specific hyperparameters and return final metrics."""
+ d = 64
+ m = 10
+ L = 12
+ sigma = 0.03
+ batch_size = 256
+ num_steps = config['num_steps']
+ lr = config['lr']
+ lam = config['lam']
+ K = config['K']
+ ema_momentum = config['ema_momentum']
+ sigma_bridge = config['sigma_bridge']
+ hidden_dim = config.get('hidden_dim', 128)
+ use_ln = config.get('use_ln', True)
+
+ torch.manual_seed(42)
+ np.random.seed(42)
+
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=42)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # Value net - optionally without LayerNorm
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=hidden_dim, num_layers=2).to(device)
+ if not use_ln:
+ value_net.ln = nn.Identity()
+
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr)
+
+ best_cos = -1.0
+ best_step = 0
+ history = []
+
+ for step in range(1, num_steps + 1):
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+ e_T = hL @ C.T - y
+ s = e_T.detach()
+ true_loss = terminal_loss(hL.detach(), C, y).detach()
+
+ # Terminal boundary
+ hL_det = hL.detach()
+ t_L = torch.ones(batch_size, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_noisy = h_next_det + noise
+ V_next = value_net_ema(h_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+
+ log_terms_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+
+ loss_bridge = loss_bridge / L
+ total_loss = loss_term + loss_bridge
+
+ opt_value.zero_grad()
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ opt_value.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ # Quick evaluation
+ if step % 500 == 0 or step == num_steps:
+ with torch.no_grad():
+ eval_batch = 128
+ h0_e = torch.randn(eval_batch, d, device=device)
+ y_e = torch.randn(eval_batch, m, device=device)
+ hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device)
+ hL_e = hiddens_e[L]
+ e_T_e = hL_e @ C.T - y_e
+ s_e = e_T_e.detach()
+ costates = exact_costate(hiddens_e, Ms, C, y_e, device)
+
+ cos_list = []
+ rho_list = []
+ nudge_list = []
+ for l in range(L):
+ h_l = hiddens_e[l].detach()
+ t_l = torch.full((eval_batch,), l / L, device=device)
+ a_exact = costates[l].detach()
+
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s_e)
+ a_credit = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0]
+
+ cos_list.append(cosine_similarity_batch(a_credit, a_exact))
+
+ fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device)
+ rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16)
+ rho_list.append(rho)
+ nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01)
+ nudge_list.append(nud)
+
+ avg_cos = np.mean(cos_list)
+ avg_rho = np.mean(rho_list)
+ avg_nudge = np.mean(nudge_list)
+
+ if avg_cos > best_cos:
+ best_cos = avg_cos
+ best_step = step
+
+ history.append({
+ 'step': step,
+ 'avg_cos': avg_cos,
+ 'avg_rho': avg_rho,
+ 'avg_nudge': avg_nudge,
+ 'loss_term': loss_term.item(),
+ 'loss_bridge': loss_bridge.item(),
+ })
+
+ return {
+ 'best_cos': best_cos,
+ 'best_step': best_step,
+ 'final_cos': history[-1]['avg_cos'],
+ 'final_rho': history[-1]['avg_rho'],
+ 'final_nudge': history[-1]['avg_nudge'],
+ 'history': history,
+ }
+
+
+def main():
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+ print(f"Device: {device}")
+
+ # Sweep configurations
+ configs = [
+ # Baseline (original)
+ {'name': 'base', 'lam': 0.1, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Larger noise
+ {'name': 'noise_0.1', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Much larger noise
+ {'name': 'noise_0.3', 'lam': 0.1, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Larger lambda
+ {'name': 'lam_1.0', 'lam': 1.0, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Large noise + large lambda
+ {'name': 'noise_lam', 'lam': 1.0, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # No LayerNorm
+ {'name': 'no_ln', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False},
+ # Larger value net
+ {'name': 'big_vnet', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 256, 'use_ln': True},
+ # Slower EMA
+ {'name': 'ema_0.999', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.999, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # More K samples
+ {'name': 'K16', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 16, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Larger noise + large lambda + no LN
+ {'name': 'best_combo', 'lam': 1.0, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False},
+ # Very large sigma
+ {'name': 'noise_1.0', 'lam': 1.0, 'sigma_bridge': 1.0, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Lower lr
+ {'name': 'lr_3e-4', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 3e-4,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ ]
+
+ results = {}
+ for cfg in configs:
+ name = cfg.pop('name')
+ print(f"\n{'='*50}")
+ print(f"Config: {name}")
+ print(f" {cfg}")
+ res = run_credit_bridge_config(cfg, device)
+ results[name] = res
+ print(f" Best cos: {res['best_cos']:.4f} (step {res['best_step']})")
+ print(f" Final cos: {res['final_cos']:.4f}, rho: {res['final_rho']:.4f}, nudge: {res['final_nudge']:.4f}")
+ cfg['name'] = name # restore
+
+ # Print summary
+ print("\n" + "="*80)
+ print("SWEEP SUMMARY")
+ print("="*80)
+ print(f"{'Config':<20} {'Best Cos':<12} {'Final Cos':<12} {'Final Rho':<12} {'Final Nudge':<12}")
+ print("-"*68)
+ for name, res in results.items():
+ print(f"{name:<20} {res['best_cos']:<12.4f} {res['final_cos']:<12.4f} "
+ f"{res['final_rho']:<12.4f} {res['final_nudge']:<12.4f}")
+
+ # Save
+ os.makedirs('results/toy_lq', exist_ok=True)
+ with open('results/toy_lq/sweep_results.json', 'w') as f:
+ json.dump(results, f, indent=2)
+ print("\nSaved to results/toy_lq/sweep_results.json")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/toy_lq_v2.py b/experiments/toy_lq_v2.py
new file mode 100644
index 0000000..ab766b6
--- /dev/null
+++ b/experiments/toy_lq_v2.py
@@ -0,0 +1,327 @@
+"""
+Phase A v2: Enhanced toy LQ experiment.
+
+Key improvements over v1:
+1. Terminal gradient matching: V_phi at terminal layer should have grad_h V matching
+ the exact terminal gradient (this is LOCAL info, no hidden BP needed).
+2. Larger noise sweep integrated.
+3. Optional FM auxiliary for gradient smoothness.
+4. Better diagnostics.
+
+The terminal gradient a_L = C^T(C h_L - y) is computed from output layer only,
+so using it is allowed under the "no hidden BP anchor" constraint.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from experiments.toy_lq import (
+ generate_stable_dynamics, rollout_forward, terminal_loss,
+ exact_costate, make_forward_fn_from_layer
+)
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+
+ d = args.d_hidden
+ m = args.output_dim
+ L = args.num_layers
+ sigma = args.sigma
+ batch_size = args.batch_size
+ num_steps = args.num_steps
+ lr = args.lr_fb
+ lam = args.lam
+ K = args.K
+ ema_momentum = args.ema_momentum
+ sigma_bridge = args.sigma_bridge
+
+ print(f"=== Toy LQ v2 Experiment ===")
+ print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}")
+ print(f"lam={lam}, sigma_bridge={sigma_bridge}, K={K}")
+ print(f"terminal_grad_weight={args.term_grad_weight}")
+ print(f"fm_weight={args.fm_weight}")
+ print(f"device={device}")
+
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # DFA
+ Bs_dfa = [torch.randn(d, m, device=device) / np.sqrt(m) for _ in range(L + 1)]
+
+ # State Bridge
+ state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ opt_state = optim.Adam(state_bridge.parameters(), lr=lr)
+
+ # Credit Bridge
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=args.vnet_hidden, num_layers=args.vnet_layers).to(device)
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr)
+
+ log = {key: [] for key in [
+ 'steps',
+ 'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos',
+ 'dfa_rho', 'state_rho', 'credit_rho',
+ 'dfa_nudge', 'state_nudge', 'credit_nudge',
+ 'bridge_residual', 'state_bridge_loss', 'credit_bridge_loss',
+ 'term_loss', 'bridge_loss', 'term_grad_loss', 'fm_loss',
+ ]}
+
+ for step in range(1, num_steps + 1):
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+ e_T = hL @ C.T - y
+ s = e_T.detach()
+
+ # ---- Train State Bridge ----
+ state_loss = 0.0
+ hL_det = hL.detach()
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ pred_hL = state_bridge(h_l_det, t_l, s)
+ state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+
+ opt_state.zero_grad()
+ state_loss.backward()
+ opt_state.step()
+
+ # ---- Train Credit Bridge ----
+ # 1. Terminal boundary: V(h_L, 1, s) ≈ Phi(h_L, y)
+ hL_det = hL.detach()
+ t_L = torch.ones(batch_size, device=device)
+ true_loss = terminal_loss(hL_det, C, y).detach()
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # 2. Terminal gradient matching: grad_h V(h_L, 1, s) ≈ a_L^exact
+ # This uses only terminal-local information (no hidden BP)
+ loss_term_grad = torch.tensor(0.0, device=device)
+ if args.term_grad_weight > 0:
+ hL_req = hL.detach().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ # Exact terminal gradient: C^T (C h_L - y)
+ a_L_exact = (e_T @ C).detach() # (batch, d) -- stop grad on target
+ loss_term_grad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # 3. Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_noisy = h_next_det + noise
+ V_next = value_net_ema(h_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_terms_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ # 4. FM auxiliary (optional): enforce gradient smoothness
+ loss_fm = torch.tensor(0.0, device=device)
+ if args.fm_weight > 0:
+ for l in range(L):
+ tau = torch.rand(batch_size, 1, device=device)
+ h_l_det = hiddens[l].detach()
+ h_next_det = hiddens[l + 1].detach()
+ f_l = h_next_det - h_l_det # residual
+
+ eps = torch.randn(batch_size, d, device=device)
+ h_mid = h_l_det + tau * f_l + (tau * (1 - tau)).sqrt() * sigma_bridge * eps
+ h_mid.requires_grad_(True)
+
+ t_mid = torch.full((batch_size, 1), 0, device=device)
+ t_mid = (l + tau) / L
+ t_mid_flat = t_mid.squeeze(-1)
+
+ V_mid = value_net(h_mid, t_mid_flat, s)
+ grad_V_mid = torch.autograd.grad(V_mid.sum(), h_mid, create_graph=True)[0]
+
+ # Interpolated target gradient
+ # Get a_l and a_{l+1} from current value net (no create_graph for targets)
+ h_l_r = h_l_det.clone().requires_grad_(True)
+ t_l_v = torch.full((batch_size,), l / L, device=device)
+ V_l_ = value_net(h_l_r, t_l_v, s)
+ a_l = torch.autograd.grad(V_l_.sum(), h_l_r, create_graph=False)[0].detach()
+
+ h_next_r = h_next_det.clone().requires_grad_(True)
+ t_next_v = torch.full((batch_size,), (l + 1) / L, device=device)
+ V_next_ = value_net(h_next_r, t_next_v, s)
+ a_next = torch.autograd.grad(V_next_.sum(), h_next_r, create_graph=False)[0].detach()
+
+ target_grad = ((1 - tau) * a_l + tau * a_next).detach()
+ loss_fm = loss_fm + ((grad_V_mid - target_grad) ** 2).sum(dim=-1).mean()
+ loss_fm = loss_fm / L
+
+ total_loss = (loss_term
+ + loss_bridge
+ + args.term_grad_weight * loss_term_grad
+ + args.fm_weight * loss_fm)
+
+ opt_value.zero_grad()
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ opt_value.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ # ---- Evaluation ----
+ if step % args.eval_every == 0 or step == 1:
+ with torch.no_grad():
+ eval_batch = 128
+ h0_e = torch.randn(eval_batch, d, device=device)
+ y_e = torch.randn(eval_batch, m, device=device)
+ hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device)
+ hL_e = hiddens_e[L]
+ e_T_e = hL_e @ C.T - y_e
+ s_e = e_T_e.detach()
+ costates = exact_costate(hiddens_e, Ms, C, y_e, device)
+
+ dfa_cos, state_cos, credit_cos = [], [], []
+ dfa_rho, state_rho, credit_rho = [], [], []
+ dfa_nudge, state_nudge, credit_nudge = [], [], []
+ bridge_res_list = []
+
+ for l in range(L):
+ h_l = hiddens_e[l].detach()
+ a_exact = costates[l].detach()
+ t_l = torch.full((eval_batch,), l / L, device=device)
+
+ # DFA
+ a_dfa = e_T_e @ Bs_dfa[l].T
+ # State bridge
+ h_l_r1 = h_l.clone().requires_grad_(True)
+ pred_hL = state_bridge(h_l_r1, t_l, s_e)
+ pred_out = pred_hL @ C.T
+ pred_loss = 0.5 * ((pred_out - y_e) ** 2).sum(dim=-1)
+ a_state = torch.autograd.grad(pred_loss.sum(), h_l_r1, create_graph=False)[0]
+ # Credit bridge
+ h_l_r2 = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_r2, t_l, s_e)
+ a_credit = torch.autograd.grad(V_l.sum(), h_l_r2, create_graph=False)[0]
+
+ dfa_cos.append(cosine_similarity_batch(a_dfa, a_exact))
+ state_cos.append(cosine_similarity_batch(a_state, a_exact))
+ credit_cos.append(cosine_similarity_batch(a_credit, a_exact))
+
+ fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device)
+
+ dfa_rho.append(perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16))
+ state_rho.append(perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16))
+ credit_rho.append(perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16))
+
+ dfa_nudge.append(nudging_test(h_l, a_dfa, fwd_fn, eta=0.01))
+ state_nudge.append(nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01))
+ credit_nudge.append(nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01))
+
+ avg = lambda x: float(np.mean(x))
+ log['steps'].append(step)
+ log['dfa_costate_cos'].append(avg(dfa_cos))
+ log['state_costate_cos'].append(avg(state_cos))
+ log['credit_costate_cos'].append(avg(credit_cos))
+ log['dfa_rho'].append(avg(dfa_rho))
+ log['state_rho'].append(avg(state_rho))
+ log['credit_rho'].append(avg(credit_rho))
+ log['dfa_nudge'].append(avg(dfa_nudge))
+ log['state_nudge'].append(avg(state_nudge))
+ log['credit_nudge'].append(avg(credit_nudge))
+ log['state_bridge_loss'].append(state_loss.item())
+ log['credit_bridge_loss'].append(total_loss.item())
+ log['term_loss'].append(loss_term.item())
+ log['bridge_loss'].append(loss_bridge.item())
+ log['term_grad_loss'].append(loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else loss_term_grad)
+ log['fm_loss'].append(loss_fm.item() if isinstance(loss_fm, torch.Tensor) else loss_fm)
+
+ print(f"Step {step}/{num_steps}")
+ print(f" Costate cos - DFA: {avg(dfa_cos):.4f}, State: {avg(state_cos):.4f}, Credit: {avg(credit_cos):.4f}")
+ print(f" Perturb rho - DFA: {avg(dfa_rho):.4f}, State: {avg(state_rho):.4f}, Credit: {avg(credit_rho):.4f}")
+ print(f" Nudging - DFA: {avg(dfa_nudge):.4f}, State: {avg(state_nudge):.4f}, Credit: {avg(credit_nudge):.4f}")
+ print(f" Losses - term: {loss_term.item():.4f}, bridge: {loss_bridge.item():.4f}, "
+ f"tgrad: {loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else 0:.4f}, "
+ f"fm: {loss_fm.item() if isinstance(loss_fm, torch.Tensor) else 0:.4f}")
+ print(f" Per-layer credit cos: {['%.3f' % x for x in credit_cos]}")
+
+ # Save
+ os.makedirs(args.output_dir, exist_ok=True)
+ results = {
+ 'config': vars(args),
+ 'log': log,
+ 'final_per_layer': {
+ 'dfa_costate_cos': dfa_cos,
+ 'state_costate_cos': state_cos,
+ 'credit_costate_cos': credit_cos,
+ 'dfa_rho': dfa_rho,
+ 'state_rho': state_rho,
+ 'credit_rho': credit_rho,
+ 'dfa_nudge': dfa_nudge,
+ 'state_nudge': state_nudge,
+ 'credit_nudge': credit_nudge,
+ }
+ }
+ tag = f"seed{args.seed}_lam{args.lam}_sig{args.sigma_bridge}_tgw{args.term_grad_weight}_fm{args.fm_weight}"
+ out_path = os.path.join(args.output_dir, f'toy_lq_v2_{tag}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+ return results
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Toy LQ v2')
+ parser.add_argument('--d_hidden', type=int, default=64)
+ parser.add_argument('--output_dim', type=int, default=10)
+ parser.add_argument('--num_layers', type=int, default=12)
+ parser.add_argument('--sigma', type=float, default=0.03)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--num_steps', type=int, default=8000)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=8)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--sigma_bridge', type=float, default=0.1)
+ parser.add_argument('--eval_every', type=int, default=500)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/toy_lq')
+ parser.add_argument('--vnet_hidden', type=int, default=256)
+ parser.add_argument('--vnet_layers', type=int, default=3)
+ # Key new options
+ parser.add_argument('--term_grad_weight', type=float, default=1.0,
+ help='Weight for terminal gradient matching loss')
+ parser.add_argument('--fm_weight', type=float, default=0.0,
+ help='Weight for FM gradient smoothness auxiliary')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()