Diffstat (limited to 'baselines')
-rw-r--r--  baselines/dense_retrieval.py   42
-rw-r--r--  baselines/peft_baseline.py    194
-rw-r--r--  baselines/profile_based.py     58
3 files changed, 294 insertions, 0 deletions
diff --git a/baselines/dense_retrieval.py b/baselines/dense_retrieval.py
new file mode 100644
index 0000000..db403a8
--- /dev/null
+++ b/baselines/dense_retrieval.py
@@ -0,0 +1,42 @@
+"""Dense Retrieval ICL baseline.
+
+Uses sentence-transformers for dense retrieval over the user's support set,
+then places the top-k retrieved items as in-context examples.
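+
+Usage (illustrative sketch; `query_input` is a string and `support_items` is a
+list of dicts with 'support_input'/'support_output' keys):
+
+    retriever = DenseRetriever()
+    examples = retriever.retrieve_top_k(query_input, support_items, k=2)
+    # `examples` are then formatted as in-context demonstrations in the prompt.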
+"""
+
+import torch
+from sentence_transformers import SentenceTransformer
+
+
+class DenseRetriever:
+    """Dense retriever using sentence-transformers embeddings."""
+
+    def __init__(self, model_name='all-MiniLM-L6-v2', device='cpu'):
+        self.model = SentenceTransformer(model_name, device=device)
+
+    def retrieve_top_k(self, query: str, support_items: list, k: int = 1):
+        """Retrieve top-k support items most relevant to query.
+
+        Args:
+            query: query input text
+            support_items: list of dicts with 'support_input', 'support_output'
+            k: number of items to retrieve
+
+        Returns:
+            List of top-k support items (sorted by relevance)
+        """
+        if len(support_items) <= k:
+            return support_items
+
+        texts = [item['support_input'] for item in support_items]
+        embeddings = self.model.encode([query] + texts, convert_to_tensor=True)
+
+        query_emb = embeddings[0]
+        item_embs = embeddings[1:]
+
+        similarities = torch.nn.functional.cosine_similarity(
+            query_emb.unsqueeze(0), item_embs, dim=1
+        )
+
+        top_indices = similarities.argsort(descending=True)[:k].tolist()
+        return [support_items[i] for i in top_indices]
diff --git a/baselines/peft_baseline.py b/baselines/peft_baseline.py
new file mode 100644
index 0000000..442ba60
--- /dev/null
+++ b/baselines/peft_baseline.py
@@ -0,0 +1,194 @@
+"""PEFT baselines: LoRA, Tiny LoRA, and VeRA.
+
+Per-user adaptation on K support examples, then standard generation.
+Uses a class-based API to avoid repeated model wrapping/unwrapping.
+
+Usage:
+    baseline = PEFTBaseline(wrapper, get_lora_config(rank=8))
+    for user in users:
+        text = baseline.adapt_and_generate(support, query, task)
+    baseline.cleanup()  # restore frozen model
+"""
+
+import torch
+from peft import LoraConfig, VeraConfig, get_peft_model, TaskType
+
+
+TARGET_MODULES = ["q_proj", "v_proj"]
+
+
+def _make_lora_config(rank, target_modules=None, lora_alpha=None):
+    if target_modules is None:
+        target_modules = TARGET_MODULES
+    if lora_alpha is None:
+        lora_alpha = 2 * rank
+    return LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=rank,
+        lora_alpha=lora_alpha,
+        lora_dropout=0.0,
+        target_modules=target_modules,
+        bias="none",
+    )
+
+
+def _make_vera_config(rank, target_modules=None):
+    if target_modules is None:
+        target_modules = TARGET_MODULES
+    return VeraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=rank,
+        target_modules=target_modules,
+        vera_dropout=0.0,
+    )
+
+
+def get_lora_config(rank=8):
+    return _make_lora_config(rank=rank)
+
+
+def get_tiny_lora_config(rank=1):
+    return _make_lora_config(rank=rank)
+
+
+def get_vera_config(rank=256):
+    return _make_vera_config(rank=rank)
+
+
+class PEFTBaseline:
+    """Manages a PEFT-wrapped model for repeated per-user adaptation."""
+
+    def __init__(self, wrapper, peft_config):
+        self.wrapper = wrapper
+        self.device = wrapper.device
+        self.peft_model = get_peft_model(wrapper.model, peft_config)
+
+        self.n_params = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
+        self.n_bytes = self.n_params * 2  # bf16
+
+        # Save initial adapter state for reset between users
+        self._init_state = {
+            name: param.data.clone()
+            for name, param in self.peft_model.named_parameters()
+            if param.requires_grad
+        }
+
+    def _reset_adapter(self):
+        """Reset adapter weights to their initial state (for LoRA, the B matrices
+        start at zero, so a freshly reset adapter contributes nothing until trained)."""
+        for name, param in self.peft_model.named_parameters():
+            if param.requires_grad and name in self._init_state:
+                param.data.copy_(self._init_state[name])
+
+    def _build_training_data(self, support_items, task):
+        """Build (input_ids, labels) pairs from support items."""
+        from data.templates import build_support_prompt
+
+        data = []
+        for item in support_items:
+            input_text = build_support_prompt(item['support_input'], task)
+            target_text = " " + item['support_output']
+
+            chat_messages = [
+                {"role": "system", "content": "You are a helpful writing assistant."},
+                {"role": "user", "content": input_text},
+            ]
+            prompt_text = self.wrapper.tokenizer.apply_chat_template(
+                chat_messages, tokenize=False, add_generation_prompt=True
+            )
+            full_text = prompt_text + target_text
+
+            prompt_ids = self.wrapper.tokenizer.encode(prompt_text, return_tensors="pt")
+            full_ids = self.wrapper.tokenizer.encode(full_text, return_tensors="pt")
+
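+            # Mask the prompt tokens with -100 so the loss is computed only on
+            # the target continuation (the user's support_output).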
+            labels = full_ids.clone()
+            labels[0, :prompt_ids.shape[1]] = -100
+
+            data.append((full_ids.to(self.device), labels.to(self.device)))
+        return data
+
+    def adapt_and_generate(
+        self,
+        support_items,
+        query_input,
+        task,
+        lr=1e-4,
+        steps=30,
+        max_new_tokens=512,
+        min_new_tokens=128,
+        verbose=False,
+    ):
+        """Reset adapter, fine-tune on support set, generate, return text."""
+        self._reset_adapter()
+
+        # Build training data
+        train_data = self._build_training_data(support_items, task)
+        if not train_data:
+            return self._generate_fallback(query_input, task, max_new_tokens, min_new_tokens)
+
+        # Fine-tune
+        trainable = [p for p in self.peft_model.parameters() if p.requires_grad]
+        optimizer = torch.optim.AdamW(trainable, lr=lr)
+
+        self.peft_model.train()
+        for step in range(steps):
+            optimizer.zero_grad()
+            total_loss = 0.0
+
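+            # Accumulate scaled gradients over the whole support set, then take
+            # a single optimizer step (one "step" = one pass over the K items).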
+            for input_ids, labels in train_data:
+                outputs = self.peft_model(input_ids=input_ids, labels=labels)
+                (outputs.loss / len(train_data)).backward()
+                total_loss += outputs.loss.item()
+
+            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
+            optimizer.step()
+
+            if verbose and (step % 10 == 0 or step == steps - 1):
+                print(f"  Step {step:3d}: loss={total_loss/len(train_data):.4f}")
+
+        # Generate
+        self.peft_model.eval()
+        generated = self._generate(query_input, task, max_new_tokens, min_new_tokens)
+
+        del optimizer
+        torch.cuda.empty_cache()
+        return generated
+
+    def _generate(self, query_input, task, max_new_tokens, min_new_tokens):
+        from data.templates import build_query_prompt
+        prompt = build_query_prompt(query_input, task)
+
+        chat_messages = [
+            {"role": "system", "content": "You are a helpful writing assistant."},
+            {"role": "user", "content": prompt},
+        ]
+        prompt_text = self.wrapper.tokenizer.apply_chat_template(
+            chat_messages, tokenize=False, add_generation_prompt=True
+        )
+        input_ids = self.wrapper.tokenizer.encode(
+            prompt_text, return_tensors="pt"
+        ).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.peft_model.generate(
+                input_ids,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=min_new_tokens,
+                temperature=None,
+                top_p=None,
+                do_sample=False,
+                pad_token_id=self.wrapper.tokenizer.pad_token_id,
+            )
+        generated_ids = outputs[0, input_ids.shape[1]:]
+        return self.wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+    def _generate_fallback(self, query_input, task, max_new_tokens, min_new_tokens):
+        """Fallback: generate without adaptation (empty support set)."""
+        self.peft_model.eval()
+        return self._generate(query_input, task, max_new_tokens, min_new_tokens)
+
+    def cleanup(self):
+        """Remove adapter and restore wrapper.model to the original base model."""
+        base_model = self.peft_model.unload()
+        self.wrapper.model = base_model
+        del self.peft_model
+        torch.cuda.empty_cache()
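+
+
+# Illustrative sweep over the three PEFT variants (sketch only; the per-user
+# field names `u["support"]` and `u["query_input"]` are assumptions, not part
+# of this module's API):
+#
+#     for name, cfg in [("lora", get_lora_config()),
+#                       ("tiny_lora", get_tiny_lora_config()),
+#                       ("vera", get_vera_config())]:
+#         baseline = PEFTBaseline(wrapper, cfg)
+#         outputs = [baseline.adapt_and_generate(u["support"], u["query_input"], task)
+#                    for u in users]
+#         baseline.cleanup()  # restore wrapper.model before the next variant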
diff --git a/baselines/profile_based.py b/baselines/profile_based.py
new file mode 100644
index 0000000..bc48679
--- /dev/null
+++ b/baselines/profile_based.py
@@ -0,0 +1,58 @@
+"""Profile-based baseline.
+
+Uses the LLM to generate a user writing style profile from K support examples,
+then conditions generation on that profile summary.
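+
+Usage (illustrative sketch; `wrapper` is the same model wrapper used by the
+other baselines, exposing `.model`, `.tokenizer`, and `.device`):
+
+    profile = generate_profile(wrapper, support_items, task)
+    prompt = build_profile_conditioned_prompt(query_input, profile, task)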
+"""
+
+
+def build_profile_prompt(support_items, task):
+    """Build prompt to generate a user writing style profile from support examples."""
+    parts = ["Analyze the following writing samples and describe the author's writing style "
+             "in 2-3 sentences. Focus on tone, vocabulary, sentence structure, and any "
+             "distinctive patterns.\n"]
+
+    for i, item in enumerate(support_items, 1):
+        parts.append(f"--- Sample {i} ---")
+        parts.append(item['support_output'][:500])  # truncate long samples
+        parts.append("")
+
+    parts.append("Writing style description:")
+    return "\n".join(parts)
+
+
+def build_profile_conditioned_prompt(query_input, profile_summary, task):
+    """Build generation prompt conditioned on the user profile."""
+    from data.templates import build_query_prompt
+    base_prompt = build_query_prompt(query_input, task)
+
+    return (
+        f"The following describes this user's writing style:\n"
+        f"{profile_summary}\n\n"
+        f"Write in this style.\n\n"
+        f"{base_prompt}"
+    )
+
+
+def generate_profile(wrapper, support_items, task, max_profile_tokens=150):
+    """Generate a user writing style profile using the LLM."""
+    import torch
+
+    prompt = build_profile_prompt(support_items, task)
+    chat_messages = [
+        {"role": "system", "content": "You are a writing style analyst."},
+        {"role": "user", "content": prompt},
+    ]
+    prompt_text = wrapper.tokenizer.apply_chat_template(
+        chat_messages, tokenize=False, add_generation_prompt=True
+    )
+    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+    with torch.no_grad():
+        outputs = wrapper.model.generate(
+            input_ids,
+            max_new_tokens=max_profile_tokens,
+            temperature=None, top_p=None, do_sample=False,
+            pad_token_id=wrapper.tokenizer.pad_token_id,
+        )
+    generated_ids = outputs[0, input_ids.shape[1]:]
+    return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)