Diffstat (limited to 'models/cvh.py')
-rw-r--r--  models/cvh.py  125
1 file changed, 125 insertions(+), 0 deletions(-)
diff --git a/models/cvh.py b/models/cvh.py
new file mode 100644
index 0000000..e669722
--- /dev/null
+++ b/models/cvh.py
@@ -0,0 +1,125 @@
+"""Contextual Vector Head (CVH): the core personalization module.
+
+h'_t = h_t + alpha * B @ (theta_u ⊙ (A @ h_t))
+
+Where:
+- A ∈ R^{d x H}: fixed random projection (down), scaled by 1/sqrt(H)
+- B ∈ R^{H x d}: fixed random projection (up), scaled by 1/sqrt(d)
+- theta_u ∈ R^d: per-user vector (the only thing that changes per user)
+- alpha: scaling factor
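+
+A quick sanity check (hypothetical sizes, for illustration only):
+
+    >>> head = CVHHead(hidden_size=8, d=4)
+    >>> h = torch.randn(3, 8)
+    >>> theta = torch.zeros(4)  # a zero user vector leaves h unchanged
+    >>> torch.equal(head(h, theta), h)
+    True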
+"""
+
+import torch
+import torch.nn as nn
+
+
+class CVHHead(nn.Module):
+ """Contextual Vector Head for style personalization."""
+
+ def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.d = d
+ self.alpha = alpha
+
+ gen = torch.Generator()
+ gen.manual_seed(basis_seed)
+
+ # A: down-projection (d, H) - fan_in = H
+ scale_a = 1.0 / (hidden_size ** 0.5)
+ self.register_buffer('A', torch.randn(d, hidden_size, generator=gen) * scale_a)
+
+ # B: up-projection (H, d) - fan_in = d
+ scale_b = 1.0 / (d ** 0.5)
+ self.register_buffer('B', torch.randn(hidden_size, d, generator=gen) * scale_b)
+
+    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        """Apply the contextual vector head.
+
+        Args:
+            h: Hidden states (..., H), e.g. (batch, H) or (seq_len, H) - float32
+            theta: User vector (d,) - float32
+
+        Returns:
+            h_prime: Modified hidden states, same shape as h
+        """
+        # Down-project: (..., H) @ (H, d) -> (..., d)
+        projected = h @ self.A.float().T
+
+        # Gate element-wise with the user vector; (d,) broadcasts over (..., d)
+        gated = theta * projected
+
+        # Up-project back to hidden size: (..., d) @ (d, H) -> (..., H)
+        residual = gated @ self.B.float().T
+
+        h_prime = h + self.alpha * residual
+        return h_prime
+
+    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        """Functional alias for forward()."""
+        return self.forward(h, theta)
+
+
+class UnconditionalHead(nn.Module):
+ """Unconditional vector head baseline: h'_t = h_t + alpha * U @ theta_u
+
+ No dependence on current hidden state - just adds a fixed user bias.
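+
+    A zero user vector is again a no-op (illustrative sizes):
+
+        >>> head = UnconditionalHead(hidden_size=8, d=4)
+        >>> h = torch.randn(3, 8)
+        >>> torch.equal(head(h, torch.zeros(4)), h)
+        True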
+ """
+
+ def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.d = d
+ self.alpha = alpha
+
+ gen = torch.Generator()
+ gen.manual_seed(basis_seed + 1000)
+ scale = 1.0 / (d ** 0.5) # fan_in = d
+
+ self.register_buffer('U', torch.randn(hidden_size, d, generator=gen) * scale)
+
+    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        """Add the fixed user bias alpha * (U @ theta) at every position."""
+        bias = self.U.float() @ theta  # (H,)
+        # (H,) broadcasts over the leading dims of h
+        return h + self.alpha * bias
+
+    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        """Functional alias for forward()."""
+        return self.forward(h, theta)
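+
+
+# A minimal smoke test for both heads; the sizes (H=768, d=64) are
+# illustrative and not tied to any particular base model.
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    h = torch.randn(4, 768)
+    theta = torch.randn(64)
+
+    cvh = CVHHead(hidden_size=768)
+    uncond = UnconditionalHead(hidden_size=768)
+
+    # Both heads preserve the hidden-state shape.
+    assert cvh(h, theta).shape == h.shape
+    assert uncond(h, theta).shape == h.shape
+
+    # The baseline's offset (h' - h) is constant across positions by
+    # construction; atol absorbs float32 rounding from the add/subtract.
+    delta = uncond(h, theta) - h
+    assert torch.allclose(delta, delta[0].expand_as(delta), atol=1e-5)
+    print("cvh.py smoke test passed")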