summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/data/dolma.py93
1 files changed, 58 insertions, 35 deletions
diff --git a/src/data/dolma.py b/src/data/dolma.py
index 4e2baaf..ed8c13b 100644
--- a/src/data/dolma.py
+++ b/src/data/dolma.py
@@ -7,6 +7,7 @@ See CLAUDE.md §3.1.1 for sequence packing specification.
from __future__ import annotations
import os
+import time
from typing import Iterator, Optional
import torch
@@ -14,6 +15,9 @@ from datasets import load_dataset
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer
+MAX_RETRIES = 10
+RETRY_WAIT = 30 # seconds
+
class DolmaPackedDataset(IterableDataset):
"""Streaming Dolma dataset with sequence packing.
@@ -49,8 +53,8 @@ class DolmaPackedDataset(IterableDataset):
self.eos_id = olmo_tokenizer.eos_token_id
assert self.eos_id is not None, "OLMo tokenizer must have an EOS token"
- def __iter__(self) -> Iterator[dict]:
- """Yield packed sequences from Dolma stream."""
+ def _load_stream(self):
+ """Load Dolma streaming dataset with fallback."""
try:
dataset = load_dataset(
self.dataset_name,
@@ -60,7 +64,6 @@ class DolmaPackedDataset(IterableDataset):
trust_remote_code=True,
)
except Exception:
- # Fallback if specific version not available
dataset = load_dataset(
self.dataset_name,
split="train",
@@ -68,43 +71,63 @@ class DolmaPackedDataset(IterableDataset):
trust_remote_code=True,
)
- # Shard for multi-GPU
if self.world_size > 1:
dataset = dataset.shard(num_shards=self.world_size, index=self.rank)
+ return dataset
+
+ def __iter__(self) -> Iterator[dict]:
+ """Yield packed sequences from Dolma stream with retry on HTTP errors."""
buffer: list[int] = []
sample_count = 0
-
- for doc in dataset:
- if self.max_samples is not None and sample_count >= self.max_samples:
- break
-
- text = doc.get("text", "")
- if not text.strip():
- continue
-
- tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
- buffer.extend(tokens)
- buffer.append(self.eos_id)
-
- # Yield packed sequences as buffer fills
- while len(buffer) >= self.seq_len + 1:
- chunk = buffer[:self.seq_len + 1]
- buffer = buffer[self.seq_len + 1:]
-
- olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
- olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long)
- raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False)
-
- yield {
- "olmo_ids": olmo_ids,
- "olmo_labels": olmo_labels,
- "raw_text": raw_text,
- }
- sample_count += 1
-
- if self.max_samples is not None and sample_count >= self.max_samples:
- break
+ retries = 0
+
+ while retries <= MAX_RETRIES:
+ try:
+ dataset = self._load_stream()
+
+ for doc in dataset:
+ if self.max_samples is not None and sample_count >= self.max_samples:
+ return
+
+ text = doc.get("text", "")
+ if not text.strip():
+ continue
+
+ tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
+ buffer.extend(tokens)
+ buffer.append(self.eos_id)
+
+ # Yield packed sequences as buffer fills
+ while len(buffer) >= self.seq_len + 1:
+ chunk = buffer[:self.seq_len + 1]
+ buffer = buffer[self.seq_len + 1:]
+
+ olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
+ olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long)
+ raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False)
+
+ yield {
+ "olmo_ids": olmo_ids,
+ "olmo_labels": olmo_labels,
+ "raw_text": raw_text,
+ }
+ sample_count += 1
+
+ if self.max_samples is not None and sample_count >= self.max_samples:
+ return
+
+ # Stream exhausted normally
+ return
+
+ except Exception as e:
+ retries += 1
+ if retries > MAX_RETRIES:
+ raise RuntimeError(f"Dolma stream failed after {MAX_RETRIES} retries: {e}") from e
+ print(f"[DolmaDataset] Stream error (retry {retries}/{MAX_RETRIES}): {e}")
+ print(f"[DolmaDataset] Waiting {RETRY_WAIT}s before reconnecting...")
+ time.sleep(RETRY_WAIT)
+ buffer = [] # reset buffer, data order doesn't matter for training
def build_train_dataloader(