summaryrefslogtreecommitdiff
path: root/scripts/prepare_corpus.py
blob: 93fc0ce973f4be8da6ed3ed8a216d32c1f454678 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Convert linear-rag chunks.json to JSONL corpus for build_memory_bank.py.

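The linear-rag dataset stores chunks as a list of strings in the format "idx:text...".
This script strips the index prefix and writes one {"text": "..."} object per line to
<output-dir>/<dataset>_corpus.jsonl. It also converts questions.json into
<output-dir>/<dataset>_questions.jsonl, with one {"id", "question", "answer",
"question_type"} object per line.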

Usage:
    python scripts/prepare_corpus.py --dataset hotpotqa
    python scripts/prepare_corpus.py --dataset hotpotqa --dataset musique --dataset 2wikimultihop
"""

import argparse
import json
import logging
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Datasets processed by default when no --dataset flag is given on the CLI.
DATASETS = [
    "hotpotqa",
    "musique",
    "2wikimultihop",
    "medical",
]


def convert_chunks(dataset: str, data_root: Path, output_dir: Path) -> Path:
    """Convert a single dataset's chunks.json to corpus JSONL.

    Each chunk string of the form ``"idx:text..."`` has its numeric index
    prefix removed. The remaining text is whitespace-stripped and written as
    one ``{"text": ...}`` object per line. Chunks that are empty after
    stripping are skipped.

    Args:
        dataset: dataset name (e.g., "hotpotqa")
        data_root: path to linear-rag clone
        output_dir: directory to write output JSONL (created if missing)

    Returns:
        Path to the output JSONL file (``<output_dir>/<dataset>_corpus.jsonl``).

    Raises:
        FileNotFoundError: if ``<data_root>/<dataset>/chunks.json`` does not exist.
    """
    chunks_path = data_root / dataset / "chunks.json"
    if not chunks_path.exists():
        raise FileNotFoundError(f"Not found: {chunks_path}")

    # Read explicitly as UTF-8. The platform default encoding (e.g. cp1252 on
    # Windows) would mis-decode or reject non-ASCII passages.
    with open(chunks_path, encoding="utf-8") as f:
        chunks = json.load(f)

    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{dataset}_corpus.jsonl"

    count = 0
    with open(output_path, "w", encoding="utf-8") as out:
        for chunk in chunks:
            # Strip the "idx:" prefix only when it really is a numeric index.
            # Splitting on any first colon would drop real text from chunks
            # such as "Title: body" that carry no index.
            prefix, sep, rest = chunk.partition(":")
            text = rest if sep and prefix.strip().isdigit() else chunk
            text = text.strip()
            if text:
                out.write(json.dumps({"text": text}) + "\n")
                count += 1

    logger.info("%s: %d chunks -> %s", dataset, count, output_path)
    return output_path


def convert_questions(dataset: str, data_root: Path, output_dir: Path) -> Path:
    """Convert questions.json to a standardized JSONL format.

    Each input record becomes one JSON object per line with the keys
    ``id``, ``question``, ``answer`` and ``question_type``. ``id`` and
    ``question_type`` default to ``""`` when absent, while ``question`` and
    ``answer`` are required.

    Args:
        dataset: dataset name
        data_root: path to linear-rag clone
        output_dir: directory to write output JSONL (created if missing)

    Returns:
        Path to the output JSONL file (``<output_dir>/<dataset>_questions.jsonl``).

    Raises:
        FileNotFoundError: if ``<data_root>/<dataset>/questions.json`` does not exist.
        KeyError: if a record lacks ``question`` or ``answer``.
    """
    questions_path = data_root / dataset / "questions.json"
    if not questions_path.exists():
        raise FileNotFoundError(f"Not found: {questions_path}")

    # Read explicitly as UTF-8. The platform default encoding (e.g. cp1252 on
    # Windows) would mis-decode or reject non-ASCII questions and answers.
    with open(questions_path, encoding="utf-8") as f:
        questions = json.load(f)

    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{dataset}_questions.jsonl"

    count = 0
    with open(output_path, "w", encoding="utf-8") as out:
        for q in questions:
            record = {
                "id": q.get("id", ""),
                "question": q["question"],
                "answer": q["answer"],
                "question_type": q.get("question_type", ""),
            }
            out.write(json.dumps(record) + "\n")
            count += 1

    logger.info("%s: %d questions -> %s", dataset, count, output_path)
    return output_path


def main() -> None:
    """Command-line entry point.

    Parses ``--dataset`` (repeatable), ``--data-root`` and ``--output-dir``.
    For each selected dataset, in order, it writes the corpus JSONL and then
    the questions JSONL. Without any ``--dataset`` flag, every name in
    ``DATASETS`` is processed.
    """
    cli = argparse.ArgumentParser(description="Prepare linear-rag data for HAG")
    cli.add_argument(
        "--dataset",
        action="append",
        type=str,
        default=None,
        help=f"Dataset(s) to process. Choices: {DATASETS}. Can specify multiple times.",
    )
    cli.add_argument(
        "--data-root",
        default="data/linear-rag",
        type=str,
        help="Path to linear-rag clone",
    )
    cli.add_argument(
        "--output-dir",
        default="data/processed",
        type=str,
        help="Output directory for processed files",
    )
    opts = cli.parse_args()

    # An explicit --dataset list wins; otherwise fall back to every known dataset.
    selected = opts.dataset or DATASETS
    source_root = Path(opts.data_root)
    target_dir = Path(opts.output_dir)

    # Chunks first, then questions, for each dataset. A missing input file
    # raises FileNotFoundError and stops the run at that point.
    for name in selected:
        for converter in (convert_chunks, convert_questions):
            converter(name, source_root, target_dir)


if __name__ == "__main__":
    main()