Diffstat (limited to 'models/qwen_wrapper.py')
-rw-r--r--  models/qwen_wrapper.py  262
1 file changed, 262 insertions(+), 0 deletions(-)
diff --git a/models/qwen_wrapper.py b/models/qwen_wrapper.py
new file mode 100644
index 0000000..6a0020f
--- /dev/null
+++ b/models/qwen_wrapper.py
@@ -0,0 +1,262 @@
+"""Wrapper around Qwen2.5-1.5B-Instruct for frozen inference and hidden state extraction."""
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+class QwenWrapper:
+ """Wraps a frozen Qwen model for hidden state extraction and generation."""
+
+ def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", device: str = "cuda:1"):
+ self.device = device
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ ).to(device)
+        self.model.eval()
+        self.model.requires_grad_(False)  # keep the wrapped model frozen, per the class docstring
+
+ # Extract lm_head weight for CVH
+ self.lm_head_weight = self.model.lm_head.weight.data # (vocab_size, H)
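+        # Note (assumption): on the smaller Qwen2.5 checkpoints the lm_head is typically
+        # weight-tied to the input embeddings, so this references the shared weight matrix
+        # rather than an independent copy.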
+ self.hidden_size = self.model.config.hidden_size
+
+ @torch.no_grad()
+ def get_hidden_states_teacher_forced(self, input_text: str, target_text: str):
+ """Run teacher-forced forward pass and extract final hidden states at target positions.
+
+ Args:
+ input_text: The prompt/input text
+ target_text: The target continuation text
+
+ Returns:
+ h_states: (num_target_tokens, H) tensor of final hidden states
+ label_ids: (num_target_tokens,) tensor of target token ids
+ """
+ # Tokenize input and target separately to know the boundary
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = self.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ full_text = prompt_text + target_text
+
+ prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+ full_ids = self.tokenizer.encode(full_text, return_tensors="pt").to(self.device)
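+        # Caveat: this assumes the prompt's tokenization is unchanged when target_text is
+        # appended; a BPE merge at the boundary could shift the prompt/target split by a token.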
+
+ prompt_len = prompt_ids.shape[1]
+ total_len = full_ids.shape[1]
+
+ if total_len <= prompt_len:
+ # Target text was empty or tokenized into nothing
+ return None, None
+
+ # Forward pass through the full sequence
+ outputs = self.model(
+ input_ids=full_ids,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+
+ # Get the last hidden layer's states
+ last_hidden = outputs.hidden_states[-1] # (1, seq_len, H)
+
+ # Hidden states at positions [prompt_len-1, ..., total_len-2] predict tokens [prompt_len, ..., total_len-1]
+ # So for target token at position t, the hidden state is at position t-1
+ start_pos = prompt_len - 1
+ end_pos = total_len - 1
+
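+        # Worked example (hypothetical lengths): with prompt_len = 5 and total_len = 8,
+        # hidden states at positions 4..6 are paired with target tokens at positions 5..7.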
+ h_states = last_hidden[0, start_pos:end_pos, :].float() # (num_target, H)
+ label_ids = full_ids[0, prompt_len:total_len] # (num_target,)
+
+ return h_states, label_ids
+
+ @torch.no_grad()
+ def generate_base(self, input_text: str, max_new_tokens: int = 512,
+ temperature: float = 0.7, top_p: float = 0.9) -> str:
+ """Generate text without any personalization."""
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = self.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+ outputs = self.model.generate(
+ input_ids,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature if temperature > 0 else None,
+ top_p=top_p if temperature > 0 else None,
+ do_sample=temperature > 0,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
+
+ generated_ids = outputs[0, input_ids.shape[1]:]
+ return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+ def generate_with_head(self, input_text: str, theta: torch.Tensor,
+ head_fn, max_new_tokens: int = 512,
+ temperature: float = 0.7, top_p: float = 0.9,
+ min_new_tokens: int = 64) -> str:
+ """Generate text with a personalized head applied at each decoding step.
+
+ Args:
+ input_text: The query prompt
+ theta: User vector (d,)
+ head_fn: Function that takes (h, theta) -> h_prime
+ max_new_tokens: Max tokens to generate
+ temperature: Sampling temperature
+ top_p: Nucleus sampling threshold
+ min_new_tokens: Suppress EOS until this many tokens generated
+ """
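+        # head_fn contract (as used below): it receives the (1, H) float32 hidden state and
+        # theta, and must return a (1, H) tensor. An additive head such as
+        # lambda h, t: h + proj(t) (with proj a hypothetical theta-to-hidden mapping) fits this.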
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = self.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+ generated_ids = []
+ past_key_values = None
+
+ for step in range(max_new_tokens):
+ if step == 0:
+ cur_input = input_ids
+ else:
+ cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
+
+ with torch.no_grad():
+ outputs = self.model(
+ input_ids=cur_input,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ past_key_values = outputs.past_key_values
+
+ # Get last hidden state of the last token
+ last_hidden = outputs.hidden_states[-1][:, -1, :] # (1, H)
+
+ # Apply personalized head
+ h_prime = head_fn(last_hidden.float(), theta) # (1, H)
+
+ # Compute logits through lm_head
+            logits = torch.nn.functional.linear(
+                h_prime.to(self.lm_head_weight.dtype),
+                self.lm_head_weight,
+                getattr(self.model.lm_head, "bias", None),
+            )  # (1, vocab_size)
+
+ logits = logits.float()
+
+ # Suppress EOS before min_new_tokens
+ if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
+ logits[0, self.tokenizer.eos_token_id] = float('-inf')
+
+            # Apply temperature and top-p (nucleus) sampling
+            if temperature > 0:
+                logits = logits / temperature
+                # Top-p filtering: sort descending and drop every token whose preceding
+                # cumulative probability already reaches top_p (the first token crossing
+                # the threshold is always kept).
+                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+                sorted_probs = torch.softmax(sorted_logits, dim=-1)
+                cum_probs = torch.cumsum(sorted_probs, dim=-1)
+                sorted_mask = (cum_probs - sorted_probs) >= top_p
+                sorted_logits[sorted_mask] = float('-inf')
+                # Scatter the filtered logits back to their original (unsorted) positions
+                logits = logits.scatter(1, sorted_indices, sorted_logits)
+                probs = torch.softmax(logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1).item()
+            else:
+                next_token = logits.argmax(dim=-1).item()
+
+ if next_token == self.tokenizer.eos_token_id:
+ break
+
+ generated_ids.append(next_token)
+
+ return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+ def generate_with_head_blended(self, input_text: str, theta: torch.Tensor,
+ head_fn, blend_gamma: float = 0.5,
+ max_new_tokens: int = 512,
+ min_new_tokens: int = 128,
+ temperature: float = 0.0) -> str:
+ """Generate with blended base + CVH logits.
+
+ logits = (1 - gamma) * base_logits + gamma * cvh_logits
+ """
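+        # blend_gamma = 0.0 reproduces the base model's logits; blend_gamma = 1.0 uses only
+        # the personalized (CVH) logits.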
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = self.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
+
+ generated_ids = []
+ past_key_values = None
+
+ for step in range(max_new_tokens):
+ if step == 0:
+ cur_input = input_ids
+ else:
+ cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
+
+ with torch.no_grad():
+ outputs = self.model(
+ input_ids=cur_input,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ past_key_values = outputs.past_key_values
+ last_hidden = outputs.hidden_states[-1][:, -1, :]
+
+ # Base logits
+            base_logits = torch.nn.functional.linear(
+                last_hidden.to(self.lm_head_weight.dtype),
+                self.lm_head_weight,
+                getattr(self.model.lm_head, "bias", None),
+            ).float()
+
+ # CVH logits
+ h_prime = head_fn(last_hidden.float(), theta)
+            cvh_logits = torch.nn.functional.linear(
+                h_prime.to(self.lm_head_weight.dtype),
+                self.lm_head_weight,
+                getattr(self.model.lm_head, "bias", None),
+            ).float()
+
+ # Blend
+ logits = (1 - blend_gamma) * base_logits + blend_gamma * cvh_logits
+
+ # Suppress EOS before min_new_tokens
+ if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
+ logits[0, self.tokenizer.eos_token_id] = float('-inf')
+
+ if temperature > 0:
+ logits = logits / temperature
+ probs = torch.softmax(logits, dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1).item()
+ else:
+ next_token = logits.argmax(dim=-1).item()
+
+ if next_token == self.tokenizer.eos_token_id:
+ break
+
+ generated_ids.append(next_token)
+
+ return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
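+
+
+# --- Minimal usage sketch (illustrative; not part of the training pipeline) ---
+# The identity head below is a stand-in for a trained personalization head, and the
+# 64-dimensional theta is an assumed placeholder dimensionality.
+if __name__ == "__main__":
+    wrapper = QwenWrapper()  # defaults to "cuda:1"; pass device="cuda:0" or "cpu" if needed
+
+    def identity_head(h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
+        # A trained head would condition h on theta; the identity keeps the sketch runnable.
+        return h
+
+    theta = torch.zeros(64, device=wrapper.device)
+    print(wrapper.generate_base("Write a two-sentence thank-you note.", max_new_tokens=64))
+    print(wrapper.generate_with_head("Write a two-sentence thank-you note.", theta,
+                                     identity_head, max_new_tokens=64, min_new_tokens=16))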