"""Wrapper around Qwen2.5-1.5B-Instruct for frozen inference and hidden state extraction.""" import torch from transformers import AutoModelForCausalLM, AutoTokenizer class QwenWrapper: """Wraps a frozen Qwen model for hidden state extraction and generation.""" def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", device: str = "cuda:1"): self.device = device self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, ).to(device) self.model.eval() # Extract lm_head weight for CVH self.lm_head_weight = self.model.lm_head.weight.data # (vocab_size, H) self.hidden_size = self.model.config.hidden_size @torch.no_grad() def get_hidden_states_teacher_forced(self, input_text: str, target_text: str): """Run teacher-forced forward pass and extract final hidden states at target positions. Args: input_text: The prompt/input text target_text: The target continuation text Returns: h_states: (num_target_tokens, H) tensor of final hidden states label_ids: (num_target_tokens,) tensor of target token ids """ # Tokenize input and target separately to know the boundary chat_messages = [ {"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": input_text}, ] prompt_text = self.tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) full_text = prompt_text + target_text prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device) full_ids = self.tokenizer.encode(full_text, return_tensors="pt").to(self.device) prompt_len = prompt_ids.shape[1] total_len = full_ids.shape[1] if total_len <= prompt_len: # Target text was empty or tokenized into nothing return None, None # Forward pass through the full sequence outputs = self.model( input_ids=full_ids, output_hidden_states=True, return_dict=True, ) # Get the last hidden layer's states last_hidden = outputs.hidden_states[-1] # (1, seq_len, H) # Hidden states at positions [prompt_len-1, ..., total_len-2] predict tokens [prompt_len, ..., total_len-1] # So for target token at position t, the hidden state is at position t-1 start_pos = prompt_len - 1 end_pos = total_len - 1 h_states = last_hidden[0, start_pos:end_pos, :].float() # (num_target, H) label_ids = full_ids[0, prompt_len:total_len] # (num_target,) return h_states, label_ids @torch.no_grad() def generate_base(self, input_text: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9) -> str: """Generate text without any personalization.""" chat_messages = [ {"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": input_text}, ] prompt_text = self.tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device) outputs = self.model.generate( input_ids, max_new_tokens=max_new_tokens, temperature=temperature if temperature > 0 else None, top_p=top_p if temperature > 0 else None, do_sample=temperature > 0, pad_token_id=self.tokenizer.pad_token_id, ) generated_ids = outputs[0, input_ids.shape[1]:] return self.tokenizer.decode(generated_ids, skip_special_tokens=True) def generate_with_head(self, input_text: str, theta: torch.Tensor, head_fn, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, 

    def generate_with_head(self, input_text: str, theta: torch.Tensor, head_fn,
                           max_new_tokens: int = 512, temperature: float = 0.7,
                           top_p: float = 0.9, min_new_tokens: int = 64) -> str:
        """Generate text with a personalized head applied at each decoding step.

        Args:
            input_text: The query prompt
            theta: User vector (d,)
            head_fn: Function that takes (h, theta) -> h_prime
            max_new_tokens: Max tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold
            min_new_tokens: Suppress EOS until this many tokens are generated
        """
        chat_messages = [
            {"role": "system", "content": "You are a helpful writing assistant."},
            {"role": "user", "content": input_text},
        ]
        prompt_text = self.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
        generated_ids = []
        past_key_values = None
        for step in range(max_new_tokens):
            if step == 0:
                cur_input = input_ids
            else:
                cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
            with torch.no_grad():
                outputs = self.model(
                    input_ids=cur_input,
                    past_key_values=past_key_values,
                    output_hidden_states=True,
                    use_cache=True,
                    return_dict=True,
                )
            past_key_values = outputs.past_key_values
            # Final-layer hidden state of the last token
            last_hidden = outputs.hidden_states[-1][:, -1, :]  # (1, H)
            # Apply the personalized head
            h_prime = head_fn(last_hidden.float(), theta)  # (1, H)
            # Project through the (frozen) lm_head to get logits
            logits = torch.nn.functional.linear(
                h_prime.to(self.lm_head_weight.dtype),
                self.lm_head_weight,
                getattr(self.model.lm_head, "bias", None),
            ).float()  # (1, vocab_size)
            # Suppress EOS before min_new_tokens
            if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
                logits[0, self.tokenizer.eos_token_id] = float("-inf")
            # Apply temperature and top-p sampling
            if temperature > 0:
                logits = logits / temperature
                # Top-p filtering: drop a token once the cumulative probability
                # of the tokens ranked above it already reaches top_p
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_mask = cum_probs - sorted_probs >= top_p
                sorted_logits[sorted_mask] = float("-inf")
                # Scatter the filtered logits back to vocabulary order
                logits = torch.full_like(logits, float("-inf")).scatter(
                    1, sorted_indices, sorted_logits
                )
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
            else:
                next_token = logits.argmax(dim=-1).item()
            if next_token == self.tokenizer.eos_token_id:
                break
            generated_ids.append(next_token)
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
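
    # Usage sketch for the per-step personalized decoding above (hypothetical
    # wiring; the trained head and the user vector theta come from elsewhere
    # in the project):
    #
    #     wrapper = QwenWrapper(device="cuda:0")
    #     theta = torch.zeros(16)  # placeholder user vector
    #     text = wrapper.generate_with_head(
    #         "Draft a friendly out-of-office reply.",
    #         theta,
    #         wrapper.identity_head,  # swap in the trained CVH head
    #         temperature=0.0,
    #     )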

    def generate_with_head_blended(self, input_text: str, theta: torch.Tensor, head_fn,
                                   blend_gamma: float = 0.5, max_new_tokens: int = 512,
                                   min_new_tokens: int = 128, temperature: float = 0.0) -> str:
        """Generate with blended base + CVH logits.

        logits = (1 - gamma) * base_logits + gamma * cvh_logits
        """
        chat_messages = [
            {"role": "system", "content": "You are a helpful writing assistant."},
            {"role": "user", "content": input_text},
        ]
        prompt_text = self.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
        generated_ids = []
        past_key_values = None
        for step in range(max_new_tokens):
            if step == 0:
                cur_input = input_ids
            else:
                cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
            with torch.no_grad():
                outputs = self.model(
                    input_ids=cur_input,
                    past_key_values=past_key_values,
                    output_hidden_states=True,
                    use_cache=True,
                    return_dict=True,
                )
            past_key_values = outputs.past_key_values
            last_hidden = outputs.hidden_states[-1][:, -1, :]  # (1, H)
            lm_head_bias = getattr(self.model.lm_head, "bias", None)
            # Base logits from the unmodified hidden state
            base_logits = torch.nn.functional.linear(
                last_hidden.to(self.lm_head_weight.dtype),
                self.lm_head_weight,
                lm_head_bias,
            ).float()
            # CVH logits from the personalized hidden state
            h_prime = head_fn(last_hidden.float(), theta)
            cvh_logits = torch.nn.functional.linear(
                h_prime.to(self.lm_head_weight.dtype),
                self.lm_head_weight,
                lm_head_bias,
            ).float()
            # Blend
            logits = (1 - blend_gamma) * base_logits + blend_gamma * cvh_logits
            # Suppress EOS before min_new_tokens
            if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
                logits[0, self.tokenizer.eos_token_id] = float("-inf")
            if temperature > 0:
                logits = logits / temperature
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
            else:
                next_token = logits.argmax(dim=-1).item()
            if next_token == self.tokenizer.eos_token_id:
                break
            generated_ids.append(next_token)
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
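
if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the training pipeline):
    # assumes a GPU at "cuda:0" and network access to download the model.
    wrapper = QwenWrapper(device="cuda:0")

    # Teacher-forced extraction on a toy pair
    h_states, label_ids = wrapper.get_hidden_states_teacher_forced(
        "Summarize in one sentence: The quick brown fox jumps over the lazy dog.",
        "A fox jumps over a dog.",
    )
    if h_states is not None:
        print(f"hidden states {tuple(h_states.shape)}, labels {tuple(label_ids.shape)}")

    # With the identity head, greedy per-step decoding should match plain
    # greedy base decoding (theta is a placeholder and is ignored).
    theta = torch.zeros(wrapper.hidden_size)
    print(wrapper.generate_with_head(
        "Write one sentence about autumn.",
        theta,
        QwenWrapper.identity_head,
        max_new_tokens=48,
        min_new_tokens=1,
        temperature=0.0,
    ))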