diff options
Diffstat (limited to 'baselines/peft_baseline.py')
| -rw-r--r-- | baselines/peft_baseline.py | 194 |
1 file changed, 194 insertions, 0 deletions
"""PEFT baselines: LoRA, Tiny LoRA, and VeRA.

Per-user adaptation on K support examples, then standard generation.
Uses a class-based API to avoid repeated model wrapping/unwrapping.

Usage:
    baseline = PEFTBaseline(wrapper, get_lora_config(rank=8))
    for user in users:
        text = baseline.adapt_and_generate(support, query, task)
    baseline.cleanup()  # restore frozen model
"""

import torch
from peft import LoraConfig, VeraConfig, get_peft_model, TaskType


# Attention projections targeted by every adapter config below.
TARGET_MODULES = ["q_proj", "v_proj"]


def _make_lora_config(rank, target_modules=None, lora_alpha=None):
    """Build a LoraConfig for causal LM.

    Args:
        rank: LoRA rank ``r``.
        target_modules: module names to adapt; defaults to TARGET_MODULES.
        lora_alpha: scaling alpha; defaults to ``2 * rank`` (common heuristic
            that keeps the effective scale alpha/r constant at 2).
    """
    if target_modules is None:
        target_modules = TARGET_MODULES
    if lora_alpha is None:
        lora_alpha = 2 * rank
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=rank,
        lora_alpha=lora_alpha,
        lora_dropout=0.0,
        target_modules=target_modules,
        bias="none",
    )


def _make_vera_config(rank, target_modules=None):
    """Build a VeraConfig for causal LM with no dropout."""
    if target_modules is None:
        target_modules = TARGET_MODULES
    return VeraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=rank,
        target_modules=target_modules,
        vera_dropout=0.0,
    )


def get_lora_config(rank=8):
    """Standard LoRA baseline configuration."""
    return _make_lora_config(rank=rank)


def get_tiny_lora_config(rank=1):
    """Rank-1 LoRA: minimal per-user parameter budget."""
    return _make_lora_config(rank=rank)


def get_vera_config(rank=256):
    """VeRA baseline configuration (higher rank is cheap for VeRA)."""
    return _make_vera_config(rank=rank)


class PEFTBaseline:
    """Manages a PEFT-wrapped model for repeated per-user adaptation.

    The base model is wrapped once at construction. Per user, the adapter is
    reset to its initial weights, fine-tuned on that user's support set, and
    used for generation; ``cleanup()`` unloads the adapter and restores the
    original base model onto the wrapper.
    """

    def __init__(self, wrapper, peft_config):
        """Wrap ``wrapper.model`` and snapshot the fresh adapter weights.

        Args:
            wrapper: object exposing ``model``, ``tokenizer`` and ``device``.
            peft_config: a config produced by one of the factories above.
        """
        self.wrapper = wrapper
        self.device = wrapper.device
        self.peft_model = get_peft_model(wrapper.model, peft_config)

        # Trainable (adapter-only) parameter count and its size.
        self.n_params = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
        # NOTE(review): assumes 2-byte (bf16/fp16) adapter weights — confirm
        # against the wrapper's load dtype.
        self.n_bytes = self.n_params * 2  # bf16

        # Save initial adapter state for reset between users
        self._init_state = {
            name: param.data.clone()
            for name, param in self.peft_model.named_parameters()
            if param.requires_grad
        }

    def _reset_adapter(self):
        """Reset adapter weights to initial state (zeros for LoRA)."""
        for name, param in self.peft_model.named_parameters():
            if param.requires_grad and name in self._init_state:
                param.data.copy_(self._init_state[name])

    def _build_training_data(self, support_items, task):
        """Build (input_ids, labels) pairs from support items.

        Each item contributes one chat-formatted example; prompt tokens are
        masked with -100 so the loss covers only the target completion.
        """
        from data.templates import build_support_prompt

        data = []
        for item in support_items:
            input_text = build_support_prompt(item['support_input'], task)
            target_text = " " + item['support_output']

            chat_messages = [
                {"role": "system", "content": "You are a helpful writing assistant."},
                {"role": "user", "content": input_text},
            ]
            prompt_text = self.wrapper.tokenizer.apply_chat_template(
                chat_messages, tokenize=False, add_generation_prompt=True
            )
            full_text = prompt_text + target_text

            # Encode prompt and prompt+target separately so the prompt span
            # can be located and masked out of the labels.
            prompt_ids = self.wrapper.tokenizer.encode(prompt_text, return_tensors="pt")
            full_ids = self.wrapper.tokenizer.encode(full_text, return_tensors="pt")

            labels = full_ids.clone()
            labels[0, :prompt_ids.shape[1]] = -100  # -100 = ignored by CE loss

            data.append((full_ids.to(self.device), labels.to(self.device)))
        return data

    def adapt_and_generate(
        self,
        support_items,
        query_input,
        task,
        lr=1e-4,
        steps=30,
        max_new_tokens=512,
        min_new_tokens=128,
        verbose=False,
    ):
        """Reset adapter, fine-tune on support set, generate, return text.

        Args:
            support_items: list of dicts with 'support_input'/'support_output'.
            query_input: query text for the final generation.
            task: task identifier forwarded to the prompt templates.
            lr: AdamW learning rate for the adapter parameters.
            steps: full-batch gradient steps over the support set.
            max_new_tokens: upper bound on generated tokens.
            min_new_tokens: lower bound on generated tokens.
            verbose: if True, print mean support loss every 10 steps.

        Returns:
            Generated text (str) with special tokens stripped.
        """
        self._reset_adapter()

        # Build training data
        train_data = self._build_training_data(support_items, task)
        if not train_data:
            # Empty support set: generate without any adaptation.
            return self._generate_fallback(query_input, task, max_new_tokens, min_new_tokens)

        # Fine-tune only the adapter parameters; the base model stays frozen.
        trainable = [p for p in self.peft_model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable, lr=lr)

        self.peft_model.train()
        for step in range(steps):
            optimizer.zero_grad()
            total_loss = 0.0

            # Accumulate gradients over the whole support set (each loss is
            # scaled so the effective objective is the mean over examples),
            # then take a single clipped optimizer step.
            for input_ids, labels in train_data:
                outputs = self.peft_model(input_ids=input_ids, labels=labels)
                (outputs.loss / len(train_data)).backward()
                total_loss += outputs.loss.item()

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            optimizer.step()

            if verbose and (step % 10 == 0 or step == steps - 1):
                print(f" Step {step:3d}: loss={total_loss/len(train_data):.4f}")

        # Generate
        self.peft_model.eval()
        generated = self._generate(query_input, task, max_new_tokens, min_new_tokens)

        del optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return generated

    def _generate(self, query_input, task, max_new_tokens, min_new_tokens):
        """Greedy-decode a response to the query prompt; return only new text."""
        from data.templates import build_query_prompt
        prompt = build_query_prompt(query_input, task)

        chat_messages = [
            {"role": "system", "content": "You are a helpful writing assistant."},
            {"role": "user", "content": prompt},
        ]
        prompt_text = self.wrapper.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.wrapper.tokenizer.encode(
            prompt_text, return_tensors="pt"
        ).to(self.device)

        # Decoder-only tokenizers often define no pad token; fall back to EOS
        # so generate() does not warn or fail when a pad id is required.
        pad_id = self.wrapper.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.wrapper.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.peft_model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                temperature=None,  # explicit None silences sampling-arg warnings
                top_p=None,
                do_sample=False,  # greedy decoding for reproducibility
                pad_token_id=pad_id,
            )
        # Strip the prompt tokens; decode only the newly generated suffix.
        generated_ids = outputs[0, input_ids.shape[1]:]
        return self.wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def _generate_fallback(self, query_input, task, max_new_tokens, min_new_tokens):
        """Fallback: generate without adaptation (empty support set)."""
        self.peft_model.eval()
        return self._generate(query_input, task, max_new_tokens, min_new_tokens)

    def cleanup(self):
        """Remove adapter and restore wrapper.model to the original base model."""
        base_model = self.peft_model.unload()
        self.wrapper.model = base_model
        del self.peft_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
