"""Wrapper around Qwen2.5-1.5B-Instruct for frozen inference and hidden state extraction."""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
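

# Illustrative example of the head_fn contract assumed by generate_with_head and
# generate_with_head_blended below: a callable that maps a (1, H) hidden-state tensor and a
# (d,) user vector theta to a modified (1, H) hidden state. The function here is a
# hypothetical stand-in useful for smoke tests; the real personalized head (e.g. the CVH
# module) is defined elsewhere in the project.
def identity_head(h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """No-op head that ignores theta and returns h unchanged."""
    return h
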
class QwenWrapper:
"""Wraps a frozen Qwen model for hidden state extraction and generation."""
def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", device: str = "cuda:1"):
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
).to(device)
self.model.eval()
# Extract lm_head weight for CVH
self.lm_head_weight = self.model.lm_head.weight.data # (vocab_size, H)
self.hidden_size = self.model.config.hidden_size

    @torch.no_grad()
def get_hidden_states_teacher_forced(self, input_text: str, target_text: str):
"""Run teacher-forced forward pass and extract final hidden states at target positions.
Args:
input_text: The prompt/input text
target_text: The target continuation text
Returns:
h_states: (num_target_tokens, H) tensor of final hidden states
label_ids: (num_target_tokens,) tensor of target token ids
"""
        # Tokenize the prompt alone and the prompt+target together so we know where the
        # target begins. This assumes the prompt tokenization is a prefix of the full-text
        # tokenization, which can break at the prompt/target boundary for some BPE merges.
chat_messages = [
{"role": "system", "content": "You are a helpful writing assistant."},
{"role": "user", "content": input_text},
]
prompt_text = self.tokenizer.apply_chat_template(
chat_messages, tokenize=False, add_generation_prompt=True
)
full_text = prompt_text + target_text
prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
full_ids = self.tokenizer.encode(full_text, return_tensors="pt").to(self.device)
prompt_len = prompt_ids.shape[1]
total_len = full_ids.shape[1]
if total_len <= prompt_len:
# Target text was empty or tokenized into nothing
return None, None
# Forward pass through the full sequence
outputs = self.model(
input_ids=full_ids,
output_hidden_states=True,
return_dict=True,
)
# Get the last hidden layer's states
last_hidden = outputs.hidden_states[-1] # (1, seq_len, H)
# Hidden states at positions [prompt_len-1, ..., total_len-2] predict tokens [prompt_len, ..., total_len-1]
# So for target token at position t, the hidden state is at position t-1
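        # Example: with prompt_len = 5 and total_len = 8, hidden states at positions 4..6
        # are paired with the target token ids at positions 5..7.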
start_pos = prompt_len - 1
end_pos = total_len - 1
h_states = last_hidden[0, start_pos:end_pos, :].float() # (num_target, H)
label_ids = full_ids[0, prompt_len:total_len] # (num_target,)
return h_states, label_ids

    @torch.no_grad()
def generate_base(self, input_text: str, max_new_tokens: int = 512,
temperature: float = 0.7, top_p: float = 0.9) -> str:
"""Generate text without any personalization."""
chat_messages = [
{"role": "system", "content": "You are a helpful writing assistant."},
{"role": "user", "content": input_text},
]
prompt_text = self.tokenizer.apply_chat_template(
chat_messages, tokenize=False, add_generation_prompt=True
)
input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature if temperature > 0 else None,
top_p=top_p if temperature > 0 else None,
do_sample=temperature > 0,
pad_token_id=self.tokenizer.pad_token_id,
)
generated_ids = outputs[0, input_ids.shape[1]:]
return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    @torch.no_grad()
    def generate_with_head(self, input_text: str, theta: torch.Tensor,
                           head_fn, max_new_tokens: int = 512,
                           temperature: float = 0.7, top_p: float = 0.9,
                           min_new_tokens: int = 64) -> str:
"""Generate text with a personalized head applied at each decoding step.
Args:
input_text: The query prompt
theta: User vector (d,)
head_fn: Function that takes (h, theta) -> h_prime
max_new_tokens: Max tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling threshold
min_new_tokens: Suppress EOS until this many tokens generated
"""
chat_messages = [
{"role": "system", "content": "You are a helpful writing assistant."},
{"role": "user", "content": input_text},
]
prompt_text = self.tokenizer.apply_chat_template(
chat_messages, tokenize=False, add_generation_prompt=True
)
input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
generated_ids = []
past_key_values = None
for step in range(max_new_tokens):
if step == 0:
cur_input = input_ids
else:
cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
with torch.no_grad():
outputs = self.model(
input_ids=cur_input,
past_key_values=past_key_values,
output_hidden_states=True,
use_cache=True,
return_dict=True,
)
past_key_values = outputs.past_key_values
# Get last hidden state of the last token
last_hidden = outputs.hidden_states[-1][:, -1, :] # (1, H)
# Apply personalized head
h_prime = head_fn(last_hidden.float(), theta) # (1, H)
# Compute logits through lm_head
logits = torch.nn.functional.linear(
h_prime.to(self.lm_head_weight.dtype),
self.lm_head_weight,
self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
) # (1, vocab_size)
logits = logits.float()
# Suppress EOS before min_new_tokens
if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
logits[0, self.tokenizer.eos_token_id] = float('-inf')
# Apply temperature and top-p sampling
if temperature > 0:
logits = logits / temperature
                # Nucleus (top-p) filtering: keep the smallest set of tokens whose
                # cumulative probability reaches top_p and mask the rest to -inf
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                # Mask tokens whose cumulative mass before them already reaches top_p
                sorted_mask = cum_probs - sorted_probs >= top_p
                sorted_logits[sorted_mask] = float('-inf')
                # Scatter the filtered logits back to their original vocab positions
                logits = logits.scatter(1, sorted_indices, sorted_logits)
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).item()
else:
next_token = logits.argmax(dim=-1).item()
if next_token == self.tokenizer.eos_token_id:
break
generated_ids.append(next_token)
return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    @torch.no_grad()
    def generate_with_head_blended(self, input_text: str, theta: torch.Tensor,
                                   head_fn, blend_gamma: float = 0.5,
                                   max_new_tokens: int = 512,
                                   min_new_tokens: int = 128,
                                   temperature: float = 0.0) -> str:
"""Generate with blended base + CVH logits.
logits = (1 - gamma) * base_logits + gamma * cvh_logits
"""
chat_messages = [
{"role": "system", "content": "You are a helpful writing assistant."},
{"role": "user", "content": input_text},
]
prompt_text = self.tokenizer.apply_chat_template(
chat_messages, tokenize=False, add_generation_prompt=True
)
input_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
generated_ids = []
past_key_values = None
for step in range(max_new_tokens):
if step == 0:
cur_input = input_ids
else:
cur_input = torch.tensor([[generated_ids[-1]]], device=self.device)
with torch.no_grad():
outputs = self.model(
input_ids=cur_input,
past_key_values=past_key_values,
output_hidden_states=True,
use_cache=True,
return_dict=True,
)
past_key_values = outputs.past_key_values
last_hidden = outputs.hidden_states[-1][:, -1, :]
# Base logits
base_logits = torch.nn.functional.linear(
last_hidden.to(self.lm_head_weight.dtype),
self.lm_head_weight,
self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
).float()
# CVH logits
h_prime = head_fn(last_hidden.float(), theta)
cvh_logits = torch.nn.functional.linear(
h_prime.to(self.lm_head_weight.dtype),
self.lm_head_weight,
self.model.lm_head.bias if hasattr(self.model.lm_head, 'bias') and self.model.lm_head.bias is not None else None,
).float()
# Blend
logits = (1 - blend_gamma) * base_logits + blend_gamma * cvh_logits
# Suppress EOS before min_new_tokens
if step < min_new_tokens and self.tokenizer.eos_token_id is not None:
logits[0, self.tokenizer.eos_token_id] = float('-inf')
if temperature > 0:
logits = logits / temperature
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).item()
else:
next_token = logits.argmax(dim=-1).item()
if next_token == self.tokenizer.eos_token_id:
break
generated_ids.append(next_token)
return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
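

if __name__ == "__main__":
    # Minimal smoke test (hypothetical usage, assuming a GPU is available at the class's
    # default device string). With the illustrative identity_head defined above and an
    # arbitrary user vector, the personalized output should closely track the base output.
    wrapper = QwenWrapper()
    theta = torch.zeros(16)  # placeholder user vector; the real dimensionality is project-specific
    prompt = "Write a short paragraph about autumn."
    print("--- base ---")
    print(wrapper.generate_base(prompt, max_new_tokens=64))
    print("--- personalized head ---")
    print(wrapper.generate_with_head(prompt, theta, identity_head,
                                     max_new_tokens=64, min_new_tokens=16))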