summaryrefslogtreecommitdiff
path: root/data/longlamp.py
blob: 05b393979a79ff5eb453f3cfe7dbe3d8f0cc862f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""LongLaMP dataset loader for Review Writing and Topic Writing tasks."""

import random
from datasets import load_dataset


def load_longlamp(config_name: str, split: str = "val"):
    """Load a LongLaMP dataset configuration and convert it to unified dicts.

    Args:
        config_name: One of product_review_user, product_review_temporal,
                     topic_writing_user, topic_writing_temporal
        split: train, val, or test

    Returns:
        List of unified dicts, one per dataset row, with keys ``task``
        ("review" or "topic"), ``setting`` ("user" or "temporal"),
        ``query_input``, ``target_output``, ``profile_items``, ``user_id``
        and ``example_id``. Each ``profile_items`` entry has keys
        ``support_input`` (prompt built from the profile record),
        ``support_output`` (the record's reference text) and ``raw``
        (the original profile record).

    Raises:
        ValueError: If ``config_name`` is not one of the supported
            configurations. This is checked before ``load_dataset`` is called.
    """
    # Explicit table instead of substring sniffing. Previously any
    # unrecognised name fell through to the "topic" branch and only failed
    # later with an opaque KeyError on p["content"].
    supported = {
        "product_review_user": ("review", "user"),
        "product_review_temporal": ("review", "temporal"),
        "topic_writing_user": ("topic", "user"),
        "topic_writing_temporal": ("topic", "temporal"),
    }
    try:
        task, setting = supported[config_name]
    except KeyError:
        raise ValueError(
            f"Unsupported LongLaMP config {config_name!r}; "
            f"expected one of {sorted(supported)}"
        ) from None

    # The two tasks differ only in how a profile record is turned into a
    # prompt and in which field holds its reference text.
    if task == "review":
        build_support_input = _build_review_support_input
        output_key = "reviewText"
    else:  # topic
        build_support_input = _build_topic_support_input
        output_key = "content"

    ds = load_dataset("LongLaMP/LongLaMP", config_name, split=split)

    examples = []
    for idx, row in enumerate(ds):
        processed_profile = [
            {
                "support_input": build_support_input(p),
                "support_output": p[output_key],
                "raw": p,
            }
            for p in row["profile"]
        ]

        # Prefer "reviewerId", then "author", then a synthetic per-row id.
        user_id = row.get("reviewerId", row.get("author", f"user_{idx}"))

        examples.append({
            "task": task,
            "setting": setting,
            "query_input": row["input"],
            "target_output": row["output"],
            "profile_items": processed_profile,
            "user_id": user_id,
            "example_id": f"{config_name}_{split}_{idx}",
        })

    return examples


def _build_review_support_input(profile_item: dict) -> str:
    """Build the input text for a review support example."""
    overall = profile_item.get("overall", "5.0")
    description = profile_item.get("description", "")
    summary = profile_item.get("summary", "")
    return (
        f'Generate the review text written by a reviewer who has a given an overall '
        f'rating of "{overall}" for a product with description "{description}". '
        f'The summary of the review text is "{summary}".'
    )


def _build_topic_support_input(profile_item: dict) -> str:
    """Build the input text for a topic support example."""
    summary = profile_item.get("summary", "")
    return f"Generate the content for a reddit post {summary}"


def select_k_profile_items(profile_items: list, K: int, seed: int = 0) -> list:
    """Pick up to K profile items, reproducibly.

    When K covers the whole profile, the input list itself (not a copy) is
    returned. Otherwise, K items are sampled without replacement by a
    ``random.Random`` seeded with ``seed``, so identical inputs always yield
    the same selection.
    """
    if K >= len(profile_items):
        return profile_items
    return random.Random(seed).sample(profile_items, K)


if __name__ == "__main__":
    # Quick smoke test: load one config and print a truncated preview of the
    # first example.
    # Use "val": load_longlamp documents its splits as train/val/test and
    # defaults to "val". The previous "validation" did not match that name.
    examples = load_longlamp("product_review_user", split="val")
    print(f"Loaded {len(examples)} review user validation examples")
    if not examples:
        # Avoid an IndexError on examples[0] when nothing was loaded.
        raise SystemExit("No examples loaded; nothing to preview.")
    ex = examples[0]
    print(f"User: {ex['user_id']}")
    print(f"Query: {ex['query_input'][:200]}...")
    print(f"Target: {ex['target_output'][:200]}...")
    print(f"Profile items: {len(ex['profile_items'])}")
    if ex['profile_items']:
        p = ex['profile_items'][0]
        print(f"  Support input: {p['support_input'][:200]}...")
        print(f"  Support output: {p['support_output'][:200]}...")