summaryrefslogtreecommitdiff
path: root/models/svd_cvh.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 15:12:34 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 15:12:34 -0500
commit8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch)
treec92a92184fb2f46f265ab84c1f754c3d5d6597bc /models/svd_cvh.py
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs, and experiment results for the User Prior Head personalization method. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'models/svd_cvh.py')
-rw-r--r--models/svd_cvh.py99
1 files changed, 99 insertions, 0 deletions
diff --git a/models/svd_cvh.py b/models/svd_cvh.py
new file mode 100644
index 0000000..82e63f0
--- /dev/null
+++ b/models/svd_cvh.py
@@ -0,0 +1,99 @@
+"""SVD-based CVH: use principal components of lm_head as basis instead of random.
+
+The key insight: instead of random A and B, use the SVD of the lm_head weight matrix.
+The top-d right singular vectors of W_lm define the most important directions in hidden
+space for token prediction. Modulating these directions with theta_u should be more
+effective than random directions.
+
+This doesn't violate "no training" since the basis comes from the frozen model's
+existing weights, not from any user data.
+"""
+
+import torch
+import torch.nn as nn
+
+
class SVDCVHHead(nn.Module):
    """CVH with SVD-derived basis from lm_head.

    Instead of random down/up projections, the basis is the top-d right
    singular vectors of the frozen lm_head weight — the principal directions
    of hidden space for token prediction. The basis is fixed (registered as
    buffers, never trained), so no user data is involved in building it.
    """

    def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1):
        """
        Args:
            lm_head_weight: (vocab_size, H) weight matrix from the frozen lm_head.
            d: Number of principal components to use. Clamped to
               min(vocab_size, H), the number of singular vectors available.
            alpha: Scaling factor for the residual update.
        """
        super().__init__()
        self.alpha = alpha

        # Compute SVD of the lm_head weight: W = U S V^T. The rows of Vh are
        # the right singular vectors (principal directions in H-space).
        with torch.no_grad():
            W = lm_head_weight.float()
            _, S, Vh = torch.linalg.svd(W, full_matrices=False)

        # d cannot exceed the number of available singular vectors; keep
        # self.d consistent with the actual basis size.
        self.d = min(d, Vh.shape[0])

        # A: down-projection, shape (d, H); each row is a right singular vector.
        # B: up-projection, shape (H, d); simply A^T, so B @ A is an orthogonal
        # projector onto the top-d subspace (no singular-value scaling applied).
        # .contiguous() detaches the slices from the full SVD output so the
        # buffers don't keep the whole (min(vocab, H), H) tensor alive.
        basis = Vh[: self.d, :].contiguous()
        self.register_buffer('A', basis)
        self.register_buffer('B', basis.T.contiguous())

        # Singular values of the kept components, for optional weighting.
        self.register_buffer('S', S[: self.d].clone())

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Apply the SVD-based contextual vector head.

        Args:
            h: Hidden states (..., H) - float32. Any number of leading batch
               dimensions is supported (generalizes the original 2-D-only path).
            theta: User vector (d,) - float32.

        Returns:
            h_prime: Modified hidden states, same shape as ``h``.
        """
        # Project down into the d-dim principal subspace: (..., d).
        projected = h @ self.A.T

        # Element-wise gating with the user vector (broadcasts over leading dims).
        gated = theta * projected

        # Project back up to hidden space: (..., H).
        residual = gated @ self.B.T

        return h + self.alpha * residual

    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`forward`, kept for callers expecting a plain function."""
        return self.forward(h, theta)
+
+
class SVDUncondHead(nn.Module):
    """Unconditional head with SVD-derived basis.

    Adds a user-specific bias in hidden space, independent of the hidden
    state content: h' = h + alpha * (U_proj @ theta), where the columns of
    U_proj are the top-d right singular vectors of the frozen lm_head weight.
    """

    def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1):
        """
        Args:
            lm_head_weight: (vocab_size, H) weight matrix from the frozen lm_head.
            d: Number of principal components. Clamped to min(vocab_size, H).
            alpha: Scaling factor for the bias.
        """
        super().__init__()
        self.alpha = alpha

        with torch.no_grad():
            W = lm_head_weight.float()
            # Only the right singular vectors are needed here.
            _, _, Vh = torch.linalg.svd(W, full_matrices=False)

        # Keep self.d consistent with the actual basis size.
        self.d = min(d, Vh.shape[0])

        # Up-projection (H, d). .contiguous() detaches the slice from the
        # full SVD output so the buffer doesn't keep it alive.
        self.register_buffer('U_proj', Vh[: self.d, :].T.contiguous())

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Add the fixed user bias to the hidden states.

        Args:
            h: Hidden states (..., H); any leading batch dims broadcast.
            theta: User vector (d,).

        Returns:
            h + alpha * U_proj @ theta, same shape as ``h``.
        """
        bias = self.U_proj @ theta  # (H,) — broadcasts over leading dims of h
        return h + self.alpha * bias

    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`forward`."""
        return self.forward(h, theta)