summaryrefslogtreecommitdiff
path: root/models/value_net.py
blob: 3c72f75258810931fe641c28f273535890cdd1f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Value network V_phi(h_l, t_l, s) -> scalar.
Used by the Credit Bridge method.
Input: [LN(h_l), time_embed(t_l), s] concatenated.
"""
import torch
import torch.nn as nn
import math
import copy


class SinusoidalTimeEmbed(nn.Module):
    """Sinusoidal positional encoding for scalar depth-time t_l = l/L.

    Produces exactly ``embed_dim`` features: ``embed_dim // 2`` sines followed
    by ``embed_dim // 2`` cosines of ``t * freq_k``, with
    ``freq_k = 10000 ** (-k / (embed_dim // 2))``. When ``embed_dim`` is odd, a
    single zero column is appended so the output width still equals
    ``embed_dim`` (callers size downstream layers by ``embed_dim``).
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed depth-times.

        Args:
            t: (batch,) or (batch, 1) scalar in [0, 1].

        Returns:
            Tensor of shape (batch, embed_dim).
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # (batch, 1)
        half = self.embed_dim // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half
        )
        args = t * freqs.unsqueeze(0)  # (batch, half)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (batch, 2 * half)
        if self.embed_dim % 2:
            # Odd width: sin/cos pairs only fill embed_dim - 1 columns; pad one
            # zero column so the advertised width (embed_dim) is honoured.
            emb = torch.cat([emb, emb.new_zeros(*emb.shape[:-1], 1)], dim=-1)
        return emb  # (batch, embed_dim)


class ValueNet(nn.Module):
    """
    Scalar value network V_phi(h_l, t_l, s).

    ``h`` is layer-normalized, ``t`` is passed through a sinusoidal embedding,
    and the three pieces are concatenated and fed through ``num_layers``
    Linear+GELU hidden layers followed by a Linear read-out to one unit.

    Inputs:
        h: hidden state (batch, d_hidden)
        t: depth-time scalar (batch,) in [0, 1]
        s: terminal modulation code (batch, s_dim)
    Output:
        V: scalar (batch,)

    Args:
        d_hidden: width of ``h``.
        s_dim: width of ``s``.
        time_embed_dim: width of the sinusoidal time embedding.
        hidden_dim: width of each hidden layer.
        num_layers: number of Linear+GELU hidden layers; ``0`` yields a single
            linear read-out from the concatenated input.

    Raises:
        ValueError: if ``num_layers`` is negative.
    """

    def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
                 hidden_dim: int = 256, num_layers: int = 3):
        super().__init__()
        if num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)

        # NOTE(review): assumes self.time_embed emits exactly time_embed_dim
        # features (relevant for odd sizes) — confirm SinusoidalTimeEmbed does.
        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        in_d = input_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
            in_d = hidden_dim
        # The read-out consumes whatever width the previous stage produced:
        # hidden_dim normally, input_dim when there are no hidden layers.
        layers.append(nn.Linear(in_d, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Return V(h, t, s) with shape (batch,).

        Args:
            h: hidden state, (batch, d_hidden).
            t: depth-time, (batch,) or (batch, 1).
            s: modulation code, (batch, s_dim).
        """
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp).squeeze(-1)


def create_ema_model(model: nn.Module) -> nn.Module:
    """Return a frozen, independent deep copy of ``model``.

    The copy starts with the same parameter values as ``model`` but every one
    of its parameters has ``requires_grad`` turned off, so it can serve as an
    EMA shadow that optimizers never touch. ``model`` itself is not modified.
    """
    shadow = copy.deepcopy(model)
    # Module-level switch: flips requires_grad on every parameter of the copy.
    shadow.requires_grad_(False)
    return shadow


@torch.no_grad()
def update_ema(model: nn.Module, ema_model: nn.Module, momentum: float = 0.99):
    """Blend ``model``'s parameters into ``ema_model`` in place.

    Each EMA parameter becomes ``momentum * ema + (1 - momentum) * param``.
    Parameters are paired positionally in ``parameters()`` order, so both
    modules must have the same architecture (e.g. ``ema_model`` obtained by
    deep-copying ``model``). Only parameters are averaged; buffers of
    ``ema_model`` are left untouched.

    Raises:
        ValueError: if the two models expose a different number of
            parameters (a plain ``zip`` would silently skip the surplus and
            leave part of the EMA model stale).
    """
    params = list(model.parameters())
    ema_params = list(ema_model.parameters())
    if len(params) != len(ema_params):
        raise ValueError(
            f"update_ema: model has {len(params)} parameters but ema_model "
            f"has {len(ema_params)}"
        )
    for p, ep in zip(params, ema_params):
        # .data keeps the in-place update out of autograd's version tracking,
        # matching the original behaviour.
        ep.data.mul_(momentum).add_(p.data, alpha=1 - momentum)