author     YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
commit     e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree       6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/download_datasets.py
Initial commit (clean history) (HEAD, main)
Diffstat (limited to 'scripts/download_datasets.py')
-rw-r--r--  scripts/download_datasets.py  210
1 file changed, 210 insertions, 0 deletions
diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py
new file mode 100644
index 0000000..f78b15f
--- /dev/null
+++ b/scripts/download_datasets.py
@@ -0,0 +1,210 @@
+import os
+import json
+import random
+from typing import Any, Dict
+from datasets import load_dataset
+from tqdm import tqdm
+
+# Configuration
+OUTPUT_DIR = "data/raw_datasets"
+
+# Dataset configurations
+# Each entry: id (Hugging Face dataset), subset, split, type (which processor
+# to apply), limit (max queries to keep), and optional data_files.
+SOURCES = [
+ {
+ "id": "lmsys/lmsys-chat-1m",
+ "subset": None,
+ "split": "train",
+ "type": "lmsys",
+ "limit": 200000
+ },
+ {
+ "id": "allenai/WildChat",
+ "subset": None,
+ "split": "train",
+ "type": "wildchat",
+ "limit": 150000
+ },
+ {
+ "id": "anon8231489123/ShareGPT_Vicuna_unfiltered",
+ "subset": None,
+ "split": "train",
+ "data_files": "ShareGPT_V3_unfiltered_cleaned_split.json",
+ "type": "sharegpt",
+ "limit": 50000
+ },
+ {
+ "id": "yahma/alpaca-cleaned",
+ "subset": None,
+ "split": "train",
+ "type": "alpaca",
+ "limit": 52000
+ },
+ {
+ "id": "Open-Orca/SlimOrca",
+ "subset": None,
+ "split": "train",
+ "type": "slimorca",
+ "limit": 100000
+ }
+]
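+
+# To add another source, append an entry above and handle its "type" in the
+# dispatch inside download_and_process(); e.g. a hypothetical "dolly" type
+# would need a matching process_dolly() helper.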
+
+def ensure_english(text: str) -> bool:
+    # A simple heuristic to filter non-English text: keep only strings that
+    # are entirely ASCII. For production, use langdetect or a similar library.
+    try:
+        return text.isascii()
+    except AttributeError:
+        # Not a string (e.g. None); treat it as non-English.
+        return False
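+
+# A stricter variant (a rough sketch, not wired in by default): it assumes the
+# third-party `langdetect` package is installed and accepts its extra cost on
+# short queries.
+def ensure_english_langdetect(text: str) -> bool:
+    try:
+        # Local import keeps the optional dependency out of the default path.
+        from langdetect import detect
+        return detect(text) == "en"
+    except Exception:
+        # Missing package, empty input, or detection failure: treat as non-English.
+        return False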
+
+def process_lmsys(example: Dict[str, Any]) -> str | None:
+ # LMSYS format: conversation is in 'conversation' list of dicts
+ try:
+ conversation = example.get("conversation", [])
+ if not conversation:
+ return None
+ # Get first user message
+ if conversation[0]["role"] == "user":
+ return conversation[0]["content"]
+    except (KeyError, IndexError, TypeError):
+        pass
+ return None
+
+def process_wildchat(example: Dict[str, Any]) -> str | None:
+    # WildChat format: 'conversation' is a list of dicts with 'role' and 'content' keys.
+ try:
+ conversation = example.get("conversation", [])
+ if not conversation:
+ return None
+ if conversation[0]["role"] == "user":
+ return conversation[0]["content"]
+    except (KeyError, IndexError, TypeError):
+        pass
+ return None
+
+def process_sharegpt(example: Dict[str, Any]) -> str | None:
+ # ShareGPT format: 'conversations' list
+ try:
+ conversations = example.get("conversations", [])
+ if not conversations:
+ return None
+ # Usually human/gpt or user/assistant
+ if conversations[0]["from"] in ["human", "user"]:
+ return conversations[0]["value"]
+    except (KeyError, IndexError, TypeError):
+        pass
+ return None
+
+def process_alpaca(example: Dict[str, Any]) -> str | None:
+ # Alpaca format: 'instruction' and 'input'. We combine them if input exists.
+ try:
+ instruction = example.get("instruction", "").strip()
+ inp = example.get("input", "").strip()
+ if inp:
+ return f"{instruction}\n\nInput: {inp}"
+ return instruction
+    except (AttributeError, TypeError):
+        pass
+ return None
+
+def process_slimorca(example: Dict[str, Any]) -> str | None:
+    # SlimOrca format: 'conversations' list of dicts (from, value), like ShareGPT,
+    # but the first turn is often a 'system' message, so scan for the first
+    # human/user turn instead of checking only index 0.
+    try:
+        conversations = example.get("conversations", [])
+        for turn in conversations:
+            if turn.get("from") in ["human", "user"]:
+                return turn.get("value")
+    except (KeyError, IndexError, TypeError, AttributeError):
+        pass
+    return None
+
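+# For reference, a hypothetical ShareGPT-style record and what the helpers above
+# extract from it (illustrative only, not real data):
+#   process_sharegpt({"conversations": [
+#       {"from": "human", "value": "How do I sort a list in Python?"},
+#       {"from": "gpt", "value": "Use sorted() or list.sort()."}]})
+#   -> "How do I sort a list in Python?"
+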
+def download_and_process():
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ all_queries = []
+
+ # Target new sources only (alpaca and slimorca)
+ # You can comment out this filter if you want to re-run everything
+ new_types = ["alpaca", "slimorca"]
+
+ for source in SOURCES:
+ if source["type"] not in new_types:
+ continue
+
+ print(f"Processing {source['id']}...")
+ try:
+            # Load in streaming mode to save disk space and memory
+ kwargs = {"streaming": True}
+ if "data_files" in source:
+ kwargs["data_files"] = source["data_files"]
+
+ ds = load_dataset(source["id"], source["subset"], split=source["split"], **kwargs)
+
+ count = 0
+ limit = source["limit"]
+
+ for example in tqdm(ds, desc=f"Reading {source['id']}", total=limit):
+ if count >= limit:
+ break
+
+ query = None
+ if source["type"] == "lmsys":
+ query = process_lmsys(example)
+ elif source["type"] == "wildchat":
+ query = process_wildchat(example)
+ elif source["type"] == "sharegpt":
+ query = process_sharegpt(example)
+ elif source["type"] == "alpaca":
+ query = process_alpaca(example)
+ elif source["type"] == "slimorca":
+ query = process_slimorca(example)
+
+                # Basic cleaning: keep queries that are non-trivial and mostly English.
+                # Note that 'limit' counts accepted queries, not raw rows read.
+                if query and len(query.strip()) > 5 and ensure_english(query):
+                    all_queries.append({
+                        "source": source["id"],
+                        "query": query.strip()
+                    })
+                    count += 1
+
+ except Exception as e:
+ print(f"Error processing {source['id']}: {e}")
+
+    print(f"Total collected new items: {len(all_queries)}")
+
+    # Load the existing output file, if any, to deduplicate against
+ output_path = os.path.join(OUTPUT_DIR, "combined_raw_queries.jsonl")
+ existing_data = []
+ if os.path.exists(output_path):
+ print("Loading existing data for deduplication...")
+ with open(output_path, "r", encoding="utf-8") as f:
+ for line in f:
+ if line.strip():
+ existing_data.append(json.loads(line))
+
+ combined = existing_data + all_queries
+ print(f"Total before final deduplication: {len(combined)}")
+
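+    # Deduplicate by query text: the dict comprehension keeps the last occurrence,
+    # so newly collected items take precedence over existing duplicates.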
+ unique_queries = {item["query"]: item for item in combined}.values()
+ final_data = list(unique_queries)
+ print(f"Total after final deduplication: {len(final_data)}")
+
+ # Shuffle
+ random.shuffle(final_data)
+
+ # Save
+ print(f"Saving to {output_path}...")
+ with open(output_path, "w", encoding="utf-8") as f:
+ for item in final_data:
+ f.write(json.dumps(item, ensure_ascii=False) + "\n")
+
+ print("Done!")
+
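+# Usage, e.g. from the repository root: python scripts/download_datasets.py
+# The merged, deduplicated queries are written to
+# data/raw_datasets/combined_raw_queries.jsonl (one JSON object per line).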
+if __name__ == "__main__":
+ download_and_process()