Diffstat (limited to 'models/cvh.py')
 -rw-r--r--  models/cvh.py | 86
 1 file changed, 86 insertions, 0 deletions
diff --git a/models/cvh.py b/models/cvh.py
new file mode 100644
index 0000000..e669722
--- /dev/null
+++ b/models/cvh.py
@@ -0,0 +1,86 @@
+"""Contextual Vector Head (CVH): the core personalization module.
+
+h'_t = h_t + alpha * B(theta_u ⊙ (A @ h_t))
+
+Where:
+- A ∈ R^{d x H}: fixed random projection (down), scaled by 1/sqrt(H)
+- B ∈ R^{H x d}: fixed random projection (up), scaled by 1/sqrt(d)
+- theta_u ∈ R^d: per-user vector (the only thing that changes per user)
+- alpha: scaling factor
+"""
+
+import torch
+import torch.nn as nn
+
+
+class CVHHead(nn.Module):
+    """Contextual Vector Head for style personalization."""
+
+    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.d = d
+        self.alpha = alpha
+
+        gen = torch.Generator()
+        gen.manual_seed(basis_seed)
+
+        # A: down-projection (d, H) - fan_in = H
+        scale_a = 1.0 / (hidden_size ** 0.5)
+        self.register_buffer('A', torch.randn(d, hidden_size, generator=gen) * scale_a)
+
+        # B: up-projection (H, d) - fan_in = d
+        scale_b = 1.0 / (d ** 0.5)
+        self.register_buffer('B', torch.randn(hidden_size, d, generator=gen) * scale_b)
+
+    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        """Apply contextual vector head.
+
+        Args:
+            h: Hidden states (batch, H) or (seq_len, H) - float32
+            theta: User vector (d,) - float32
+
+        Returns:
+            h_prime: Modified hidden states, same shape as h
+        """
+        # A @ h^T -> (d, batch) then transpose -> (batch, d)
+        projected = (self.A.float() @ h.T).T  # (batch, d)
+
+        # Element-wise gating with user vector
+        gated = theta.unsqueeze(0) * projected  # (batch, d)
+
+        # Project back up to hidden size; alpha scales the residual below
+        residual = (self.B.float() @ gated.T).T  # (batch, H)
+
+        h_prime = h + self.alpha * residual
+        return h_prime
+
+    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        return self.forward(h, theta)
+
+
+class UnconditionalHead(nn.Module):
+    """Unconditional vector head baseline: h'_t = h_t + alpha * U @ theta_u
+
+    No dependence on the current hidden state - just adds a fixed user bias.
+    """
+
+    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.d = d
+        self.alpha = alpha
+
+        gen = torch.Generator()
+        gen.manual_seed(basis_seed + 1000)
+        scale = 1.0 / (d ** 0.5)  # fan_in = d
+
+        self.register_buffer('U', torch.randn(hidden_size, d, generator=gen) * scale)
+
+    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        bias = self.U.float() @ theta  # (H,)
+        h_prime = h + self.alpha * bias.unsqueeze(0)
+        return h_prime
+
+    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        return self.forward(h, theta)
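For quick orientation, a minimal usage sketch follows. It is not part of the commit above; it assumes the file is importable as models.cvh, and hidden_size=768 and batch=4 are illustrative values only.

# Minimal usage sketch (illustrative; hidden_size=768 and batch=4 are
# assumed values, not taken from the commit).
import torch

from models.cvh import CVHHead, UnconditionalHead

hidden_size, d, batch = 768, 64, 4

cvh = CVHHead(hidden_size, d=d, alpha=0.1, basis_seed=42)
uncond = UnconditionalHead(hidden_size, d=d, alpha=0.1, basis_seed=42)

h = torch.randn(batch, hidden_size)   # hidden states for 4 token positions
theta = torch.randn(d)                # per-user vector theta_u

h_cvh = cvh(h, theta)      # h + alpha * B(theta ⊙ (A @ h)): input-dependent edit
h_unc = uncond(h, theta)   # h + alpha * U @ theta: same bias at every position

assert h_cvh.shape == h_unc.shape == h.shape

# With theta_u = 0, both heads reduce to the identity on h.
assert torch.allclose(cvh(h, torch.zeros(d)), h)
assert torch.allclose(uncond(h, torch.zeros(d)), h)

Note the design choice in the commit: A, B, and U are registered as buffers with 1/sqrt(fan_in) scaling, so the residual's magnitude stays roughly independent of H and d, and alpha plus the norm of theta_u alone control how strongly the user vector perturbs the hidden state.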
