""" Sweep over credit bridge hyperparameters to find a configuration where the value field gradient actually aligns with the costate. Key hypothesis: the credit bridge needs sufficient noise (sigma_bridge) and temperature (lambda) to make V_phi sensitive to cost-relevant directions. """ import os import sys import json import numpy as np import torch import torch.nn as nn import torch.optim as optim from itertools import product sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.value_net import ValueNet, create_ema_model, update_ema from models.state_bridge import StateBridgeNet from experiments.toy_lq import ( generate_stable_dynamics, rollout_forward, terminal_loss, exact_costate, make_forward_fn_from_layer ) from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test def run_credit_bridge_config(config, device): """Run credit bridge with specific hyperparameters and return final metrics.""" d = 64 m = 10 L = 12 sigma = 0.03 batch_size = 256 num_steps = config['num_steps'] lr = config['lr'] lam = config['lam'] K = config['K'] ema_momentum = config['ema_momentum'] sigma_bridge = config['sigma_bridge'] hidden_dim = config.get('hidden_dim', 128) use_ln = config.get('use_ln', True) torch.manual_seed(42) np.random.seed(42) Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=42) C = torch.randn(m, d, device=device) / np.sqrt(d) # Value net - optionally without LayerNorm value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, hidden_dim=hidden_dim, num_layers=2).to(device) if not use_ln: value_net.ln = nn.Identity() value_net_ema = create_ema_model(value_net) opt_value = optim.Adam(value_net.parameters(), lr=lr) best_cos = -1.0 best_step = 0 history = [] for step in range(1, num_steps + 1): h0 = torch.randn(batch_size, d, device=device) y = torch.randn(batch_size, m, device=device) hiddens = rollout_forward(h0, Ms, sigma, L, device) hL = hiddens[L] e_T = hL @ C.T - y s = e_T.detach() true_loss = terminal_loss(hL.detach(), C, y).detach() # Terminal boundary hL_det = hL.detach() t_L = torch.ones(batch_size, device=device) V_terminal = value_net(hL_det, t_L, s) loss_term = ((V_terminal - true_loss) ** 2).mean() # Bridge consistency loss_bridge = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch_size,), l / L, device=device) t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) V_l = value_net(h_l_det, t_l, s) with torch.no_grad(): h_next_det = hiddens[l + 1].detach() log_terms = [] for k in range(K): noise = sigma_bridge * torch.randn(batch_size, d, device=device) h_noisy = h_next_det + noise V_next = value_net_ema(h_noisy, t_l_next, s) log_terms.append(-V_next / lam) log_terms_stack = torch.stack(log_terms, dim=-1) V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() loss_bridge = loss_bridge / L total_loss = loss_term + loss_bridge opt_value.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) opt_value.step() update_ema(value_net, value_net_ema, ema_momentum) # Quick evaluation if step % 500 == 0 or step == num_steps: with torch.no_grad(): eval_batch = 128 h0_e = torch.randn(eval_batch, d, device=device) y_e = torch.randn(eval_batch, m, device=device) hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device) hL_e = hiddens_e[L] e_T_e = hL_e @ C.T - y_e s_e = e_T_e.detach() costates = exact_costate(hiddens_e, Ms, C, y_e, device) cos_list = [] rho_list = [] nudge_list = [] for l in range(L): h_l = hiddens_e[l].detach() t_l = torch.full((eval_batch,), l / L, device=device) a_exact = costates[l].detach() h_l_req = h_l.clone().requires_grad_(True) V_l = value_net(h_l_req, t_l, s_e) a_credit = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0] cos_list.append(cosine_similarity_batch(a_credit, a_exact)) fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device) rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16) rho_list.append(rho) nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01) nudge_list.append(nud) avg_cos = np.mean(cos_list) avg_rho = np.mean(rho_list) avg_nudge = np.mean(nudge_list) if avg_cos > best_cos: best_cos = avg_cos best_step = step history.append({ 'step': step, 'avg_cos': avg_cos, 'avg_rho': avg_rho, 'avg_nudge': avg_nudge, 'loss_term': loss_term.item(), 'loss_bridge': loss_bridge.item(), }) return { 'best_cos': best_cos, 'best_step': best_step, 'final_cos': history[-1]['avg_cos'], 'final_rho': history[-1]['avg_rho'], 'final_nudge': history[-1]['avg_nudge'], 'history': history, } def main(): device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") # Sweep configurations configs = [ # Baseline (original) {'name': 'base', 'lam': 0.1, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, # Larger noise {'name': 'noise_0.1', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, # Much larger noise {'name': 'noise_0.3', 'lam': 0.1, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, # Larger lambda {'name': 'lam_1.0', 'lam': 1.0, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, # Large noise + large lambda {'name': 'noise_lam', 'lam': 1.0, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, # No LayerNorm {'name': 'no_ln', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False}, # Larger value net {'name': 'big_vnet', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 256, 'use_ln': True}, # Slower EMA {'name': 'ema_0.999', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.999, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, # More K samples {'name': 'K16', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 16, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, # Larger noise + large lambda + no LN {'name': 'best_combo', 'lam': 1.0, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False}, # Very large sigma {'name': 'noise_1.0', 'lam': 1.0, 'sigma_bridge': 1.0, 'K': 8, 'lr': 1e-3, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, # Lower lr {'name': 'lr_3e-4', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 3e-4, 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, ] results = {} for cfg in configs: name = cfg.pop('name') print(f"\n{'='*50}") print(f"Config: {name}") print(f" {cfg}") res = run_credit_bridge_config(cfg, device) results[name] = res print(f" Best cos: {res['best_cos']:.4f} (step {res['best_step']})") print(f" Final cos: {res['final_cos']:.4f}, rho: {res['final_rho']:.4f}, nudge: {res['final_nudge']:.4f}") cfg['name'] = name # restore # Print summary print("\n" + "="*80) print("SWEEP SUMMARY") print("="*80) print(f"{'Config':<20} {'Best Cos':<12} {'Final Cos':<12} {'Final Rho':<12} {'Final Nudge':<12}") print("-"*68) for name, res in results.items(): print(f"{name:<20} {res['best_cos']:<12.4f} {res['final_cos']:<12.4f} " f"{res['final_rho']:<12.4f} {res['final_nudge']:<12.4f}") # Save os.makedirs('results/toy_lq', exist_ok=True) with open('results/toy_lq/sweep_results.json', 'w') as f: json.dump(results, f, indent=2) print("\nSaved to results/toy_lq/sweep_results.json") if __name__ == '__main__': main()