summaryrefslogtreecommitdiff
path: root/experiments/toy_lq_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/toy_lq_v2.py')
-rw-r--r--experiments/toy_lq_v2.py327
1 files changed, 327 insertions, 0 deletions
diff --git a/experiments/toy_lq_v2.py b/experiments/toy_lq_v2.py
new file mode 100644
index 0000000..ab766b6
--- /dev/null
+++ b/experiments/toy_lq_v2.py
@@ -0,0 +1,327 @@
+"""
+Phase A v2: Enhanced toy LQ experiment.
+
+Key improvements over v1:
+1. Terminal gradient matching: V_phi at terminal layer should have grad_h V matching
+ the exact terminal gradient (this is LOCAL info, no hidden BP needed).
+2. Larger noise sweep integrated.
+3. Optional FM auxiliary for gradient smoothness.
+4. Better diagnostics.
+
+The terminal gradient a_L = C^T(C h_L - y) is computed from output layer only,
+so using it is allowed under the "no hidden BP anchor" constraint.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from experiments.toy_lq import (
+ generate_stable_dynamics, rollout_forward, terminal_loss,
+ exact_costate, make_forward_fn_from_layer
+)
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+
+ d = args.d_hidden
+ m = args.output_dim
+ L = args.num_layers
+ sigma = args.sigma
+ batch_size = args.batch_size
+ num_steps = args.num_steps
+ lr = args.lr_fb
+ lam = args.lam
+ K = args.K
+ ema_momentum = args.ema_momentum
+ sigma_bridge = args.sigma_bridge
+
+ print(f"=== Toy LQ v2 Experiment ===")
+ print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}")
+ print(f"lam={lam}, sigma_bridge={sigma_bridge}, K={K}")
+ print(f"terminal_grad_weight={args.term_grad_weight}")
+ print(f"fm_weight={args.fm_weight}")
+ print(f"device={device}")
+
+ Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed)
+ C = torch.randn(m, d, device=device) / np.sqrt(d)
+
+ # DFA
+ Bs_dfa = [torch.randn(d, m, device=device) / np.sqrt(m) for _ in range(L + 1)]
+
+ # State Bridge
+ state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=128, num_layers=2).to(device)
+ opt_state = optim.Adam(state_bridge.parameters(), lr=lr)
+
+ # Credit Bridge
+ value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16,
+ hidden_dim=args.vnet_hidden, num_layers=args.vnet_layers).to(device)
+ value_net_ema = create_ema_model(value_net)
+ opt_value = optim.Adam(value_net.parameters(), lr=lr)
+
+ log = {key: [] for key in [
+ 'steps',
+ 'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos',
+ 'dfa_rho', 'state_rho', 'credit_rho',
+ 'dfa_nudge', 'state_nudge', 'credit_nudge',
+ 'bridge_residual', 'state_bridge_loss', 'credit_bridge_loss',
+ 'term_loss', 'bridge_loss', 'term_grad_loss', 'fm_loss',
+ ]}
+
+ for step in range(1, num_steps + 1):
+ h0 = torch.randn(batch_size, d, device=device)
+ y = torch.randn(batch_size, m, device=device)
+ hiddens = rollout_forward(h0, Ms, sigma, L, device)
+ hL = hiddens[L]
+ e_T = hL @ C.T - y
+ s = e_T.detach()
+
+ # ---- Train State Bridge ----
+ state_loss = 0.0
+ hL_det = hL.detach()
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ pred_hL = state_bridge(h_l_det, t_l, s)
+ state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+
+ opt_state.zero_grad()
+ state_loss.backward()
+ opt_state.step()
+
+ # ---- Train Credit Bridge ----
+ # 1. Terminal boundary: V(h_L, 1, s) ≈ Phi(h_L, y)
+ hL_det = hL.detach()
+ t_L = torch.ones(batch_size, device=device)
+ true_loss = terminal_loss(hL_det, C, y).detach()
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # 2. Terminal gradient matching: grad_h V(h_L, 1, s) ≈ a_L^exact
+ # This uses only terminal-local information (no hidden BP)
+ loss_term_grad = torch.tensor(0.0, device=device)
+ if args.term_grad_weight > 0:
+ hL_req = hL.detach().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ # Exact terminal gradient: C^T (C h_L - y)
+ a_L_exact = (e_T @ C).detach() # (batch, d) -- stop grad on target
+ loss_term_grad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # 3. Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch_size,), l / L, device=device)
+ t_l_next = torch.full((batch_size,), (l + 1) / L, device=device)
+
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn(batch_size, d, device=device)
+ h_noisy = h_next_det + noise
+ V_next = value_net_ema(h_noisy, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_terms_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ # 4. FM auxiliary (optional): enforce gradient smoothness
+ loss_fm = torch.tensor(0.0, device=device)
+ if args.fm_weight > 0:
+ for l in range(L):
+ tau = torch.rand(batch_size, 1, device=device)
+ h_l_det = hiddens[l].detach()
+ h_next_det = hiddens[l + 1].detach()
+ f_l = h_next_det - h_l_det # residual
+
+ eps = torch.randn(batch_size, d, device=device)
+ h_mid = h_l_det + tau * f_l + (tau * (1 - tau)).sqrt() * sigma_bridge * eps
+ h_mid.requires_grad_(True)
+
+ t_mid = torch.full((batch_size, 1), 0, device=device)
+ t_mid = (l + tau) / L
+ t_mid_flat = t_mid.squeeze(-1)
+
+ V_mid = value_net(h_mid, t_mid_flat, s)
+ grad_V_mid = torch.autograd.grad(V_mid.sum(), h_mid, create_graph=True)[0]
+
+ # Interpolated target gradient
+ # Get a_l and a_{l+1} from current value net (no create_graph for targets)
+ h_l_r = h_l_det.clone().requires_grad_(True)
+ t_l_v = torch.full((batch_size,), l / L, device=device)
+ V_l_ = value_net(h_l_r, t_l_v, s)
+ a_l = torch.autograd.grad(V_l_.sum(), h_l_r, create_graph=False)[0].detach()
+
+ h_next_r = h_next_det.clone().requires_grad_(True)
+ t_next_v = torch.full((batch_size,), (l + 1) / L, device=device)
+ V_next_ = value_net(h_next_r, t_next_v, s)
+ a_next = torch.autograd.grad(V_next_.sum(), h_next_r, create_graph=False)[0].detach()
+
+ target_grad = ((1 - tau) * a_l + tau * a_next).detach()
+ loss_fm = loss_fm + ((grad_V_mid - target_grad) ** 2).sum(dim=-1).mean()
+ loss_fm = loss_fm / L
+
+ total_loss = (loss_term
+ + loss_bridge
+ + args.term_grad_weight * loss_term_grad
+ + args.fm_weight * loss_fm)
+
+ opt_value.zero_grad()
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ opt_value.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ # ---- Evaluation ----
+ if step % args.eval_every == 0 or step == 1:
+ with torch.no_grad():
+ eval_batch = 128
+ h0_e = torch.randn(eval_batch, d, device=device)
+ y_e = torch.randn(eval_batch, m, device=device)
+ hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device)
+ hL_e = hiddens_e[L]
+ e_T_e = hL_e @ C.T - y_e
+ s_e = e_T_e.detach()
+ costates = exact_costate(hiddens_e, Ms, C, y_e, device)
+
+ dfa_cos, state_cos, credit_cos = [], [], []
+ dfa_rho, state_rho, credit_rho = [], [], []
+ dfa_nudge, state_nudge, credit_nudge = [], [], []
+ bridge_res_list = []
+
+ for l in range(L):
+ h_l = hiddens_e[l].detach()
+ a_exact = costates[l].detach()
+ t_l = torch.full((eval_batch,), l / L, device=device)
+
+ # DFA
+ a_dfa = e_T_e @ Bs_dfa[l].T
+ # State bridge
+ h_l_r1 = h_l.clone().requires_grad_(True)
+ pred_hL = state_bridge(h_l_r1, t_l, s_e)
+ pred_out = pred_hL @ C.T
+ pred_loss = 0.5 * ((pred_out - y_e) ** 2).sum(dim=-1)
+ a_state = torch.autograd.grad(pred_loss.sum(), h_l_r1, create_graph=False)[0]
+ # Credit bridge
+ h_l_r2 = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_r2, t_l, s_e)
+ a_credit = torch.autograd.grad(V_l.sum(), h_l_r2, create_graph=False)[0]
+
+ dfa_cos.append(cosine_similarity_batch(a_dfa, a_exact))
+ state_cos.append(cosine_similarity_batch(a_state, a_exact))
+ credit_cos.append(cosine_similarity_batch(a_credit, a_exact))
+
+ fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device)
+
+ dfa_rho.append(perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16))
+ state_rho.append(perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16))
+ credit_rho.append(perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16))
+
+ dfa_nudge.append(nudging_test(h_l, a_dfa, fwd_fn, eta=0.01))
+ state_nudge.append(nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01))
+ credit_nudge.append(nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01))
+
+ avg = lambda x: float(np.mean(x))
+ log['steps'].append(step)
+ log['dfa_costate_cos'].append(avg(dfa_cos))
+ log['state_costate_cos'].append(avg(state_cos))
+ log['credit_costate_cos'].append(avg(credit_cos))
+ log['dfa_rho'].append(avg(dfa_rho))
+ log['state_rho'].append(avg(state_rho))
+ log['credit_rho'].append(avg(credit_rho))
+ log['dfa_nudge'].append(avg(dfa_nudge))
+ log['state_nudge'].append(avg(state_nudge))
+ log['credit_nudge'].append(avg(credit_nudge))
+ log['state_bridge_loss'].append(state_loss.item())
+ log['credit_bridge_loss'].append(total_loss.item())
+ log['term_loss'].append(loss_term.item())
+ log['bridge_loss'].append(loss_bridge.item())
+ log['term_grad_loss'].append(loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else loss_term_grad)
+ log['fm_loss'].append(loss_fm.item() if isinstance(loss_fm, torch.Tensor) else loss_fm)
+
+ print(f"Step {step}/{num_steps}")
+ print(f" Costate cos - DFA: {avg(dfa_cos):.4f}, State: {avg(state_cos):.4f}, Credit: {avg(credit_cos):.4f}")
+ print(f" Perturb rho - DFA: {avg(dfa_rho):.4f}, State: {avg(state_rho):.4f}, Credit: {avg(credit_rho):.4f}")
+ print(f" Nudging - DFA: {avg(dfa_nudge):.4f}, State: {avg(state_nudge):.4f}, Credit: {avg(credit_nudge):.4f}")
+ print(f" Losses - term: {loss_term.item():.4f}, bridge: {loss_bridge.item():.4f}, "
+ f"tgrad: {loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else 0:.4f}, "
+ f"fm: {loss_fm.item() if isinstance(loss_fm, torch.Tensor) else 0:.4f}")
+ print(f" Per-layer credit cos: {['%.3f' % x for x in credit_cos]}")
+
+ # Save
+ os.makedirs(args.output_dir, exist_ok=True)
+ results = {
+ 'config': vars(args),
+ 'log': log,
+ 'final_per_layer': {
+ 'dfa_costate_cos': dfa_cos,
+ 'state_costate_cos': state_cos,
+ 'credit_costate_cos': credit_cos,
+ 'dfa_rho': dfa_rho,
+ 'state_rho': state_rho,
+ 'credit_rho': credit_rho,
+ 'dfa_nudge': dfa_nudge,
+ 'state_nudge': state_nudge,
+ 'credit_nudge': credit_nudge,
+ }
+ }
+ tag = f"seed{args.seed}_lam{args.lam}_sig{args.sigma_bridge}_tgw{args.term_grad_weight}_fm{args.fm_weight}"
+ out_path = os.path.join(args.output_dir, f'toy_lq_v2_{tag}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+ return results
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Toy LQ v2')
+ parser.add_argument('--d_hidden', type=int, default=64)
+ parser.add_argument('--output_dim', type=int, default=10)
+ parser.add_argument('--num_layers', type=int, default=12)
+ parser.add_argument('--sigma', type=float, default=0.03)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--num_steps', type=int, default=8000)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=8)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--sigma_bridge', type=float, default=0.1)
+ parser.add_argument('--eval_every', type=int, default=500)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/toy_lq')
+ parser.add_argument('--vnet_hidden', type=int, default=256)
+ parser.add_argument('--vnet_layers', type=int, default=3)
+ # Key new options
+ parser.add_argument('--term_grad_weight', type=float, default=1.0,
+ help='Weight for terminal gradient matching loss')
+ parser.add_argument('--fm_weight', type=float, default=0.0,
+ help='Weight for FM gradient smoothness auxiliary')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()