diff options
Diffstat (limited to 'experiments/toy_lq_v2.py')
| -rw-r--r-- | experiments/toy_lq_v2.py | 327 |
1 files changed, 327 insertions, 0 deletions
"""
Phase A v2: Enhanced toy LQ experiment.

Key improvements over v1:
1. Terminal gradient matching: V_phi at terminal layer should have grad_h V matching
   the exact terminal gradient (this is LOCAL info, no hidden BP needed).
2. Larger noise sweep integrated.
3. Optional FM auxiliary for gradient smoothness.
4. Better diagnostics.

The terminal gradient a_L = C^T(C h_L - y) is computed from output layer only,
so using it is allowed under the "no hidden BP anchor" constraint.
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.value_net import ValueNet, create_ema_model, update_ema
from models.state_bridge import StateBridgeNet
from experiments.toy_lq import (
    generate_stable_dynamics, rollout_forward, terminal_loss,
    exact_costate, make_forward_fn_from_layer
)
from metrics.credit_metrics import (
    cosine_similarity_batch, perturbation_correlation, nudging_test
)


def _mean(values):
    """Return the arithmetic mean of *values* as a builtin float."""
    return float(np.mean(values))


def _as_builtin(values):
    """Convert each entry of *values* to a builtin Python value for JSON output.

    NOTE(review): the return types of the metric helpers are not visible in
    this file. Plain floats pass through unchanged; NumPy scalars (which
    ``json.dump`` may reject) become Python floats.
    """
    return [np.asarray(v).tolist() for v in values]


def run_experiment(args):
    """Train the state bridge and credit bridge on the toy LQ task and save diagnostics.

    Each training step draws a fresh batch, rolls it through the fixed dynamics,
    and (a) fits ``StateBridgeNet`` to predict the terminal hidden state from
    every intermediate layer, then (b) fits ``ValueNet`` with a terminal-value
    loss, an optional terminal-gradient loss, a soft-min bridge-consistency loss
    against an EMA copy, and an optional FM gradient-smoothness loss.

    Every ``args.eval_every`` steps (and at step 1) the DFA, state-bridge and
    credit-bridge credit signals are compared against the exact costate.

    Args:
        args: ``argparse.Namespace`` carrying the options declared in ``main``.

    Returns:
        dict with keys ``'config'`` (``vars(args)``), ``'log'`` (metric
        histories) and ``'final_per_layer'`` (per-layer metrics from the most
        recent evaluation; empty lists if no evaluation ran). The same dict is
        written as JSON under ``args.output_dir``.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    d = args.d_hidden
    m = args.output_dim
    L = args.num_layers
    sigma = args.sigma
    batch_size = args.batch_size
    num_steps = args.num_steps
    lr = args.lr_fb
    lam = args.lam
    K = args.K
    ema_momentum = args.ema_momentum
    sigma_bridge = args.sigma_bridge

    print("=== Toy LQ v2 Experiment ===")
    print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}")
    print(f"lam={lam}, sigma_bridge={sigma_bridge}, K={K}")
    print(f"terminal_grad_weight={args.term_grad_weight}")
    print(f"fm_weight={args.fm_weight}")
    print(f"device={device}")

    Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed)
    C = torch.randn(m, d, device=device) / np.sqrt(d)

    # DFA: one fixed random feedback matrix per layer
    Bs_dfa = [torch.randn(d, m, device=device) / np.sqrt(m) for _ in range(L + 1)]

    # State Bridge
    state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16,
                                  hidden_dim=128, num_layers=2).to(device)
    opt_state = optim.Adam(state_bridge.parameters(), lr=lr)

    # Credit Bridge
    value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
                         hidden_dim=args.vnet_hidden, num_layers=args.vnet_layers).to(device)
    value_net_ema = create_ema_model(value_net)
    opt_value = optim.Adam(value_net.parameters(), lr=lr)

    # NOTE: 'bridge_residual' is kept for output-schema compatibility; nothing
    # in this script appends to it.
    log = {key: [] for key in [
        'steps',
        'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos',
        'dfa_rho', 'state_rho', 'credit_rho',
        'dfa_nudge', 'state_nudge', 'credit_nudge',
        'bridge_residual', 'state_bridge_loss', 'credit_bridge_loss',
        'term_loss', 'bridge_loss', 'term_grad_loss', 'fm_loss',
    ]}

    # Per-layer metrics of the most recent evaluation. Initialized up front so
    # the save step below is well-defined even when no evaluation ran
    # (e.g. num_steps < 1).
    final_per_layer = {key: [] for key in [
        'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos',
        'dfa_rho', 'state_rho', 'credit_rho',
        'dfa_nudge', 'state_nudge', 'credit_nudge',
    ]}

    for step in range(1, num_steps + 1):
        h0 = torch.randn(batch_size, d, device=device)
        y = torch.randn(batch_size, m, device=device)
        hiddens = rollout_forward(h0, Ms, sigma, L, device)
        hL = hiddens[L]
        e_T = hL @ C.T - y
        # Terminal output error, passed to both nets as their third argument.
        s = e_T.detach()

        # ---- Train State Bridge ----
        state_loss = 0.0
        hL_det = hL.detach()
        for l in range(L):
            h_l_det = hiddens[l].detach()
            t_l = torch.full((batch_size,), l / L, device=device)
            pred_hL = state_bridge(h_l_det, t_l, s)
            state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean()
        state_loss = state_loss / L

        opt_state.zero_grad()
        state_loss.backward()
        opt_state.step()

        # ---- Train Credit Bridge ----
        # 1. Terminal boundary: V(h_L, 1, s) ≈ Phi(h_L, y)
        t_L = torch.ones(batch_size, device=device)
        true_loss = terminal_loss(hL_det, C, y).detach()
        V_terminal = value_net(hL_det, t_L, s)
        loss_term = ((V_terminal - true_loss) ** 2).mean()

        # 2. Terminal gradient matching: grad_h V(h_L, 1, s) ≈ a_L^exact
        #    This uses only terminal-local information (no hidden BP)
        loss_term_grad = torch.tensor(0.0, device=device)
        if args.term_grad_weight > 0:
            hL_req = hL.detach().requires_grad_(True)
            V_at_L = value_net(hL_req, t_L, s)
            # create_graph=True so the gradient-matching loss can itself be
            # backpropagated into the value-net parameters.
            grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
            # Exact terminal gradient: C^T (C h_L - y)
            a_L_exact = (e_T @ C).detach()  # (batch, d) -- stop grad on target
            loss_term_grad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()

        # 3. Bridge consistency
        loss_bridge = 0.0
        for l in range(L):
            h_l_det = hiddens[l].detach()
            t_l = torch.full((batch_size,), l / L, device=device)
            t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)

            V_l = value_net(h_l_det, t_l, s)

            # Target: soft-min (negative scaled log-mean-exp) of the EMA value
            # over K noisy copies of the next hidden state.
            with torch.no_grad():
                h_next_det = hiddens[l + 1].detach()
                log_terms = []
                for _ in range(K):
                    noise = sigma_bridge * torch.randn(batch_size, d, device=device)
                    h_noisy = h_next_det + noise
                    V_next = value_net_ema(h_noisy, t_l_next, s)
                    log_terms.append(-V_next / lam)
                log_terms_stack = torch.stack(log_terms, dim=-1)
                V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))

            loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
        loss_bridge = loss_bridge / L

        # 4. FM auxiliary (optional): enforce gradient smoothness
        loss_fm = torch.tensor(0.0, device=device)
        if args.fm_weight > 0:
            for l in range(L):
                tau = torch.rand(batch_size, 1, device=device)
                h_l_det = hiddens[l].detach()
                h_next_det = hiddens[l + 1].detach()
                f_l = h_next_det - h_l_det  # residual

                # Interpolate between consecutive layers; the added noise has
                # variance proportional to tau * (1 - tau), so it vanishes at
                # both endpoints.
                eps = torch.randn(batch_size, d, device=device)
                h_mid = h_l_det + tau * f_l + (tau * (1 - tau)).sqrt() * sigma_bridge * eps
                h_mid.requires_grad_(True)

                t_mid = (l + tau) / L
                t_mid_flat = t_mid.squeeze(-1)

                V_mid = value_net(h_mid, t_mid_flat, s)
                grad_V_mid = torch.autograd.grad(V_mid.sum(), h_mid, create_graph=True)[0]

                # Interpolated target gradient
                # Get a_l and a_{l+1} from current value net (no create_graph for targets)
                h_l_r = h_l_det.clone().requires_grad_(True)
                t_l_v = torch.full((batch_size,), l / L, device=device)
                V_l_ = value_net(h_l_r, t_l_v, s)
                a_l = torch.autograd.grad(V_l_.sum(), h_l_r, create_graph=False)[0].detach()

                h_next_r = h_next_det.clone().requires_grad_(True)
                t_next_v = torch.full((batch_size,), (l + 1) / L, device=device)
                V_next_ = value_net(h_next_r, t_next_v, s)
                a_next = torch.autograd.grad(V_next_.sum(), h_next_r, create_graph=False)[0].detach()

                target_grad = ((1 - tau) * a_l + tau * a_next).detach()
                loss_fm = loss_fm + ((grad_V_mid - target_grad) ** 2).sum(dim=-1).mean()
            loss_fm = loss_fm / L

        total_loss = (loss_term
                      + loss_bridge
                      + args.term_grad_weight * loss_term_grad
                      + args.fm_weight * loss_fm)

        opt_value.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
        opt_value.step()
        update_ema(value_net, value_net_ema, ema_momentum)

        # ---- Evaluation ----
        if step % args.eval_every == 0 or step == 1:
            # NOTE(review): the extent of this no_grad scope was reconstructed
            # from a whitespace-flattened source. The per-layer loop below must
            # run with grad enabled because it calls torch.autograd.grad.
            with torch.no_grad():
                eval_batch = 128
                h0_e = torch.randn(eval_batch, d, device=device)
                y_e = torch.randn(eval_batch, m, device=device)
                hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device)
                hL_e = hiddens_e[L]
                e_T_e = hL_e @ C.T - y_e
                s_e = e_T_e.detach()
                costates = exact_costate(hiddens_e, Ms, C, y_e, device)

            dfa_cos, state_cos, credit_cos = [], [], []
            dfa_rho, state_rho, credit_rho = [], [], []
            dfa_nudge, state_nudge, credit_nudge = [], [], []

            for l in range(L):
                h_l = hiddens_e[l].detach()
                a_exact = costates[l].detach()
                t_l = torch.full((eval_batch,), l / L, device=device)

                # DFA
                a_dfa = e_T_e @ Bs_dfa[l].T
                # State bridge: differentiate the loss of the predicted
                # terminal state with respect to the layer-l hidden state.
                h_l_r1 = h_l.clone().requires_grad_(True)
                pred_hL = state_bridge(h_l_r1, t_l, s_e)
                pred_out = pred_hL @ C.T
                pred_loss = 0.5 * ((pred_out - y_e) ** 2).sum(dim=-1)
                a_state = torch.autograd.grad(pred_loss.sum(), h_l_r1, create_graph=False)[0]
                # Credit bridge: gradient of the learned value w.r.t. h_l
                h_l_r2 = h_l.clone().requires_grad_(True)
                V_l = value_net(h_l_r2, t_l, s_e)
                a_credit = torch.autograd.grad(V_l.sum(), h_l_r2, create_graph=False)[0]

                dfa_cos.append(cosine_similarity_batch(a_dfa, a_exact))
                state_cos.append(cosine_similarity_batch(a_state, a_exact))
                credit_cos.append(cosine_similarity_batch(a_credit, a_exact))

                fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device)

                dfa_rho.append(perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16))
                state_rho.append(perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16))
                credit_rho.append(perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16))

                dfa_nudge.append(nudging_test(h_l, a_dfa, fwd_fn, eta=0.01))
                state_nudge.append(nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01))
                credit_nudge.append(nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01))

            log['steps'].append(step)
            log['dfa_costate_cos'].append(_mean(dfa_cos))
            log['state_costate_cos'].append(_mean(state_cos))
            log['credit_costate_cos'].append(_mean(credit_cos))
            log['dfa_rho'].append(_mean(dfa_rho))
            log['state_rho'].append(_mean(state_rho))
            log['credit_rho'].append(_mean(credit_rho))
            log['dfa_nudge'].append(_mean(dfa_nudge))
            log['state_nudge'].append(_mean(state_nudge))
            log['credit_nudge'].append(_mean(credit_nudge))
            log['state_bridge_loss'].append(state_loss.item())
            log['credit_bridge_loss'].append(total_loss.item())
            log['term_loss'].append(loss_term.item())
            log['bridge_loss'].append(loss_bridge.item())
            # Both auxiliary losses are always tensors (initialized via
            # torch.tensor above), so .item() is safe without a type check.
            log['term_grad_loss'].append(loss_term_grad.item())
            log['fm_loss'].append(loss_fm.item())

            print(f"Step {step}/{num_steps}")
            print(f" Costate cos - DFA: {_mean(dfa_cos):.4f}, State: {_mean(state_cos):.4f}, Credit: {_mean(credit_cos):.4f}")
            print(f" Perturb rho - DFA: {_mean(dfa_rho):.4f}, State: {_mean(state_rho):.4f}, Credit: {_mean(credit_rho):.4f}")
            print(f" Nudging - DFA: {_mean(dfa_nudge):.4f}, State: {_mean(state_nudge):.4f}, Credit: {_mean(credit_nudge):.4f}")
            print(f" Losses - term: {loss_term.item():.4f}, bridge: {loss_bridge.item():.4f}, "
                  f"tgrad: {loss_term_grad.item():.4f}, "
                  f"fm: {loss_fm.item():.4f}")
            print(f" Per-layer credit cos: {['%.3f' % x for x in credit_cos]}")

            # Remember this evaluation's per-layer values for the final dump.
            final_per_layer = {
                'dfa_costate_cos': _as_builtin(dfa_cos),
                'state_costate_cos': _as_builtin(state_cos),
                'credit_costate_cos': _as_builtin(credit_cos),
                'dfa_rho': _as_builtin(dfa_rho),
                'state_rho': _as_builtin(state_rho),
                'credit_rho': _as_builtin(credit_rho),
                'dfa_nudge': _as_builtin(dfa_nudge),
                'state_nudge': _as_builtin(state_nudge),
                'credit_nudge': _as_builtin(credit_nudge),
            }

    # Save
    os.makedirs(args.output_dir, exist_ok=True)
    results = {
        'config': vars(args),
        'log': log,
        'final_per_layer': final_per_layer,
    }
    tag = f"seed{args.seed}_lam{args.lam}_sig{args.sigma_bridge}_tgw{args.term_grad_weight}_fm{args.fm_weight}"
    out_path = os.path.join(args.output_dir, f'toy_lq_v2_{tag}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {out_path}")
    return results


def main():
    """Parse command-line options and run one toy LQ v2 experiment."""
    parser = argparse.ArgumentParser(description='Toy LQ v2')
    parser.add_argument('--d_hidden', type=int, default=64)
    parser.add_argument('--output_dim', type=int, default=10)
    parser.add_argument('--num_layers', type=int, default=12)
    parser.add_argument('--sigma', type=float, default=0.03)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_steps', type=int, default=8000)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--lam', type=float, default=0.1)
    parser.add_argument('--K', type=int, default=8)
    parser.add_argument('--ema_momentum', type=float, default=0.995)
    parser.add_argument('--sigma_bridge', type=float, default=0.1)
    parser.add_argument('--eval_every', type=int, default=500)
    parser.add_argument('--seed', type=int, default=42)
    # NOTE: the default CUDA device index is 1, not 0; it is only used when
    # CUDA is available.
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--output_dir', type=str, default='results/toy_lq')
    parser.add_argument('--vnet_hidden', type=int, default=256)
    parser.add_argument('--vnet_layers', type=int, default=3)
    # Key new options
    parser.add_argument('--term_grad_weight', type=float, default=1.0,
                        help='Weight for terminal gradient matching loss')
    parser.add_argument('--fm_weight', type=float, default=0.0,
                        help='Weight for FM gradient smoothness auxiliary')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()
