"""Contextual Vector Head (CVH): the core personalization module.
h'_t = h_t + alpha * B(theta_u ⊙ (A @ h_t))
Where:
- A ∈ R^{d x H}: fixed random projection (down), scaled by 1/sqrt(H)
- B ∈ R^{H x d}: fixed random projection (up), scaled by 1/sqrt(d)
- theta_u ∈ R^d: per-user vector (the only thing that changes per user)
- alpha: scaling factor
"""
import torch
import torch.nn as nn


class CVHHead(nn.Module):
    """Contextual Vector Head for style personalization."""

    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
super().__init__()
self.hidden_size = hidden_size
self.d = d
self.alpha = alpha
gen = torch.Generator()
gen.manual_seed(basis_seed)
# A: down-projection (d, H) - fan_in = H
scale_a = 1.0 / (hidden_size ** 0.5)
self.register_buffer('A', torch.randn(d, hidden_size, generator=gen) * scale_a)
# B: up-projection (H, d) - fan_in = d
scale_b = 1.0 / (d ** 0.5)
        self.register_buffer('B', torch.randn(hidden_size, d, generator=gen) * scale_b)

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
"""Apply contextual vector head.
Args:
h: Hidden states (batch, H) or (seq_len, H) - float32
theta: User vector (d,) - float32
Returns:
h_prime: Modified hidden states, same shape as h
"""
        # Down-project: (batch, H) @ (H, d) -> (batch, d)
        projected = h @ self.A.float().T
        # Element-wise gating with the per-user vector
        gated = theta.unsqueeze(0) * projected  # (batch, d)
        # Up-project back to hidden size: (batch, d) @ (d, H) -> (batch, H)
        residual = gated @ self.B.float().T
        h_prime = h + self.alpha * residual
        return h_prime
    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Functional alias for forward()."""
        return self.forward(h, theta)
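
# A minimal usage sketch for CVHHead (the hidden size and theta values below
# are illustrative assumptions, not tied to any particular model):
#
#     head = CVHHead(hidden_size=768, d=64, alpha=0.1)
#     h = torch.randn(4, 768)    # (batch, H) hidden states
#     theta_u = torch.zeros(64)  # theta_u = 0 gates the residual to zero,
#     assert torch.allclose(head(h, theta_u), h)  # so the update is the identity
#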
class UnconditionalHead(nn.Module):
    """Unconditional vector head baseline: h'_t = h_t + alpha * U @ theta_u

    No dependence on the current hidden state: it just adds a fixed per-user bias.
    """

    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
super().__init__()
self.hidden_size = hidden_size
self.d = d
self.alpha = alpha
gen = torch.Generator()
gen.manual_seed(basis_seed + 1000)
scale = 1.0 / (d ** 0.5) # fan_in = d
        self.register_buffer('U', torch.randn(hidden_size, d, generator=gen) * scale)

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        # Fixed per-user bias, independent of the current hidden state
        bias = self.U.float() @ theta  # (H,)
        h_prime = h + self.alpha * bias.unsqueeze(0)  # broadcast over batch/seq
        return h_prime
    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Functional alias for forward()."""
        return self.forward(h, theta)
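
# A minimal smoke test for both heads (a sketch; the sizes below are
# illustrative assumptions, not tied to any particular model):
if __name__ == "__main__":
    torch.manual_seed(0)
    H, d, batch = 768, 64, 4
    h = torch.randn(batch, H)
    theta = torch.randn(d)

    cvh = CVHHead(hidden_size=H, d=d)
    out = cvh(h, theta)
    assert out.shape == h.shape
    # theta = 0 must leave hidden states untouched: B(0 ⊙ (A @ h)) = 0.
    assert torch.allclose(cvh(h, torch.zeros(d)), h)

    uncond = UnconditionalHead(hidden_size=H, d=d)
    out_u = uncond(h, theta)
    assert out_u.shape == h.shape
    # The unconditional bias is identical at every position.
    delta = out_u - h
    assert torch.allclose(delta[0], delta[1])
    print("CVHHead and UnconditionalHead smoke tests passed.")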