""" Phase A: Linear-Quadratic Residual Sanity Check. Fixed forward dynamics (no forward net training). Only train feedback/bridge models. Compare DFA, State Bridge, Credit Bridge against exact costate. System: h_{l+1} = M_l h_l + sigma * xi_l, xi_l ~ N(0, I) Phi(h_L, y) = 0.5 * ||C h_L - y||^2 Exact costate: a_L = C^T (C h_L - y), a_l = M_l^T a_{l+1} """ import os import sys import json import argparse import numpy as np import torch import torch.nn as nn import torch.optim as optim from datetime import datetime sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.value_net import ValueNet, create_ema_model, update_ema from models.state_bridge import StateBridgeNet from metrics.credit_metrics import ( cosine_similarity_batch, perturbation_correlation, nudging_test, bridge_residual ) def generate_stable_dynamics(d, L, spectral_max=0.05, seed=42): """Generate stable linear maps M_l = I + A_l with ||A_l||_2 <= spectral_max.""" rng = np.random.RandomState(seed) Ms = [] for _ in range(L): A = rng.randn(d, d).astype(np.float32) # Scale to desired spectral norm u, s, v = np.linalg.svd(A, full_matrices=False) A = A * (spectral_max / s[0]) M = np.eye(d, dtype=np.float32) + A Ms.append(torch.from_numpy(M)) return Ms # list of (d, d) def rollout_forward(h0, Ms, sigma, L, device): """Roll out forward dynamics: h_{l+1} = M_l h_l + sigma * xi_l.""" batch = h0.shape[0] d = h0.shape[1] hiddens = [h0] h = h0 for l in range(L): M = Ms[l].to(device) noise = sigma * torch.randn(batch, d, device=device) h = h @ M.T + noise hiddens.append(h) return hiddens # [h_0, ..., h_L] def terminal_loss(hL, C, y): """Phi(hL, y) = 0.5 * ||C hL - y||^2, returns per-sample loss.""" diff = hL @ C.T - y # (batch, m) return 0.5 * (diff ** 2).sum(dim=-1) # (batch,) def exact_costate(hiddens, Ms, C, y, device): """Compute exact costate a_l for all layers.""" L = len(hiddens) - 1 hL = hiddens[L] # Terminal: a_L = C^T (C h_L - y) diff = hL @ C.T - y # (batch, m) a_L = diff @ C # (batch, d) costates = [None] * (L + 1) costates[L] = a_L for l in range(L - 1, -1, -1): M = Ms[l].to(device) costates[l] = costates[l + 1] @ M # a_l = M_l^T a_{l+1} -> a_{l+1} @ M return costates def make_forward_fn_from_layer(hiddens, Ms, C, y, sigma, start_layer, device): """Create a function that rolls forward from layer start_layer and returns per-sample loss.""" L = len(Ms) def forward_fn(h): current = h for l in range(start_layer, L): M = Ms[l].to(device) # No noise for perturbation test (deterministic rollout) current = current @ M.T return terminal_loss(current, C, y) return forward_fn def run_experiment(args): device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') torch.manual_seed(args.seed) np.random.seed(args.seed) # Hyperparams d = args.d_hidden # 64 m = args.output_dim # 10 L = args.num_layers # 12 sigma = args.sigma # 0.03 batch_size = args.batch_size # 256 num_steps = args.num_steps # 5000 lr_fb = args.lr_fb # 1e-3 lam = args.lam # 0.1 K = args.K # 8 ema_momentum = args.ema_momentum # 0.995 sigma_bridge = args.sigma_bridge # 0.03 print(f"=== Toy LQ Experiment ===") print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}") print(f"device={device}") # Generate fixed dynamics Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed) C = torch.randn(m, d, device=device) / np.sqrt(d) # DFA random feedback matrices Bs_dfa = [] for l in range(L + 1): B = torch.randn(d, m, device=device) / np.sqrt(m) Bs_dfa.append(B) # State Bridge model state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16, hidden_dim=128, num_layers=2).to(device) opt_state = optim.Adam(state_bridge.parameters(), lr=lr_fb) # Credit Bridge value net value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, hidden_dim=128, num_layers=2).to(device) value_net_ema = create_ema_model(value_net) opt_value = optim.Adam(value_net.parameters(), lr=lr_fb) # Training logs log = { 'steps': [], 'state_bridge_loss': [], 'credit_bridge_loss': [], 'dfa_costate_cos': [], 'state_costate_cos': [], 'credit_costate_cos': [], 'dfa_rho': [], 'state_rho': [], 'credit_rho': [], 'dfa_nudge': [], 'state_nudge': [], 'credit_nudge': [], 'bridge_residual': [], } for step in range(1, num_steps + 1): # Generate data h0 = torch.randn(batch_size, d, device=device) y = torch.randn(batch_size, m, device=device) # Forward rollout hiddens = rollout_forward(h0, Ms, sigma, L, device) hL = hiddens[L] # Terminal error e_T = (hL @ C.T - y) # (batch, m) - gradient of Phi w.r.t. prediction # Terminal modulation code s = e_T (P=I) s = e_T.detach() # ---- Train State Bridge ---- state_loss = 0.0 hL_detached = hL.detach() for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch_size,), l / L, device=device) pred_hL = state_bridge(h_l_det, t_l, s) state_loss = state_loss + ((pred_hL - hL_detached) ** 2).sum(dim=-1).mean() state_loss = state_loss / L opt_state.zero_grad() state_loss.backward() opt_state.step() # ---- Train Credit Bridge (value net) ---- # Terminal boundary: V(h_L, 1, s) should equal Phi(h_L, y) hL_det = hL.detach().requires_grad_(False) t_L = torch.ones(batch_size, device=device) true_loss = terminal_loss(hL_det, C, y).detach() V_terminal = value_net(hL_det, t_L, s) loss_term = ((V_terminal - true_loss) ** 2).mean() # Bridge consistency loss_bridge = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch_size,), l / L, device=device) t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) V_l = value_net(h_l_det, t_l, s) # Generate noisy next states with torch.no_grad(): M = Ms[l].to(device) h_next_det = hiddens[l + 1].detach() log_terms = [] for k in range(K): noise = sigma_bridge * torch.randn(batch_size, d, device=device) h_next_noisy = h_next_det + noise V_next = value_net_ema(h_next_noisy, t_l_next, s) log_terms.append(-V_next / lam) log_terms_stack = torch.stack(log_terms, dim=-1) # (batch, K) V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() loss_bridge = loss_bridge / L loss_value = loss_term + loss_bridge opt_value.zero_grad() loss_value.backward() opt_value.step() update_ema(value_net, value_net_ema, ema_momentum) # ---- Evaluation ---- if step % args.eval_every == 0 or step == 1: with torch.no_grad(): eval_batch = min(batch_size, 128) h0_eval = torch.randn(eval_batch, d, device=device) y_eval = torch.randn(eval_batch, m, device=device) hiddens_eval = rollout_forward(h0_eval, Ms, sigma, L, device) hL_eval = hiddens_eval[L] e_T_eval = hL_eval @ C.T - y_eval s_eval = e_T_eval.detach() # Exact costate costates_exact = exact_costate(hiddens_eval, Ms, C, y_eval, device) # Compute credits for each method at each layer dfa_cos_layers = [] state_cos_layers = [] credit_cos_layers = [] dfa_rho_layers = [] state_rho_layers = [] credit_rho_layers = [] dfa_nudge_layers = [] state_nudge_layers = [] credit_nudge_layers = [] bridge_res_layers = [] for l in range(L + 1): h_l = hiddens_eval[l].detach() a_exact = costates_exact[l].detach() t_l = torch.full((eval_batch,), l / L, device=device) # DFA credit a_dfa = e_T_eval @ Bs_dfa[l].T # (batch, d) # State bridge credit h_l_req = h_l.clone().requires_grad_(True) pred_hL = state_bridge(h_l_req, t_l, s_eval) # Loss through state bridge prediction pred_out = pred_hL @ C.T # Use C as output projection for consistency pred_loss = 0.5 * ((pred_out - y_eval) ** 2).sum(dim=-1) a_state = torch.autograd.grad(pred_loss.sum(), h_l_req, create_graph=False)[0] # Credit bridge credit h_l_req2 = h_l.clone().requires_grad_(True) V_l = value_net(h_l_req2, t_l, s_eval) a_credit = torch.autograd.grad(V_l.sum(), h_l_req2, create_graph=False)[0] # Costate cosine dfa_cos_layers.append(cosine_similarity_batch(a_dfa, a_exact)) state_cos_layers.append(cosine_similarity_batch(a_state, a_exact)) credit_cos_layers.append(cosine_similarity_batch(a_credit, a_exact)) # Perturbation correlation and nudging (skip terminal layer for forward_fn) if l < L: fwd_fn = make_forward_fn_from_layer(hiddens_eval, Ms, C, y_eval, sigma, l, device) dfa_rho = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16) state_rho = perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16) credit_rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16) dfa_rho_layers.append(dfa_rho) state_rho_layers.append(state_rho) credit_rho_layers.append(credit_rho) dfa_nud = nudging_test(h_l, a_dfa, fwd_fn, eta=0.01) state_nud = nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01) credit_nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01) dfa_nudge_layers.append(dfa_nud) state_nudge_layers.append(state_nud) credit_nudge_layers.append(credit_nud) # Bridge residual for credit bridge if l < L: t_l_next = torch.full((eval_batch,), (l + 1) / L, device=device) h_next = hiddens_eval[l + 1].detach() noisy_list = [h_next + sigma_bridge * torch.randn_like(h_next) for _ in range(K)] br = bridge_residual(value_net, value_net_ema, h_l, t_l, s_eval, noisy_list, t_l_next, lam) bridge_res_layers.append(br) # Average across layers avg_dfa_cos = np.mean(dfa_cos_layers) avg_state_cos = np.mean(state_cos_layers) avg_credit_cos = np.mean(credit_cos_layers) avg_dfa_rho = np.mean(dfa_rho_layers) avg_state_rho = np.mean(state_rho_layers) avg_credit_rho = np.mean(credit_rho_layers) avg_dfa_nudge = np.mean(dfa_nudge_layers) avg_state_nudge = np.mean(state_nudge_layers) avg_credit_nudge = np.mean(credit_nudge_layers) avg_bridge_res = np.mean(bridge_res_layers) if bridge_res_layers else 0.0 log['steps'].append(step) log['dfa_costate_cos'].append(avg_dfa_cos) log['state_costate_cos'].append(avg_state_cos) log['credit_costate_cos'].append(avg_credit_cos) log['dfa_rho'].append(avg_dfa_rho) log['state_rho'].append(avg_state_rho) log['credit_rho'].append(avg_credit_rho) log['dfa_nudge'].append(avg_dfa_nudge) log['state_nudge'].append(avg_state_nudge) log['credit_nudge'].append(avg_credit_nudge) log['bridge_residual'].append(avg_bridge_res) log['state_bridge_loss'].append(state_loss.item()) log['credit_bridge_loss'].append(loss_value.item()) print(f"Step {step}/{num_steps}") print(f" Costate cos - DFA: {avg_dfa_cos:.4f}, State: {avg_state_cos:.4f}, Credit: {avg_credit_cos:.4f}") print(f" Perturb rho - DFA: {avg_dfa_rho:.4f}, State: {avg_state_rho:.4f}, Credit: {avg_credit_rho:.4f}") print(f" Nudging - DFA: {avg_dfa_nudge:.4f}, State: {avg_state_nudge:.4f}, Credit: {avg_credit_nudge:.4f}") print(f" Bridge res - {avg_bridge_res:.4f}") print(f" Losses - State: {state_loss.item():.4f}, Credit: {loss_value.item():.4f}") print(f" Per-layer costate cos (credit): {['%.3f' % x for x in credit_cos_layers]}") # Save results os.makedirs(args.output_dir, exist_ok=True) results = { 'config': vars(args), 'log': log, 'final_per_layer': { 'dfa_costate_cos': dfa_cos_layers, 'state_costate_cos': state_cos_layers, 'credit_costate_cos': credit_cos_layers, 'dfa_rho': dfa_rho_layers, 'state_rho': state_rho_layers, 'credit_rho': credit_rho_layers, 'dfa_nudge': dfa_nudge_layers, 'state_nudge': state_nudge_layers, 'credit_nudge': credit_nudge_layers, 'bridge_residual': bridge_res_layers, } } out_path = os.path.join(args.output_dir, f'toy_lq_seed{args.seed}.json') with open(out_path, 'w') as f: json.dump(results, f, indent=2) print(f"\nResults saved to {out_path}") # Also save models torch.save(value_net.state_dict(), os.path.join(args.output_dir, f'value_net_seed{args.seed}.pt')) torch.save(state_bridge.state_dict(), os.path.join(args.output_dir, f'state_bridge_seed{args.seed}.pt')) return results def main(): parser = argparse.ArgumentParser(description='Toy LQ Sanity Check') parser.add_argument('--d_hidden', type=int, default=64) parser.add_argument('--output_dim', type=int, default=10) parser.add_argument('--num_layers', type=int, default=12) parser.add_argument('--sigma', type=float, default=0.03) parser.add_argument('--batch_size', type=int, default=256) parser.add_argument('--num_steps', type=int, default=5000) parser.add_argument('--lr_fb', type=float, default=1e-3) parser.add_argument('--lam', type=float, default=0.1) parser.add_argument('--K', type=int, default=8) parser.add_argument('--ema_momentum', type=float, default=0.995) parser.add_argument('--sigma_bridge', type=float, default=0.03) parser.add_argument('--eval_every', type=int, default=200) parser.add_argument('--seed', type=int, default=42) parser.add_argument('--gpu', type=int, default=1) parser.add_argument('--output_dir', type=str, default='results/toy_lq') args = parser.parse_args() run_experiment(args) if __name__ == '__main__': main()