author    YurenHao0426 <Blackhao0426@gmail.com>  2026-03-23 18:21:48 -0500
committer YurenHao0426 <Blackhao0426@gmail.com>  2026-03-23 18:21:48 -0500
commit    245e1174695c819642030461e3f544dffb7062fd
tree      4379b4329a830a2bb29bd5ca895dca8cb864f036 /experiments/toy_lq_v2.py
parent    6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712
Sync state bridge: use normalized MSE target in both toy and CIFAR
Reason: the toy experiment used raw MSE while the CIFAR experiment used normalized MSE. Both must use the same loss for consistent reporting. Normalized MSE is also more robust to varying h_L magnitudes.
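
Illustration of the difference between the two losses, as a minimal sketch rather than code from the repository. The helper names `raw_mse` and `normalized_mse`, the `min_norm` parameter, and the example tensors are hypothetical; the formulas follow the removed and added lines of the patch below. It assumes PyTorch is installed.

```python
import torch


def raw_mse(pred, target):
    # Squared error summed over the feature dimension, averaged over the batch.
    return ((pred - target) ** 2).sum(dim=-1).mean()


def normalized_mse(pred, target, min_norm=1.0):
    # Divide the residual by the per-sample target norm. The clamp keeps
    # targets with a small norm from inflating the loss.
    target_norm = target.norm(dim=-1, keepdim=True).clamp(min=min_norm)
    return (((pred - target) / target_norm) ** 2).sum(dim=-1).mean()


# Deterministic example: 8 samples, 16 features, every row norm above 1.
target = torch.arange(1, 8 * 16 + 1, dtype=torch.float64).reshape(8, 16) / 16
pred = target + 0.125

# Scaling prediction and target by 10 multiplies the raw loss by 100 ...
raw_small = raw_mse(pred, target)
raw_large = raw_mse(10 * pred, 10 * target)

# ... while the normalized loss stays the same, because every target norm
# is above the clamp threshold in both cases.
norm_small = normalized_mse(pred, target)
norm_large = normalized_mse(10 * pred, 10 * target)

# When every target norm is below the threshold, the clamp sets the divisor
# to 1 and the normalized loss falls back to the raw loss.
tiny_raw = raw_mse(0.01 * pred, 0.01 * target)
tiny_norm = normalized_mse(0.01 * pred, 0.01 * target)
```

Because of `clamp(min=1.0)`, the normalization only takes effect for targets whose norm exceeds 1, so the two experiments report comparable numbers mainly in that regime.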
Diffstat (limited to 'experiments/toy_lq_v2.py')
-rw-r--r-- experiments/toy_lq_v2.py | 4
1 file changed, 3 insertions, 1 deletion
diff --git a/experiments/toy_lq_v2.py b/experiments/toy_lq_v2.py
index ab766b6..4d1b9a2 100644
--- a/experiments/toy_lq_v2.py
+++ b/experiments/toy_lq_v2.py
@@ -92,13 +92,15 @@ def run_experiment(args):
s = e_T.detach()
# ---- Train State Bridge ----
+ # Use normalized MSE (consistent with CIFAR experiment)
state_loss = 0.0
hL_det = hL.detach()
for l in range(L):
h_l_det = hiddens[l].detach()
t_l = torch.full((batch_size,), l / L, device=device)
pred_hL = state_bridge(h_l_det, t_l, s)
- state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean()
+ target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean()
state_loss = state_loss / L
opt_state.zero_grad()