"""Contextual Vector Head (CVH): the core personalization module.

    h'_t = h_t + alpha * B (theta_u ⊙ (A h_t))

Where:
    - A ∈ R^{d x H}: fixed random down-projection, entries scaled by 1/sqrt(H)
    - B ∈ R^{H x d}: fixed random up-projection, entries scaled by 1/sqrt(d)
    - theta_u ∈ R^d: per-user vector (the only thing that changes per user)
    - alpha: scalar scaling factor applied to the residual

The module also defines ``UnconditionalHead``, a baseline that adds a
user-dependent bias which does not depend on the hidden state.
"""

import torch
import torch.nn as nn


class CVHHead(nn.Module):
    """Contextual Vector Head for style personalization.

    Computes ``h + alpha * B @ (theta * (A @ h))`` along the last axis of ``h``.
    ``A`` and ``B`` are registered buffers, not parameters. They are drawn from
    a dedicated ``torch.Generator`` seeded with ``basis_seed``, so the same seed
    reproduces the same basis independently of the global RNG state.

    Args:
        hidden_size: Size ``H`` of the hidden states.
        d: Dimension of the user vector ``theta``.
        alpha: Scale applied to the residual before it is added to ``h``.
        basis_seed: Seed for the random projection matrices.
    """

    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1,
                 basis_seed: int = 42):
        super().__init__()
        self.hidden_size = hidden_size
        self.d = d
        self.alpha = alpha

        gen = torch.Generator()
        gen.manual_seed(basis_seed)
        # A: down-projection (d, H); fan_in = H, so entries are scaled by 1/sqrt(H).
        scale_a = 1.0 / (hidden_size ** 0.5)
        self.register_buffer('A', torch.randn(d, hidden_size, generator=gen) * scale_a)
        # B: up-projection (H, d); fan_in = d, so entries are scaled by 1/sqrt(d).
        # B is drawn after A from the same generator, so this order fixes both values.
        scale_b = 1.0 / (d ** 0.5)
        self.register_buffer('B', torch.randn(hidden_size, d, generator=gen) * scale_b)

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Apply the contextual vector head.

        Args:
            h: Hidden states of shape ``(..., H)``, e.g. ``(batch, H)``,
                ``(seq_len, H)`` or ``(batch, seq_len, H)``. Expected to be
                float32, because the buffers are cast to float32 before use.
            theta: User vector of shape ``(d,)``, float32.

        Returns:
            Modified hidden states with the same shape as ``h``.
        """
        # h @ A^T equals (A @ h^T)^T for 2-D h, but also handles any number of
        # leading dimensions (and 1-D h) without using the deprecated N-D `.T`.
        projected = h @ self.A.float().T  # (..., d)
        # Element-wise gating with the user vector (broadcast over leading dims).
        gated = projected * theta  # (..., d)
        # Project back up to the hidden size; no other scaling is applied here.
        residual = gated @ self.B.float().T  # (..., H)
        return h + self.alpha * residual

    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`forward`.

        This calls ``forward`` directly, so hooks that ``nn.Module.__call__``
        would run are skipped.
        """
        return self.forward(h, theta)


class UnconditionalHead(nn.Module):
    """Unconditional vector head baseline: h'_t = h_t + alpha * U @ theta_u.

    The head does not depend on the current hidden state. It adds the same
    user bias ``alpha * U @ theta`` at every position of ``h``.

    Args:
        hidden_size: Size ``H`` of the hidden states.
        d: Dimension of the user vector ``theta``.
        alpha: Scale applied to the bias before it is added to ``h``.
        basis_seed: Base seed. ``U`` is drawn with ``basis_seed + 1000``, so it
            uses a different seed than ``CVHHead`` given the same ``basis_seed``.
    """

    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1,
                 basis_seed: int = 42):
        super().__init__()
        self.hidden_size = hidden_size
        self.d = d
        self.alpha = alpha

        gen = torch.Generator()
        gen.manual_seed(basis_seed + 1000)
        # U: (H, d); fan_in = d, so entries are scaled by 1/sqrt(d).
        scale = 1.0 / (d ** 0.5)
        self.register_buffer('U', torch.randn(hidden_size, d, generator=gen) * scale)

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Add the user bias to the hidden states.

        Args:
            h: Hidden states of shape ``(..., H)``, float32.
            theta: User vector of shape ``(d,)``, float32.

        Returns:
            ``h + alpha * U @ theta``, with the same shape as ``h``.
        """
        bias = self.U.float() @ theta  # (H,)
        # (H,) broadcasts against any (..., H) input, keeping h's shape.
        return h + self.alpha * bias

    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`forward`.

        This calls ``forward`` directly, so hooks that ``nn.Module.__call__``
        would run are skipped.
        """
        return self.forward(h, theta)