diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
| commit | 6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch) | |
| tree | d7c63adcd19c4f5d46c8a937e5047fece55dea62 /experiments/toy_lq.py | |
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching.
Credit bridge matches state bridge on linear system (~0.94 cosine).
CIFAR experiments in progress.
Diffstat (limited to 'experiments/toy_lq.py')
| -rw-r--r-- | experiments/toy_lq.py | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/experiments/toy_lq.py b/experiments/toy_lq.py new file mode 100644 index 0000000..4fd8919 --- /dev/null +++ b/experiments/toy_lq.py @@ -0,0 +1,395 @@ +""" +Phase A: Linear-Quadratic Residual Sanity Check. + +Fixed forward dynamics (no forward net training). +Only train feedback/bridge models. +Compare DFA, State Bridge, Credit Bridge against exact costate. + +System: + h_{l+1} = M_l h_l + sigma * xi_l, xi_l ~ N(0, I) + Phi(h_L, y) = 0.5 * ||C h_L - y||^2 + Exact costate: a_L = C^T (C h_L - y), a_l = M_l^T a_{l+1} +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from datetime import datetime + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test, bridge_residual +) + + +def generate_stable_dynamics(d, L, spectral_max=0.05, seed=42): + """Generate stable linear maps M_l = I + A_l with ||A_l||_2 <= spectral_max.""" + rng = np.random.RandomState(seed) + Ms = [] + for _ in range(L): + A = rng.randn(d, d).astype(np.float32) + # Scale to desired spectral norm + u, s, v = np.linalg.svd(A, full_matrices=False) + A = A * (spectral_max / s[0]) + M = np.eye(d, dtype=np.float32) + A + Ms.append(torch.from_numpy(M)) + return Ms # list of (d, d) + + +def rollout_forward(h0, Ms, sigma, L, device): + """Roll out forward dynamics: h_{l+1} = M_l h_l + sigma * xi_l.""" + batch = h0.shape[0] + d = h0.shape[1] + hiddens = [h0] + h = h0 + for l in range(L): + M = Ms[l].to(device) + noise = sigma * torch.randn(batch, d, device=device) + h = h @ M.T + noise + hiddens.append(h) + return hiddens # [h_0, ..., h_L] + + +def terminal_loss(hL, C, y): + """Phi(hL, y) = 0.5 * ||C hL - y||^2, returns per-sample loss.""" + diff = hL @ C.T - y # (batch, m) + return 0.5 * (diff ** 2).sum(dim=-1) # (batch,) + + +def exact_costate(hiddens, Ms, C, y, device): + """Compute exact costate a_l for all layers.""" + L = len(hiddens) - 1 + hL = hiddens[L] + # Terminal: a_L = C^T (C h_L - y) + diff = hL @ C.T - y # (batch, m) + a_L = diff @ C # (batch, d) + + costates = [None] * (L + 1) + costates[L] = a_L + for l in range(L - 1, -1, -1): + M = Ms[l].to(device) + costates[l] = costates[l + 1] @ M # a_l = M_l^T a_{l+1} -> a_{l+1} @ M + return costates + + +def make_forward_fn_from_layer(hiddens, Ms, C, y, sigma, start_layer, device): + """Create a function that rolls forward from layer start_layer and returns per-sample loss.""" + L = len(Ms) + + def forward_fn(h): + current = h + for l in range(start_layer, L): + M = Ms[l].to(device) + # No noise for perturbation test (deterministic rollout) + current = current @ M.T + return terminal_loss(current, C, y) + + return forward_fn + + +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # Hyperparams + d = args.d_hidden # 64 + m = args.output_dim # 10 + L = args.num_layers # 12 + sigma = args.sigma # 0.03 + batch_size = args.batch_size # 256 + num_steps = args.num_steps # 5000 + lr_fb = args.lr_fb # 1e-3 + lam = args.lam # 0.1 + K = args.K # 8 + ema_momentum = args.ema_momentum # 0.995 + sigma_bridge = args.sigma_bridge # 0.03 + + print(f"=== Toy LQ Experiment ===") + print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}") + print(f"device={device}") + + # Generate fixed dynamics + Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed) + C = torch.randn(m, d, device=device) / np.sqrt(d) + + # DFA random feedback matrices + Bs_dfa = [] + for l in range(L + 1): + B = torch.randn(d, m, device=device) / np.sqrt(m) + Bs_dfa.append(B) + + # State Bridge model + state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=128, num_layers=2).to(device) + opt_state = optim.Adam(state_bridge.parameters(), lr=lr_fb) + + # Credit Bridge value net + value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=128, num_layers=2).to(device) + value_net_ema = create_ema_model(value_net) + opt_value = optim.Adam(value_net.parameters(), lr=lr_fb) + + # Training logs + log = { + 'steps': [], + 'state_bridge_loss': [], + 'credit_bridge_loss': [], + 'dfa_costate_cos': [], + 'state_costate_cos': [], + 'credit_costate_cos': [], + 'dfa_rho': [], + 'state_rho': [], + 'credit_rho': [], + 'dfa_nudge': [], + 'state_nudge': [], + 'credit_nudge': [], + 'bridge_residual': [], + } + + for step in range(1, num_steps + 1): + # Generate data + h0 = torch.randn(batch_size, d, device=device) + y = torch.randn(batch_size, m, device=device) + + # Forward rollout + hiddens = rollout_forward(h0, Ms, sigma, L, device) + hL = hiddens[L] + + # Terminal error + e_T = (hL @ C.T - y) # (batch, m) - gradient of Phi w.r.t. prediction + + # Terminal modulation code s = e_T (P=I) + s = e_T.detach() + + # ---- Train State Bridge ---- + state_loss = 0.0 + hL_detached = hL.detach() + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + pred_hL = state_bridge(h_l_det, t_l, s) + state_loss = state_loss + ((pred_hL - hL_detached) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + + opt_state.zero_grad() + state_loss.backward() + opt_state.step() + + # ---- Train Credit Bridge (value net) ---- + # Terminal boundary: V(h_L, 1, s) should equal Phi(h_L, y) + hL_det = hL.detach().requires_grad_(False) + t_L = torch.ones(batch_size, device=device) + true_loss = terminal_loss(hL_det, C, y).detach() + + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) + + V_l = value_net(h_l_det, t_l, s) + + # Generate noisy next states + with torch.no_grad(): + M = Ms[l].to(device) + h_next_det = hiddens[l + 1].detach() + + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn(batch_size, d, device=device) + h_next_noisy = h_next_det + noise + V_next = value_net_ema(h_next_noisy, t_l_next, s) + log_terms.append(-V_next / lam) + + log_terms_stack = torch.stack(log_terms, dim=-1) # (batch, K) + V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + + loss_bridge = loss_bridge / L + loss_value = loss_term + loss_bridge + + opt_value.zero_grad() + loss_value.backward() + opt_value.step() + update_ema(value_net, value_net_ema, ema_momentum) + + # ---- Evaluation ---- + if step % args.eval_every == 0 or step == 1: + with torch.no_grad(): + eval_batch = min(batch_size, 128) + h0_eval = torch.randn(eval_batch, d, device=device) + y_eval = torch.randn(eval_batch, m, device=device) + hiddens_eval = rollout_forward(h0_eval, Ms, sigma, L, device) + hL_eval = hiddens_eval[L] + e_T_eval = hL_eval @ C.T - y_eval + s_eval = e_T_eval.detach() + + # Exact costate + costates_exact = exact_costate(hiddens_eval, Ms, C, y_eval, device) + + # Compute credits for each method at each layer + dfa_cos_layers = [] + state_cos_layers = [] + credit_cos_layers = [] + dfa_rho_layers = [] + state_rho_layers = [] + credit_rho_layers = [] + dfa_nudge_layers = [] + state_nudge_layers = [] + credit_nudge_layers = [] + bridge_res_layers = [] + + for l in range(L + 1): + h_l = hiddens_eval[l].detach() + a_exact = costates_exact[l].detach() + t_l = torch.full((eval_batch,), l / L, device=device) + + # DFA credit + a_dfa = e_T_eval @ Bs_dfa[l].T # (batch, d) + + # State bridge credit + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = state_bridge(h_l_req, t_l, s_eval) + # Loss through state bridge prediction + pred_out = pred_hL @ C.T # Use C as output projection for consistency + pred_loss = 0.5 * ((pred_out - y_eval) ** 2).sum(dim=-1) + a_state = torch.autograd.grad(pred_loss.sum(), h_l_req, create_graph=False)[0] + + # Credit bridge credit + h_l_req2 = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req2, t_l, s_eval) + a_credit = torch.autograd.grad(V_l.sum(), h_l_req2, create_graph=False)[0] + + # Costate cosine + dfa_cos_layers.append(cosine_similarity_batch(a_dfa, a_exact)) + state_cos_layers.append(cosine_similarity_batch(a_state, a_exact)) + credit_cos_layers.append(cosine_similarity_batch(a_credit, a_exact)) + + # Perturbation correlation and nudging (skip terminal layer for forward_fn) + if l < L: + fwd_fn = make_forward_fn_from_layer(hiddens_eval, Ms, C, y_eval, sigma, l, device) + + dfa_rho = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16) + state_rho = perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16) + credit_rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16) + dfa_rho_layers.append(dfa_rho) + state_rho_layers.append(state_rho) + credit_rho_layers.append(credit_rho) + + dfa_nud = nudging_test(h_l, a_dfa, fwd_fn, eta=0.01) + state_nud = nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01) + credit_nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01) + dfa_nudge_layers.append(dfa_nud) + state_nudge_layers.append(state_nud) + credit_nudge_layers.append(credit_nud) + + # Bridge residual for credit bridge + if l < L: + t_l_next = torch.full((eval_batch,), (l + 1) / L, device=device) + h_next = hiddens_eval[l + 1].detach() + noisy_list = [h_next + sigma_bridge * torch.randn_like(h_next) for _ in range(K)] + br = bridge_residual(value_net, value_net_ema, h_l, t_l, s_eval, + noisy_list, t_l_next, lam) + bridge_res_layers.append(br) + + # Average across layers + avg_dfa_cos = np.mean(dfa_cos_layers) + avg_state_cos = np.mean(state_cos_layers) + avg_credit_cos = np.mean(credit_cos_layers) + avg_dfa_rho = np.mean(dfa_rho_layers) + avg_state_rho = np.mean(state_rho_layers) + avg_credit_rho = np.mean(credit_rho_layers) + avg_dfa_nudge = np.mean(dfa_nudge_layers) + avg_state_nudge = np.mean(state_nudge_layers) + avg_credit_nudge = np.mean(credit_nudge_layers) + avg_bridge_res = np.mean(bridge_res_layers) if bridge_res_layers else 0.0 + + log['steps'].append(step) + log['dfa_costate_cos'].append(avg_dfa_cos) + log['state_costate_cos'].append(avg_state_cos) + log['credit_costate_cos'].append(avg_credit_cos) + log['dfa_rho'].append(avg_dfa_rho) + log['state_rho'].append(avg_state_rho) + log['credit_rho'].append(avg_credit_rho) + log['dfa_nudge'].append(avg_dfa_nudge) + log['state_nudge'].append(avg_state_nudge) + log['credit_nudge'].append(avg_credit_nudge) + log['bridge_residual'].append(avg_bridge_res) + log['state_bridge_loss'].append(state_loss.item()) + log['credit_bridge_loss'].append(loss_value.item()) + + print(f"Step {step}/{num_steps}") + print(f" Costate cos - DFA: {avg_dfa_cos:.4f}, State: {avg_state_cos:.4f}, Credit: {avg_credit_cos:.4f}") + print(f" Perturb rho - DFA: {avg_dfa_rho:.4f}, State: {avg_state_rho:.4f}, Credit: {avg_credit_rho:.4f}") + print(f" Nudging - DFA: {avg_dfa_nudge:.4f}, State: {avg_state_nudge:.4f}, Credit: {avg_credit_nudge:.4f}") + print(f" Bridge res - {avg_bridge_res:.4f}") + print(f" Losses - State: {state_loss.item():.4f}, Credit: {loss_value.item():.4f}") + print(f" Per-layer costate cos (credit): {['%.3f' % x for x in credit_cos_layers]}") + + # Save results + os.makedirs(args.output_dir, exist_ok=True) + results = { + 'config': vars(args), + 'log': log, + 'final_per_layer': { + 'dfa_costate_cos': dfa_cos_layers, + 'state_costate_cos': state_cos_layers, + 'credit_costate_cos': credit_cos_layers, + 'dfa_rho': dfa_rho_layers, + 'state_rho': state_rho_layers, + 'credit_rho': credit_rho_layers, + 'dfa_nudge': dfa_nudge_layers, + 'state_nudge': state_nudge_layers, + 'credit_nudge': credit_nudge_layers, + 'bridge_residual': bridge_res_layers, + } + } + + out_path = os.path.join(args.output_dir, f'toy_lq_seed{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {out_path}") + + # Also save models + torch.save(value_net.state_dict(), os.path.join(args.output_dir, f'value_net_seed{args.seed}.pt')) + torch.save(state_bridge.state_dict(), os.path.join(args.output_dir, f'state_bridge_seed{args.seed}.pt')) + + return results + + +def main(): + parser = argparse.ArgumentParser(description='Toy LQ Sanity Check') + parser.add_argument('--d_hidden', type=int, default=64) + parser.add_argument('--output_dim', type=int, default=10) + parser.add_argument('--num_layers', type=int, default=12) + parser.add_argument('--sigma', type=float, default=0.03) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--num_steps', type=int, default=5000) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=8) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--sigma_bridge', type=float, default=0.03) + parser.add_argument('--eval_every', type=int, default=200) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/toy_lq') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() |
