#!/usr/bin/env python3
"""
Script to download and process OpenAssistant/oasst1 dataset.
Converts it into a flat list of user turns (ChatTurn) for our pipeline.
Output format per line (JSONL):
{
"original_query": str,
"source": "oasst1",
"user_id": str,
"session_id": str,
"turn_id": int
}
"""
import json
import os
from datasets import load_dataset
from tqdm import tqdm
def main() -> None:
    """Download OASST1 and export every user ('prompter') turn as JSONL.

    Each output line is one ChatTurn record:
        original_query: the user's message text
        source:         constant "oasst1"
        user_id:        OASST1 author id, stringified
        session_id:     OASST1 'message_tree_id' (one id per conversation tree)
        turn_id:        placeholder 0 — precise in-session ordering is not
                        needed by the current pipeline
    """
    output_path = "data/raw_datasets/oasst1_queries.jsonl"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print("Downloading OpenAssistant/oasst1 dataset...")
    # OASST1 stores conversations as message trees; each row carries
    # 'message_id', 'parent_id', 'user_id', 'text', 'role', 'message_tree_id'.
    ds = load_dataset("OpenAssistant/oasst1", split="train")
    print(f"Loaded {len(ds)} messages. Reconstructing threads...")

    # Goal: build a unified ChatTurn sequence from OASST1, carrying user_id
    # and session_id. Every row with role == 'prompter' is a user turn, and
    # 'message_tree_id' serves as the session identifier — no tree traversal
    # is required for that, so we stream rows straight to disk instead of
    # first indexing messages by id (the old index was never used).
    print(f"Saving to {output_path}...")
    n_written = 0
    with open(output_path, "w", encoding="utf-8") as f:
        for row in tqdm(ds, desc="Processing"):
            if row["role"] != "prompter":
                continue
            record = {
                "original_query": row["text"],
                "source": "oasst1",
                "user_id": str(row["user_id"]),
                "session_id": str(row["message_tree_id"]),
                # Placeholder: we don't need a precise turn index yet
                # unless/until we sort turns within a session.
                "turn_id": 0,
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            n_written += 1

    print(f"Extracted {n_written} user queries.")
    print("Done.")
if __name__ == "__main__":
    main()