summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
Diffstat (limited to 'experiments')
-rw-r--r--experiments/toy_lq_v2.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/experiments/toy_lq_v2.py b/experiments/toy_lq_v2.py
index ab766b6..4d1b9a2 100644
--- a/experiments/toy_lq_v2.py
+++ b/experiments/toy_lq_v2.py
@@ -92,13 +92,15 @@ def run_experiment(args):
s = e_T.detach()
# ---- Train State Bridge ----
+ # Use normalized MSE (consistent with CIFAR experiment)
state_loss = 0.0
hL_det = hL.detach()
for l in range(L):
h_l_det = hiddens[l].detach()
t_l = torch.full((batch_size,), l / L, device=device)
pred_hL = state_bridge(h_l_det, t_l, s)
- state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean()
+ target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean()
state_loss = state_loss / L
opt_state.zero_grad()