summaryrefslogtreecommitdiff
path: root/models/state_bridge.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
commit6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch)
treed7c63adcd19c4f5d46c8a937e5047fece55dea62 /models/state_bridge.py
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching. Credit bridge matches state bridge on linear system (~0.94 cosine). CIFAR experiments in progress.
Diffstat (limited to 'models/state_bridge.py')
-rw-r--r--models/state_bridge.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/models/state_bridge.py b/models/state_bridge.py
new file mode 100644
index 0000000..0a0e7aa
--- /dev/null
+++ b/models/state_bridge.py
@@ -0,0 +1,35 @@
+"""
+State Bridge predictor G_psi(h_l, t_l, s) -> predicted h_L.
+Used by the State Bridge method.
+"""
+import torch
+import torch.nn as nn
+from .value_net import SinusoidalTimeEmbed
+
+
+class StateBridgeNet(nn.Module):
+ """
+ State predictor G_psi(h_l, t_l, s) -> predicted terminal state h_L.
+ """
+
+ def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
+ hidden_dim: int = 256, num_layers: int = 3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
+ """Returns predicted h_L as (batch, d_hidden)."""
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)