summaryrefslogtreecommitdiff
path: root/baselines/profile_based.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-05 10:31:36 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-05 10:31:36 -0500
commitea4a8f837e81b5e5fab6086cb3014c711c5e58e9 (patch)
tree11638546dc91c97815e5bdab8fa0b587481d0a3c /baselines/profile_based.py
parent8fe28101366dd32562b8c5534d7fe359b252bdf3 (diff)
Add PEFT baselines, ICL baselines, profile-based, and unified pipeline
New baselines: - baselines/peft_baseline.py: LoRA, Tiny LoRA, VeRA (per-user PEFT adaptation) - baselines/dense_retrieval.py: Dense retrieval ICL (sentence-transformers) - baselines/profile_based.py: LLM-generated user profile conditioned generation New scripts: - scripts/run_all_methods.py: Unified pipeline running all 9 methods with per-method directory output structure (method/per_user.json) - scripts/run_peft_baselines.py: PEFT-only evaluation (legacy) - scripts/run_significance.py: Significance tests (UPH+Base per-user) - scripts/run_uph_base_per_user.py: UPH+Base with full per-user data - scripts/compute_bertscore.py: BERTScore from saved predictions - scripts/significance_test.py: Standalone significance test framework Updated .gitignore to exclude outputs/ directory. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'baselines/profile_based.py')
-rw-r--r--baselines/profile_based.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/baselines/profile_based.py b/baselines/profile_based.py
new file mode 100644
index 0000000..bc48679
--- /dev/null
+++ b/baselines/profile_based.py
@@ -0,0 +1,58 @@
+"""Profile-based baseline.
+
+Uses the LLM to generate a user writing style profile from K support examples,
+then conditions generation on that profile summary.
+"""
+
+
def build_profile_prompt(support_items, task):
    """Assemble the instruction prompt that asks the LLM to summarize a user's style.

    Each support example contributes its ``support_output`` (capped at 500
    characters) as a numbered sample section. ``task`` is accepted for
    signature parity with the other prompt builders but is not used here.
    """
    instruction = ("Analyze the following writing samples and describe the author's writing style "
                   "in 2-3 sentences. Focus on tone, vocabulary, sentence structure, and any "
                   "distinctive patterns.\n")

    # Each section renders as "--- Sample i ---\n<text>\n"; the joining "\n"
    # between sections reproduces the original blank separator line.
    sample_sections = [
        f"--- Sample {idx} ---\n{item['support_output'][:500]}\n"  # truncate long samples
        for idx, item in enumerate(support_items, 1)
    ]

    return "\n".join([instruction, *sample_sections, "Writing style description:"])
+
+
def build_profile_conditioned_prompt(query_input, profile_summary, task):
    """Prefix the standard query prompt with the user's generated style profile.

    The final prompt states the user's writing-style description, instructs
    the model to imitate it, then appends the unmodified base task prompt.
    """
    from data.templates import build_query_prompt

    style_header = (
        "The following describes this user's writing style:\n"
        f"{profile_summary}\n\n"
        "Write in this style.\n\n"
    )
    return style_header + build_query_prompt(query_input, task)
+
+
def generate_profile(wrapper, support_items, task, max_profile_tokens=150):
    """Generate a user writing-style profile using the LLM.

    Builds the profile prompt from the support examples, renders it through
    the tokenizer's chat template, and greedily decodes up to
    ``max_profile_tokens`` new tokens.

    Args:
        wrapper: model wrapper exposing ``.model``, ``.tokenizer`` and
            ``.device`` (presumably a HF model/tokenizer pair — confirm
            against the callers).
        support_items: list of dicts, each with a ``'support_output'`` key.
        task: task identifier forwarded to ``build_profile_prompt``.
        max_profile_tokens: cap on the number of generated profile tokens.

    Returns:
        The decoded profile text, with special tokens stripped.
    """
    import torch

    prompt = build_profile_prompt(support_items, task)
    chat_messages = [
        {"role": "system", "content": "You are a writing style analyst."},
        {"role": "user", "content": prompt},
    ]
    prompt_text = wrapper.tokenizer.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )
    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)

    # Some tokenizers (e.g. Llama-family) define no pad token; fall back to
    # EOS so generate() does not have to guess and emit a warning.
    pad_id = wrapper.tokenizer.pad_token_id
    if pad_id is None:
        pad_id = wrapper.tokenizer.eos_token_id

    with torch.no_grad():
        outputs = wrapper.model.generate(
            input_ids,
            max_new_tokens=max_profile_tokens,
            # Explicit Nones silence HF warnings about sampling parameters
            # being set while do_sample=False (pure greedy decoding).
            temperature=None, top_p=None, do_sample=False,
            pad_token_id=pad_id,
        )
    # Strip the prompt tokens; decode only the newly generated suffix.
    generated_ids = outputs[0, input_ids.shape[1]:]
    return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)