Diffstat (limited to 'baselines/peft_baseline.py')
-rw-r--r--  baselines/peft_baseline.py | 30
1 file changed, 28 insertions(+), 2 deletions(-)
diff --git a/baselines/peft_baseline.py b/baselines/peft_baseline.py
index 442ba60..246385f 100644
--- a/baselines/peft_baseline.py
+++ b/baselines/peft_baseline.py
@@ -1,4 +1,4 @@
-"""PEFT baselines: LoRA, Tiny LoRA, and VeRA.
+"""PEFT baselines: LoRA, Tiny LoRA, VeRA, Prompt Tuning, Prefix Tuning.
Per-user adaptation on K support examples, then standard generation.
Uses a class-based API to avoid repeated model wrapping/unwrapping.
@@ -11,7 +11,10 @@ Usage:
"""
import torch
-from peft import LoraConfig, VeraConfig, get_peft_model, TaskType
+from peft import (
+ LoraConfig, VeraConfig, PromptTuningConfig, PrefixTuningConfig,
+ get_peft_model, TaskType, PromptTuningInit,
+)
TARGET_MODULES = ["q_proj", "v_proj"]
@@ -43,6 +46,21 @@ def _make_vera_config(rank, target_modules=None):
)
+def _make_prompt_tuning_config(num_virtual_tokens):
+ return PromptTuningConfig(
+ task_type=TaskType.CAUSAL_LM,
+ num_virtual_tokens=num_virtual_tokens,
+ prompt_tuning_init=PromptTuningInit.RANDOM,
+ )
+
+
+def _make_prefix_tuning_config(num_virtual_tokens):
+ return PrefixTuningConfig(
+ task_type=TaskType.CAUSAL_LM,
+ num_virtual_tokens=num_virtual_tokens,
+ )
+
+
def get_lora_config(rank=8):
return _make_lora_config(rank=rank)
@@ -55,6 +73,14 @@ def get_vera_config(rank=256):
return _make_vera_config(rank=rank)
+def get_prompt_tuning_config(num_tokens=10):
+ return _make_prompt_tuning_config(num_virtual_tokens=num_tokens)
+
+
+def get_prefix_tuning_config(num_tokens=10):
+ return _make_prefix_tuning_config(num_virtual_tokens=num_tokens)
+
+
class PEFTBaseline:
"""Manages a PEFT-wrapped model for repeated per-user adaptation."""