Diffstat (limited to 'data')
-rw-r--r--   data/__init__.py          0
-rw-r--r--   data/longlamp.py        101
-rw-r--r--   data/style_features.py  138
-rw-r--r--   data/templates.py        60
4 files changed, 299 insertions(+), 0 deletions(-)
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/data/__init__.py
diff --git a/data/longlamp.py b/data/longlamp.py
new file mode 100644
index 0000000..05b3939
--- /dev/null
+++ b/data/longlamp.py
@@ -0,0 +1,101 @@
+"""LongLaMP dataset loader for Review Writing and Topic Writing tasks."""
+
+import random
+from datasets import load_dataset
+
+
+def load_longlamp(config_name: str, split: str = "validation"):
+ """Load a LongLaMP dataset configuration.
+
+ Args:
+ config_name: One of product_review_user, product_review_temporal,
+ topic_writing_user, topic_writing_temporal
+ split: train, val, or test
+
+ Returns:
+ List of unified dicts.
+ """
+ ds = load_dataset("LongLaMP/LongLaMP", config_name, split=split)
+
+ task = "review" if "review" in config_name else "topic"
+ setting = "user" if "user" in config_name else "temporal"
+
+ examples = []
+ for idx, row in enumerate(ds):
+ profile_items = row["profile"]
+
+ if task == "review":
+ processed_profile = []
+ for p in profile_items:
+ processed_profile.append({
+ "support_input": _build_review_support_input(p),
+ "support_output": p["reviewText"],
+ "raw": p,
+ })
+ else: # topic
+ processed_profile = []
+ for p in profile_items:
+ processed_profile.append({
+ "support_input": _build_topic_support_input(p),
+ "support_output": p["content"],
+ "raw": p,
+ })
+
+ user_id = row.get("reviewerId", row.get("author", f"user_{idx}"))
+
+ examples.append({
+ "task": task,
+ "setting": setting,
+ "query_input": row["input"],
+ "target_output": row["output"],
+ "profile_items": processed_profile,
+ "user_id": user_id,
+ "example_id": f"{config_name}_{split}_{idx}",
+ })
+
+ return examples
+
+
+def _build_review_support_input(profile_item: dict) -> str:
+    """Build the input text for a review support example."""
+    overall = profile_item.get("overall", "5.0")
+    description = profile_item.get("description", "")
+    summary = profile_item.get("summary", "")
+    return (
+        f'Generate the review text written by a reviewer who has given an overall '
+        f'rating of "{overall}" for a product with description "{description}". '
+        f'The summary of the review text is "{summary}".'
+    )
+
+
+def _build_topic_support_input(profile_item: dict) -> str:
+    """Build the input text for a topic support example."""
+    summary = profile_item.get("summary", "")
+    return f"Generate the content for a reddit post {summary}"
+
+
+def select_k_profile_items(profile_items: list, K: int, seed: int = 0) -> list:
+    """Select K profile items from the available profile.
+
+    If fewer than K items are available, return all of them.
+    Uses random selection with a fixed seed for reproducibility.
+    """
+    if len(profile_items) <= K:
+        return profile_items
+    rng = random.Random(seed)
+    return rng.sample(profile_items, K)
+
+
+if __name__ == "__main__":
+    # Quick smoke test
+    examples = load_longlamp("product_review_user", split="validation")
+    print(f"Loaded {len(examples)} review user validation examples")
+    ex = examples[0]
+    print(f"User: {ex['user_id']}")
+    print(f"Query: {ex['query_input'][:200]}...")
+    print(f"Target: {ex['target_output'][:200]}...")
+    print(f"Profile items: {len(ex['profile_items'])}")
+    if ex['profile_items']:
+        p = ex['profile_items'][0]
+        print(f"  Support input: {p['support_input'][:200]}...")
+        print(f"  Support output: {p['support_output'][:200]}...")
diff --git a/data/style_features.py b/data/style_features.py
new file mode 100644
index 0000000..d674419
--- /dev/null
+++ b/data/style_features.py
@@ -0,0 +1,138 @@
+"""Style feature extraction for SFD (Style Feature Distance) metric."""
+
+import nltk
+from nltk.tokenize import sent_tokenize, word_tokenize
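+
+# NOTE: word_tokenize, sent_tokenize, and nltk.pos_tag rely on NLTK data
+# packages; on most NLTK versions, run nltk.download("punkt") and
+# nltk.download("averaged_perceptron_tagger") once before using this module.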
+
+FEATURE_NAMES = [
+    'length', 'avg_sent_len', 'TTR', 'newline_rate', 'exclaim_rate',
+    'first_person_rate', 'adj_adv_rate', 'sentiment_score',
+]
+
+
+def extract_style_features(text: str) -> list:
+    """Extract style feature vector from text.
+
+    Returns:
+        [length, avg_sent_len, TTR, newline_rate, exclaim_rate,
+         first_person_rate, adj_adv_rate, sentiment_score]
+    """
+    if not text or not text.strip():
+        return [0.0] * 8
+
+    words = word_tokenize(text)
+    num_words = max(len(words), 1)
+
+    # 1. Length (word count)
+    length = float(num_words)
+
+    # 2. Average sentence length (words per sentence)
+    sentences = sent_tokenize(text)
+    num_sents = max(len(sentences), 1)
+    avg_sent_len = num_words / num_sents
+
+    # 3. Type-token ratio over alphabetic tokens
+    unique_words = set(w.lower() for w in words if w.isalpha())
+    alpha_words = [w for w in words if w.isalpha()]
+    ttr = len(unique_words) / max(len(alpha_words), 1)
+
+    # 4. Newline rate (newlines per sentence)
+    newline_count = text.count('\n')
+    newline_rate = newline_count / num_sents
+
+    # 5. Exclamation rate (exclamation marks per sentence)
+    exclaim_count = text.count('!')
+    exclaim_rate = exclaim_count / num_sents
+
+    # 6. First-person pronoun rate
+    first_person = {'i', 'me', 'my', 'mine', 'myself', 'we', 'us', 'our', 'ours', 'ourselves'}
+    fp_count = sum(1 for w in words if w.lower() in first_person)
+    first_person_rate = fp_count / num_words
+
+    # 7. Adjective/adverb rate (POS tagging may fail on odd inputs)
+    try:
+        tagged = nltk.pos_tag(words)
+        adj_adv_tags = {'JJ', 'JJR', 'JJS', 'RB', 'RBR', 'RBS'}
+        adj_adv_count = sum(1 for _, tag in tagged if tag in adj_adv_tags)
+        adj_adv_rate = adj_adv_count / num_words
+    except Exception:
+        adj_adv_rate = 0.0
+
+    # 8. Lexicon-based sentiment score: (positive - negative) / total words
+    positive_words = {
+        'good', 'great', 'excellent', 'amazing', 'wonderful', 'fantastic',
+        'love', 'loved', 'best', 'perfect', 'awesome', 'beautiful',
+        'enjoy', 'enjoyed', 'happy', 'glad', 'nice', 'brilliant',
+        'outstanding', 'superb', 'delightful', 'pleasant', 'favorite',
+        'recommend', 'recommended', 'impressive', 'incredible',
+    }
+    negative_words = {
+        'bad', 'terrible', 'awful', 'horrible', 'worst', 'poor',
+        'hate', 'hated', 'boring', 'disappointing', 'disappointed',
+        'ugly', 'annoying', 'waste', 'useless', 'mediocre', 'dull',
+        'pathetic', 'garbage', 'rubbish', 'disgusting', 'dreadful',
+    }
+    pos_count = sum(1 for w in words if w.lower() in positive_words)
+    neg_count = sum(1 for w in words if w.lower() in negative_words)
+    sentiment_score = (pos_count - neg_count) / num_words
+
+    return [length, avg_sent_len, ttr, newline_rate, exclaim_rate,
+            first_person_rate, adj_adv_rate, sentiment_score]
+
+
+def compute_sfd(generated_text: str, support_texts: list, exclude_length: bool = False) -> float:
+    """Compute Style Feature Distance (SFD).
+
+    Args:
+        generated_text: The model's generated output.
+        support_texts: List of the user's support-set output texts.
+        exclude_length: If True, exclude the length feature (index 0) from SFD.
+
+    Returns:
+        Scale-normalized L1 distance between the generated text's style vector
+        and the user's style prototype (the element-wise mean of the support
+        vectors); each per-feature delta is divided by max(|prototype|, 1).
+    """
+    gen_features = extract_style_features(generated_text)
+    support_features_list = [extract_style_features(t) for t in support_texts]
+    num_support = len(support_features_list)
+    if num_support == 0:
+        return 0.0
+
+    prototype = [0.0] * len(gen_features)
+    for sf in support_features_list:
+        for i in range(len(prototype)):
+            prototype[i] += sf[i]
+    prototype = [p / num_support for p in prototype]
+
+    start_idx = 1 if exclude_length else 0
+    sfd = 0.0
+    for i in range(start_idx, len(gen_features)):
+        g, p = gen_features[i], prototype[i]
+        scale = max(abs(p), 1.0)
+        sfd += abs(g - p) / scale
+    return sfd
+
+
+def compute_feature_deltas(generated_text: str, support_texts: list) -> dict:
+    """Compute per-feature deltas between generated text and the user style prototype.
+
+    Returns:
+        Dict mapping feature_name -> {'gen': gen_value, 'proto': proto_value,
+        'delta': gen_value - proto_value}.
+    """
+    gen_features = extract_style_features(generated_text)
+    support_features_list = [extract_style_features(t) for t in support_texts]
+    num_support = len(support_features_list)
+    if num_support == 0:
+        return {}
+
+    prototype = [0.0] * len(gen_features)
+    for sf in support_features_list:
+        for i in range(len(prototype)):
+            prototype[i] += sf[i]
+    prototype = [p / num_support for p in prototype]
+
+    deltas = {}
+    for i, name in enumerate(FEATURE_NAMES):
+        deltas[name] = {
+            'gen': gen_features[i],
+            'proto': prototype[i],
+            'delta': gen_features[i] - prototype[i],
+        }
+    return deltas
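A minimal sketch of how the SFD metric is meant to be called; the texts below are toy stand-ins for a real generation and a user's support-set outputs:

    from data.style_features import compute_sfd, compute_feature_deltas

    generated = "I loved this keyboard! The keys feel great and typing is fun."
    supports = [
        "This mouse is good. I enjoyed using it daily for work.",
        "Terrible cable, but the switches are excellent. I recommend them.",
    ]
    print(compute_sfd(generated, supports))                       # scale-normalized L1
    print(compute_sfd(generated, supports, exclude_length=True))  # drop raw length
    print(compute_feature_deltas(generated, supports)["exclaim_rate"])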
diff --git a/data/templates.py b/data/templates.py
new file mode 100644
index 0000000..a4bc9a4
--- /dev/null
+++ b/data/templates.py
@@ -0,0 +1,60 @@
+"""Prompt templates for LongLaMP tasks."""
+
+
+SYSTEM_PROMPT = "You are a helpful writing assistant."
+
+
+def build_query_prompt(query_input: str, task: str) -> str:
+    """Build the inference prompt for a query (no personalization text)."""
+    if task == "review":
+        return (
+            f"{query_input}\n\n"
+            f"Write the review text:"
+        )
+    else:  # topic
+        return (
+            f"{query_input}\n\n"
+            f"Write the post content:"
+        )
+
+
+def build_support_prompt(support_input: str, task: str) -> str:
+    """Build the prompt for a support item (used under teacher forcing to cache hidden states)."""
+    if task == "review":
+        return (
+            f"{support_input}\n\n"
+            f"Write the review text:"
+        )
+    else:  # topic
+        return (
+            f"{support_input}\n\n"
+            f"Write the post content:"
+        )
+
+
+def build_prompt_with_examples(query_input: str, support_items: list, task: str) -> str:
+    """Build a prompt with K support examples included as in-context demonstrations.
+
+    Used by the Prompt-All-K and BM25-Top1 baselines.
+    """
+    parts = []
+    parts.append("Here are some examples of this user's writing style:\n")
+
+    for i, item in enumerate(support_items, 1):
+        parts.append(f"--- Example {i} ---")
+        parts.append(f"Prompt: {item['support_input']}")
+        parts.append(f"Response: {item['support_output']}")
+        parts.append("")
+
+    parts.append("Now, write in the same style as the examples above.\n")
+    parts.append(build_query_prompt(query_input, task))
+
+    return "\n".join(parts)
+
+
+def build_chat_messages(prompt: str) -> list:
+    """Wrap a prompt into chat-message format for Qwen."""
+    return [
+        {"role": "system", "content": SYSTEM_PROMPT},
+        {"role": "user", "content": prompt},
+    ]
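A sketch of how these templates compose for the Prompt-All-K baseline; the support item and query strings are illustrative placeholders, and the final comment assumes a Hugging Face tokenizer with a Qwen-style chat template:

    from data.templates import build_prompt_with_examples, build_chat_messages

    support_items = [
        {"support_input": "Generate the review text written by a reviewer ...",
         "support_output": "Great blender. It crushes ice without stalling."},
    ]
    prompt = build_prompt_with_examples(
        "Generate the review text for a coffee grinder ...",
        support_items,
        task="review",
    )
    messages = build_chat_messages(prompt)
    # messages can then be rendered with tokenizer.apply_chat_template(
    #     messages, tokenize=False, add_generation_prompt=True)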