Diffstat (limited to 'src/data')
-rw-r--r--  src/data/__init__.py    0
-rw-r--r--  src/data/dolma.py     226
2 files changed, 226 insertions, 0 deletions
diff --git a/src/data/__init__.py b/src/data/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/data/__init__.py
diff --git a/src/data/dolma.py b/src/data/dolma.py
new file mode 100644
index 0000000..4e2baaf
--- /dev/null
+++ b/src/data/dolma.py
@@ -0,0 +1,226 @@
+"""Streaming dataloader for Dolma v1.7 with sequence packing.
+
+Produces packed sequences of fixed length for both OLMo and Qwen tokenizers.
+See CLAUDE.md §3.1.1 for the sequence packing specification.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Iterator, Optional
+
+import torch
+from datasets import load_dataset
+from torch.utils.data import IterableDataset
+from transformers import AutoTokenizer
+
+
+class DolmaPackedDataset(IterableDataset):
+ """Streaming Dolma dataset with sequence packing.
+
+ Concatenates documents with EOS separators, then chunks into fixed-length
+ sequences. No padding — every token contributes to NLL.
+
+ Each sample yields:
+ olmo_ids: [seq_len] — OLMo input token IDs
+ olmo_labels: [seq_len] — shifted labels (next-token prediction)
+ raw_text: str — decoded text for Qwen encoder
+ """
+
+    def __init__(
+        self,
+        olmo_tokenizer: AutoTokenizer,
+        seq_len: int = 1024,
+        dataset_name: str = "allenai/dolma",
+        dataset_version: str = "v1_7",
+        rank: int = 0,
+        world_size: int = 1,
+        max_samples: Optional[int] = None,
+    ):
+        super().__init__()
+        self.olmo_tokenizer = olmo_tokenizer
+        self.seq_len = seq_len
+        self.dataset_name = dataset_name
+        self.dataset_version = dataset_version
+        self.rank = rank
+        self.world_size = world_size
+        self.max_samples = max_samples
+
+        self.eos_id = olmo_tokenizer.eos_token_id
+        assert self.eos_id is not None, "OLMo tokenizer must have an EOS token"
+
+    def __iter__(self) -> Iterator[dict]:
+        """Yield packed sequences from Dolma stream."""
+        try:
+            dataset = load_dataset(
+                self.dataset_name,
+                name=self.dataset_version,
+                split="train",
+                streaming=True,
+                trust_remote_code=True,
+            )
+        except Exception:
+            # Fallback if specific version not available
+            dataset = load_dataset(
+                self.dataset_name,
+                split="train",
+                streaming=True,
+                trust_remote_code=True,
+            )
+
+        # Shard for multi-GPU
+        if self.world_size > 1:
+            dataset = dataset.shard(num_shards=self.world_size, index=self.rank)
+
+        buffer: list[int] = []
+        sample_count = 0
+
+        for doc in dataset:
+            if self.max_samples is not None and sample_count >= self.max_samples:
+                break
+
+            text = doc.get("text", "")
+            if not text.strip():
+                continue
+
+            tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
+            buffer.extend(tokens)
+            buffer.append(self.eos_id)
+
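+            # Each chunk holds seq_len + 1 tokens so that inputs (chunk[:seq_len]) and
+            # next-token labels (chunk[1:]) stay aligned; leftover tokens remain in the
+            # buffer until later documents fill it again.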
+            # Yield packed sequences as buffer fills
+            while len(buffer) >= self.seq_len + 1:
+                chunk = buffer[:self.seq_len + 1]
+                buffer = buffer[self.seq_len + 1:]
+
+                olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
+                olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long)
+                raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False)
+
+                yield {
+                    "olmo_ids": olmo_ids,
+                    "olmo_labels": olmo_labels,
+                    "raw_text": raw_text,
+                }
+                sample_count += 1
+
+                if self.max_samples is not None and sample_count >= self.max_samples:
+                    break
+
+
+def build_train_dataloader(
+    olmo_tokenizer: AutoTokenizer,
+    seq_len: int = 1024,
+    batch_size: int = 4,
+    dataset_name: str = "allenai/dolma",
+    dataset_version: str = "v1_7",
+    rank: int = 0,
+    world_size: int = 1,
+    num_workers: int = 0,
+) -> torch.utils.data.DataLoader:
+    """Build training dataloader with sequence packing."""
+    dataset = DolmaPackedDataset(
+        olmo_tokenizer=olmo_tokenizer,
+        seq_len=seq_len,
+        dataset_name=dataset_name,
+        dataset_version=dataset_version,
+        rank=rank,
+        world_size=world_size,
+    )
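+    # DolmaPackedDataset shards by rank/world_size itself but does not shard again
+    # across DataLoader workers, so num_workers > 1 would yield duplicate sequences;
+    # the default of 0 keeps iteration in the main process.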
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        collate_fn=_collate_packed,
+    )
+
+
+def build_eval_dataloader(
+    olmo_tokenizer: AutoTokenizer,
+    seq_len: int = 1024,
+    batch_size: int = 4,
+    dataset_name: str = "allenai/dolma",
+    dataset_version: str = "v1_7",
+    eval_skip: int = 1_000_000,
+    eval_size: int = 1_000,
+    cache_path: Optional[str] = None,
+) -> list[dict]:
+    """Build eval batches (held in memory).
+
+    Skips eval_skip documents in the stream, then takes eval_size packed sequences.
+    Optionally caches the batches to disk so the skip is not repeated on restart.
+    """
+    # Try loading from cache
+    if cache_path and os.path.exists(cache_path):
+        print(f"Loading eval cache from {cache_path}")
+        return torch.load(cache_path)
+
+    print(f"Building eval set (skip={eval_skip}, size={eval_size})...")
+
+    try:
+        dataset = load_dataset(
+            dataset_name,
+            name=dataset_version,
+            split="train",
+            streaming=True,
+            trust_remote_code=True,
+        )
+    except Exception:
+        dataset = load_dataset(
+            dataset_name,
+            split="train",
+            streaming=True,
+            trust_remote_code=True,
+        )
+
+    # Skip to held-out region
+    dataset = dataset.skip(eval_skip)
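+    # Held out purely by stream position: the eval set is drawn from the same
+    # "train" split, starting after the first eval_skip documents.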
+
+    eos_id = olmo_tokenizer.eos_token_id
+    buffer: list[int] = []
+    eval_samples: list[dict] = []
+
+    for doc in dataset:
+        if len(eval_samples) >= eval_size:
+            break
+
+        text = doc.get("text", "")
+        if not text.strip():
+            continue
+
+        tokens = olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
+        buffer.extend(tokens)
+        buffer.append(eos_id)
+
+        while len(buffer) >= seq_len + 1 and len(eval_samples) < eval_size:
+            chunk = buffer[:seq_len + 1]
+            buffer = buffer[seq_len + 1:]
+            eval_samples.append({
+                "olmo_ids": torch.tensor(chunk[:seq_len], dtype=torch.long),
+                "olmo_labels": torch.tensor(chunk[1:seq_len + 1], dtype=torch.long),
+                "raw_text": olmo_tokenizer.decode(chunk[:seq_len], skip_special_tokens=False),
+            })
+
+ print(f"Built {len(eval_samples)} eval sequences")
+
+ # Batch the samples
+ eval_batches = []
+ for i in range(0, len(eval_samples), batch_size):
+ batch_items = eval_samples[i:i + batch_size]
+ eval_batches.append(_collate_packed(batch_items))
+
+ # Cache to disk
+ if cache_path:
+ os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
+ torch.save(eval_batches, cache_path)
+ print(f"Eval cache saved to {cache_path}")
+
+ return eval_batches
+
+
+def _collate_packed(batch: list[dict]) -> dict:
+ """Collate packed samples into a batch dict."""
+ return {
+ "olmo_ids": torch.stack([s["olmo_ids"] for s in batch]),
+ "olmo_labels": torch.stack([s["olmo_labels"] for s in batch]),
+ "raw_text": [s["raw_text"] for s in batch],
+ }
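+
+
+if __name__ == "__main__":
+    # Minimal usage sketch: stream one packed batch and print its shapes.
+    # The tokenizer checkpoint name below is an assumption for illustration;
+    # substitute the OLMo tokenizer the project actually trains against.
+    tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
+    loader = build_train_dataloader(tok, seq_len=128, batch_size=2)
+    batch = next(iter(loader))
+    print(batch["olmo_ids"].shape, batch["olmo_labels"].shape, len(batch["raw_text"]))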