"""PEFT baselines: LoRA, Tiny LoRA, VeRA, Prompt Tuning, Prefix Tuning. Per-user adaptation on K support examples, then standard generation. Uses a class-based API to avoid repeated model wrapping/unwrapping. Usage: baseline = PEFTBaseline(wrapper, get_lora_config(rank=8)) for user in users: text = baseline.adapt_and_generate(support, query, task) baseline.cleanup() # restore frozen model """ import torch from peft import ( LoraConfig, VeraConfig, PromptTuningConfig, PrefixTuningConfig, get_peft_model, TaskType, PromptTuningInit, ) TARGET_MODULES = ["q_proj", "v_proj"] def _make_lora_config(rank, target_modules=None, lora_alpha=None): if target_modules is None: target_modules = TARGET_MODULES if lora_alpha is None: lora_alpha = 2 * rank return LoraConfig( task_type=TaskType.CAUSAL_LM, r=rank, lora_alpha=lora_alpha, lora_dropout=0.0, target_modules=target_modules, bias="none", ) def _make_vera_config(rank, target_modules=None): if target_modules is None: target_modules = TARGET_MODULES return VeraConfig( task_type=TaskType.CAUSAL_LM, r=rank, target_modules=target_modules, vera_dropout=0.0, ) def _make_prompt_tuning_config(num_virtual_tokens): return PromptTuningConfig( task_type=TaskType.CAUSAL_LM, num_virtual_tokens=num_virtual_tokens, prompt_tuning_init=PromptTuningInit.RANDOM, ) def _make_prefix_tuning_config(num_virtual_tokens): return PrefixTuningConfig( task_type=TaskType.CAUSAL_LM, num_virtual_tokens=num_virtual_tokens, ) def get_lora_config(rank=8): return _make_lora_config(rank=rank) def get_tiny_lora_config(rank=1): return _make_lora_config(rank=rank) def get_vera_config(rank=256): return _make_vera_config(rank=rank) def get_prompt_tuning_config(num_tokens=10): return _make_prompt_tuning_config(num_virtual_tokens=num_tokens) def get_prefix_tuning_config(num_tokens=10): return _make_prefix_tuning_config(num_virtual_tokens=num_tokens) class PEFTBaseline: """Manages a PEFT-wrapped model for repeated per-user adaptation.""" def __init__(self, wrapper, peft_config): self.wrapper = wrapper self.device = wrapper.device # Save reference to original model BEFORE wrapping self._original_model = wrapper.model self.peft_model = get_peft_model(wrapper.model, peft_config) self.n_params = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad) self.n_bytes = self.n_params * 2 # bf16 # Save initial adapter state for reset between users self._init_state = { name: param.data.clone() for name, param in self.peft_model.named_parameters() if param.requires_grad } def _reset_adapter(self): """Reset adapter weights to initial state (zeros for LoRA).""" for name, param in self.peft_model.named_parameters(): if param.requires_grad and name in self._init_state: param.data.copy_(self._init_state[name]) def _build_training_data(self, support_items, task): """Build (input_ids, labels) pairs from support items.""" from data.templates import build_support_prompt data = [] for item in support_items: input_text = build_support_prompt(item['support_input'], task) target_text = " " + item['support_output'] chat_messages = [ {"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": input_text}, ] prompt_text = self.wrapper.tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) full_text = prompt_text + target_text prompt_ids = self.wrapper.tokenizer.encode(prompt_text, return_tensors="pt") full_ids = self.wrapper.tokenizer.encode(full_text, return_tensors="pt") labels = full_ids.clone() labels[0, :prompt_ids.shape[1]] = -100 
data.append((full_ids.to(self.device), labels.to(self.device))) return data def adapt_and_generate( self, support_items, query_input, task, lr=1e-4, steps=30, max_new_tokens=512, min_new_tokens=128, verbose=False, ): """Reset adapter, fine-tune on support set, generate, return text.""" self._reset_adapter() # Build training data train_data = self._build_training_data(support_items, task) if not train_data: return self._generate_fallback(query_input, task, max_new_tokens, min_new_tokens) # Fine-tune trainable = [p for p in self.peft_model.parameters() if p.requires_grad] optimizer = torch.optim.AdamW(trainable, lr=lr) self.peft_model.train() for step in range(steps): optimizer.zero_grad() total_loss = 0.0 for input_ids, labels in train_data: outputs = self.peft_model(input_ids=input_ids, labels=labels) (outputs.loss / len(train_data)).backward() total_loss += outputs.loss.item() torch.nn.utils.clip_grad_norm_(trainable, 1.0) optimizer.step() if verbose and (step % 10 == 0 or step == steps - 1): print(f" Step {step:3d}: loss={total_loss/len(train_data):.4f}") # Generate self.peft_model.eval() generated = self._generate(query_input, task, max_new_tokens, min_new_tokens) del optimizer torch.cuda.empty_cache() return generated def _generate(self, query_input, task, max_new_tokens, min_new_tokens): from data.templates import build_query_prompt prompt = build_query_prompt(query_input, task) chat_messages = [ {"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": prompt}, ] prompt_text = self.wrapper.tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) input_ids = self.wrapper.tokenizer.encode( prompt_text, return_tensors="pt" ).to(self.device) with torch.no_grad(): outputs = self.peft_model.generate( input_ids, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, temperature=None, top_p=None, do_sample=False, pad_token_id=self.wrapper.tokenizer.pad_token_id, ) generated_ids = outputs[0, input_ids.shape[1]:] return self.wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) def _generate_fallback(self, query_input, task, max_new_tokens, min_new_tokens): """Fallback: generate without adaptation (empty support set).""" self.peft_model.eval() return self._generate(query_input, task, max_new_tokens, min_new_tokens) def cleanup(self): """Remove adapter and restore wrapper.model to the original base model.""" # Always restore the saved original model reference — safe for all PEFT types self.wrapper.model = self._original_model del self.peft_model torch.cuda.empty_cache()
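

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch of the per-user loop described in the module
# docstring. It assumes a `wrapper` object exposing `.model`, `.tokenizer`, and
# `.device` (the interface PEFTBaseline relies on above) and an iterable of
# per-user episodes with `support_items`, `query_input`, and `task` fields.
# The helpers `load_wrapper` and `load_user_episodes` and the episode field
# names are hypothetical placeholders, not part of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from models.wrapper import load_wrapper          # hypothetical helper
    from data.episodes import load_user_episodes     # hypothetical helper

    wrapper = load_wrapper()  # returns an object with .model, .tokenizer, .device
    baseline = PEFTBaseline(wrapper, get_lora_config(rank=8))
    print(f"Trainable params: {baseline.n_params} (~{baseline.n_bytes} bytes in bf16)")

    for episode in load_user_episodes(split="test"):
        # Adapter is reset, fine-tuned on this user's support set, then used
        # for greedy generation on the query.
        text = baseline.adapt_and_generate(
            episode["support_items"],
            episode["query_input"],
            episode["task"],
            verbose=True,
        )
        print(text[:200])

    baseline.cleanup()  # restore the frozen base model on the wrapper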