"""
Value network V_phi(h_l, t_l, s) -> scalar.

Used by the Credit Bridge method.
Input: [LN(h_l), time_embed(t_l), s] concatenated.
"""

import copy
import math

import torch
import torch.nn as nn


class SinusoidalTimeEmbed(nn.Module):
    """Sinusoidal positional encoding for scalar depth-time t_l = l/L."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed depth-time values.

        Args:
            t: (batch,) or (batch, 1) scalar, expected in [0, 1].

        Returns:
            (batch, embed_dim) tensor laid out as ``[sin(t*f_k) ..., cos(t*f_k) ...]``
            over ``half = embed_dim // 2`` geometrically spaced frequencies
            ``f_k = 10000 ** (-k / half)``. When ``embed_dim`` is odd a trailing
            zero column is appended so the width always equals ``embed_dim``.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # (batch, 1)
        if not t.is_floating_point():
            # Integer/bool inputs: promote so arange/exp/sin/cos are well defined.
            t = t.to(torch.get_default_dtype())

        half = self.embed_dim // 2
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=t.device, dtype=t.dtype)
            / half
        )
        args = t * freqs.unsqueeze(0)  # (batch, half)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (batch, 2*half)
        if self.embed_dim % 2:
            # 2*half == embed_dim - 1 here; pad so layers sized with embed_dim fit.
            emb = nn.functional.pad(emb, (0, 1))
        return emb  # (batch, embed_dim)


class ValueNet(nn.Module):
    """
    Scalar value network V_phi(h_l, t_l, s).

    Inputs:
        h: hidden state (batch, d_hidden)
        t: depth-time scalar (batch,) in [0, 1]
        s: terminal modulation code (batch, s_dim)
    Output:
        V: scalar (batch,)

    Architecture: LayerNorm(h) and the sinusoidal embedding of t are
    concatenated with s, then passed through ``num_layers`` Linear+GELU
    layers of width ``hidden_dim`` and a final Linear projection to 1.
    """

    def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
                 hidden_dim: int = 256, num_layers: int = 3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)

        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        in_d = input_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
            in_d = hidden_dim
        # With num_layers == 0 the output head reads the raw concatenated input.
        layers.append(nn.Linear(in_d, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Returns V(h, t, s) as (batch,) scalar."""
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp).squeeze(-1)


def create_ema_model(model: nn.Module) -> nn.Module:
    """Create an EMA copy of a model.

    The copy is a ``copy.deepcopy`` of ``model`` (parameters and buffers) with
    every parameter's ``requires_grad`` set to False, so it is never updated
    by an optimizer; refresh it with :func:`update_ema`.
    """
    ema = copy.deepcopy(model)
    for p in ema.parameters():
        p.requires_grad_(False)
    return ema


@torch.no_grad()
def update_ema(model: nn.Module, ema_model: nn.Module, momentum: float = 0.99,
               copy_buffers: bool = True):
    """Update EMA model parameters in place.

    Each EMA parameter becomes ``momentum * ema + (1 - momentum) * live``.

    Args:
        model: the live (trained) model.
        ema_model: the EMA copy, typically from :func:`create_ema_model`.
        momentum: weight kept on the previous EMA value.
        copy_buffers: if True, buffers (e.g. normalization running stats) are
            copied verbatim from ``model``; otherwise they are left untouched.

    Raises:
        ValueError: if the two models do not have the same number of
            parameters (or buffers, when ``copy_buffers`` is True).
    """
    params = list(model.parameters())
    ema_params = list(ema_model.parameters())
    if len(params) != len(ema_params):
        raise ValueError(
            f"parameter count mismatch: model has {len(params)}, "
            f"ema_model has {len(ema_params)}"
        )
    for p, ep in zip(params, ema_params):
        # Inside no_grad, in-place ops are not tracked; no `.data` needed.
        ep.mul_(momentum).add_(p.detach(), alpha=1 - momentum)

    if copy_buffers:
        buffers = list(model.buffers())
        ema_buffers = list(ema_model.buffers())
        if len(buffers) != len(ema_buffers):
            raise ValueError(
                f"buffer count mismatch: model has {len(buffers)}, "
                f"ema_model has {len(ema_buffers)}"
            )
        for b, eb in zip(buffers, ema_buffers):
            eb.copy_(b)