diff options
Diffstat (limited to 'experiments/toy_lq_sweep.py')
| -rw-r--r-- | experiments/toy_lq_sweep.py | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/experiments/toy_lq_sweep.py b/experiments/toy_lq_sweep.py new file mode 100644 index 0000000..ae82ef0 --- /dev/null +++ b/experiments/toy_lq_sweep.py @@ -0,0 +1,243 @@ +""" +Sweep over credit bridge hyperparameters to find a configuration +where the value field gradient actually aligns with the costate. + +Key hypothesis: the credit bridge needs sufficient noise (sigma_bridge) +and temperature (lambda) to make V_phi sensitive to cost-relevant directions. +""" +import os +import sys +import json +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from itertools import product + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from experiments.toy_lq import ( + generate_stable_dynamics, rollout_forward, terminal_loss, + exact_costate, make_forward_fn_from_layer +) +from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test + + +def run_credit_bridge_config(config, device): + """Run credit bridge with specific hyperparameters and return final metrics.""" + d = 64 + m = 10 + L = 12 + sigma = 0.03 + batch_size = 256 + num_steps = config['num_steps'] + lr = config['lr'] + lam = config['lam'] + K = config['K'] + ema_momentum = config['ema_momentum'] + sigma_bridge = config['sigma_bridge'] + hidden_dim = config.get('hidden_dim', 128) + use_ln = config.get('use_ln', True) + + torch.manual_seed(42) + np.random.seed(42) + + Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=42) + C = torch.randn(m, d, device=device) / np.sqrt(d) + + # Value net - optionally without LayerNorm + value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=hidden_dim, num_layers=2).to(device) + if not use_ln: + value_net.ln = nn.Identity() + + value_net_ema = create_ema_model(value_net) + opt_value = optim.Adam(value_net.parameters(), lr=lr) + + best_cos = -1.0 + best_step = 0 + history = [] + + for step in range(1, num_steps + 1): + h0 = torch.randn(batch_size, d, device=device) + y = torch.randn(batch_size, m, device=device) + hiddens = rollout_forward(h0, Ms, sigma, L, device) + hL = hiddens[L] + e_T = hL @ C.T - y + s = e_T.detach() + true_loss = terminal_loss(hL.detach(), C, y).detach() + + # Terminal boundary + hL_det = hL.detach() + t_L = torch.ones(batch_size, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) + + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn(batch_size, d, device=device) + h_noisy = h_next_det + noise + V_next = value_net_ema(h_noisy, t_l_next, s) + log_terms.append(-V_next / lam) + + log_terms_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + + loss_bridge = loss_bridge / L + total_loss = loss_term + loss_bridge + + opt_value.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + opt_value.step() + update_ema(value_net, value_net_ema, ema_momentum) + + # Quick evaluation + if step % 500 == 0 or step == num_steps: + with torch.no_grad(): + eval_batch = 128 + h0_e = torch.randn(eval_batch, d, device=device) + y_e = torch.randn(eval_batch, m, device=device) + hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device) + hL_e = hiddens_e[L] + e_T_e = hL_e @ C.T - y_e + s_e = e_T_e.detach() + costates = exact_costate(hiddens_e, Ms, C, y_e, device) + + cos_list = [] + rho_list = [] + nudge_list = [] + for l in range(L): + h_l = hiddens_e[l].detach() + t_l = torch.full((eval_batch,), l / L, device=device) + a_exact = costates[l].detach() + + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s_e) + a_credit = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0] + + cos_list.append(cosine_similarity_batch(a_credit, a_exact)) + + fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device) + rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16) + rho_list.append(rho) + nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01) + nudge_list.append(nud) + + avg_cos = np.mean(cos_list) + avg_rho = np.mean(rho_list) + avg_nudge = np.mean(nudge_list) + + if avg_cos > best_cos: + best_cos = avg_cos + best_step = step + + history.append({ + 'step': step, + 'avg_cos': avg_cos, + 'avg_rho': avg_rho, + 'avg_nudge': avg_nudge, + 'loss_term': loss_term.item(), + 'loss_bridge': loss_bridge.item(), + }) + + return { + 'best_cos': best_cos, + 'best_step': best_step, + 'final_cos': history[-1]['avg_cos'], + 'final_rho': history[-1]['avg_rho'], + 'final_nudge': history[-1]['avg_nudge'], + 'history': history, + } + + +def main(): + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + print(f"Device: {device}") + + # Sweep configurations + configs = [ + # Baseline (original) + {'name': 'base', 'lam': 0.1, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger noise + {'name': 'noise_0.1', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Much larger noise + {'name': 'noise_0.3', 'lam': 0.1, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger lambda + {'name': 'lam_1.0', 'lam': 1.0, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Large noise + large lambda + {'name': 'noise_lam', 'lam': 1.0, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # No LayerNorm + {'name': 'no_ln', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False}, + # Larger value net + {'name': 'big_vnet', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 256, 'use_ln': True}, + # Slower EMA + {'name': 'ema_0.999', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.999, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # More K samples + {'name': 'K16', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 16, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger noise + large lambda + no LN + {'name': 'best_combo', 'lam': 1.0, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False}, + # Very large sigma + {'name': 'noise_1.0', 'lam': 1.0, 'sigma_bridge': 1.0, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Lower lr + {'name': 'lr_3e-4', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 3e-4, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + ] + + results = {} + for cfg in configs: + name = cfg.pop('name') + print(f"\n{'='*50}") + print(f"Config: {name}") + print(f" {cfg}") + res = run_credit_bridge_config(cfg, device) + results[name] = res + print(f" Best cos: {res['best_cos']:.4f} (step {res['best_step']})") + print(f" Final cos: {res['final_cos']:.4f}, rho: {res['final_rho']:.4f}, nudge: {res['final_nudge']:.4f}") + cfg['name'] = name # restore + + # Print summary + print("\n" + "="*80) + print("SWEEP SUMMARY") + print("="*80) + print(f"{'Config':<20} {'Best Cos':<12} {'Final Cos':<12} {'Final Rho':<12} {'Final Nudge':<12}") + print("-"*68) + for name, res in results.items(): + print(f"{name:<20} {res['best_cos']:<12.4f} {res['final_cos']:<12.4f} " + f"{res['final_rho']:<12.4f} {res['final_nudge']:<12.4f}") + + # Save + os.makedirs('results/toy_lq', exist_ok=True) + with open('results/toy_lq/sweep_results.json', 'w') as f: + json.dump(results, f, indent=2) + print("\nSaved to results/toy_lq/sweep_results.json") + + +if __name__ == '__main__': + main() |
