diff options
Diffstat (limited to 'data')
-rw-r--r--  data/__init__.py          0
-rw-r--r--  data/longlamp.py        101
-rw-r--r--  data/style_features.py  138
-rw-r--r--  data/templates.py        60
4 files changed, 299 insertions, 0 deletions
diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/data/__init__.py diff --git a/data/longlamp.py b/data/longlamp.py new file mode 100644 index 0000000..05b3939 --- /dev/null +++ b/data/longlamp.py @@ -0,0 +1,101 @@ +"""LongLaMP dataset loader for Review Writing and Topic Writing tasks.""" + +import random +from datasets import load_dataset + + +def load_longlamp(config_name: str, split: str = "val"): + """Load a LongLaMP dataset configuration. + + Args: + config_name: One of product_review_user, product_review_temporal, + topic_writing_user, topic_writing_temporal + split: train, val, or test + + Returns: + List of unified dicts. + """ + ds = load_dataset("LongLaMP/LongLaMP", config_name, split=split) + + task = "review" if "review" in config_name else "topic" + setting = "user" if "user" in config_name else "temporal" + + examples = [] + for idx, row in enumerate(ds): + profile_items = row["profile"] + + if task == "review": + processed_profile = [] + for p in profile_items: + processed_profile.append({ + "support_input": _build_review_support_input(p), + "support_output": p["reviewText"], + "raw": p, + }) + else: # topic + processed_profile = [] + for p in profile_items: + processed_profile.append({ + "support_input": _build_topic_support_input(p), + "support_output": p["content"], + "raw": p, + }) + + user_id = row.get("reviewerId", row.get("author", f"user_{idx}")) + + examples.append({ + "task": task, + "setting": setting, + "query_input": row["input"], + "target_output": row["output"], + "profile_items": processed_profile, + "user_id": user_id, + "example_id": f"{config_name}_{split}_{idx}", + }) + + return examples + + +def _build_review_support_input(profile_item: dict) -> str: + """Build the input text for a review support example.""" + overall = profile_item.get("overall", "5.0") + description = profile_item.get("description", "") + summary = profile_item.get("summary", "") + return ( + 
f'Generate the review text written by a reviewer who has a given an overall ' + f'rating of "{overall}" for a product with description "{description}". ' + f'The summary of the review text is "{summary}".' + ) + + +def _build_topic_support_input(profile_item: dict) -> str: + """Build the input text for a topic support example.""" + summary = profile_item.get("summary", "") + return f"Generate the content for a reddit post {summary}" + + +def select_k_profile_items(profile_items: list, K: int, seed: int = 0) -> list: + """Select K profile items from the available profile. + + If fewer than K items available, return all of them. + Uses random selection with a fixed seed for reproducibility. + """ + if len(profile_items) <= K: + return profile_items + rng = random.Random(seed) + return rng.sample(profile_items, K) + + +if __name__ == "__main__": + # Quick test + examples = load_longlamp("product_review_user", split="validation") + print(f"Loaded {len(examples)} review user validation examples") + ex = examples[0] + print(f"User: {ex['user_id']}") + print(f"Query: {ex['query_input'][:200]}...") + print(f"Target: {ex['target_output'][:200]}...") + print(f"Profile items: {len(ex['profile_items'])}") + if ex['profile_items']: + p = ex['profile_items'][0] + print(f" Support input: {p['support_input'][:200]}...") + print(f" Support output: {p['support_output'][:200]}...") diff --git a/data/style_features.py b/data/style_features.py new file mode 100644 index 0000000..d674419 --- /dev/null +++ b/data/style_features.py @@ -0,0 +1,138 @@ +"""Style feature extraction for SFD (Style Feature Distance) metric.""" + +import nltk +from nltk.tokenize import sent_tokenize, word_tokenize + +FEATURE_NAMES = [ + 'length', 'avg_sent_len', 'TTR', 'newline_rate', 'exclaim_rate', + 'first_person_rate', 'adj_adv_rate', 'sentiment_score' +] + + +def extract_style_features(text: str) -> list: + """Extract style feature vector from text. 
+ + Returns: + [length, avg_sent_len, TTR, newline_rate, exclaim_rate, + first_person_rate, adj_adv_rate, sentiment_score] + """ + if not text or not text.strip(): + return [0.0] * 8 + + words = word_tokenize(text) + num_words = max(len(words), 1) + + # 1. Length (word count) + length = float(num_words) + + # 2. Average sentence length + sentences = sent_tokenize(text) + num_sents = max(len(sentences), 1) + avg_sent_len = num_words / num_sents + + # 3. Type-Token Ratio + unique_words = set(w.lower() for w in words if w.isalpha()) + alpha_words = [w for w in words if w.isalpha()] + ttr = len(unique_words) / max(len(alpha_words), 1) + + # 4. Newline rate + newline_count = text.count('\n') + newline_rate = newline_count / num_sents + + # 5. Exclamation rate + exclaim_count = text.count('!') + exclaim_rate = exclaim_count / num_sents + + # 6. First-person rate + first_person = {'i', 'me', 'my', 'mine', 'myself', 'we', 'us', 'our', 'ours', 'ourselves'} + fp_count = sum(1 for w in words if w.lower() in first_person) + first_person_rate = fp_count / num_words + + # 7. Adjective/Adverb rate + try: + tagged = nltk.pos_tag(words) + adj_adv_tags = {'JJ', 'JJR', 'JJS', 'RB', 'RBR', 'RBS'} + adj_adv_count = sum(1 for _, tag in tagged if tag in adj_adv_tags) + adj_adv_rate = adj_adv_count / num_words + except Exception: + adj_adv_rate = 0.0 + + # 8. 
Simple sentiment score (positive - negative word ratio) + positive_words = { + 'good', 'great', 'excellent', 'amazing', 'wonderful', 'fantastic', + 'love', 'loved', 'best', 'perfect', 'awesome', 'beautiful', + 'enjoy', 'enjoyed', 'happy', 'glad', 'nice', 'brilliant', + 'outstanding', 'superb', 'delightful', 'pleasant', 'favorite', + 'recommend', 'recommended', 'impressive', 'incredible', + } + negative_words = { + 'bad', 'terrible', 'awful', 'horrible', 'worst', 'poor', + 'hate', 'hated', 'boring', 'disappointing', 'disappointed', + 'ugly', 'annoying', 'waste', 'useless', 'mediocre', 'dull', + 'pathetic', 'garbage', 'rubbish', 'disgusting', 'dreadful', + } + pos_count = sum(1 for w in words if w.lower() in positive_words) + neg_count = sum(1 for w in words if w.lower() in negative_words) + sentiment_score = (pos_count - neg_count) / num_words + + return [length, avg_sent_len, ttr, newline_rate, exclaim_rate, + first_person_rate, adj_adv_rate, sentiment_score] + + +def compute_sfd(generated_text: str, support_texts: list, exclude_length: bool = False) -> float: + """Compute Style Feature Distance. + + Args: + generated_text: The model's generated output. + support_texts: List of the user's support set output texts. + exclude_length: If True, exclude length feature (index 0) from SFD. + + Returns: + L1 distance between generated style and user style prototype. 
+ """ + gen_features = extract_style_features(generated_text) + support_features_list = [extract_style_features(t) for t in support_texts] + num_support = len(support_features_list) + if num_support == 0: + return 0.0 + + prototype = [0.0] * len(gen_features) + for sf in support_features_list: + for i in range(len(prototype)): + prototype[i] += sf[i] + prototype = [p / num_support for p in prototype] + + start_idx = 1 if exclude_length else 0 + sfd = 0.0 + for i in range(start_idx, len(gen_features)): + g, p = gen_features[i], prototype[i] + scale = max(abs(p), 1.0) + sfd += abs(g - p) / scale + return sfd + + +def compute_feature_deltas(generated_text: str, support_texts: list) -> dict: + """Compute per-feature deltas between generated text and user style prototype. + + Returns dict mapping feature_name -> (gen_value, proto_value, delta). + """ + gen_features = extract_style_features(generated_text) + support_features_list = [extract_style_features(t) for t in support_texts] + num_support = len(support_features_list) + if num_support == 0: + return {} + + prototype = [0.0] * len(gen_features) + for sf in support_features_list: + for i in range(len(prototype)): + prototype[i] += sf[i] + prototype = [p / num_support for p in prototype] + + deltas = {} + for i, name in enumerate(FEATURE_NAMES): + deltas[name] = { + 'gen': gen_features[i], + 'proto': prototype[i], + 'delta': gen_features[i] - prototype[i], + } + return deltas diff --git a/data/templates.py b/data/templates.py new file mode 100644 index 0000000..a4bc9a4 --- /dev/null +++ b/data/templates.py @@ -0,0 +1,60 @@ +"""Prompt templates for LongLaMP tasks.""" + + +SYSTEM_PROMPT = "You are a helpful writing assistant." 
+ + +def build_query_prompt(query_input: str, task: str) -> str: + """Build the inference prompt for a query (no personalization text).""" + if task == "review": + return ( + f"{query_input}\n\n" + f"Write the review text:" + ) + else: # topic + return ( + f"{query_input}\n\n" + f"Write the post content:" + ) + + +def build_support_prompt(support_input: str, task: str) -> str: + """Build the prompt for a support item (used in teacher forcing to cache hidden states).""" + if task == "review": + return ( + f"{support_input}\n\n" + f"Write the review text:" + ) + else: # topic + return ( + f"{support_input}\n\n" + f"Write the post content:" + ) + + +def build_prompt_with_examples(query_input: str, support_items: list, task: str) -> str: + """Build prompt with K support examples included as in-context demonstrations. + + Used by Prompt-All-K and BM25-Top1 baselines. + """ + parts = [] + parts.append("Here are some examples of this user's writing style:\n") + + for i, item in enumerate(support_items, 1): + parts.append(f"--- Example {i} ---") + parts.append(f"Prompt: {item['support_input']}") + parts.append(f"Response: {item['support_output']}") + parts.append("") + + parts.append("Now, write in the same style as the examples above.\n") + parts.append(build_query_prompt(query_input, task)) + + return "\n".join(parts) + + +def build_chat_messages(prompt: str) -> list: + """Wrap a prompt into chat message format for Qwen.""" + return [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ] |
