From 8fe28101366dd32562b8c5534d7fe359b252bdf3 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Fri, 3 Apr 2026 15:12:34 -0500
Subject: Initial commit: UPH project codebase and experiment results

Includes model code, evaluation scripts, configs, analysis outputs, and
experiment results for the User Prior Head personalization method.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 models/__init__.py     |   0
 models/cvh.py          |  86 ++++++++++++++++
 models/qwen_wrapper.py | 262 +++++++++++++++++++++++++++++++++++++++++++++++++
 models/svd_cvh.py      |  99 +++++++++++++++++++
 4 files changed, 447 insertions(+)
 create mode 100644 models/__init__.py
 create mode 100644 models/cvh.py
 create mode 100644 models/qwen_wrapper.py
 create mode 100644 models/svd_cvh.py
(limited to 'models')

diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/cvh.py b/models/cvh.py
new file mode 100644
index 0000000..e669722
--- /dev/null
+++ b/models/cvh.py
@@ -0,0 +1,86 @@
+"""Contextual Vector Head (CVH): the core personalization module.
+
+h'_t = h_t + alpha * B(theta_u ⊙ (A @ h_t))
+
+Where:
+- A ∈ R^{d x H}: fixed random projection (down), scaled by 1/sqrt(H)
+- B ∈ R^{H x d}: fixed random projection (up), scaled by 1/sqrt(d)
+- theta_u ∈ R^d: per-user vector (the only thing that changes per user)
+- alpha: scaling factor
+"""
+
+import torch
+import torch.nn as nn
+
+
+class CVHHead(nn.Module):
+    """Contextual Vector Head for style personalization."""
+
+    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.d = d
+        self.alpha = alpha
+
+        gen = torch.Generator()
+        gen.manual_seed(basis_seed)
+
+        # A: down-projection (d, H) - fan_in = H
+        scale_a = 1.0 / (hidden_size ** 0.5)
+        self.register_buffer('A', torch.randn(d, hidden_size, generator=gen) * scale_a)
+
+        # B: up-projection (H, d) - fan_in = d
+        scale_b = 1.0 / (d ** 0.5)
+        self.register_buffer('B', torch.randn(hidden_size, d, generator=gen) * scale_b)
+
+    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        """Apply contextual vector head.
+
+        Args:
+            h: Hidden states (batch, H) or (seq_len, H) - float32
+            theta: User vector (d,) - float32
+
+        Returns:
+            h_prime: Modified hidden states, same shape as h
+        """
+        # A @ h^T -> (d, batch) then transpose -> (batch, d)
+        projected = (self.A.float() @ h.T).T  # (batch, d)
+
+        # Element-wise gating with user vector
+        gated = theta.unsqueeze(0) * projected  # (batch, d)
+
+        # Project back up to hidden size and add as an alpha-scaled residual
+        residual = (self.B.float() @ gated.T).T  # (batch, H)
+
+        h_prime = h + self.alpha * residual
+        return h_prime
+
+    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        return self.forward(h, theta)
+
+
+class UnconditionalHead(nn.Module):
+    """Unconditional vector head baseline: h'_t = h_t + alpha * U @ theta_u
+
+    No dependence on current hidden state - just adds a fixed user bias.
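+
+    Example (illustrative only; shapes assume Qwen2.5-1.5B's hidden size of 1536):
+
+        >>> head = UnconditionalHead(hidden_size=1536, d=64, alpha=0.1)
+        >>> h = torch.randn(4, 1536)           # a batch of hidden states
+        >>> theta = torch.zeros(64)            # "neutral" user vector
+        >>> torch.allclose(head(h, theta), h)  # zero theta leaves h unchanged
+        True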
+    """
+
+    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.d = d
+        self.alpha = alpha
+
+        gen = torch.Generator()
+        gen.manual_seed(basis_seed + 1000)
+        scale = 1.0 / (d ** 0.5)  # fan_in = d
+
+        self.register_buffer('U', torch.randn(hidden_size, d, generator=gen) * scale)
+
+    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        bias = self.U.float() @ theta  # (H,)
+        h_prime = h + self.alpha * bias.unsqueeze(0)
+        return h_prime
+
+    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        return self.forward(h, theta)
diff --git a/models/qwen_wrapper.py b/models/qwen_wrapper.py
new file mode 100644
index 0000000..6a0020f
--- /dev/null
+++ b/models/qwen_wrapper.py
@@ -0,0 +1,262 @@
+"""Wrapper around Qwen2.5-1.5B-Instruct for frozen inference and hidden state extraction."""
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+class QwenWrapper:
+    """Wraps a frozen Qwen model for hidden state extraction and generation."""
+
+    def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", device: str = "cuda:1"):
+        self.device = device
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        ).to(device)
+        self.model.eval()
+
+        # Extract lm_head weight for CVH
+        self.lm_head_weight = self.model.lm_head.weight.data  # (vocab_size, H)
+        self.hidden_size = self.model.config.hidden_size
+
+    @torch.no_grad()
+    def get_hidden_states_teacher_forced(self, input_text: str, target_text: str):
+        """Run teacher-forced forward pass and extract final hidden states at target positions.
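+
+        ``h_states[i]`` is the hidden state from which the frozen model predicts
+        ``label_ids[i]``. One way a caller might use the pair to score a candidate
+        user vector (an illustrative sketch, not necessarily this project's
+        evaluation pipeline; ``wrapper`` is this QwenWrapper and ``head`` an
+        already-built CVHHead):
+
+            logits = head(h_states, theta) @ wrapper.lm_head_weight.float().T
+            nll = torch.nn.functional.cross_entropy(logits, label_ids)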
+
+        Args:
+            input_text: The prompt/input text
+            target_text: The target continuation text
+
+        Returns:
+            h_states: (num_target_tokens, H) tensor of final hidden states
+            label_ids: (num_target_tokens,) tensor of target token ids
+        """
+        # Tokenize input and target separately to know the boundary
+        chat_messages = [
+            {"role": "system", "content": "You are a helpful writing assistant."},
+            {"role": "user", "content": input_text},
+        ]
+        prompt_text = self.tokenizer.apply_chat_template(
+            chat_messages, tokenize=False, add_generation_prompt=True
+        )
+        full_text = prompt_text + target_text
+
+        prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+        full_ids = self.tokenizer.encode(full_text, return_tensors="pt").to(self.device)
+
+        prompt_len = prompt_ids.shape[1]
+        total_len = full_ids.shape[1]
+
+        if total_len <= prompt_len:
+            # Target text was empty or tokenized into nothing
+            return None, None
+
+        # Forward pass through the full sequence
+        outputs = self.model(
+            input_ids=full_ids,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+
+        # Get the last hidden layer's states
+        last_hidden = outputs.hidden_states[-1]  # (1, seq_len, H)
+
+        # Hidden states at positions [prompt_len-1, ..., total_len-2] predict tokens [prompt_len, ..., total_len-1]
+        # So for target token at position t, the hidden state is at position t-1
+        start_pos = prompt_len - 1
+        end_pos = total_len - 1
+
+        h_states = last_hidden[0, start_pos:end_pos, :].float()  # (num_target, H)
+        label_ids = full_ids[0, prompt_len:total_len]  # (num_target,)
+
+        return h_states, label_ids
+
+    @torch.no_grad()
+    def generate_base(self, input_text: str, max_new_tokens: int = 512,
+                      temperature: float = 0.7, top_p: float = 0.9) -> str:
+        """Generate text without any personalization."""
+        chat_messages = [
+            {"role": "system", "content": "You are a helpful writing assistant."},
+            {"role": "user", "content": input_text},
+        ]
+        prompt_text = self.tokenizer.apply_chat_template(
+            chat_messages, tokenize=False, add_generation_prompt=True
+        )
+        input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+        outputs = self.model.generate(
+            input_ids,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature if temperature > 0 else None,
+            top_p=top_p if temperature > 0 else None,
+            do_sample=temperature > 0,
+            pad_token_id=self.tokenizer.pad_token_id,
+        )
+
+        generated_ids = outputs[0, input_ids.shape[1]:]
+        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+    def generate_with_head(self, input_text: str, theta: torch.Tensor,
+                           head_fn, max_new_tokens: int = 512,
+                           temperature: float = 0.7, top_p: float = 0.9,
+                           min_new_tokens: int = 64) -> str:
+        """Generate text with a personalized head applied at each decoding step.
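+
+        Hypothetical call (the names ``wrapper``, ``head`` and ``theta`` are
+        assumptions for illustration, not defined in this module):
+
+            text = wrapper.generate_with_head(
+                "Write a short review of a coffee maker.",
+                theta, head.forward_fn, max_new_tokens=256)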
+
+        Args:
+            input_text: The query prompt
+            theta: User vector (d,)
+            head_fn: Function that takes (h, theta) -> h_prime
+            max_new_tokens: Max tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling threshold
+            min_new_tokens: Suppress EOS until this many tokens generated
+        """
+        chat_messages = [
+            {"role": "system", "content": "You are a helpful writing assistant."},
+            {"role": "user", "content": input_text},
+        ]
+        prompt_text = self.tokenizer.apply_chat_template(
+            chat_messages, tokenize=False, add_generation_prompt=True
+        )
+        input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+        generated_ids = []
+        past_key_values = None
+
+        for step in range(max_new_tokens):
+            if step == 0:
+                cur_input = input_ids
+            else:
+                cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
+
+            with torch.no_grad():
+                outputs = self.model(
+                    input_ids=cur_input,
+                    past_key_values=past_key_values,
+                    output_hidden_states=True,
+                    use_cache=True,
+                    return_dict=True,
+                )
+
+            past_key_values = outputs.past_key_values
+
+            # Get last hidden state of the last token
+            last_hidden = outputs.hidden_states[-1][:, -1, :]  # (1, H)
+
+            # Apply personalized head
+            h_prime = head_fn(last_hidden.float(), theta)  # (1, H)
+
+            # Compute logits through lm_head
+            logits = torch.nn.functional.linear(
+                h_prime.to(self.lm_head_weight.dtype),
+                self.lm_head_weight,
+                self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
+            )  # (1, vocab_size)
+
+            logits = logits.float()
+
+            # Suppress EOS before min_new_tokens
+            if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
+                logits[0, self.tokenizer.eos_token_id] = float('-inf')
+
+            # Apply temperature and top-p sampling
+            if temperature > 0:
+                logits = logits / temperature
+                # Top-p filtering
+                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+                cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                sorted_mask = cum_probs - torch.softmax(sorted_logits, dim=-1) >= top_p
+                sorted_logits[sorted_mask] = float('-inf')
+                # Scatter back
+                logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
+                probs = torch.softmax(logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1).item()
+            else:
+                next_token = logits.argmax(dim=-1).item()
+
+            if next_token == self.tokenizer.eos_token_id:
+                break
+
+            generated_ids.append(next_token)
+
+        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+    def generate_with_head_blended(self, input_text: str, theta: torch.Tensor,
+                                   head_fn, blend_gamma: float = 0.5,
+                                   max_new_tokens: int = 512,
+                                   min_new_tokens: int = 128,
+                                   temperature: float = 0.0) -> str:
+        """Generate with blended base + CVH logits.
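+
+        ``blend_gamma`` in [0, 1] interpolates between the frozen model's
+        next-token distribution (gamma = 0) and the fully personalized one
+        (gamma = 1); at each step the logits are combined as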
+
+        logits = (1 - gamma) * base_logits + gamma * cvh_logits
+        """
+        chat_messages = [
+            {"role": "system", "content": "You are a helpful writing assistant."},
+            {"role": "user", "content": input_text},
+        ]
+        prompt_text = self.tokenizer.apply_chat_template(
+            chat_messages, tokenize=False, add_generation_prompt=True
+        )
+        input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+        generated_ids = []
+        past_key_values = None
+
+        for step in range(max_new_tokens):
+            if step == 0:
+                cur_input = input_ids
+            else:
+                cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
+
+            with torch.no_grad():
+                outputs = self.model(
+                    input_ids=cur_input,
+                    past_key_values=past_key_values,
+                    output_hidden_states=True,
+                    use_cache=True,
+                    return_dict=True,
+                )
+
+            past_key_values = outputs.past_key_values
+            last_hidden = outputs.hidden_states[-1][:, -1, :]
+
+            # Base logits
+            base_logits = torch.nn.functional.linear(
+                last_hidden.to(self.lm_head_weight.dtype),
+                self.lm_head_weight,
+                self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
+            ).float()
+
+            # CVH logits
+            h_prime = head_fn(last_hidden.float(), theta)
+            cvh_logits = torch.nn.functional.linear(
+                h_prime.to(self.lm_head_weight.dtype),
+                self.lm_head_weight,
+                self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
+            ).float()
+
+            # Blend
+            logits = (1 - blend_gamma) * base_logits + blend_gamma * cvh_logits
+
+            # Suppress EOS before min_new_tokens
+            if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
+                logits[0, self.tokenizer.eos_token_id] = float('-inf')
+
+            if temperature > 0:
+                logits = logits / temperature
+                probs = torch.softmax(logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1).item()
+            else:
+                next_token = logits.argmax(dim=-1).item()
+
+            if next_token == self.tokenizer.eos_token_id:
+                break
+
+            generated_ids.append(next_token)
+
+        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
diff --git a/models/svd_cvh.py b/models/svd_cvh.py
new file mode 100644
index 0000000..82e63f0
--- /dev/null
+++ b/models/svd_cvh.py
@@ -0,0 +1,99 @@
+"""SVD-based CVH: use principal components of lm_head as basis instead of random.
+
+The key insight: instead of random A and B, use the SVD of the lm_head weight matrix.
+The top-d right singular vectors of W_lm define the most important directions in hidden
+space for token prediction. Modulating these directions with theta_u should be more
+effective than random directions.
+
+This doesn't violate "no training" since the basis comes from the frozen model's
+existing weights, not from any user data.
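+
+Typical construction (illustrative sketch; ``wrapper`` is assumed to be a
+QwenWrapper instance from models.qwen_wrapper):
+
+    head = SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.1)
+    text = wrapper.generate_with_head(prompt, theta, head.forward_fn)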
+"""
+
+import torch
+import torch.nn as nn
+
+
+class SVDCVHHead(nn.Module):
+    """CVH with SVD-derived basis from lm_head."""
+
+    def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1):
+        """
+        Args:
+            lm_head_weight: (vocab_size, H) weight matrix from frozen lm_head
+            d: Number of principal components to use
+            alpha: Scaling factor
+        """
+        super().__init__()
+        self.d = d
+        self.alpha = alpha
+
+        # Compute SVD of lm_head weight: W = U S V^T
+        # V^T rows are the right singular vectors (principal directions in H-space)
+        with torch.no_grad():
+            W = lm_head_weight.float()
+            # Reduced SVD (full_matrices=False) keeps only min(vocab, H) components
+            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+            # Vh: (min(vocab, H), H) - take top d rows
+            # S: (min(vocab, H),) - singular values
+
+            # A: down-projection using top-d right singular vectors
+            # Shape: (d, H) - each row is a right singular vector
+            A = Vh[:d, :]  # (d, H)
+
+            # B: up-projection - the same top-d singular vectors, transposed
+            # (no singular-value rescaling applied)
+            # Shape: (H, d)
+            B = Vh[:d, :].T  # (H, d)
+
+        # Store singular values for optional weighting
+        self.register_buffer('S', S[:d].clone())
+
+        self.register_buffer('A', A)
+        self.register_buffer('B', B)
+
+    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        """Apply SVD-based contextual vector head.
+
+        Args:
+            h: Hidden states (batch, H) - float32
+            theta: User vector (d,) - float32
+
+        Returns:
+            h_prime: Modified hidden states
+        """
+        # Project down to d-dim PCA space: (batch, d)
+        projected = (self.A @ h.T).T  # (batch, d)
+
+        # Element-wise gating with user vector
+        gated = theta.unsqueeze(0) * projected  # (batch, d)
+
+        # Project back up: (batch, H)
+        residual = (self.B @ gated.T).T  # (batch, H)
+
+        return h + self.alpha * residual
+
+    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        return self.forward(h, theta)
+
+
+class SVDUncondHead(nn.Module):
+    """Unconditional head with SVD-derived basis."""
+
+    def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1):
+        super().__init__()
+        self.d = d
+        self.alpha = alpha
+
+        with torch.no_grad():
+            W = lm_head_weight.float()
+            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+            # Use Vh^T as the up-projection
+            B = Vh[:d, :].T  # (H, d)
+
+        self.register_buffer('U_proj', B)
+
+    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        bias = self.U_proj @ theta  # (H,)
+        return h + self.alpha * bias.unsqueeze(0)
+
+    def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        return self.forward(h, theta)
--
cgit v1.2.3