Diffstat (limited to 'src/data/dolma.py')
| -rw-r--r-- | src/data/dolma.py | 93 |
1 file changed, 58 insertions, 35 deletions
diff --git a/src/data/dolma.py b/src/data/dolma.py
index 4e2baaf..ed8c13b 100644
--- a/src/data/dolma.py
+++ b/src/data/dolma.py
@@ -7,6 +7,7 @@ See CLAUDE.md §3.1.1 for sequence packing specification.
 from __future__ import annotations
 
 import os
+import time
 from typing import Iterator, Optional
 
 import torch
@@ -14,6 +15,9 @@ from datasets import load_dataset
 from torch.utils.data import IterableDataset
 from transformers import AutoTokenizer
 
+MAX_RETRIES = 10
+RETRY_WAIT = 30  # seconds
+
 
 class DolmaPackedDataset(IterableDataset):
     """Streaming Dolma dataset with sequence packing.
@@ -49,8 +53,8 @@ class DolmaPackedDataset(IterableDataset):
         self.eos_id = olmo_tokenizer.eos_token_id
         assert self.eos_id is not None, "OLMo tokenizer must have an EOS token"
 
-    def __iter__(self) -> Iterator[dict]:
-        """Yield packed sequences from Dolma stream."""
+    def _load_stream(self):
+        """Load Dolma streaming dataset with fallback."""
         try:
             dataset = load_dataset(
                 self.dataset_name,
@@ -60,7 +64,6 @@ class DolmaPackedDataset(IterableDataset):
                 trust_remote_code=True,
             )
         except Exception:
-            # Fallback if specific version not available
             dataset = load_dataset(
                 self.dataset_name,
                 split="train",
@@ -68,43 +71,63 @@ class DolmaPackedDataset(IterableDataset):
                 trust_remote_code=True,
             )
 
-        # Shard for multi-GPU
         if self.world_size > 1:
             dataset = dataset.shard(num_shards=self.world_size, index=self.rank)
 
+        return dataset
+
+    def __iter__(self) -> Iterator[dict]:
+        """Yield packed sequences from Dolma stream with retry on HTTP errors."""
         buffer: list[int] = []
         sample_count = 0
-
-        for doc in dataset:
-            if self.max_samples is not None and sample_count >= self.max_samples:
-                break
-
-            text = doc.get("text", "")
-            if not text.strip():
-                continue
-
-            tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
-            buffer.extend(tokens)
-            buffer.append(self.eos_id)
-
-            # Yield packed sequences as buffer fills
-            while len(buffer) >= self.seq_len + 1:
-                chunk = buffer[:self.seq_len + 1]
-                buffer = buffer[self.seq_len + 1:]
-
-                olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
-                olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long)
-                raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False)
-
-                yield {
-                    "olmo_ids": olmo_ids,
-                    "olmo_labels": olmo_labels,
-                    "raw_text": raw_text,
-                }
-                sample_count += 1
-
-                if self.max_samples is not None and sample_count >= self.max_samples:
-                    break
+        retries = 0
+
+        while retries <= MAX_RETRIES:
+            try:
+                dataset = self._load_stream()
+
+                for doc in dataset:
+                    if self.max_samples is not None and sample_count >= self.max_samples:
+                        return
+
+                    text = doc.get("text", "")
+                    if not text.strip():
+                        continue
+
+                    tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
+                    buffer.extend(tokens)
+                    buffer.append(self.eos_id)
+
+                    # Yield packed sequences as buffer fills
+                    while len(buffer) >= self.seq_len + 1:
+                        chunk = buffer[:self.seq_len + 1]
+                        buffer = buffer[self.seq_len + 1:]
+
+                        olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
+                        olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long)
+                        raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False)
+
+                        yield {
+                            "olmo_ids": olmo_ids,
+                            "olmo_labels": olmo_labels,
+                            "raw_text": raw_text,
+                        }
+                        sample_count += 1
+
+                        if self.max_samples is not None and sample_count >= self.max_samples:
+                            return
+
+                # Stream exhausted normally
+                return
+
+            except Exception as e:
+                retries += 1
+                if retries > MAX_RETRIES:
+                    raise RuntimeError(f"Dolma stream failed after {MAX_RETRIES} retries: {e}") from e
+                print(f"[DolmaDataset] Stream error (retry {retries}/{MAX_RETRIES}): {e}")
+                print(f"[DolmaDataset] Waiting {RETRY_WAIT}s before reconnecting...")
+                time.sleep(RETRY_WAIT)
+                buffer = []  # reset buffer, data order doesn't matter for training
 
 
 def build_train_dataloader(
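As a reading aid, the retry behaviour introduced by this change boils down to the pattern sketched below. This is a minimal standalone illustration, not code from the module: iterate_with_retries and make_stream are hypothetical names, and only the MAX_RETRIES / RETRY_WAIT constants mirror the diff.

# Minimal sketch of the retry pattern (hypothetical helper, not part of src/data/dolma.py):
# re-create the stream after a transient error, back off, and give up after MAX_RETRIES.
import time
from typing import Callable, Iterable, Iterator

MAX_RETRIES = 10
RETRY_WAIT = 30  # seconds


def iterate_with_retries(make_stream: Callable[[], Iterable]) -> Iterator:
    """Yield items from a re-creatable stream, reconnecting on errors."""
    retries = 0
    while retries <= MAX_RETRIES:
        try:
            for item in make_stream():
                yield item
            return  # stream exhausted normally
        except Exception as exc:
            retries += 1
            if retries > MAX_RETRIES:
                raise RuntimeError(f"stream failed after {MAX_RETRIES} retries: {exc}") from exc
            time.sleep(RETRY_WAIT)  # wait before reconnecting

Two properties of the actual change are worth noting: sample_count is initialised outside the retry loop, so the max_samples cap still counts sequences yielded before a reconnect, while buffer is discarded on each reconnect because, as the inline comment says, data order does not matter for training.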
