| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-05 10:31:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-05 10:31:36 -0500 |
| commit | ea4a8f837e81b5e5fab6086cb3014c711c5e58e9 (patch) | |
| tree | 11638546dc91c97815e5bdab8fa0b587481d0a3c /baselines | |
| parent | 8fe28101366dd32562b8c5534d7fe359b252bdf3 (diff) | |
Add PEFT baselines, ICL baselines, profile-based, and unified pipeline
New baselines:
- baselines/peft_baseline.py: LoRA, Tiny LoRA, VeRA (per-user PEFT adaptation)
- baselines/dense_retrieval.py: Dense retrieval ICL (sentence-transformers)
- baselines/profile_based.py: LLM-generated user profile conditioned generation
New scripts:
- scripts/run_all_methods.py: Unified pipeline running all 9 methods with a
  per-method output directory structure (method/per_user.json); see the output-reading sketch after this list
- scripts/run_peft_baselines.py: PEFT-only evaluation (legacy)
- scripts/run_significance.py: Significance tests (UPH+Base per-user)
- scripts/run_uph_base_per_user.py: UPH+Base with full per-user data
- scripts/compute_bertscore.py: BERTScore from saved predictions
- scripts/significance_test.py: Standalone significance test framework
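A minimal sketch of consuming that per-method layout; the outputs/ root and the shape of the JSON records are assumptions for illustration, not pinned down by this commit:

    # Hypothetical consumer of <root>/<method>/per_user.json files written by
    # scripts/run_all_methods.py; the directory root and record fields are assumed.
    import json
    from pathlib import Path

    root = Path("outputs")  # assumed output root (excluded via .gitignore)
    for method_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        per_user_path = method_dir / "per_user.json"
        if per_user_path.exists():
            records = json.loads(per_user_path.read_text())
            print(f"{method_dir.name}: {len(records)} per-user records")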
Updated .gitignore to exclude outputs/ directory.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'baselines')
| -rw-r--r-- | baselines/dense_retrieval.py | 42 |
| -rw-r--r-- | baselines/peft_baseline.py | 194 |
| -rw-r--r-- | baselines/profile_based.py | 58 |
3 files changed, 294 insertions, 0 deletions
diff --git a/baselines/dense_retrieval.py b/baselines/dense_retrieval.py
new file mode 100644
index 0000000..db403a8
--- /dev/null
+++ b/baselines/dense_retrieval.py
@@ -0,0 +1,42 @@
+"""Dense Retrieval ICL baseline.
+
+Uses sentence-transformers for dense retrieval over user support set,
+then places top-K retrieved items as in-context examples.
+"""
+
+import torch
+from sentence_transformers import SentenceTransformer
+
+
+class DenseRetriever:
+    """Dense retriever using sentence-transformers embeddings."""
+
+    def __init__(self, model_name='all-MiniLM-L6-v2', device='cpu'):
+        self.model = SentenceTransformer(model_name, device=device)
+
+    def retrieve_top_k(self, query: str, support_items: list, k: int = 1):
+        """Retrieve top-k support items most relevant to query.
+
+        Args:
+            query: query input text
+            support_items: list of dicts with 'support_input', 'support_output'
+            k: number of items to retrieve
+
+        Returns:
+            List of top-k support items (sorted by relevance)
+        """
+        if len(support_items) <= k:
+            return support_items
+
+        texts = [item['support_input'] for item in support_items]
+        embeddings = self.model.encode([query] + texts, convert_to_tensor=True)
+
+        query_emb = embeddings[0]
+        item_embs = embeddings[1:]
+
+        similarities = torch.nn.functional.cosine_similarity(
+            query_emb.unsqueeze(0), item_embs, dim=1
+        )
+
+        top_indices = similarities.argsort(descending=True)[:k].tolist()
+        return [support_items[i] for i in top_indices]
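A minimal usage sketch for the retriever above, assuming support items shaped as in the retrieve_top_k docstring; the prompt assembly is illustrative only and is not the project's actual template code:

    # Hypothetical driver: pick the top-k support examples for a query and
    # format them as in-context examples (formatting here is illustrative).
    retriever = DenseRetriever(model_name='all-MiniLM-L6-v2', device='cpu')
    support_items = [
        {'support_input': 'Review of a sci-fi novel', 'support_output': 'A gripping read ...'},
        {'support_input': 'Review of a cookbook', 'support_output': 'Clear recipes ...'},
    ]
    top = retriever.retrieve_top_k('Review of a fantasy novel', support_items, k=1)
    icl_block = "\n\n".join(
        f"Input: {it['support_input']}\nOutput: {it['support_output']}" for it in top
    )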
diff --git a/baselines/peft_baseline.py b/baselines/peft_baseline.py
new file mode 100644
index 0000000..442ba60
--- /dev/null
+++ b/baselines/peft_baseline.py
@@ -0,0 +1,194 @@
+"""PEFT baselines: LoRA, Tiny LoRA, and VeRA.
+
+Per-user adaptation on K support examples, then standard generation.
+Uses a class-based API to avoid repeated model wrapping/unwrapping.
+
+Usage:
+    baseline = PEFTBaseline(wrapper, get_lora_config(rank=8))
+    for user in users:
+        text = baseline.adapt_and_generate(support, query, task)
+    baseline.cleanup()  # restore frozen model
+"""
+
+import torch
+from peft import LoraConfig, VeraConfig, get_peft_model, TaskType
+
+
+TARGET_MODULES = ["q_proj", "v_proj"]
+
+
+def _make_lora_config(rank, target_modules=None, lora_alpha=None):
+    if target_modules is None:
+        target_modules = TARGET_MODULES
+    if lora_alpha is None:
+        lora_alpha = 2 * rank
+    return LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=rank,
+        lora_alpha=lora_alpha,
+        lora_dropout=0.0,
+        target_modules=target_modules,
+        bias="none",
+    )
+
+
+def _make_vera_config(rank, target_modules=None):
+    if target_modules is None:
+        target_modules = TARGET_MODULES
+    return VeraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=rank,
+        target_modules=target_modules,
+        vera_dropout=0.0,
+    )
+
+
+def get_lora_config(rank=8):
+    return _make_lora_config(rank=rank)
+
+
+def get_tiny_lora_config(rank=1):
+    return _make_lora_config(rank=rank)
+
+
+def get_vera_config(rank=256):
+    return _make_vera_config(rank=rank)
+
+
+class PEFTBaseline:
+    """Manages a PEFT-wrapped model for repeated per-user adaptation."""
+
+    def __init__(self, wrapper, peft_config):
+        self.wrapper = wrapper
+        self.device = wrapper.device
+        self.peft_model = get_peft_model(wrapper.model, peft_config)
+
+        self.n_params = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
+        self.n_bytes = self.n_params * 2  # bf16
+
+        # Save initial adapter state for reset between users
+        self._init_state = {
+            name: param.data.clone()
+            for name, param in self.peft_model.named_parameters()
+            if param.requires_grad
+        }
+
+    def _reset_adapter(self):
+        """Reset adapter weights to initial state (zeros for LoRA)."""
+        for name, param in self.peft_model.named_parameters():
+            if param.requires_grad and name in self._init_state:
+                param.data.copy_(self._init_state[name])
+
+    def _build_training_data(self, support_items, task):
+        """Build (input_ids, labels) pairs from support items."""
+        from data.templates import build_support_prompt
+
+        data = []
+        for item in support_items:
+            input_text = build_support_prompt(item['support_input'], task)
+            target_text = " " + item['support_output']
+
+            chat_messages = [
+                {"role": "system", "content": "You are a helpful writing assistant."},
+                {"role": "user", "content": input_text},
+            ]
+            prompt_text = self.wrapper.tokenizer.apply_chat_template(
+                chat_messages, tokenize=False, add_generation_prompt=True
+            )
+            full_text = prompt_text + target_text
+
+            prompt_ids = self.wrapper.tokenizer.encode(prompt_text, return_tensors="pt")
+            full_ids = self.wrapper.tokenizer.encode(full_text, return_tensors="pt")
+
+            labels = full_ids.clone()
+            labels[0, :prompt_ids.shape[1]] = -100
+
+            data.append((full_ids.to(self.device), labels.to(self.device)))
+        return data
+
+    def adapt_and_generate(
+        self,
+        support_items,
+        query_input,
+        task,
+        lr=1e-4,
+        steps=30,
+        max_new_tokens=512,
+        min_new_tokens=128,
+        verbose=False,
+    ):
+        """Reset adapter, fine-tune on support set, generate, return text."""
+        self._reset_adapter()

+        # Build training data
+        train_data = self._build_training_data(support_items, task)
+        if not train_data:
+            return self._generate_fallback(query_input, task, max_new_tokens, min_new_tokens)
+
+        # Fine-tune
+        trainable = [p for p in self.peft_model.parameters() if p.requires_grad]
+        optimizer = torch.optim.AdamW(trainable, lr=lr)
+
+        self.peft_model.train()
+        for step in range(steps):
+            optimizer.zero_grad()
+            total_loss = 0.0
+
+            for input_ids, labels in train_data:
+                outputs = self.peft_model(input_ids=input_ids, labels=labels)
+                (outputs.loss / len(train_data)).backward()
+                total_loss += outputs.loss.item()
+
+            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
+            optimizer.step()
+
+            if verbose and (step % 10 == 0 or step == steps - 1):
+                print(f"  Step {step:3d}: loss={total_loss/len(train_data):.4f}")
+
+        # Generate
+        self.peft_model.eval()
+        generated = self._generate(query_input, task, max_new_tokens, min_new_tokens)
+
+        del optimizer
+        torch.cuda.empty_cache()
+        return generated
+
+    def _generate(self, query_input, task, max_new_tokens, min_new_tokens):
+        from data.templates import build_query_prompt
+        prompt = build_query_prompt(query_input, task)
+
+        chat_messages = [
+            {"role": "system", "content": "You are a helpful writing assistant."},
+            {"role": "user", "content": prompt},
+        ]
+        prompt_text = self.wrapper.tokenizer.apply_chat_template(
+            chat_messages, tokenize=False, add_generation_prompt=True
+        )
+        input_ids = self.wrapper.tokenizer.encode(
+            prompt_text, return_tensors="pt"
+        ).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.peft_model.generate(
+                input_ids,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=min_new_tokens,
+                temperature=None,
+                top_p=None,
+                do_sample=False,
+                pad_token_id=self.wrapper.tokenizer.pad_token_id,
+            )
+        generated_ids = outputs[0, input_ids.shape[1]:]
+        return self.wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+    def _generate_fallback(self, query_input, task, max_new_tokens, min_new_tokens):
+        """Fallback: generate without adaptation (empty support set)."""
+        self.peft_model.eval()
+        return self._generate(query_input, task, max_new_tokens, min_new_tokens)
+
+    def cleanup(self):
+        """Remove adapter and restore wrapper.model to the original base model."""
+        base_model = self.peft_model.unload()
+        self.wrapper.model = base_model
+        del self.peft_model
+        torch.cuda.empty_cache()
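A slightly fuller driver sketch than the module's Usage docstring, showing how the three config helpers pair with the class; the wrapper object and the per-user data fields are assumptions about the surrounding codebase, not defined in this file:

    # Hypothetical evaluation loop; 'wrapper' is assumed to expose .model,
    # .tokenizer and .device, and 'users'/'task' come from the runner scripts.
    configs = [
        ("lora", get_lora_config(rank=8)),
        ("tiny_lora", get_tiny_lora_config(rank=1)),
        ("vera", get_vera_config(rank=256)),
    ]
    for name, cfg in configs:
        baseline = PEFTBaseline(wrapper, cfg)
        print(f"{name}: {baseline.n_params} trainable params ({baseline.n_bytes} bytes in bf16)")
        for user in users:  # assumed per-user dicts with support/query fields
            text = baseline.adapt_and_generate(
                user["support_items"], user["query_input"], task, steps=30
            )
        baseline.cleanup()  # unload the adapter and restore the frozen base model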
diff --git a/baselines/profile_based.py b/baselines/profile_based.py
new file mode 100644
index 0000000..bc48679
--- /dev/null
+++ b/baselines/profile_based.py
@@ -0,0 +1,58 @@
+"""Profile-based baseline.
+
+Uses the LLM to generate a user writing style profile from K support examples,
+then conditions generation on that profile summary.
+"""
+
+
+def build_profile_prompt(support_items, task):
+    """Build prompt to generate a user writing style profile from support examples."""
+    parts = ["Analyze the following writing samples and describe the author's writing style "
+             "in 2-3 sentences. Focus on tone, vocabulary, sentence structure, and any "
+             "distinctive patterns.\n"]
+
+    for i, item in enumerate(support_items, 1):
+        parts.append(f"--- Sample {i} ---")
+        parts.append(item['support_output'][:500])  # truncate long samples
+        parts.append("")
+
+    parts.append("Writing style description:")
+    return "\n".join(parts)
+
+
+def build_profile_conditioned_prompt(query_input, profile_summary, task):
+    """Build generation prompt conditioned on the user profile."""
+    from data.templates import build_query_prompt
+    base_prompt = build_query_prompt(query_input, task)
+
+    return (
+        f"The following describes this user's writing style:\n"
+        f"{profile_summary}\n\n"
+        f"Write in this style.\n\n"
+        f"{base_prompt}"
+    )
+
+
+def generate_profile(wrapper, support_items, task, max_profile_tokens=150):
+    """Generate a user writing style profile using the LLM."""
+    import torch
+
+    prompt = build_profile_prompt(support_items, task)
+    chat_messages = [
+        {"role": "system", "content": "You are a writing style analyst."},
+        {"role": "user", "content": prompt},
+    ]
+    prompt_text = wrapper.tokenizer.apply_chat_template(
+        chat_messages, tokenize=False, add_generation_prompt=True
+    )
+    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+    with torch.no_grad():
+        outputs = wrapper.model.generate(
+            input_ids,
+            max_new_tokens=max_profile_tokens,
+            temperature=None, top_p=None, do_sample=False,
+            pad_token_id=wrapper.tokenizer.pad_token_id,
+        )
+    generated_ids = outputs[0, input_ids.shape[1]:]
+    return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
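A minimal sketch of how these pieces compose per user; the final decoding step is assumed to live in the runner scripts (this module only builds the profile and the prompts), so it is indicated as a comment rather than shown:

    # Hypothetical per-user flow for the profile-based baseline; 'wrapper',
    # 'support_items', 'query_input', and 'task' are assumed from the runner.
    profile = generate_profile(wrapper, support_items, task)               # LLM style summary
    prompt = build_profile_conditioned_prompt(query_input, profile, task)  # conditioned prompt
    # 'prompt' is then decoded with the same chat-template + greedy setup the
    # other baselines use (e.g. the _generate path in peft_baseline.py).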
