diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2025-12-17 04:29:37 -0600 |
| commit | e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch) | |
| tree | 6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/download_oasst1.py | |
Diffstat (limited to 'scripts/download_oasst1.py')
| -rw-r--r-- | scripts/download_oasst1.py | 78 |
1 files changed, 78 insertions, 0 deletions
#!/usr/bin/env python3
"""Download and flatten the OpenAssistant/oasst1 dataset.

Converts the conversation trees into a flat list of user turns
(ChatTurn) for our pipeline.

Output format per line (JSONL):
{
    "original_query": str,
    "source": "oasst1",
    "user_id": str,
    "session_id": str,
    "turn_id": int
}
"""

import json
import os

from datasets import load_dataset
from tqdm import tqdm


def _turn_index(row, id2msg):
    """Return the 0-based user-turn index of a prompter message.

    Counts the row's ancestors by walking ``parent_id`` links through
    *id2msg* and halves the depth: in an alternating prompter/assistant
    thread the root prompter sits at depth 0 (turn 0), the next prompter
    at depth 2 (turn 1), and so on.

    NOTE(review): assumes strict prompter/assistant alternation, which
    holds for oasst1 trees — verify before reusing on other datasets.
    """
    depth = 0
    parent_id = row["parent_id"]
    while parent_id is not None:
        parent = id2msg.get(parent_id)
        if parent is None:
            # Orphaned node (parent not in this split); stop counting.
            break
        depth += 1
        parent_id = parent["parent_id"]
    return depth // 2


def main():
    """Download oasst1, extract all user (prompter) turns, write JSONL."""
    output_path = "data/raw_datasets/oasst1_queries.jsonl"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print("Downloading OpenAssistant/oasst1 dataset...")
    # OASST1 is a tree structure: each row carries 'message_id',
    # 'parent_id', 'user_id', 'text', 'role', and 'message_tree_id'
    # (which identifies the conversation tree).
    ds = load_dataset("OpenAssistant/oasst1", split="train")

    print(f"Loaded {len(ds)} messages. Reconstructing threads...")

    # Index by message_id so we can walk each message's parent chain and
    # position every user turn within its thread (previously this index
    # was built but never used, and turn_id was hard-coded to 0).
    id2msg = {}
    for row in tqdm(ds, desc="Indexing"):
        id2msg[row["message_id"]] = row

    # Build the unified ChatTurn records (with user_id and session_id):
    # every 'prompter' message is a user turn, and 'message_tree_id'
    # serves as the session identifier.
    queries = []
    for row in tqdm(ds, desc="Processing"):
        if row["role"] != "prompter":
            continue
        queries.append({
            "original_query": row["text"],
            "source": "oasst1",
            "user_id": str(row["user_id"]),
            "session_id": str(row["message_tree_id"]),
            # Real position of this user turn within its thread.
            "turn_id": _turn_index(row, id2msg),
        })

    print(f"Extracted {len(queries)} user queries.")

    print(f"Saving to {output_path}...")
    with open(output_path, "w", encoding="utf-8") as f:
        for q in queries:
            f.write(json.dumps(q, ensure_ascii=False) + "\n")

    print("Done.")


if __name__ == "__main__":
    main()
