summaryrefslogtreecommitdiff
path: root/models/cvh.py
blob: e669722784fc329d0ead291ab62630fa8489a9b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Contextual Vector Head (CVH): the core personalization module.

h'_t = h_t + alpha * B @ (theta_u ⊙ (A @ h_t))

Where:
- A ∈ R^{d x H}: fixed random projection (down), scaled by 1/sqrt(H)
- B ∈ R^{H x d}: fixed random projection (up), scaled by 1/sqrt(d)
- theta_u ∈ R^d: per-user vector (the only thing that changes per user)
- alpha: scaling factor
"""

import torch
import torch.nn as nn


class CVHHead(nn.Module):
    """Contextual Vector Head for style personalization.

    For every hidden vector ``h`` this computes
    ``h' = h + alpha * B @ (theta ⊙ (A @ h))``, where ``A`` (d, H) and
    ``B`` (H, d) are random projections drawn once from a generator seeded
    with ``basis_seed`` and stored as buffers (not parameters). Only
    ``theta`` is supplied per call.
    """

    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
        """Build the two random bases.

        Args:
            hidden_size: Size H of the hidden states being modified.
            d: Size of the user vector / bottleneck dimension.
            alpha: Scaling factor applied to the residual update.
            basis_seed: Seed for the generator that draws A and then B.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.d = d
        self.alpha = alpha

        # A dedicated generator keeps the bases reproducible from the seed
        # without touching the global RNG state. A is drawn before B; the
        # order matters for reproducing the same bases.
        gen = torch.Generator()
        gen.manual_seed(basis_seed)

        # A: down-projection (d, H) - fan_in = H
        scale_a = 1.0 / (hidden_size ** 0.5)
        self.register_buffer('A', torch.randn(d, hidden_size, generator=gen) * scale_a)

        # B: up-projection (H, d) - fan_in = d
        scale_b = 1.0 / (d ** 0.5)
        self.register_buffer('B', torch.randn(hidden_size, d, generator=gen) * scale_b)

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Apply contextual vector head.

        Args:
            h: Hidden states of shape (..., H), e.g. (batch, H), (seq_len, H)
                or (batch, seq_len, H) - float32
            theta: User vector (d,) - float32

        Returns:
            h_prime: Modified hidden states, same shape as h. A 1-D ``h`` of
                shape (H,) is treated as a single row and returned as (1, H).
        """
        # Right-multiplying by the transposed basis projects the last axis
        # only, so any number of leading axes is supported (the previous
        # ``(A @ h.T).T`` form was limited to 2-D input).
        projected = h @ self.A.float().T  # (..., d)

        # Element-wise gating with user vector
        gated = theta.unsqueeze(0) * projected  # (..., d)

        # Project back up to the hidden size
        residual = gated @ self.B.float().T  # (..., H)

        h_prime = h + self.alpha * residual
        return h_prime

    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`forward` taking the same arguments."""
        return self.forward(h, theta)


class UnconditionalHead(nn.Module):
    """Baseline head that adds a fixed per-user bias to the hidden states.

    Computes ``h'_t = h_t + alpha * U @ theta_u``. The added term does not
    depend on the current hidden state, only on the user vector.
    """

    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
        """Draw the random basis ``U`` of shape (hidden_size, d).

        Args:
            hidden_size: Size H of the hidden states being modified.
            d: Size of the user vector.
            alpha: Scaling factor applied to the bias.
            basis_seed: Base seed; the generator is seeded with
                ``basis_seed + 1000``.
        """
        super().__init__()
        self.hidden_size, self.d, self.alpha = hidden_size, d, alpha

        # Private generator so the basis is reproducible from the seed.
        rng = torch.Generator()
        rng.manual_seed(basis_seed + 1000)

        # Scale by 1/sqrt(d) since fan_in = d.
        basis = torch.randn(hidden_size, d, generator=rng)
        self.register_buffer('U', basis * (1.0 / (d ** 0.5)))

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Return ``h`` shifted by ``alpha * U @ theta``.

        The bias gets a leading axis of size 1 so it broadcasts over the
        rows of ``h``.
        """
        user_bias = torch.matmul(self.U.float(), theta)  # (H,)
        return h + self.alpha * user_bias[None]

    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`forward` taking the same arguments."""
        return self.forward(h, theta)