diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
| commit | 6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch) | |
| tree | d7c63adcd19c4f5d46c8a937e5047fece55dea62 /models/state_bridge.py | |
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching.
Credit bridge matches state bridge on linear system (~0.94 cosine).
CIFAR experiments in progress.
Diffstat (limited to 'models/state_bridge.py')
| -rw-r--r-- | models/state_bridge.py | 35 |
1 files changed, 35 insertions, 0 deletions
# models/state_bridge.py -- new file (mode 100644, blob 0a0e7aa)
"""
State Bridge predictor G_psi(h_l, t_l, s) -> predicted h_L.
Used by the State Bridge method.
"""
import torch
import torch.nn as nn
from .value_net import SinusoidalTimeEmbed


class StateBridgeNet(nn.Module):
    """
    State predictor G_psi(h_l, t_l, s) -> predicted terminal state h_L.

    The hidden state is layer-normalised, concatenated with a time embedding
    and the conditioning vector ``s`` along the last dimension, and passed
    through an MLP of ``num_layers`` (Linear -> GELU) stages followed by a
    linear projection back to ``d_hidden``.

    Args:
        d_hidden: Feature size of ``h``; also the size of the prediction.
        s_dim: Feature size of the conditioning input ``s``.
        time_embed_dim: Size passed to ``SinusoidalTimeEmbed``; assumed to be
            the width of the embedding it returns (TODO confirm against
            ``value_net``).
        hidden_dim: Width of every hidden Linear layer.
        num_layers: Number of hidden (Linear -> GELU) stages. ``0`` yields a
            single linear map from the concatenated input to ``d_hidden``.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
                 hidden_dim: int = 256, num_layers: int = 3):
        if num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)

        # Track the width flowing into the next Linear so that the output
        # projection always matches what precedes it. Previously the output
        # layer assumed ``hidden_dim`` inputs even when no hidden stage
        # existed (num_layers == 0), which failed at the first forward pass.
        in_d = d_hidden + time_embed_dim + s_dim
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
            in_d = hidden_dim
        layers.append(nn.Linear(in_d, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Returns predicted h_L as (batch, d_hidden).

        ``h`` must have ``d_hidden`` features in its last dimension (it is
        layer-normalised over that dimension) and ``s`` must have ``s_dim``.
        ``t`` is handed to the time embedding as-is; its expected shape is
        defined by ``SinusoidalTimeEmbed`` (not visible here). All three
        pieces are joined along the last dimension, so their leading
        dimensions must agree.
        """
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)
