"""PEFT baselines: LoRA, Tiny LoRA, and VeRA.

Per-user adaptation on K support examples, then standard generation.
Uses a class-based API to avoid repeated model wrapping/unwrapping.

Usage:
    baseline = PEFTBaseline(wrapper, get_lora_config(rank=8))
    for user in users:
        text = baseline.adapt_and_generate(support, query, task)
    baseline.cleanup()  # restore frozen model
"""

import torch
from peft import LoraConfig, VeraConfig, get_peft_model, TaskType


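# Attention query/value projections; the adapter target set shared by all configs below.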
TARGET_MODULES = ["q_proj", "v_proj"]


def _make_lora_config(rank, target_modules=None, lora_alpha=None):
    if target_modules is None:
        target_modules = TARGET_MODULES
    if lora_alpha is None:
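        # Default alpha = 2 * rank keeps the LoRA scaling factor (alpha / r) fixed at 2 across ranks.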
        lora_alpha = 2 * rank
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=rank,
        lora_alpha=lora_alpha,
        lora_dropout=0.0,
        target_modules=target_modules,
        bias="none",
    )


def _make_vera_config(rank, target_modules=None):
    if target_modules is None:
        target_modules = TARGET_MODULES
    return VeraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=rank,
        target_modules=target_modules,
        vera_dropout=0.0,
    )


def get_lora_config(rank=8):
    return _make_lora_config(rank=rank)


def get_tiny_lora_config(rank=1):
    return _make_lora_config(rank=rank)


def get_vera_config(rank=256):
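    # VeRA shares frozen random projections and trains only small scaling vectors, so a high rank stays parameter-cheap.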
    return _make_vera_config(rank=rank)


class PEFTBaseline:
    """Manages a PEFT-wrapped model for repeated per-user adaptation."""

    def __init__(self, wrapper, peft_config):
        self.wrapper = wrapper
        self.device = wrapper.device
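        # Wrapping injects adapter modules into wrapper.model itself; cleanup() strips them again.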
        self.peft_model = get_peft_model(wrapper.model, peft_config)

        self.n_params = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
        self.n_bytes = self.n_params * 2  # 2 bytes per trainable parameter, assuming bf16 adapter weights

        # Save initial adapter state for reset between users
        self._init_state = {
            name: param.data.clone()
            for name, param in self.peft_model.named_parameters()
            if param.requires_grad
        }

    def _reset_adapter(self):
        """Reset adapter weights to initial state (zeros for LoRA)."""
        for name, param in self.peft_model.named_parameters():
            if param.requires_grad and name in self._init_state:
                param.data.copy_(self._init_state[name])

    def _build_training_data(self, support_items, task):
        """Build (input_ids, labels) pairs from support items."""
        from data.templates import build_support_prompt

        data = []
        for item in support_items:
            input_text = build_support_prompt(item['support_input'], task)
            target_text = " " + item['support_output']

            chat_messages = [
                {"role": "system", "content": "You are a helpful writing assistant."},
                {"role": "user", "content": input_text},
            ]
            prompt_text = self.wrapper.tokenizer.apply_chat_template(
                chat_messages, tokenize=False, add_generation_prompt=True
            )
            full_text = prompt_text + target_text

            prompt_ids = self.wrapper.tokenizer.encode(prompt_text, return_tensors="pt")
            full_ids = self.wrapper.tokenizer.encode(full_text, return_tensors="pt")

            labels = full_ids.clone()
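            # Mask prompt tokens (-100) so the loss covers only the target; assumes prompt_ids are a prefix of full_ids.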
            labels[0, :prompt_ids.shape[1]] = -100

            data.append((full_ids.to(self.device), labels.to(self.device)))
        return data

    def adapt_and_generate(
        self,
        support_items,
        query_input,
        task,
        lr=1e-4,
        steps=30,
        max_new_tokens=512,
        min_new_tokens=128,
        verbose=False,
    ):
        """Reset adapter, fine-tune on support set, generate, return text."""
        self._reset_adapter()

        # Build training data
        train_data = self._build_training_data(support_items, task)
        if not train_data:
            return self._generate_fallback(query_input, task, max_new_tokens, min_new_tokens)

        # Fine-tune
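        # A fresh optimizer per call keeps AdamW state from carrying over between users.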
        trainable = [p for p in self.peft_model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable, lr=lr)

        self.peft_model.train()
        for step in range(steps):
            optimizer.zero_grad()
            total_loss = 0.0

            for input_ids, labels in train_data:
                outputs = self.peft_model(input_ids=input_ids, labels=labels)
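                # Scale each example's loss by 1/N so the accumulated gradient is the mean over the support set.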
                (outputs.loss / len(train_data)).backward()
                total_loss += outputs.loss.item()

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            optimizer.step()

            if verbose and (step % 10 == 0 or step == steps - 1):
                print(f"  Step {step:3d}: loss={total_loss/len(train_data):.4f}")

        # Generate
        self.peft_model.eval()
        generated = self._generate(query_input, task, max_new_tokens, min_new_tokens)

        del optimizer
        torch.cuda.empty_cache()
        return generated

    def _generate(self, query_input, task, max_new_tokens, min_new_tokens):
        from data.templates import build_query_prompt
        prompt = build_query_prompt(query_input, task)

        chat_messages = [
            {"role": "system", "content": "You are a helpful writing assistant."},
            {"role": "user", "content": prompt},
        ]
        prompt_text = self.wrapper.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = self.wrapper.tokenizer.encode(
            prompt_text, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
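            # Greedy decoding; temperature/top_p are explicitly None so they do not conflict with do_sample=False.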
            outputs = self.peft_model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                temperature=None,
                top_p=None,
                do_sample=False,
                pad_token_id=self.wrapper.tokenizer.pad_token_id,
            )
        generated_ids = outputs[0, input_ids.shape[1]:]
        return self.wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def _generate_fallback(self, query_input, task, max_new_tokens, min_new_tokens):
        """Fallback: generate without adaptation (empty support set)."""
        self.peft_model.eval()
        return self._generate(query_input, task, max_new_tokens, min_new_tokens)

    def cleanup(self):
        """Remove adapter and restore wrapper.model to the original base model."""
        base_model = self.peft_model.unload()
        self.wrapper.model = base_model
        del self.peft_model
        torch.cuda.empty_cache()