summaryrefslogtreecommitdiff
path: root/models/svd_cvh.py
blob: 82e63f0de07a4abd38bb4d716cbbf433ad7f20d0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""SVD-based CVH: use principal components of lm_head as basis instead of random.

The key insight: instead of random A and B, use the SVD of the lm_head weight matrix,
W_lm = U S V^T. The top-d right singular vectors of W_lm (the first d rows of V^T, i.e.
those with the largest singular values) are the hidden-space directions along which the
output logits change most. Modulating these directions with theta_u is therefore expected
to be more effective than modulating random directions.

This doesn't violate "no training" since the basis comes from the frozen model's
existing weights, not from any user data.
"""

import torch
import torch.nn as nn


class SVDCVHHead(nn.Module):
    """CVH with SVD-derived basis from lm_head."""

    def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1):
        """
        Args:
            lm_head_weight: (vocab_size, H) weight matrix from frozen lm_head
            d: Number of principal components to use
            alpha: Scaling factor
        """
        super().__init__()
        self.d = d
        self.alpha = alpha

        # Compute SVD of lm_head weight: W = U S V^T
        # V^T rows are the right singular vectors (principal directions in H-space)
        with torch.no_grad():
            W = lm_head_weight.float()
            # Use truncated SVD for efficiency
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            # Vh: (min(vocab, H), H) - take top d rows
            # S: (min(vocab, H),) - singular values

            # A: down-projection using top-d right singular vectors
            # Shape: (d, H) - each row is a right singular vector
            A = Vh[:d, :]  # (d, H)

            # B: up-projection - scale by inverse singular values for conditioning
            # Shape: (H, d)
            B = Vh[:d, :].T  # (H, d)

            # Store singular values for optional weighting
            self.register_buffer('S', S[:d].clone())

        self.register_buffer('A', A)
        self.register_buffer('B', B)

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Apply SVD-based contextual vector head.

        Args:
            h: Hidden states (batch, H) - float32
            theta: User vector (d,) - float32

        Returns:
            h_prime: Modified hidden states
        """
        # Project down to d-dim PCA space: (batch, d)
        projected = (self.A @ h.T).T  # (batch, d)

        # Element-wise gating with user vector
        gated = theta.unsqueeze(0) * projected  # (batch, d)

        # Project back up: (batch, H)
        residual = (self.B @ gated.T).T  # (batch, H)

        return h + self.alpha * residual

    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        return self.forward(h, theta)


class SVDUncondHead(nn.Module):
    """Unconditional head with SVD-derived basis.

    Adds the same input-independent bias ``alpha * U_proj @ theta`` to every
    row of ``h``. The columns of ``U_proj`` are the top-``d`` right singular
    vectors of the lm_head weight (largest singular value first).

    Buffers:
        U_proj: (H, d) up-projection, stored contiguously.
    """

    def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1):
        """
        Args:
            lm_head_weight: (vocab_size, H) weight matrix from frozen lm_head
            d: Number of principal components to use; must satisfy
                ``1 <= d <= min(vocab_size, H)``.
            alpha: Scaling factor applied to the bias.

        Raises:
            ValueError: if ``lm_head_weight`` is not 2-D or ``d`` is out of range.
        """
        super().__init__()
        if lm_head_weight.dim() != 2:
            raise ValueError(
                f"lm_head_weight must be 2-D (vocab_size, H), got shape "
                f"{tuple(lm_head_weight.shape)}"
            )
        max_d = min(lm_head_weight.shape)
        # Vh has only min(vocab_size, H) rows; reject d values that would
        # silently select the wrong number of components.
        if not 1 <= d <= max_d:
            raise ValueError(f"d must be in [1, {max_d}], got {d}")
        self.d = d
        self.alpha = alpha

        with torch.no_grad():
            W = lm_head_weight.float()
            # Thin SVD W = U S V^T; only the right singular vectors are used.
            _, _, Vh = torch.linalg.svd(W, full_matrices=False)
            # Use Vh^T as the up-projection. Copy into fresh contiguous
            # storage so the buffer does not pin (or serialize) the full Vh.
            B = Vh[:d].T.clone(memory_format=torch.contiguous_format)  # (H, d)

        self.register_buffer('U_proj', B)

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Add the bias ``alpha * U_proj @ theta`` to the hidden states.

        Args:
            h: Hidden states, typically (batch, H); any shape broadcastable
                with (1, H) works.
            theta: User vector (d,).

        Returns:
            ``h + alpha * (U_proj @ theta)`` broadcast over the leading dims.
        """
        bias = self.U_proj @ theta  # (H,)
        return h + self.alpha * bias.unsqueeze(0)

    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`forward`."""
        return self.forward(h, theta)