"""SVD-based CVH: use principal components of lm_head as basis instead of random. The key insight: instead of random A and B, use the SVD of the lm_head weight matrix. The top-d right singular vectors of W_lm define the most important directions in hidden space for token prediction. Modulating these directions with theta_u should be more effective than random directions. This doesn't violate "no training" since the basis comes from the frozen model's existing weights, not from any user data. """ import torch import torch.nn as nn class SVDCVHHead(nn.Module): """CVH with SVD-derived basis from lm_head.""" def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1): """ Args: lm_head_weight: (vocab_size, H) weight matrix from frozen lm_head d: Number of principal components to use alpha: Scaling factor """ super().__init__() self.d = d self.alpha = alpha # Compute SVD of lm_head weight: W = U S V^T # V^T rows are the right singular vectors (principal directions in H-space) with torch.no_grad(): W = lm_head_weight.float() # Use truncated SVD for efficiency U, S, Vh = torch.linalg.svd(W, full_matrices=False) # Vh: (min(vocab, H), H) - take top d rows # S: (min(vocab, H),) - singular values # A: down-projection using top-d right singular vectors # Shape: (d, H) - each row is a right singular vector A = Vh[:d, :] # (d, H) # B: up-projection - scale by inverse singular values for conditioning # Shape: (H, d) B = Vh[:d, :].T # (H, d) # Store singular values for optional weighting self.register_buffer('S', S[:d].clone()) self.register_buffer('A', A) self.register_buffer('B', B) def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: """Apply SVD-based contextual vector head. Args: h: Hidden states (batch, H) - float32 theta: User vector (d,) - float32 Returns: h_prime: Modified hidden states """ # Project down to d-dim PCA space: (batch, d) projected = (self.A @ h.T).T # (batch, d) # Element-wise gating with user vector gated = theta.unsqueeze(0) * projected # (batch, d) # Project back up: (batch, H) residual = (self.B @ gated.T).T # (batch, H) return h + self.alpha * residual def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: return self.forward(h, theta) class SVDUncondHead(nn.Module): """Unconditional head with SVD-derived basis.""" def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1): super().__init__() self.d = d self.alpha = alpha with torch.no_grad(): W = lm_head_weight.float() U, S, Vh = torch.linalg.svd(W, full_matrices=False) # Use Vh^T as the up-projection B = Vh[:d, :].T # (H, d) self.register_buffer('U_proj', B) def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: bias = self.U_proj @ theta # (H,) return h + self.alpha * bias.unsqueeze(0) def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: return self.forward(h, theta)