summaryrefslogtreecommitdiff
path: root/models/state_bridge.py
blob: 0a0e7aa7e7da8b736f8b1d2f555c1b3d5a9e1356 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""
State Bridge predictor G_psi(h_l, t_l, s) -> predicted h_L.
Used by the State Bridge method.
"""
import torch
import torch.nn as nn
from .value_net import SinusoidalTimeEmbed


class StateBridgeNet(nn.Module):
    """
    State predictor G_psi(h_l, t_l, s) -> predicted terminal state h_L.

    The state ``h`` is layer-normalized, concatenated with the embedding of
    ``t`` produced by ``SinusoidalTimeEmbed`` and with the conditioning tensor
    ``s``, then passed through ``num_layers`` hidden (Linear -> GELU) blocks
    and a final Linear projection back to width ``d_hidden``.

    Args:
        d_hidden: Last-dimension width of the input state ``h`` and of the
            predicted output.
        s_dim: Last-dimension width of the conditioning input ``s``.
        time_embed_dim: Width passed to ``SinusoidalTimeEmbed``; the MLP input
            size assumes the embedding has exactly this many features.
        hidden_dim: Width of each hidden MLP layer.
        num_layers: Number of hidden (Linear -> GELU) blocks. ``0`` yields a
            single linear map from the concatenated input to ``d_hidden``.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
                 hidden_dim: int = 256, num_layers: int = 3):
        super().__init__()
        if num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)

        # Width of the concatenation built in forward(): [LN(h), t_emb, s].
        in_d = d_hidden + time_embed_dim + s_dim
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
            in_d = hidden_dim
        # The output projection takes the running width: hidden_dim after at
        # least one hidden block, or the raw concatenated width when
        # num_layers == 0. Layer order/indices inside the Sequential are the
        # same as before, so existing state_dict keys (net.0, net.2, ...)
        # still line up.
        layers.append(nn.Linear(in_d, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Predict the terminal state h_L.

        Args:
            h: Intermediate state; last dimension must be ``d_hidden``.
            t: Time values fed to ``SinusoidalTimeEmbed``. The accepted shape
                is defined by that module (presumably (batch,) — TODO confirm);
                its output must share ``h``'s leading dimensions.
            s: Conditioning tensor; last dimension must be ``s_dim`` and the
                leading dimensions must match ``h``.

        Returns:
            Predicted h_L with last dimension ``d_hidden`` — (batch, d_hidden)
            for 2-D inputs.
        """
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        # Concatenate along features; order must match the input width
        # computed in __init__.
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)