""" State Bridge predictor G_psi(h_l, t_l, s) -> predicted h_L. Used by the State Bridge method. """ import torch import torch.nn as nn from .value_net import SinusoidalTimeEmbed class StateBridgeNet(nn.Module): """ State predictor G_psi(h_l, t_l, s) -> predicted terminal state h_L. """ def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32, hidden_dim: int = 256, num_layers: int = 3): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.time_embed = SinusoidalTimeEmbed(time_embed_dim) input_dim = d_hidden + time_embed_dim + s_dim layers = [] for i in range(num_layers): in_d = input_dim if i == 0 else hidden_dim layers.append(nn.Linear(in_d, hidden_dim)) layers.append(nn.GELU()) layers.append(nn.Linear(hidden_dim, d_hidden)) self.net = nn.Sequential(*layers) def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor: """Returns predicted h_L as (batch, d_hidden).""" h_normed = self.ln(h) t_emb = self.time_embed(t) inp = torch.cat([h_normed, t_emb, s], dim=-1) return self.net(inp)