diff options
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/toy_lq_v2.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/experiments/toy_lq_v2.py b/experiments/toy_lq_v2.py index ab766b6..4d1b9a2 100644 --- a/experiments/toy_lq_v2.py +++ b/experiments/toy_lq_v2.py @@ -92,13 +92,15 @@ def run_experiment(args): s = e_T.detach() # ---- Train State Bridge ---- + # Use normalized MSE (consistent with CIFAR experiment) state_loss = 0.0 hL_det = hL.detach() for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch_size,), l / L, device=device) pred_hL = state_bridge(h_l_det, t_l, s) - state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean() + target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0) + state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean() state_loss = state_loss / L opt_state.zero_grad() |
