summaryrefslogtreecommitdiff
path: root/experiments/toy_lq_sweep.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/toy_lq_sweep.py')
-rw-r--r--experiments/toy_lq_sweep.py243
1 files changed, 243 insertions, 0 deletions
diff --git a/experiments/toy_lq_sweep.py b/experiments/toy_lq_sweep.py
new file mode 100644
index 0000000..ae82ef0
--- /dev/null
+++ b/experiments/toy_lq_sweep.py
@@ -0,0 +1,243 @@
+"""
+Sweep over credit bridge hyperparameters to find a configuration
+where the value field gradient actually aligns with the costate.
+
+Key hypothesis: the credit bridge needs sufficient noise (sigma_bridge)
+and temperature (lambda) to make V_phi sensitive to cost-relevant directions.
+"""
+import os
+import sys
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from itertools import product
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from experiments.toy_lq import (
+ generate_stable_dynamics, rollout_forward, terminal_loss,
+ exact_costate, make_forward_fn_from_layer
+)
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test
+
+
+def run_credit_bridge_config(config, device):
+ """Run credit bridge with specific hyperparameters and return final metrics."""
+ d = 64
+ m = 10
+ L = 12
+ sigma = 0.03
+ batch_size = 256
+ num_steps = config['num_steps']
+ lr = config['lr']
+ lam = config['lam']
+ K = config['K']
+ ema_momentum = config['ema_momentum']
+ sigma_bridge = config['sigma_bridge']
+ hidden_dim = config.get('hidden_dim', 128)
+ use_ln = config.get('use_ln', True)
+
+ torch.manual_seed(42)
+ np.random.seed(42)
+
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=42)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # Value net - optionally without LayerNorm
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=hidden_dim, num_layers=2).to(device)
+ if not use_ln:
+ value_net.ln = nn.Identity()
+
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr)
+
+ best_cos = -1.0
+ best_step = 0
+ history = []
+
+ for step in range(1, num_steps + 1):
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+ e_T = hL @ C.T - y
+ s = e_T.detach()
+ true_loss = terminal_loss(hL.detach(), C, y).detach()
+
+ # Terminal boundary
+ hL_det = hL.detach()
+ t_L = torch.ones(batch_size, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_noisy = h_next_det + noise
+ V_next = value_net_ema(h_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+
+ log_terms_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+
+ loss_bridge = loss_bridge / L
+ total_loss = loss_term + loss_bridge
+
+ opt_value.zero_grad()
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ opt_value.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ # Quick evaluation
+ if step % 500 == 0 or step == num_steps:
+ with torch.no_grad():
+ eval_batch = 128
+ h0_e = torch.randn(eval_batch, d, device=device)
+ y_e = torch.randn(eval_batch, m, device=device)
+ hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device)
+ hL_e = hiddens_e[L]
+ e_T_e = hL_e @ C.T - y_e
+ s_e = e_T_e.detach()
+ costates = exact_costate(hiddens_e, Ms, C, y_e, device)
+
+ cos_list = []
+ rho_list = []
+ nudge_list = []
+ for l in range(L):
+ h_l = hiddens_e[l].detach()
+ t_l = torch.full((eval_batch,), l / L, device=device)
+ a_exact = costates[l].detach()
+
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s_e)
+ a_credit = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0]
+
+ cos_list.append(cosine_similarity_batch(a_credit, a_exact))
+
+ fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device)
+ rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16)
+ rho_list.append(rho)
+ nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01)
+ nudge_list.append(nud)
+
+ avg_cos = np.mean(cos_list)
+ avg_rho = np.mean(rho_list)
+ avg_nudge = np.mean(nudge_list)
+
+ if avg_cos > best_cos:
+ best_cos = avg_cos
+ best_step = step
+
+ history.append({
+ 'step': step,
+ 'avg_cos': avg_cos,
+ 'avg_rho': avg_rho,
+ 'avg_nudge': avg_nudge,
+ 'loss_term': loss_term.item(),
+ 'loss_bridge': loss_bridge.item(),
+ })
+
+ return {
+ 'best_cos': best_cos,
+ 'best_step': best_step,
+ 'final_cos': history[-1]['avg_cos'],
+ 'final_rho': history[-1]['avg_rho'],
+ 'final_nudge': history[-1]['avg_nudge'],
+ 'history': history,
+ }
+
+
+def main():
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+ print(f"Device: {device}")
+
+ # Sweep configurations
+ configs = [
+ # Baseline (original)
+ {'name': 'base', 'lam': 0.1, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Larger noise
+ {'name': 'noise_0.1', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Much larger noise
+ {'name': 'noise_0.3', 'lam': 0.1, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Larger lambda
+ {'name': 'lam_1.0', 'lam': 1.0, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Large noise + large lambda
+ {'name': 'noise_lam', 'lam': 1.0, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # No LayerNorm
+ {'name': 'no_ln', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False},
+ # Larger value net
+ {'name': 'big_vnet', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 256, 'use_ln': True},
+ # Slower EMA
+ {'name': 'ema_0.999', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.999, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # More K samples
+ {'name': 'K16', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 16, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Larger noise + large lambda + no LN
+ {'name': 'best_combo', 'lam': 1.0, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False},
+ # Very large sigma
+ {'name': 'noise_1.0', 'lam': 1.0, 'sigma_bridge': 1.0, 'K': 8, 'lr': 1e-3,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ # Lower lr
+ {'name': 'lr_3e-4', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 3e-4,
+ 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True},
+ ]
+
+ results = {}
+ for cfg in configs:
+ name = cfg.pop('name')
+ print(f"\n{'='*50}")
+ print(f"Config: {name}")
+ print(f" {cfg}")
+ res = run_credit_bridge_config(cfg, device)
+ results[name] = res
+ print(f" Best cos: {res['best_cos']:.4f} (step {res['best_step']})")
+ print(f" Final cos: {res['final_cos']:.4f}, rho: {res['final_rho']:.4f}, nudge: {res['final_nudge']:.4f}")
+ cfg['name'] = name # restore
+
+ # Print summary
+ print("\n" + "="*80)
+ print("SWEEP SUMMARY")
+ print("="*80)
+ print(f"{'Config':<20} {'Best Cos':<12} {'Final Cos':<12} {'Final Rho':<12} {'Final Nudge':<12}")
+ print("-"*68)
+ for name, res in results.items():
+ print(f"{name:<20} {res['best_cos']:<12.4f} {res['final_cos']:<12.4f} "
+ f"{res['final_rho']:<12.4f} {res['final_nudge']:<12.4f}")
+
+ # Save
+ os.makedirs('results/toy_lq', exist_ok=True)
+ with open('results/toy_lq/sweep_results.json', 'w') as f:
+ json.dump(results, f, indent=2)
+ print("\nSaved to results/toy_lq/sweep_results.json")
+
+
+if __name__ == '__main__':
+ main()