summaryrefslogtreecommitdiff
path: root/experiments/toy_lq.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
commit6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch)
treed7c63adcd19c4f5d46c8a937e5047fece55dea62 /experiments/toy_lq.py
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching. Credit bridge matches state bridge on linear system (~0.94 cosine). CIFAR experiments in progress.
Diffstat (limited to 'experiments/toy_lq.py')
-rw-r--r--experiments/toy_lq.py395
1 files changed, 395 insertions, 0 deletions
diff --git a/experiments/toy_lq.py b/experiments/toy_lq.py
new file mode 100644
index 0000000..4fd8919
--- /dev/null
+++ b/experiments/toy_lq.py
@@ -0,0 +1,395 @@
+"""
+Phase A: Linear-Quadratic Residual Sanity Check.
+
+Fixed forward dynamics (no forward net training).
+Only train feedback/bridge models.
+Compare DFA, State Bridge, Credit Bridge against exact costate.
+
+System:
+ h_{l+1} = M_l h_l + sigma * xi_l, xi_l ~ N(0, I)
+ Phi(h_L, y) = 0.5 * ||C h_L - y||^2
+ Exact costate: a_L = C^T (C h_L - y), a_l = M_l^T a_{l+1}
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from datetime import datetime
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test, bridge_residual
+)
+
+
+def generate_stable_dynamics(d, L, spectral_max=0.05, seed=42):
+ """Generate stable linear maps M_l = I + A_l with ||A_l||_2 <= spectral_max."""
+ rng = np.random.RandomState(seed)
+ Ms = []
+ for _ in range(L):
+ A = rng.randn(d, d).astype(np.float32)
+ # Scale to desired spectral norm
+ u, s, v = np.linalg.svd(A, full_matrices=False)
+ A = A * (spectral_max / s[0])
+ M = np.eye(d, dtype=np.float32) + A
+ Ms.append(torch.from_numpy(M))
+ return Ms # list of (d, d)
+
+
+def rollout_forward(h0, Ms, sigma, L, device):
+ """Roll out forward dynamics: h_{l+1} = M_l h_l + sigma * xi_l."""
+ batch = h0.shape[0]
+ d = h0.shape[1]
+ hiddens = [h0]
+ h = h0
+ for l in range(L):
+ M = Ms[l].to(device)
+ noise = sigma * torch.randn(batch, d, device=device)
+ h = h @ M.T + noise
+ hiddens.append(h)
+ return hiddens # [h_0, ..., h_L]
+
+
+def terminal_loss(hL, C, y):
+ """Phi(hL, y) = 0.5 * ||C hL - y||^2, returns per-sample loss."""
+ diff = hL @ C.T - y # (batch, m)
+ return 0.5 * (diff ** 2).sum(dim=-1) # (batch,)
+
+
+def exact_costate(hiddens, Ms, C, y, device):
+ """Compute exact costate a_l for all layers."""
+ L = len(hiddens) - 1
+ hL = hiddens[L]
+ # Terminal: a_L = C^T (C h_L - y)
+ diff = hL @ C.T - y # (batch, m)
+ a_L = diff @ C # (batch, d)
+
+ costates = [None] * (L + 1)
+ costates[L] = a_L
+ for l in range(L - 1, -1, -1):
+ M = Ms[l].to(device)
+ costates[l] = costates[l + 1] @ M # a_l = M_l^T a_{l+1} -> a_{l+1} @ M
+ return costates
+
+
+def make_forward_fn_from_layer(hiddens, Ms, C, y, sigma, start_layer, device):
+ """Create a function that rolls forward from layer start_layer and returns per-sample loss."""
+ L = len(Ms)
+
+ def forward_fn(h):
+ current = h
+ for l in range(start_layer, L):
+ M = Ms[l].to(device)
+ # No noise for perturbation test (deterministic rollout)
+ current = current @ M.T
+ return terminal_loss(current, C, y)
+
+ return forward_fn
+
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+
+ # Hyperparams
+ d = args.d_hidden # 64
+ m = args.output_dim # 10
+ L = args.num_layers # 12
+ sigma = args.sigma # 0.03
+ batch_size = args.batch_size # 256
+ num_steps = args.num_steps # 5000
+ lr_fb = args.lr_fb # 1e-3
+ lam = args.lam # 0.1
+ K = args.K # 8
+ ema_momentum = args.ema_momentum # 0.995
+ sigma_bridge = args.sigma_bridge # 0.03
+
+ print(f"=== Toy LQ Experiment ===")
+ print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}")
+ print(f"device={device}")
+
+ # Generate fixed dynamics
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # DFA random feedback matrices
+ Bs_dfa = []
+ for l in range(L + 1):
+ B = torch.randn(d, m, device=device) / np.sqrt(m)
+ Bs_dfa.append(B)
+
+ # State Bridge model
+ state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ opt_state = optim.Adam(state_bridge.parameters(), lr=lr_fb)
+
+ # Credit Bridge value net
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr_fb)
+
+ # Training logs
+ log = {
+ 'steps': [],
+ 'state_bridge_loss': [],
+ 'credit_bridge_loss': [],
+ 'dfa_costate_cos': [],
+ 'state_costate_cos': [],
+ 'credit_costate_cos': [],
+ 'dfa_rho': [],
+ 'state_rho': [],
+ 'credit_rho': [],
+ 'dfa_nudge': [],
+ 'state_nudge': [],
+ 'credit_nudge': [],
+ 'bridge_residual': [],
+ }
+
+ for step in range(1, num_steps + 1):
+ # Generate data
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+
+ # Forward rollout
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+
+ # Terminal error
+ e_T = (hL @ C.T - y) # (batch, m) - gradient of Phi w.r.t. prediction
+
+ # Terminal modulation code s = e_T (P=I)
+ s = e_T.detach()
+
+ # ---- Train State Bridge ----
+ state_loss = 0.0
+ hL_detached = hL.detach()
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ pred_hL = state_bridge(h_l_det, t_l, s)
+ state_loss = state_loss + ((pred_hL - hL_detached) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+
+ opt_state.zero_grad()
+ state_loss.backward()
+ opt_state.step()
+
+ # ---- Train Credit Bridge (value net) ----
+ # Terminal boundary: V(h_L, 1, s) should equal Phi(h_L, y)
+ hL_det = hL.detach().requires_grad_(False)
+ t_L = torch.ones(batch_size, device=device)
+ true_loss = terminal_loss(hL_det, C, y).detach()
+
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ # Generate noisy next states
+ with torch.no_grad():
+ M = Ms[l].to(device)
+ h_next_det = hiddens[l + 1].detach()
+
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_next_noisy = h_next_det + noise
+ V_next = value_net_ema(h_next_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+
+ log_terms_stack = torch.stack(log_terms, dim=-1) # (batch, K)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+
+ loss_bridge = loss_bridge / L
+ loss_value = loss_term + loss_bridge
+
+ opt_value.zero_grad()
+ loss_value.backward()
+ opt_value.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ # ---- Evaluation ----
+ if step % args.eval_every == 0 or step == 1:
+ with torch.no_grad():
+ eval_batch = min(batch_size, 128)
+ h0_eval = torch.randn(eval_batch, d, device=device)
+ y_eval = torch.randn(eval_batch, m, device=device)
+ hiddens_eval = rollout_forward(h0_eval, Ms, sigma, L, device)
+ hL_eval = hiddens_eval[L]
+ e_T_eval = hL_eval @ C.T - y_eval
+ s_eval = e_T_eval.detach()
+
+ # Exact costate
+ costates_exact = exact_costate(hiddens_eval, Ms, C, y_eval, device)
+
+ # Compute credits for each method at each layer
+ dfa_cos_layers = []
+ state_cos_layers = []
+ credit_cos_layers = []
+ dfa_rho_layers = []
+ state_rho_layers = []
+ credit_rho_layers = []
+ dfa_nudge_layers = []
+ state_nudge_layers = []
+ credit_nudge_layers = []
+ bridge_res_layers = []
+
+ for l in range(L + 1):
+ h_l = hiddens_eval[l].detach()
+ a_exact = costates_exact[l].detach()
+ t_l = torch.full((eval_batch,), l / L, device=device)
+
+ # DFA credit
+ a_dfa = e_T_eval @ Bs_dfa[l].T # (batch, d)
+
+ # State bridge credit
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = state_bridge(h_l_req, t_l, s_eval)
+ # Loss through state bridge prediction
+ pred_out = pred_hL @ C.T # Use C as output projection for consistency
+ pred_loss = 0.5 * ((pred_out - y_eval) ** 2).sum(dim=-1)
+ a_state = torch.autograd.grad(pred_loss.sum(), h_l_req, create_graph=False)[0]
+
+ # Credit bridge credit
+ h_l_req2 = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req2, t_l, s_eval)
+ a_credit = torch.autograd.grad(V_l.sum(), h_l_req2, create_graph=False)[0]
+
+ # Costate cosine
+ dfa_cos_layers.append(cosine_similarity_batch(a_dfa, a_exact))
+ state_cos_layers.append(cosine_similarity_batch(a_state, a_exact))
+ credit_cos_layers.append(cosine_similarity_batch(a_credit, a_exact))
+
+ # Perturbation correlation and nudging (skip terminal layer for forward_fn)
+ if l < L:
+ fwd_fn = make_forward_fn_from_layer(hiddens_eval, Ms, C, y_eval, sigma, l, device)
+
+ dfa_rho = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16)
+ state_rho = perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16)
+ credit_rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16)
+ dfa_rho_layers.append(dfa_rho)
+ state_rho_layers.append(state_rho)
+ credit_rho_layers.append(credit_rho)
+
+ dfa_nud = nudging_test(h_l, a_dfa, fwd_fn, eta=0.01)
+ state_nud = nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01)
+ credit_nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01)
+ dfa_nudge_layers.append(dfa_nud)
+ state_nudge_layers.append(state_nud)
+ credit_nudge_layers.append(credit_nud)
+
+ # Bridge residual for credit bridge
+ if l < L:
+ t_l_next = torch.full((eval_batch,), (l + 1) / L, device=device)
+ h_next = hiddens_eval[l + 1].detach()
+ noisy_list = [h_next + sigma_bridge * torch.randn_like(h_next) for _ in range(K)]
+ br = bridge_residual(value_net, value_net_ema, h_l, t_l, s_eval,
+ noisy_list, t_l_next, lam)
+ bridge_res_layers.append(br)
+
+ # Average across layers
+ avg_dfa_cos = np.mean(dfa_cos_layers)
+ avg_state_cos = np.mean(state_cos_layers)
+ avg_credit_cos = np.mean(credit_cos_layers)
+ avg_dfa_rho = np.mean(dfa_rho_layers)
+ avg_state_rho = np.mean(state_rho_layers)
+ avg_credit_rho = np.mean(credit_rho_layers)
+ avg_dfa_nudge = np.mean(dfa_nudge_layers)
+ avg_state_nudge = np.mean(state_nudge_layers)
+ avg_credit_nudge = np.mean(credit_nudge_layers)
+ avg_bridge_res = np.mean(bridge_res_layers) if bridge_res_layers else 0.0
+
+ log['steps'].append(step)
+ log['dfa_costate_cos'].append(avg_dfa_cos)
+ log['state_costate_cos'].append(avg_state_cos)
+ log['credit_costate_cos'].append(avg_credit_cos)
+ log['dfa_rho'].append(avg_dfa_rho)
+ log['state_rho'].append(avg_state_rho)
+ log['credit_rho'].append(avg_credit_rho)
+ log['dfa_nudge'].append(avg_dfa_nudge)
+ log['state_nudge'].append(avg_state_nudge)
+ log['credit_nudge'].append(avg_credit_nudge)
+ log['bridge_residual'].append(avg_bridge_res)
+ log['state_bridge_loss'].append(state_loss.item())
+ log['credit_bridge_loss'].append(loss_value.item())
+
+ print(f"Step {step}/{num_steps}")
+ print(f" Costate cos - DFA: {avg_dfa_cos:.4f}, State: {avg_state_cos:.4f}, Credit: {avg_credit_cos:.4f}")
+ print(f" Perturb rho - DFA: {avg_dfa_rho:.4f}, State: {avg_state_rho:.4f}, Credit: {avg_credit_rho:.4f}")
+ print(f" Nudging - DFA: {avg_dfa_nudge:.4f}, State: {avg_state_nudge:.4f}, Credit: {avg_credit_nudge:.4f}")
+ print(f" Bridge res - {avg_bridge_res:.4f}")
+ print(f" Losses - State: {state_loss.item():.4f}, Credit: {loss_value.item():.4f}")
+ print(f" Per-layer costate cos (credit): {['%.3f' % x for x in credit_cos_layers]}")
+
+ # Save results
+ os.makedirs(args.output_dir, exist_ok=True)
+ results = {
+ 'config': vars(args),
+ 'log': log,
+ 'final_per_layer': {
+ 'dfa_costate_cos': dfa_cos_layers,
+ 'state_costate_cos': state_cos_layers,
+ 'credit_costate_cos': credit_cos_layers,
+ 'dfa_rho': dfa_rho_layers,
+ 'state_rho': state_rho_layers,
+ 'credit_rho': credit_rho_layers,
+ 'dfa_nudge': dfa_nudge_layers,
+ 'state_nudge': state_nudge_layers,
+ 'credit_nudge': credit_nudge_layers,
+ 'bridge_residual': bridge_res_layers,
+ }
+ }
+
+ out_path = os.path.join(args.output_dir, f'toy_lq_seed{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+
+ # Also save models
+ torch.save(value_net.state_dict(), os.path.join(args.output_dir, f'value_net_seed{args.seed}.pt'))
+ torch.save(state_bridge.state_dict(), os.path.join(args.output_dir, f'state_bridge_seed{args.seed}.pt'))
+
+ return results
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Toy LQ Sanity Check')
+ parser.add_argument('--d_hidden', type=int, default=64)
+ parser.add_argument('--output_dim', type=int, default=10)
+ parser.add_argument('--num_layers', type=int, default=12)
+ parser.add_argument('--sigma', type=float, default=0.03)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--num_steps', type=int, default=5000)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=8)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--sigma_bridge', type=float, default=0.03)
+ parser.add_argument('--eval_every', type=int, default=200)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/toy_lq')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()