path: root/scripts/download_oasst1.py
#!/usr/bin/env python3
"""
Download the OpenAssistant/oasst1 dataset and convert it into a flat
list of user turns (ChatTurn) for our pipeline.

Output format per line (JSONL):
{
    "original_query": str,
    "source": "oasst1",
    "user_id": str,
    "session_id": str,
    "turn_id": int
}
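
Example line (illustrative values, not taken from the dataset):

{"original_query": "How do I sort a list in Python?", "source": "oasst1",
 "user_id": "1ab2c3d4", "session_id": "6ab24d72-0181-4594", "turn_id": 0}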
"""

import json
import os
from datasets import load_dataset
from tqdm import tqdm

def main():
    output_path = "data/raw_datasets/oasst1_queries.jsonl"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    print("Downloading OpenAssistant/oasst1 dataset...")
    # OASST1 stores each conversation as a message tree. Each row carries
    # 'message_id', 'parent_id', 'user_id', 'text', 'role', and
    # 'message_tree_id', so threads can be reconstructed via parent links.
    ds = load_dataset("OpenAssistant/oasst1", split="train")
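    # Note: the HF dataset also ships a "validation" split; we only
    # process "train" here.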
    
    print(f"Loaded {len(ds)} messages. Reconstructing threads...")
    
    # Index by message_id so we can walk parent links back to the root.
    id2msg = {}
    for row in tqdm(ds, desc="Indexing"):
        id2msg[row["message_id"]] = row

    def user_turn_index(msg):
        """0-based index of this prompter turn within its thread.

        Prompter and assistant messages alternate along a root-to-leaf
        path, so half the ancestor count gives the user-turn index.
        """
        depth = 0
        parent_id = msg["parent_id"]
        while parent_id is not None and parent_id in id2msg:
            depth += 1
            parent_id = id2msg[parent_id]["parent_id"]
        return depth // 2

    # Task: construct a unified ChatTurn sequence from OASST1 (with user_id
    # and session_id). Every message with role == 'prompter' is a user turn:
    # 'user_id' is the author ID, and 'message_tree_id' (the conversation
    # tree) serves as session_id.
    
    queries = []
    
    # Iterate all rows
    for row in tqdm(ds, desc="Processing"):
        if row["role"] == "prompter":
            # This is a user turn
            user_id = row["user_id"] # Author ID
            session_id = row["message_tree_id"]
            text = row["text"]
            
            # Simple metadata
            queries.append({
                "original_query": text,
                "source": "oasst1",
                "user_id": str(user_id),
                "session_id": str(session_id),
                "turn_id": 0 # We don't strictly need precise turn_id for Day 1 pipeline right now unless we sort
            })
            
    print(f"Extracted {len(queries)} user queries.")
    
    # Save
    print(f"Saving to {output_path}...")
    with open(output_path, "w", encoding="utf-8") as f:
        for q in queries:
            f.write(json.dumps(q, ensure_ascii=False) + "\n")
            
    print("Done.")

if __name__ == "__main__":
    main()
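
# Usage (a sketch; run from the repo root so the relative output path resolves):
#   python scripts/download_oasst1.py
#
# Quick sanity check on the first output record:
#   head -n 1 data/raw_datasets/oasst1_queries.jsonl | python -m json.tool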