Diffstat (limited to 'scripts/download_datasets.py')
| -rw-r--r-- | scripts/download_datasets.py | 210 |
1 files changed, 210 insertions, 0 deletions
diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py
new file mode 100644
index 0000000..f78b15f
--- /dev/null
+++ b/scripts/download_datasets.py
@@ -0,0 +1,210 @@
+import os
+import json
+import random
+from typing import List, Dict, Any
+from datasets import load_dataset
+from tqdm import tqdm
+
+# Configuration
+OUTPUT_DIR = "data/raw_datasets"
+
+# Dataset configurations
+# Each entry: id (HF repo), subset, split, type (which parser to use), limit (max examples), optional data_files
+SOURCES = [
+    {
+        "id": "lmsys/lmsys-chat-1m",
+        "subset": None,
+        "split": "train",
+        "type": "lmsys",
+        "limit": 200000
+    },
+    {
+        "id": "allenai/WildChat",
+        "subset": None,
+        "split": "train",
+        "type": "wildchat",
+        "limit": 150000
+    },
+    {
+        "id": "anon8231489123/ShareGPT_Vicuna_unfiltered",
+        "subset": None,
+        "split": "train",
+        "data_files": "ShareGPT_V3_unfiltered_cleaned_split.json",
+        "type": "sharegpt",
+        "limit": 50000
+    },
+    {
+        "id": "yahma/alpaca-cleaned",
+        "subset": None,
+        "split": "train",
+        "type": "alpaca",
+        "limit": 52000
+    },
+    {
+        "id": "Open-Orca/SlimOrca",
+        "subset": None,
+        "split": "train",
+        "type": "slimorca",
+        "limit": 100000
+    }
+]
+
+def ensure_english(text: str) -> bool:
+    # A simple heuristic to filter non-English text.
+    # For production, use langdetect or similar libraries.
+    # Here we simply require the whole string to be ASCII.
+    try:
+        return text.isascii()
+    except Exception:
+        return False
+
+def process_lmsys(example: Dict[str, Any]) -> str | None:
+    # LMSYS format: 'conversation' is a list of role/content dicts
+    try:
+        conversation = example.get("conversation", [])
+        if not conversation:
+            return None
+        # Get first user message
+        if conversation[0]["role"] == "user":
+            return conversation[0]["content"]
+    except Exception:
+        pass
+    return None
+
+def process_wildchat(example: Dict[str, Any]) -> str | None:
+    # WildChat format: 'conversation' is a list of dicts with 'role' and 'content',
+    # the same shape as LMSYS; take the first user turn
+    try:
+        conversation = example.get("conversation", [])
+        if not conversation:
+            return None
+        if conversation[0]["role"] == "user":
+            return conversation[0]["content"]
+    except Exception:
+        pass
+    return None
+
+def process_sharegpt(example: Dict[str, Any]) -> str | None:
+    # ShareGPT format: 'conversations' list of from/value dicts
+    try:
+        conversations = example.get("conversations", [])
+        if not conversations:
+            return None
+        # Usually human/gpt or user/assistant
+        if conversations[0]["from"] in ["human", "user"]:
+            return conversations[0]["value"]
+    except Exception:
+        pass
+    return None
+
+def process_alpaca(example: Dict[str, Any]) -> str | None:
+    # Alpaca format: 'instruction' and 'input'. We combine them if input exists.
+    try:
+        instruction = example.get("instruction", "").strip()
+        inp = example.get("input", "").strip()
+        if inp:
+            return f"{instruction}\n\nInput: {inp}"
+        return instruction
+    except Exception:
+        pass
+    return None
+
+def process_slimorca(example: Dict[str, Any]) -> str | None:
+    # SlimOrca format: 'conversations' list of dicts with 'from'/'value' keys,
+    # the same layout as ShareGPT
+    try:
+        conversations = example.get("conversations", [])
+        if not conversations:
+            return None
+        # Usually from: human/user
+        if conversations[0]["from"] in ["human", "user"]:
+            return conversations[0]["value"]
+    except Exception:
+        pass
+    return None
+
+def download_and_process():
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+    all_queries = []
+
+    # Target new sources only (alpaca and slimorca)
+    # You can comment out this filter if you want to re-run everything
+    new_types = ["alpaca", "slimorca"]
+
+    for source in SOURCES:
+        if source["type"] not in new_types:
+            continue
+
+        print(f"Processing {source['id']}...")
+        try:
+            # Load in streaming mode to save disk/memory
+            kwargs = {"streaming": True}
+            if "data_files" in source:
+                kwargs["data_files"] = source["data_files"]
+
+            ds = load_dataset(source["id"], source["subset"], split=source["split"], **kwargs)
+
+            count = 0
+            limit = source["limit"]
+
+            for example in tqdm(ds, desc=f"Reading {source['id']}", total=limit):
+                if count >= limit:
+                    break
+
+                query = None
+                if source["type"] == "lmsys":
+                    query = process_lmsys(example)
+                elif source["type"] == "wildchat":
+                    query = process_wildchat(example)
+                elif source["type"] == "sharegpt":
+                    query = process_sharegpt(example)
+                elif source["type"] == "alpaca":
+                    query = process_alpaca(example)
+                elif source["type"] == "slimorca":
+                    query = process_slimorca(example)
+
+                # Basic cleaning
+                if query and len(query.strip()) > 5 and ensure_english(query):
+                    all_queries.append({
+                        "source": source["id"],
+                        "query": query.strip()
+                    })
+                    count += 1
+
+        except Exception as e:
+            print(f"Error processing {source['id']}: {e}")
+
+    # Deduplicate based on query content
+    print(f"Total collected new items: {len(all_queries)}")
+
+    # Load existing output if available to dedup against
+    output_path = os.path.join(OUTPUT_DIR, "combined_raw_queries.jsonl")
+    existing_data = []
+    if os.path.exists(output_path):
+        print("Loading existing data for deduplication...")
+        with open(output_path, "r", encoding="utf-8") as f:
+            for line in f:
+                if line.strip():
+                    existing_data.append(json.loads(line))
+
+    combined = existing_data + all_queries
+    print(f"Total before final deduplication: {len(combined)}")
+
+    unique_queries = {item["query"]: item for item in combined}.values()
+    final_data = list(unique_queries)
+    print(f"Total after final deduplication: {len(final_data)}")
+
+    # Shuffle
+    random.shuffle(final_data)
+
+    # Save
+    print(f"Saving to {output_path}...")
+    with open(output_path, "w", encoding="utf-8") as f:
+        for item in final_data:
+            f.write(json.dumps(item, ensure_ascii=False) + "\n")
+
+    print("Done!")
+
+if __name__ == "__main__":
+    download_and_process()
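Note on the English filter: ensure_english() keeps only pure-ASCII text, which also discards English prompts that contain emoji or accented characters. The in-code comment already points at langdetect as the production alternative; below is a minimal sketch of that swap, assuming the langdetect package is installed (the 20-character cutoff is an arbitrary illustration, not part of this patch). The rest of the pipeline (length filter, deduplication, shuffle) would stay unchanged.

    # Hypothetical replacement for ensure_english(); not part of the committed script.
    from langdetect import DetectorFactory, detect

    DetectorFactory.seed = 0  # langdetect is non-deterministic by default; pin the seed

    def ensure_english(text: str) -> bool:
        # Very short strings detect unreliably; keep the ASCII heuristic for them.
        if len(text) < 20:
            return text.isascii()
        try:
            return detect(text) == "en"
        except Exception:
            # detect() raises when it finds no usable features (e.g. only digits/punctuation)
            return False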
