"""Phase A v2: Enhanced toy LQ experiment.

Key improvements over v1:
1. Terminal gradient matching: V_phi at terminal layer should have
   grad_h V matching the exact terminal gradient (this is LOCAL info,
   no hidden BP needed).
2. Larger noise sweep integrated.
3. Optional FM auxiliary for gradient smoothness.
4. Better diagnostics.

The terminal gradient a_L = C^T(C h_L - y) is computed from output layer
only, so using it is allowed under the "no hidden BP anchor" constraint.
"""

import os
import sys
import json
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.value_net import ValueNet, create_ema_model, update_ema
from models.state_bridge import StateBridgeNet
from experiments.toy_lq import (
    generate_stable_dynamics, rollout_forward, terminal_loss,
    exact_costate, make_forward_fn_from_layer
)
from metrics.credit_metrics import (
    cosine_similarity_batch, perturbation_correlation, nudging_test
)


# Per-layer diagnostics produced by every evaluation pass. The same names are
# used for the averaged curves in ``log`` and for ``final_per_layer``.
_PER_LAYER_KEYS = (
    'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos',
    'dfa_rho', 'state_rho', 'credit_rho',
    'dfa_nudge', 'state_nudge', 'credit_nudge',
)


def _mean_float(values):
    """Return the arithmetic mean of *values* as a built-in ``float``."""
    return float(np.mean(values))


def _scalar(value):
    """Return *value* as a Python float, unwrapping a tensor if necessary."""
    return value.item() if isinstance(value, torch.Tensor) else float(value)


def _validate_args(args):
    """Reject configurations that can only fail or produce NaNs later on.

    Raises:
        ValueError: if ``num_layers``, ``K`` or ``eval_every`` is smaller
            than 1, or if ``lam`` is exactly zero.
    """
    if args.num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {args.num_layers}")
    if args.K < 1:
        raise ValueError(f"K must be >= 1, got {args.K}")
    if args.eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {args.eval_every}")
    if args.lam == 0:
        # The bridge target divides by lam and multiplies by it again.
        raise ValueError("lam must be non-zero")


def _value_input_gradient(value_net, h, t, s):
    """Return d(sum V(h, t, s)) / dh as a detached tensor.

    The gradient is taken w.r.t. a fresh leaf copy of *h*. Autograd is
    enabled explicitly so the helper also works when the caller is inside a
    ``torch.no_grad()`` region.
    """
    h_req = h.detach().clone().requires_grad_(True)
    with torch.enable_grad():
        total = value_net(h_req, t, s).sum()
        grad = torch.autograd.grad(total, h_req, create_graph=False)[0]
    return grad.detach()


def _state_bridge_input_gradient(state_bridge, h, t, s, C, y):
    """Return the gradient w.r.t. *h* of the loss predicted via the state bridge.

    The state bridge predicts the terminal state from *h*; the quadratic
    loss 0.5 * ||pred @ C^T - y||^2 of that prediction is differentiated
    w.r.t. a fresh leaf copy of *h*.
    """
    h_req = h.detach().clone().requires_grad_(True)
    with torch.enable_grad():
        pred_out = state_bridge(h_req, t, s) @ C.T
        pred_loss = 0.5 * ((pred_out - y) ** 2).sum(dim=-1)
        grad = torch.autograd.grad(pred_loss.sum(), h_req, create_graph=False)[0]
    return grad.detach()


def _state_bridge_loss(state_bridge, hiddens, s, L, batch_size, device):
    """Normalized MSE between the predicted and actual terminal state.

    Averaged over layers ``0 .. L-1`` (consistent with the CIFAR experiment,
    per the original comment).
    """
    hL_det = hiddens[L].detach()
    # Loop-invariant: the target (and hence its norm) is the same per layer.
    target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0)
    loss = 0.0
    for layer in range(L):
        t_l = torch.full((batch_size,), layer / L, device=device)
        pred_hL = state_bridge(hiddens[layer].detach(), t_l, s)
        loss = loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean()
    return loss / L


def _bridge_consistency_loss(value_net, value_net_ema, hiddens, s, L, K, lam,
                             sigma_bridge, batch_size, d, device):
    """Bridge consistency loss between V at layer l and a soft-min target.

    The target is ``-lam * log(mean_k exp(-V_ema(h_{l+1} + noise_k) / lam))``
    using *K* Gaussian perturbations of scale *sigma_bridge*, evaluated with
    the EMA network and without gradient tracking.
    """
    loss = 0.0
    for layer in range(L):
        h_l = hiddens[layer].detach()
        t_l = torch.full((batch_size,), layer / L, device=device)
        t_next = torch.full((batch_size,), (layer + 1) / L, device=device)
        V_l = value_net(h_l, t_l, s)

        with torch.no_grad():
            h_next = hiddens[layer + 1].detach()
            log_terms = []
            for _ in range(K):
                noise = sigma_bridge * torch.randn(batch_size, d, device=device)
                log_terms.append(-value_net_ema(h_next + noise, t_next, s) / lam)
            stacked = torch.stack(log_terms, dim=-1)
            V_target = -lam * (torch.logsumexp(stacked, dim=-1) - np.log(K))

        loss = loss + ((V_l - V_target.detach()) ** 2).mean()
    return loss / L


def _fm_smoothness_loss(value_net, hiddens, s, L, sigma_bridge, batch_size, d,
                        device):
    """Optional FM auxiliary: enforce gradient smoothness between layers.

    At a random interpolation point between h_l and h_{l+1}, grad_h V is
    regressed onto the linear interpolation of the (detached) value-net
    gradients at the two endpoints.
    """
    loss = torch.tensor(0.0, device=device)
    for layer in range(L):
        tau = torch.rand(batch_size, 1, device=device)
        h_l = hiddens[layer].detach()
        h_next = hiddens[layer + 1].detach()
        f_l = h_next - h_l  # residual
        eps = torch.randn(batch_size, d, device=device)
        h_mid = h_l + tau * f_l + (tau * (1 - tau)).sqrt() * sigma_bridge * eps
        h_mid.requires_grad_(True)
        t_mid = ((layer + tau) / L).squeeze(-1)

        V_mid = value_net(h_mid, t_mid, s)
        grad_V_mid = torch.autograd.grad(V_mid.sum(), h_mid, create_graph=True)[0]

        # Target gradients come from the current value net; they are detached
        # so no second-order graph is built for the targets.
        t_l = torch.full((batch_size,), layer / L, device=device)
        t_next = torch.full((batch_size,), (layer + 1) / L, device=device)
        a_l = _value_input_gradient(value_net, h_l, t_l, s)
        a_next = _value_input_gradient(value_net, h_next, t_next, s)

        target_grad = ((1 - tau) * a_l + tau * a_next).detach()
        loss = loss + ((grad_V_mid - target_grad) ** 2).sum(dim=-1).mean()
    return loss / L


def _evaluate_credit_signals(state_bridge, value_net, Bs_dfa, Ms, C, sigma, L,
                             d, m, device, eval_batch=128):
    """Compare DFA, state-bridge and credit-bridge signals on a fresh batch.

    Returns:
        dict mapping each name in ``_PER_LAYER_KEYS`` to a list with one
        entry per layer ``0 .. L-1``.
    """
    with torch.no_grad():
        h0_e = torch.randn(eval_batch, d, device=device)
        y_e = torch.randn(eval_batch, m, device=device)
        hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device)
        e_T_e = hiddens_e[L] @ C.T - y_e
        s_e = e_T_e.detach()

    # NOTE(review): the original indentation was lost, so it is unclear whether
    # exact_costate ran inside the no_grad block. It is called outside here,
    # which is safe whether or not it relies on autograd internally.
    costates = exact_costate(hiddens_e, Ms, C, y_e, device)

    per_layer = {key: [] for key in _PER_LAYER_KEYS}
    for layer in range(L):
        h_l = hiddens_e[layer].detach()
        a_exact = costates[layer].detach()
        t_l = torch.full((eval_batch,), layer / L, device=device)

        # DFA: fixed random feedback of the terminal error.
        a_dfa = e_T_e @ Bs_dfa[layer].T
        # State bridge: gradient of the loss of the predicted terminal state.
        a_state = _state_bridge_input_gradient(state_bridge, h_l, t_l, s_e, C, y_e)
        # Credit bridge: gradient of the learned value function.
        a_credit = _value_input_gradient(value_net, h_l, t_l, s_e)

        signals = (('dfa', a_dfa), ('state', a_state), ('credit', a_credit))

        for name, signal in signals:
            per_layer[f'{name}_costate_cos'].append(
                cosine_similarity_batch(signal, a_exact))

        fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, layer, device)
        for name, signal in signals:
            per_layer[f'{name}_rho'].append(
                perturbation_correlation(h_l, signal, fwd_fn, epsilon=1e-3, M=16))
        for name, signal in signals:
            per_layer[f'{name}_nudge'].append(
                nudging_test(h_l, signal, fwd_fn, eta=0.01))

    return per_layer


def run_experiment(args):
    """Train the state bridge and credit bridge on the toy LQ problem.

    Every training step samples a fresh batch, rolls it through the dynamics,
    updates the state bridge, then updates the value network (credit bridge)
    with terminal, terminal-gradient, bridge-consistency and optional FM
    losses. Diagnostics are evaluated on step 1, every ``args.eval_every``
    steps, and on the last step, so ``final_per_layer`` always describes the
    fully trained model. Results are written as JSON to ``args.output_dir``.

    Args:
        args: namespace with the options defined in ``main``.

    Returns:
        dict with keys ``'config'``, ``'log'`` and ``'final_per_layer'``.

    Raises:
        ValueError: if the configuration is degenerate (see ``_validate_args``).
    """
    _validate_args(args)

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    d = args.d_hidden
    m = args.output_dim
    L = args.num_layers
    sigma = args.sigma
    batch_size = args.batch_size
    num_steps = args.num_steps
    lr = args.lr_fb
    lam = args.lam
    K = args.K
    ema_momentum = args.ema_momentum
    sigma_bridge = args.sigma_bridge

    print("=== Toy LQ v2 Experiment ===")
    print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}")
    print(f"lam={lam}, sigma_bridge={sigma_bridge}, K={K}")
    print(f"terminal_grad_weight={args.term_grad_weight}")
    print(f"fm_weight={args.fm_weight}")
    print(f"device={device}")

    Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed)
    C = torch.randn(m, d, device=device) / np.sqrt(d)

    # DFA
    Bs_dfa = [torch.randn(d, m, device=device) / np.sqrt(m) for _ in range(L + 1)]

    # State Bridge
    state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16,
                                  hidden_dim=128, num_layers=2).to(device)
    opt_state = optim.Adam(state_bridge.parameters(), lr=lr)

    # Credit Bridge
    value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
                         hidden_dim=args.vnet_hidden,
                         num_layers=args.vnet_layers).to(device)
    value_net_ema = create_ema_model(value_net)
    opt_value = optim.Adam(value_net.parameters(), lr=lr)

    # 'bridge_residual' is never populated; the key is kept so the output
    # schema stays the same for downstream readers.
    log = {key: [] for key in [
        'steps',
        'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos',
        'dfa_rho', 'state_rho', 'credit_rho',
        'dfa_nudge', 'state_nudge', 'credit_nudge',
        'bridge_residual',
        'state_bridge_loss', 'credit_bridge_loss',
        'term_loss', 'bridge_loss', 'term_grad_loss', 'fm_loss',
    ]}
    # Pre-initialised so saving works even when no evaluation ever runs
    # (num_steps == 0).
    per_layer = {key: [] for key in _PER_LAYER_KEYS}

    for step in range(1, num_steps + 1):
        h0 = torch.randn(batch_size, d, device=device)
        y = torch.randn(batch_size, m, device=device)

        hiddens = rollout_forward(h0, Ms, sigma, L, device)
        hL = hiddens[L]
        e_T = hL @ C.T - y
        s = e_T.detach()

        # ---- Train State Bridge ----
        state_loss = _state_bridge_loss(state_bridge, hiddens, s, L, batch_size, device)
        opt_state.zero_grad()
        state_loss.backward()
        opt_state.step()

        # ---- Train Credit Bridge ----
        # 1. Terminal boundary: V(h_L, 1, s) ≈ Phi(h_L, y)
        # NOTE(review): assumes value_net and terminal_loss return tensors of
        # the same shape; otherwise the subtraction would broadcast - confirm.
        hL_det = hL.detach()
        t_L = torch.ones(batch_size, device=device)
        true_loss = terminal_loss(hL_det, C, y).detach()
        V_terminal = value_net(hL_det, t_L, s)
        loss_term = ((V_terminal - true_loss) ** 2).mean()

        # 2. Terminal gradient matching: grad_h V(h_L, 1, s) ≈ a_L^exact
        #    This uses only terminal-local information (no hidden BP).
        loss_term_grad = torch.tensor(0.0, device=device)
        if args.term_grad_weight > 0:
            hL_req = hL.detach().requires_grad_(True)
            V_at_L = value_net(hL_req, t_L, s)
            grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
            # Exact terminal gradient: C^T (C h_L - y); stop grad on target.
            a_L_exact = (e_T @ C).detach()  # (batch, d)
            loss_term_grad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()

        # 3. Bridge consistency
        loss_bridge = _bridge_consistency_loss(
            value_net, value_net_ema, hiddens, s, L, K, lam,
            sigma_bridge, batch_size, d, device)

        # 4. FM auxiliary (optional): enforce gradient smoothness
        loss_fm = torch.tensor(0.0, device=device)
        if args.fm_weight > 0:
            loss_fm = _fm_smoothness_loss(
                value_net, hiddens, s, L, sigma_bridge, batch_size, d, device)

        total_loss = (loss_term + loss_bridge
                      + args.term_grad_weight * loss_term_grad
                      + args.fm_weight * loss_fm)

        opt_value.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
        opt_value.step()
        update_ema(value_net, value_net_ema, ema_momentum)

        # ---- Evaluation ----
        # Also evaluate on the last step so 'final_per_layer' is truly final
        # even when num_steps is not a multiple of eval_every.
        if step % args.eval_every == 0 or step == 1 or step == num_steps:
            per_layer = _evaluate_credit_signals(
                state_bridge, value_net, Bs_dfa, Ms, C, sigma, L, d, m, device)
            averages = {key: _mean_float(vals) for key, vals in per_layer.items()}

            log['steps'].append(step)
            for key, value in averages.items():
                log[key].append(value)
            log['state_bridge_loss'].append(_scalar(state_loss))
            log['credit_bridge_loss'].append(_scalar(total_loss))
            log['term_loss'].append(_scalar(loss_term))
            log['bridge_loss'].append(_scalar(loss_bridge))
            log['term_grad_loss'].append(_scalar(loss_term_grad))
            log['fm_loss'].append(_scalar(loss_fm))

            print(f"Step {step}/{num_steps}")
            print(f" Costate cos - DFA: {averages['dfa_costate_cos']:.4f}, "
                  f"State: {averages['state_costate_cos']:.4f}, "
                  f"Credit: {averages['credit_costate_cos']:.4f}")
            print(f" Perturb rho - DFA: {averages['dfa_rho']:.4f}, "
                  f"State: {averages['state_rho']:.4f}, "
                  f"Credit: {averages['credit_rho']:.4f}")
            print(f" Nudging - DFA: {averages['dfa_nudge']:.4f}, "
                  f"State: {averages['state_nudge']:.4f}, "
                  f"Credit: {averages['credit_nudge']:.4f}")
            print(f" Losses - term: {_scalar(loss_term):.4f}, "
                  f"bridge: {_scalar(loss_bridge):.4f}, "
                  f"tgrad: {_scalar(loss_term_grad):.4f}, "
                  f"fm: {_scalar(loss_fm):.4f}")
            print(f" Per-layer credit cos: "
                  f"{['%.3f' % x for x in per_layer['credit_costate_cos']]}")

    # Save
    os.makedirs(args.output_dir, exist_ok=True)
    results = {
        'config': vars(args),
        'log': log,
        # Cast to built-in floats: NumPy float32 scalars are not
        # JSON-serialisable.
        'final_per_layer': {
            key: [float(v) for v in values] for key, values in per_layer.items()
        },
    }
    tag = (f"seed{args.seed}_lam{args.lam}_sig{args.sigma_bridge}"
           f"_tgw{args.term_grad_weight}_fm{args.fm_weight}")
    out_path = os.path.join(args.output_dir, f'toy_lq_v2_{tag}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {out_path}")
    return results


def main():
    """Parse command-line options and run the toy LQ v2 experiment."""
    parser = argparse.ArgumentParser(description='Toy LQ v2')
    parser.add_argument('--d_hidden', type=int, default=64)
    parser.add_argument('--output_dim', type=int, default=10)
    parser.add_argument('--num_layers', type=int, default=12)
    parser.add_argument('--sigma', type=float, default=0.03)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_steps', type=int, default=8000)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--lam', type=float, default=0.1)
    parser.add_argument('--K', type=int, default=8)
    parser.add_argument('--ema_momentum', type=float, default=0.995)
    parser.add_argument('--sigma_bridge', type=float, default=0.1)
    parser.add_argument('--eval_every', type=int, default=500)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--output_dir', type=str, default='results/toy_lq')
    parser.add_argument('--vnet_hidden', type=int, default=256)
    parser.add_argument('--vnet_layers', type=int, default=3)
    # Key new options
    parser.add_argument('--term_grad_weight', type=float, default=1.0,
                        help='Weight for terminal gradient matching loss')
    parser.add_argument('--fm_weight', type=float, default=0.0,
                        help='Weight for FM gradient smoothness auxiliary')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()