summaryrefslogtreecommitdiff
path: root/data/cached_challenge_fineweb.py
diff options
context:
space:
mode:
authorWill DePue <williamd@openai.com>2026-03-18 09:32:01 -0700
committerWill DePue <williamd@openai.com>2026-03-18 09:32:01 -0700
commita15093adad328a650d421e53c078cbd2c45beb0e (patch)
treee054c4bde12b89e6d3b39d611d9caadabc7f7234 /data/cached_challenge_fineweb.py
Launch snapshot
Diffstat (limited to 'data/cached_challenge_fineweb.py')
-rw-r--r--data/cached_challenge_fineweb.py157
1 files changed, 157 insertions, 0 deletions
diff --git a/data/cached_challenge_fineweb.py b/data/cached_challenge_fineweb.py
new file mode 100644
index 0000000..fa8029b
--- /dev/null
+++ b/data/cached_challenge_fineweb.py
@@ -0,0 +1,157 @@
+import argparse
+import json
+import os
+import shutil
+from pathlib import Path
+
+from huggingface_hub import hf_hub_download
+
+
+REPO_ID = os.environ.get("MATCHED_FINEWEB_REPO_ID", "willdepueoai/parameter-golf")
+REMOTE_ROOT_PREFIX = os.environ.get("MATCHED_FINEWEB_REMOTE_ROOT_PREFIX", "datasets")
+ROOT = Path(__file__).resolve().parent
+DATASETS_DIR = ROOT / "datasets"
+TOKENIZERS_DIR = ROOT / "tokenizers"
+
+def dataset_dir_for_variant(name: str) -> str:
+ if name == "byte260":
+ return "fineweb10B_byte260"
+ if name.startswith("sp") and name[2:].isdigit():
+ return f"fineweb10B_{name}"
+ raise ValueError(f"unsupported variant {name!r}; expected byte260 or sp<VOCAB_SIZE>")
+
+
+def local_path_for_remote(relative_path: str) -> Path:
+ remote_path = Path(relative_path)
+ if REMOTE_ROOT_PREFIX and remote_path.parts[:1] == (REMOTE_ROOT_PREFIX,):
+ remote_path = remote_path.relative_to(REMOTE_ROOT_PREFIX)
+ if remote_path.parts[:1] == ("datasets",):
+ return DATASETS_DIR.joinpath(*remote_path.parts[1:])
+ if remote_path.parts[:1] == ("tokenizers",):
+ return TOKENIZERS_DIR.joinpath(*remote_path.parts[1:])
+ return ROOT / remote_path
+
+
+def get(relative_path: str) -> None:
+ destination = local_path_for_remote(relative_path)
+ if destination.exists():
+ return
+ if destination.is_symlink():
+ destination.unlink()
+
+ remote_path = Path(relative_path)
+ cached_path = Path(
+ hf_hub_download(
+ repo_id=REPO_ID,
+ filename=remote_path.name,
+ subfolder=remote_path.parent.as_posix() if remote_path.parent != Path(".") else None,
+ repo_type="dataset",
+ )
+ )
+ # HF cache entries may be snapshot symlinks. Resolve to the underlying blob so we
+ # always materialize a real file in data/, not a broken relative symlink.
+ cached_source = cached_path.resolve(strict=True)
+ destination.parent.mkdir(parents=True, exist_ok=True)
+ try:
+ os.link(cached_source, destination)
+ except OSError:
+ shutil.copy2(cached_source, destination)
+
+
+def manifest_path() -> Path:
+ return local_path_for_remote(f"{REMOTE_ROOT_PREFIX}/manifest.json")
+
+
+def load_manifest(*, skip_manifest_download: bool) -> dict:
+ path = manifest_path()
+ if not path.is_file():
+ if skip_manifest_download:
+ raise FileNotFoundError(
+ f"manifest.json is required for manifest-driven shard counts but is not present locally at {path}"
+ )
+ get(f"{REMOTE_ROOT_PREFIX}/manifest.json")
+ return json.loads(path.read_text(encoding="utf-8"))
+
+
+def artifact_paths_for_tokenizer(tokenizer_entry: dict) -> list[str]:
+ artifacts = []
+ for key in ("model_path", "vocab_path", "path"):
+ value = tokenizer_entry.get(key)
+ if value:
+ artifacts.append(str(value))
+ if not artifacts:
+ raise ValueError(f"tokenizer entry is missing downloadable artifacts: {tokenizer_entry}")
+ return artifacts
+
+
+def build_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(description="Download challenge FineWeb shards from Hugging Face")
+ parser.add_argument(
+ "train_shards_positional",
+ nargs="?",
+ type=int,
+ default=None,
+ help=argparse.SUPPRESS,
+ )
+ parser.add_argument(
+ "--train-shards",
+ type=int,
+ default=80,
+ help="Number of training shards to download for the selected variant. Defaults to 80.",
+ )
+ parser.add_argument(
+ "--variant",
+ default="sp1024",
+ help="Tokenizer family to download, for example sp1024, sp4096, or byte260.",
+ )
+ parser.add_argument(
+ "--skip-manifest",
+ action="store_true",
+ help="Skip downloading manifest.json.",
+ )
+ parser.add_argument(
+ "--with-docs",
+ action="store_true",
+ help="Also download docs_selected.jsonl and its sidecar for tokenizer retraining or dataset re-export.",
+ )
+ return parser
+
+
+def main() -> None:
+ args = build_parser().parse_args()
+ dataset_dir = dataset_dir_for_variant(args.variant)
+ train_shards = args.train_shards_positional if args.train_shards_positional is not None else args.train_shards
+ if train_shards < 0:
+ raise ValueError("train_shards must be non-negative")
+
+ manifest = load_manifest(skip_manifest_download=args.skip_manifest)
+ dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir), None)
+ if dataset_entry is None:
+ raise ValueError(f"dataset {dataset_dir} not found in {REMOTE_ROOT_PREFIX}/manifest.json")
+ max_train_shards = int((dataset_entry.get("stats") or {}).get("files_train"))
+ val_shards = int((dataset_entry.get("stats") or {}).get("files_val"))
+ if train_shards > max_train_shards:
+ raise ValueError(
+ f"{args.variant} only has {max_train_shards} training shards on {REPO_ID}, requested {train_shards}"
+ )
+ tokenizer_name = dataset_entry.get("tokenizer_name")
+ tokenizer_entry = next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None)
+ if tokenizer_entry is None:
+ raise ValueError(f"tokenizer {tokenizer_name} not found in {REMOTE_ROOT_PREFIX}/manifest.json")
+
+ if args.with_docs:
+ get(f"{REMOTE_ROOT_PREFIX}/docs_selected.jsonl")
+ get(f"{REMOTE_ROOT_PREFIX}/docs_selected.source_manifest.json")
+
+ dataset_prefix = f"{REMOTE_ROOT_PREFIX}/datasets/{dataset_dir}"
+ for i in range(val_shards):
+ get(f"{dataset_prefix}/fineweb_val_{i:06d}.bin")
+ for i in range(train_shards):
+ get(f"{dataset_prefix}/fineweb_train_{i:06d}.bin")
+
+ for artifact_path in artifact_paths_for_tokenizer(tokenizer_entry):
+ get(f"{REMOTE_ROOT_PREFIX}/{artifact_path}")
+
+
+if __name__ == "__main__":
+ main()