1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
|
"""PEFT baselines: LoRA, Tiny LoRA, and VeRA.
Per-user adaptation on K support examples, then standard generation.
Uses a class-based API to avoid repeated model wrapping/unwrapping.
Usage:
baseline = PEFTBaseline(wrapper, get_lora_config(rank=8))
for user in users:
text = baseline.adapt_and_generate(support, query, task)
baseline.cleanup() # restore frozen model
"""
import torch
from peft import LoraConfig, VeraConfig, get_peft_model, TaskType
TARGET_MODULES = ["q_proj", "v_proj"]
def _make_lora_config(rank, target_modules=None, lora_alpha=None):
    """Build a causal-LM LoraConfig.

    Args:
        rank: LoRA rank ``r``.
        target_modules: modules to adapt; ``None`` means ``TARGET_MODULES``.
        lora_alpha: scaling alpha; ``None`` means ``2 * rank`` (the common
            alpha/rank = 2 convention).
    """
    modules = TARGET_MODULES if target_modules is None else target_modules
    alpha = 2 * rank if lora_alpha is None else lora_alpha
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=rank,
        lora_alpha=alpha,
        lora_dropout=0.0,
        target_modules=modules,
        bias="none",
    )
def _make_vera_config(rank, target_modules=None):
    """Build a causal-LM VeraConfig.

    Args:
        rank: VeRA rank ``r``.
        target_modules: modules to adapt; ``None`` means ``TARGET_MODULES``.
    """
    modules = TARGET_MODULES if target_modules is None else target_modules
    return VeraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=rank,
        target_modules=modules,
        vera_dropout=0.0,
    )
def get_lora_config(rank=8, target_modules=None, lora_alpha=None):
    """Standard LoRA baseline config (default rank 8, alpha = 2 * rank).

    Generalized to pass ``target_modules`` / ``lora_alpha`` through to the
    underlying builder; ``None`` keeps the module-level defaults, so existing
    callers are unaffected.
    """
    return _make_lora_config(rank=rank, target_modules=target_modules, lora_alpha=lora_alpha)
def get_tiny_lora_config(rank=1, target_modules=None, lora_alpha=None):
    """Tiny LoRA baseline config (default rank 1, alpha = 2 * rank).

    Generalized to pass ``target_modules`` / ``lora_alpha`` through to the
    underlying builder; ``None`` keeps the module-level defaults, so existing
    callers are unaffected.
    """
    return _make_lora_config(rank=rank, target_modules=target_modules, lora_alpha=lora_alpha)
def get_vera_config(rank=256, target_modules=None):
    """VeRA baseline config (default rank 256).

    Generalized to pass ``target_modules`` through to the underlying builder;
    ``None`` keeps the module-level default, so existing callers are unaffected.
    """
    return _make_vera_config(rank=rank, target_modules=target_modules)
class PEFTBaseline:
    """Manages a PEFT-wrapped model for repeated per-user adaptation.

    The base model is wrapped with a PEFT adapter exactly once, at
    construction. Each user then gets an adapt -> generate cycle: the adapter
    weights are reset to their initial state, fine-tuned on the user's
    support set, and used for greedy generation. Call ``cleanup()`` at the
    end to unload the adapter and restore the original base model on the
    wrapper.
    """

    def __init__(self, wrapper, peft_config):
        """Wrap ``wrapper.model`` with the given PEFT config.

        Args:
            wrapper: object exposing ``.model``, ``.tokenizer`` and
                ``.device`` (duck-typed model wrapper; see callers).
            peft_config: a peft config (e.g. from ``get_lora_config`` /
                ``get_vera_config``) applied to the base model.
        """
        self.wrapper = wrapper
        self.device = wrapper.device
        # Wrap once; all per-user adaptation reuses this same wrapped model.
        self.peft_model = get_peft_model(wrapper.model, peft_config)
        # Trainable (adapter-only) parameter count, for reporting.
        self.n_params = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
        # 2 bytes per parameter — assumes bf16 adapter weights; TODO confirm
        # against the wrapper's actual dtype.
        self.n_bytes = self.n_params * 2  # bf16
        # Save initial adapter state for reset between users
        self._init_state = {
            name: param.data.clone()
            for name, param in self.peft_model.named_parameters()
            if param.requires_grad
        }

    def _reset_adapter(self):
        """Reset adapter weights to initial state (zeros for LoRA)."""
        # In-place copy back the snapshot taken in __init__, so each user
        # starts from the same (untrained) adapter.
        for name, param in self.peft_model.named_parameters():
            if param.requires_grad and name in self._init_state:
                param.data.copy_(self._init_state[name])

    def _build_training_data(self, support_items, task):
        """Build (input_ids, labels) pairs from support items.

        Each pair is the chat-templated support prompt followed by the
        target; prompt tokens are masked out of the labels with -100 so the
        loss only covers the target continuation.
        """
        from data.templates import build_support_prompt
        data = []
        for item in support_items:
            input_text = build_support_prompt(item['support_input'], task)
            # Leading space separates the target from the generation prompt.
            target_text = " " + item['support_output']
            chat_messages = [
                {"role": "system", "content": "You are a helpful writing assistant."},
                {"role": "user", "content": input_text},
            ]
            prompt_text = self.wrapper.tokenizer.apply_chat_template(
                chat_messages, tokenize=False, add_generation_prompt=True
            )
            full_text = prompt_text + target_text
            prompt_ids = self.wrapper.tokenizer.encode(prompt_text, return_tensors="pt")
            full_ids = self.wrapper.tokenizer.encode(full_text, return_tensors="pt")
            labels = full_ids.clone()
            # NOTE(review): assumes encode(prompt_text) is a token-level
            # prefix of encode(full_text); true for typical BPE tokenizers
            # but worth verifying for this tokenizer.
            labels[0, :prompt_ids.shape[1]] = -100
            data.append((full_ids.to(self.device), labels.to(self.device)))
        return data

    def adapt_and_generate(
        self,
        support_items,
        query_input,
        task,
        lr=1e-4,
        steps=30,
        max_new_tokens=512,
        min_new_tokens=128,
        verbose=False,
    ):
        """Reset adapter, fine-tune on support set, generate, return text.

        Args:
            support_items: per-user support examples (dicts with
                'support_input' / 'support_output').
            query_input: the query to generate for after adaptation.
            task: task identifier passed to the prompt builders.
            lr: AdamW learning rate for adapter fine-tuning.
            steps: number of full-batch optimization steps.
            max_new_tokens / min_new_tokens: generation length bounds.
            verbose: print loss every 10 steps when True.

        Returns:
            The generated text (str). Falls back to unadapted generation
            when the support set is empty.
        """
        self._reset_adapter()
        # Build training data
        train_data = self._build_training_data(support_items, task)
        if not train_data:
            return self._generate_fallback(query_input, task, max_new_tokens, min_new_tokens)
        # Fine-tune
        trainable = [p for p in self.peft_model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable, lr=lr)
        self.peft_model.train()
        for step in range(steps):
            optimizer.zero_grad()
            total_loss = 0.0
            # Gradient accumulation over the whole support set: each item's
            # loss is scaled by 1/len so the step equals a full-batch mean.
            for input_ids, labels in train_data:
                outputs = self.peft_model(input_ids=input_ids, labels=labels)
                (outputs.loss / len(train_data)).backward()
                total_loss += outputs.loss.item()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            optimizer.step()
            if verbose and (step % 10 == 0 or step == steps - 1):
                print(f"  Step {step:3d}: loss={total_loss/len(train_data):.4f}")
        # Generate
        self.peft_model.eval()
        generated = self._generate(query_input, task, max_new_tokens, min_new_tokens)
        del optimizer
        torch.cuda.empty_cache()
        return generated

    def _generate(self, query_input, task, max_new_tokens, min_new_tokens):
        """Greedy-decode a response to the query with the current adapter."""
        from data.templates import build_query_prompt
        prompt = build_query_prompt(query_input, task)
        chat_messages = [
            {"role": "system", "content": "You are a helpful writing assistant."},
            {"role": "user", "content": prompt},
        ]
        prompt_text = self.wrapper.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.wrapper.tokenizer.encode(
            prompt_text, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.peft_model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                # Greedy decoding; temperature/top_p explicitly None —
                # presumably to silence HF sampling-config warnings.
                temperature=None,
                top_p=None,
                do_sample=False,
                pad_token_id=self.wrapper.tokenizer.pad_token_id,
            )
        # Strip the prompt tokens; decode only the continuation.
        generated_ids = outputs[0, input_ids.shape[1]:]
        return self.wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def _generate_fallback(self, query_input, task, max_new_tokens, min_new_tokens):
        """Fallback: generate without adaptation (empty support set)."""
        self.peft_model.eval()
        return self._generate(query_input, task, max_new_tokens, min_new_tokens)

    def cleanup(self):
        """Remove adapter and restore wrapper.model to the original base model."""
        # peft's unload() strips the adapter layers and returns the base model.
        base_model = self.peft_model.unload()
        self.wrapper.model = base_model
        del self.peft_model
        torch.cuda.empty_cache()
|