1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
|
#!/usr/bin/env python3
"""Replace challenge dataset artifacts in a Hugging Face dataset repo with a local export."""
from __future__ import annotations
import argparse
from pathlib import Path
from huggingface_hub import HfApi
# Dataset repo that is targeted when --repo-id is not given.
DEFAULT_REPO_ID = "willdepueoai/parameter-golf"

# Subdirectory of the repo that receives the export when --path-in-repo is not given.
DEFAULT_PATH_IN_REPO = "datasets"

# Top-level entry names that are treated as export artifacts. When the export
# is uploaded to the repo root, remote entries with these names are deleted
# before the upload (together with anything present in the local export).
DATA_ARTIFACT_NAMES = {
    # directories
    "datasets",
    "tokenizers",
    # individual files
    "docs_selected.jsonl",
    "docs_selected.source_manifest.json",
    "manifest.json",
    "snapshot_meta.json",
    "tokenizer_config.export.json",
}
def repo_path(prefix: str, name: str) -> str:
    """Join *prefix* and *name* into a repo-relative path.

    Args:
        prefix: Directory inside the repo; an empty string means the repo root.
        name: Entry name to place under ``prefix``.

    Returns:
        ``"<prefix>/<name>"``, or just ``name`` when ``prefix`` is empty.
    """
    if not prefix:
        # Nothing to prepend: the entry lives at the repo root.
        return name
    return f"{prefix}/{name}"
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the replace-export script.

    Returns:
        A parser accepting ``--repo-id``, ``--local-export-root`` (required),
        ``--path-in-repo``, ``--repo-type``, ``--revision``,
        ``--commit-message`` and the ``--dry-run`` flag.
    """
    cli = argparse.ArgumentParser(
        description="Replace old dataset artifacts in a HF dataset repo with a local export",
    )
    # (flag, add_argument keyword options), registered in this order so the
    # generated --help output lists the options in the same sequence.
    option_specs = (
        ("--repo-id", {"default": DEFAULT_REPO_ID}),
        ("--local-export-root", {"required": True}),
        (
            "--path-in-repo",
            {"default": DEFAULT_PATH_IN_REPO, "help": "Subdirectory inside the dataset repo"},
        ),
        ("--repo-type", {"default": "dataset"}),
        ("--revision", {"default": None}),
        ("--commit-message", {"default": "Replace dataset export"}),
        ("--dry-run", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
def _is_remote_folder(entry: object) -> bool:
    """Return True when a repo tree entry is a folder rather than a file."""
    # Imported locally so the check is a real isinstance test instead of a
    # comparison against the class name string.
    from huggingface_hub.hf_api import RepoFolder

    return isinstance(entry, RepoFolder)


def _list_remote_dir(api: "HfApi", args: argparse.Namespace, path_in_repo: str | None) -> dict:
    """Return ``{path: entry}`` for the direct children of one remote directory.

    ``path_in_repo=None`` lists the repo root. The directory must exist on the
    remote; callers only pass paths they have already seen in a parent listing.
    """
    return {
        entry.path: entry
        for entry in api.list_repo_tree(
            repo_id=args.repo_id,
            path_in_repo=path_in_repo,
            recursive=False,
            repo_type=args.repo_type,
            revision=args.revision,
        )
    }


def _find_remote_entry(api: "HfApi", args: argparse.Namespace, remote_path: str):
    """Return the tree entry at ``remote_path``, or None when it does not exist.

    The path is resolved one component at a time, listing each parent directory
    only after it has been confirmed to exist as a folder. This makes nested
    paths such as ``a/b`` discoverable (a root-only listing contains just ``a``)
    without ever listing a directory that is missing on the remote.
    """
    components = remote_path.split("/")
    parent = None
    entry = None
    for depth in range(1, len(components) + 1):
        current = "/".join(components[:depth])
        entry = _list_remote_dir(api, args, parent).get(current)
        if entry is None:
            return None
        is_last = depth == len(components)
        if not is_last and not _is_remote_folder(entry):
            # An intermediate component is a file, so nothing can live below it.
            return None
        parent = current
    return entry


def _delete_remote_entry(api: "HfApi", args: argparse.Namespace, remote_path: str, entry: object) -> None:
    """Announce and (unless --dry-run) delete one remote file or folder."""
    print(f"delete {remote_path}")
    if args.dry_run:
        return
    # Folders and files need different API calls; both take the same arguments.
    delete = api.delete_folder if _is_remote_folder(entry) else api.delete_file
    delete(
        remote_path,
        repo_id=args.repo_id,
        repo_type=args.repo_type,
        revision=args.revision,
        commit_message=f"Delete {remote_path}",
    )


def main() -> None:
    """Replace the export stored in a Hugging Face repo with a local directory.

    Steps:
      1. Validate that ``--local-export-root`` is an existing directory.
      2. Delete the old remote export:
         * with a non-empty ``--path-in-repo``, the entry at that path
           (including nested paths such as ``a/b``) is deleted if it exists;
         * when uploading to the repo root, every root entry whose name is in
           ``DATA_ARTIFACT_NAMES`` or matches a top-level local entry is deleted.
      3. Upload the local directory with ``upload_folder``.

    With ``--dry-run`` the planned ``delete``/``upload`` lines are printed but
    no commit is made (the remote tree is still listed).

    Raises:
        FileNotFoundError: If the local export root is not a directory.
    """
    args = build_parser().parse_args()
    api = HfApi()

    local_export_root = Path(args.local_export_root).expanduser().resolve()
    if not local_export_root.is_dir():
        raise FileNotFoundError(local_export_root)

    prefix = args.path_in_repo.strip("/")

    if prefix:
        # The whole export lives under one path: remove it as a unit.
        existing = _find_remote_entry(api, args, prefix)
        if existing is not None:
            _delete_remote_entry(api, args, prefix, existing)
    else:
        # Export goes to the repo root: remove only known artifact names and
        # anything that the upload is about to replace.
        top_level_local = {path.name for path in local_export_root.iterdir()}
        root_entries = _list_remote_dir(api, args, None)
        for name in sorted(DATA_ARTIFACT_NAMES | top_level_local):
            remote_path = repo_path(prefix, name)
            existing = root_entries.get(remote_path)
            if existing is not None:
                _delete_remote_entry(api, args, remote_path, existing)

    print(f"upload {local_export_root} -> {prefix or '/'}")
    if args.dry_run:
        return
    api.upload_folder(
        repo_id=args.repo_id,
        repo_type=args.repo_type,
        revision=args.revision,
        folder_path=local_export_root,
        path_in_repo=prefix or None,
        commit_message=args.commit_message,
    )
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|