Diffstat (limited to 'baselines/profile_based.py')
-rw-r--r-- baselines/profile_based.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 58 insertions(+), 0 deletions(-)
diff --git a/baselines/profile_based.py b/baselines/profile_based.py
new file mode 100644
index 0000000..bc48679
--- /dev/null
+++ b/baselines/profile_based.py
@@ -0,0 +1,58 @@
+"""Profile-based baseline.
+
+Uses the LLM to generate a user writing style profile from K support examples,
+then conditions generation on that profile summary.
+"""
+
+
+def build_profile_prompt(support_items, task):
+    """Build prompt to generate a user writing style profile from support examples."""
+    parts = ["Analyze the following writing samples and describe the author's writing style "
+             "in 2-3 sentences. Focus on tone, vocabulary, sentence structure, and any "
+             "distinctive patterns.\n"]
+
+    for i, item in enumerate(support_items, 1):
+        parts.append(f"--- Sample {i} ---")
+        parts.append(item['support_output'][:500])  # truncate long samples
+        parts.append("")
+
+    parts.append("Writing style description:")
+    return "\n".join(parts)
+
+
+def build_profile_conditioned_prompt(query_input, profile_summary, task):
+    """Build generation prompt conditioned on the user profile."""
+    from data.templates import build_query_prompt
+    base_prompt = build_query_prompt(query_input, task)
+
+    return (
+        f"The following describes this user's writing style:\n"
+        f"{profile_summary}\n\n"
+        f"Write in this style.\n\n"
+        f"{base_prompt}"
+    )
+
+
+def generate_profile(wrapper, support_items, task, max_profile_tokens=150):
+    """Generate a user writing style profile using the LLM."""
+    import torch
+
+    prompt = build_profile_prompt(support_items, task)
+    chat_messages = [
+        {"role": "system", "content": "You are a writing style analyst."},
+        {"role": "user", "content": prompt},
+    ]
+    prompt_text = wrapper.tokenizer.apply_chat_template(
+        chat_messages, tokenize=False, add_generation_prompt=True
+    )
+    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+    with torch.no_grad():
+        outputs = wrapper.model.generate(
+            input_ids,
+            max_new_tokens=max_profile_tokens,
+            temperature=None, top_p=None, do_sample=False,
+            pad_token_id=wrapper.tokenizer.pad_token_id,
+        )
+    generated_ids = outputs[0, input_ids.shape[1]:]
+    return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
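
For reference, a minimal usage sketch of how the three functions above compose, not part of the commit. It assumes the `data.templates` module from this repo is importable, and it builds a `wrapper` exposing `.model`, `.tokenizer`, and `.device`, which is the interface `generate_profile` relies on; the checkpoint name, task string, and sample data are placeholders, and wrapping the conditioned prompt in the chat template mirrors what `generate_profile` does rather than any confirmed generation path in the repo.

# Usage sketch (assumptions: checkpoint, task string, and sample data are
# placeholders; `wrapper` mimics the interface generate_profile expects).
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from baselines.profile_based import (
    build_profile_conditioned_prompt,
    generate_profile,
)

model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
wrapper = SimpleNamespace(model=model, tokenizer=tokenizer, device=model.device)

# K support examples for one user, matching the 'support_output' field
# that build_profile_prompt reads.
support_items = [
    {"support_output": "Honestly, the best part of the trip was the food..."},
    {"support_output": "Quick note: meeting moved to 3pm, don't be late!"},
]

# Step 1: summarize the user's style from the support set.
profile = generate_profile(wrapper, support_items, task="writing")

# Step 2: condition the final generation prompt on that profile.
prompt = build_profile_conditioned_prompt(
    "Write a short review of a new cafe.", profile, task="writing"
)

# Step 3: generate, wrapping the prompt the same way generate_profile does
# and reusing its greedy decoding settings.
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize=False, add_generation_prompt=True,
)
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(
        input_ids,
        max_new_tokens=256,
        temperature=None, top_p=None, do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
print(tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True))

One note on the decoding choice in the file itself: passing do_sample=False with temperature and top_p explicitly set to None forces greedy decoding (and suppresses transformers' unused-sampling-parameter warnings), so the generated profile is deterministic for a given support set, which is what you want from a reproducible baseline.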
