summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 15:12:34 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 15:12:34 -0500
commit8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch)
treec92a92184fb2f46f265ab84c1f754c3d5d6597bc /models
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs, and experiment results for the User Prior Head personalization method. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'models')
-rw-r--r--models/__init__.py0
-rw-r--r--models/cvh.py86
-rw-r--r--models/qwen_wrapper.py262
-rw-r--r--models/svd_cvh.py99
4 files changed, 447 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/cvh.py b/models/cvh.py
new file mode 100644
index 0000000..e669722
--- /dev/null
+++ b/models/cvh.py
@@ -0,0 +1,86 @@
+"""Contextual Vector Head (CVH): the core personalization module.
+
+h'_t = h_t + alpha * B(theta_u ⊙ (A @ h_t))
+
+Where:
+- A ∈ R^{d x H}: fixed random projection (down), scaled by 1/sqrt(H)
+- B ∈ R^{H x d}: fixed random projection (up), scaled by 1/sqrt(d)
+- theta_u ∈ R^d: per-user vector (the only thing that changes per user)
+- alpha: scaling factor
+"""
+
+import torch
+import torch.nn as nn
+
+
+class CVHHead(nn.Module):
+ """Contextual Vector Head for style personalization."""
+
+ def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.d = d
+ self.alpha = alpha
+
+ gen = torch.Generator()
+ gen.manual_seed(basis_seed)
+
+ # A: down-projection (d, H) - fan_in = H
+ scale_a = 1.0 / (hidden_size ** 0.5)
+ self.register_buffer('A', torch.randn(d, hidden_size, generator=gen) * scale_a)
+
+ # B: up-projection (H, d) - fan_in = d
+ scale_b = 1.0 / (d ** 0.5)
+ self.register_buffer('B', torch.randn(hidden_size, d, generator=gen) * scale_b)
+
+ def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+ """Apply contextual vector head.
+
+ Args:
+ h: Hidden states (batch, H) or (seq_len, H) - float32
+ theta: User vector (d,) - float32
+
+ Returns:
+ h_prime: Modified hidden states, same shape as h
+ """
+ # A @ h^T -> (d, batch) then transpose -> (batch, d)
+ projected = (self.A.float() @ h.T).T # (batch, d)
+
+ # Element-wise gating with user vector
+ gated = theta.unsqueeze(0) * projected # (batch, d)
+
+ # Project back up, scale by original hidden state magnitude
+ residual = (self.B.float() @ gated.T).T # (batch, H)
+
+ h_prime = h + self.alpha * residual
+ return h_prime
+
+ def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+ return self.forward(h, theta)
+
+
+class UnconditionalHead(nn.Module):
+ """Unconditional vector head baseline: h'_t = h_t + alpha * U @ theta_u
+
+ No dependence on current hidden state - just adds a fixed user bias.
+ """
+
+ def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1, basis_seed: int = 42):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.d = d
+ self.alpha = alpha
+
+ gen = torch.Generator()
+ gen.manual_seed(basis_seed + 1000)
+ scale = 1.0 / (d ** 0.5) # fan_in = d
+
+ self.register_buffer('U', torch.randn(hidden_size, d, generator=gen) * scale)
+
+ def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+ bias = self.U.float() @ theta # (H,)
+ h_prime = h + self.alpha * bias.unsqueeze(0)
+ return h_prime
+
+ def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+ return self.forward(h, theta)
diff --git a/models/qwen_wrapper.py b/models/qwen_wrapper.py
new file mode 100644
index 0000000..6a0020f
--- /dev/null
+++ b/models/qwen_wrapper.py
@@ -0,0 +1,262 @@
+"""Wrapper around Qwen2.5-1.5B-Instruct for frozen inference and hidden state extraction."""
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+class QwenWrapper:
+ """Wraps a frozen Qwen model for hidden state extraction and generation."""
+
+ def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", device: str = "cuda:1"):
+ self.device = device
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ ).to(device)
+ self.model.eval()
+
+ # Extract lm_head weight for CVH
+ self.lm_head_weight = self.model.lm_head.weight.data # (vocab_size, H)
+ self.hidden_size = self.model.config.hidden_size
+
+ @torch.no_grad()
+ def get_hidden_states_teacher_forced(self, input_text: str, target_text: str):
+ """Run teacher-forced forward pass and extract final hidden states at target positions.
+
+ Args:
+ input_text: The prompt/input text
+ target_text: The target continuation text
+
+ Returns:
+ h_states: (num_target_tokens, H) tensor of final hidden states
+ label_ids: (num_target_tokens,) tensor of target token ids
+ """
+ # Tokenize input and target separately to know the boundary
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = self.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ full_text = prompt_text + target_text
+
+ prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+ full_ids = self.tokenizer.encode(full_text, return_tensors="pt").to(self.device)
+
+ prompt_len = prompt_ids.shape[1]
+ total_len = full_ids.shape[1]
+
+ if total_len <= prompt_len:
+ # Target text was empty or tokenized into nothing
+ return None, None
+
+ # Forward pass through the full sequence
+ outputs = self.model(
+ input_ids=full_ids,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+
+ # Get the last hidden layer's states
+ last_hidden = outputs.hidden_states[-1] # (1, seq_len, H)
+
+ # Hidden states at positions [prompt_len-1, ..., total_len-2] predict tokens [prompt_len, ..., total_len-1]
+ # So for target token at position t, the hidden state is at position t-1
+ start_pos = prompt_len - 1
+ end_pos = total_len - 1
+
+ h_states = last_hidden[0, start_pos:end_pos, :].float() # (num_target, H)
+ label_ids = full_ids[0, prompt_len:total_len] # (num_target,)
+
+ return h_states, label_ids
+
+ @torch.no_grad()
+ def generate_base(self, input_text: str, max_new_tokens: int = 512,
+ temperature: float = 0.7, top_p: float = 0.9) -> str:
+ """Generate text without any personalization."""
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = self.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+ outputs = self.model.generate(
+ input_ids,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature if temperature > 0 else None,
+ top_p=top_p if temperature > 0 else None,
+ do_sample=temperature > 0,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
+
+ generated_ids = outputs[0, input_ids.shape[1]:]
+ return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+ def generate_with_head(self, input_text: str, theta: torch.Tensor,
+ head_fn, max_new_tokens: int = 512,
+ temperature: float = 0.7, top_p: float = 0.9,
+ min_new_tokens: int = 64) -> str:
+ """Generate text with a personalized head applied at each decoding step.
+
+ Args:
+ input_text: The query prompt
+ theta: User vector (d,)
+ head_fn: Function that takes (h, theta) -> h_prime
+ max_new_tokens: Max tokens to generate
+ temperature: Sampling temperature
+ top_p: Nucleus sampling threshold
+ min_new_tokens: Suppress EOS until this many tokens generated
+ """
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = self.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+ generated_ids = []
+ past_key_values = None
+
+ for step in range(max_new_tokens):
+ if step == 0:
+ cur_input = input_ids
+ else:
+ cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
+
+ with torch.no_grad():
+ outputs = self.model(
+ input_ids=cur_input,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ past_key_values = outputs.past_key_values
+
+ # Get last hidden state of the last token
+ last_hidden = outputs.hidden_states[-1][:, -1, :] # (1, H)
+
+ # Apply personalized head
+ h_prime = head_fn(last_hidden.float(), theta) # (1, H)
+
+ # Compute logits through lm_head
+ logits = torch.nn.functional.linear(
+ h_prime.to(self.lm_head_weight.dtype),
+ self.lm_head_weight,
+ self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
+ ) # (1, vocab_size)
+
+ logits = logits.float()
+
+ # Suppress EOS before min_new_tokens
+ if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
+ logits[0, self.tokenizer.eos_token_id] = float('-inf')
+
+ # Apply temperature and top-p sampling
+ if temperature > 0:
+ logits = logits / temperature
+ # Top-p filtering
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+ cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_mask = cum_probs - torch.softmax(sorted_logits, dim=-1) >= top_p
+ sorted_logits[sorted_mask] = float('-inf')
+ # Scatter back
+ logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
+ probs = torch.softmax(logits, dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1).item()
+ else:
+ next_token = logits.argmax(dim=-1).item()
+
+ if next_token == self.tokenizer.eos_token_id:
+ break
+
+ generated_ids.append(next_token)
+
+ return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+ def generate_with_head_blended(self, input_text: str, theta: torch.Tensor,
+ head_fn, blend_gamma: float = 0.5,
+ max_new_tokens: int = 512,
+ min_new_tokens: int = 128,
+ temperature: float = 0.0) -> str:
+ """Generate with blended base + CVH logits.
+
+ logits = (1 - gamma) * base_logits + gamma * cvh_logits
+ """
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = self.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+ generated_ids = []
+ past_key_values = None
+
+ for step in range(max_new_tokens):
+ if step == 0:
+ cur_input = input_ids
+ else:
+ cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
+
+ with torch.no_grad():
+ outputs = self.model(
+ input_ids=cur_input,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ past_key_values = outputs.past_key_values
+ last_hidden = outputs.hidden_states[-1][:, -1, :]
+
+ # Base logits
+ base_logits = torch.nn.functional.linear(
+ last_hidden.to(self.lm_head_weight.dtype),
+ self.lm_head_weight,
+ self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
+ ).float()
+
+ # CVH logits
+ h_prime = head_fn(last_hidden.float(), theta)
+ cvh_logits = torch.nn.functional.linear(
+ h_prime.to(self.lm_head_weight.dtype),
+ self.lm_head_weight,
+ self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
+ ).float()
+
+ # Blend
+ logits = (1 - blend_gamma) * base_logits + blend_gamma * cvh_logits
+
+ # Suppress EOS before min_new_tokens
+ if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
+ logits[0, self.tokenizer.eos_token_id] = float('-inf')
+
+ if temperature > 0:
+ logits = logits / temperature
+ probs = torch.softmax(logits, dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1).item()
+ else:
+ next_token = logits.argmax(dim=-1).item()
+
+ if next_token == self.tokenizer.eos_token_id:
+ break
+
+ generated_ids.append(next_token)
+
+ return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
diff --git a/models/svd_cvh.py b/models/svd_cvh.py
new file mode 100644
index 0000000..82e63f0
--- /dev/null
+++ b/models/svd_cvh.py
@@ -0,0 +1,99 @@
+"""SVD-based CVH: use principal components of lm_head as basis instead of random.
+
+The key insight: instead of random A and B, use the SVD of the lm_head weight matrix.
+The top-d right singular vectors of W_lm define the most important directions in hidden
+space for token prediction. Modulating these directions with theta_u should be more
+effective than random directions.
+
+This doesn't violate "no training" since the basis comes from the frozen model's
+existing weights, not from any user data.
+"""
+
+import torch
+import torch.nn as nn
+
+
+class SVDCVHHead(nn.Module):
+ """CVH with SVD-derived basis from lm_head."""
+
+ def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1):
+ """
+ Args:
+ lm_head_weight: (vocab_size, H) weight matrix from frozen lm_head
+ d: Number of principal components to use
+ alpha: Scaling factor
+ """
+ super().__init__()
+ self.d = d
+ self.alpha = alpha
+
+ # Compute SVD of lm_head weight: W = U S V^T
+ # V^T rows are the right singular vectors (principal directions in H-space)
+ with torch.no_grad():
+ W = lm_head_weight.float()
+ # Use truncated SVD for efficiency
+ U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+ # Vh: (min(vocab, H), H) - take top d rows
+ # S: (min(vocab, H),) - singular values
+
+ # A: down-projection using top-d right singular vectors
+ # Shape: (d, H) - each row is a right singular vector
+ A = Vh[:d, :] # (d, H)
+
+ # B: up-projection - scale by inverse singular values for conditioning
+ # Shape: (H, d)
+ B = Vh[:d, :].T # (H, d)
+
+ # Store singular values for optional weighting
+ self.register_buffer('S', S[:d].clone())
+
+ self.register_buffer('A', A)
+ self.register_buffer('B', B)
+
+ def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+ """Apply SVD-based contextual vector head.
+
+ Args:
+ h: Hidden states (batch, H) - float32
+ theta: User vector (d,) - float32
+
+ Returns:
+ h_prime: Modified hidden states
+ """
+ # Project down to d-dim PCA space: (batch, d)
+ projected = (self.A @ h.T).T # (batch, d)
+
+ # Element-wise gating with user vector
+ gated = theta.unsqueeze(0) * projected # (batch, d)
+
+ # Project back up: (batch, H)
+ residual = (self.B @ gated.T).T # (batch, H)
+
+ return h + self.alpha * residual
+
+ def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+ return self.forward(h, theta)
+
+
+class SVDUncondHead(nn.Module):
+ """Unconditional head with SVD-derived basis."""
+
+ def __init__(self, lm_head_weight: torch.Tensor, d: int = 64, alpha: float = 0.1):
+ super().__init__()
+ self.d = d
+ self.alpha = alpha
+
+ with torch.no_grad():
+ W = lm_head_weight.float()
+ U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+ # Use Vh^T as the up-projection
+ B = Vh[:d, :].T # (H, d)
+
+ self.register_buffer('U_proj', B)
+
+ def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+ bias = self.U_proj @ theta # (H,)
+ return h + self.alpha * bias.unsqueeze(0)
+
+ def forward_fn(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+ return self.forward(h, theta)