summaryrefslogtreecommitdiff
path: root/scripts/download_oasst1.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/download_oasst1.py')
-rw-r--r--scripts/download_oasst1.py78
1 file changed, 78 insertions, 0 deletions
diff --git a/scripts/download_oasst1.py b/scripts/download_oasst1.py
new file mode 100644
index 0000000..3105f28
--- /dev/null
+++ b/scripts/download_oasst1.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+"""
+Script to download and process OpenAssistant/oasst1 dataset.
+Converts it into a flat list of user turns (ChatTurn) for our pipeline.
+
+Output format per line (JSONL):
+{
+ "original_query": str,
+ "source": "oasst1",
+ "user_id": str,
+ "session_id": str,
+ "turn_id": int
+}
+"""
+
+import json
+import os
+from datasets import load_dataset
+from tqdm import tqdm
+
+def main():
+ output_path = "data/raw_datasets/oasst1_queries.jsonl"
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ print("Downloading OpenAssistant/oasst1 dataset...")
+ # OASST1 is a tree structure. We need to traverse it to reconstruct conversations.
+ # It has 'message_id', 'parent_id', 'user_id', 'text', 'role'
+ ds = load_dataset("OpenAssistant/oasst1", split="train")
+
+ print(f"Loaded {len(ds)} messages. Reconstructing threads...")
+
+ # Index by message_id
+ id2msg = {}
+ for row in tqdm(ds, desc="Indexing"):
+ id2msg[row["message_id"]] = row
+
+ # Find leaf nodes to trace back threads?
+ # Or just find all user messages and trace back to root to establish session context?
+ # For this task: "从 OASST1 里构造统一的 ChatTurn 序列(带 user_id 和 session_id)"
+ # We want valid user turns.
+ # OASST1 'user_id' is the author ID.
+ # 'message_tree_id' identifies the conversation tree (session).
+
+ # We can iterate all messages. If role=='prompter' (user), we treat it as a turn.
+ # We use 'message_tree_id' as session_id.
+
+ queries = []
+
+ # Iterate all rows
+ for row in tqdm(ds, desc="Processing"):
+ if row["role"] == "prompter":
+ # This is a user turn
+ user_id = row["user_id"] # Author ID
+ session_id = row["message_tree_id"]
+ text = row["text"]
+
+ # Simple metadata
+ queries.append({
+ "original_query": text,
+ "source": "oasst1",
+ "user_id": str(user_id),
+ "session_id": str(session_id),
+ "turn_id": 0 # We don't strictly need precise turn_id for Day 1 pipeline right now unless we sort
+ })
+
+ print(f"Extracted {len(queries)} user queries.")
+
+ # Save
+ print(f"Saving to {output_path}...")
+ with open(output_path, "w", encoding="utf-8") as f:
+ for q in queries:
+ f.write(json.dumps(q, ensure_ascii=False) + "\n")
+
+ print("Done.")
+
+if __name__ == "__main__":
+ main()
+