summaryrefslogtreecommitdiff
path: root/scripts/upload_to_hf.py
blob: 3c4101104bd508663420822a92fb89de110cfb25 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import os
import sys

from huggingface_hub import HfApi, create_repo, upload_folder

def main(argv=None):
    """Upload a local checkpoint directory to the Hugging Face Hub.

    Parses the command-line options, creates the target model repository if
    it does not exist yet, and uploads the checkpoint folder. Unless
    ``--include-optimizer`` is given, training-resumption artifacts
    (optimizer/scheduler/scaler states, RNG states, trainer state and
    training arguments) are left out of the upload.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) reads the real command line.

    Returns:
        0 on success, 1 on failure (checkpoint path missing or not a
        directory, repository creation error, or upload error). Failures are
        reported on stderr.
    """
    parser = argparse.ArgumentParser(description="Upload a checkpoint to Hugging Face Hub")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the checkpoint directory")
    parser.add_argument("--repo_id", type=str, required=True, help="Hugging Face repo ID (e.g., username/model-name)")
    parser.add_argument("--token", type=str, help="Hugging Face token (optional if logged in via CLI)")
    parser.add_argument("--private", action="store_true", help="Make the repository private")
    parser.add_argument("--include-optimizer", action="store_true", help="Include optimizer states (optimizer.pt, scheduler.pt, etc.)")

    args = parser.parse_args(argv)

    ckpt_path = os.path.abspath(args.ckpt_path)
    # upload_folder needs a directory. Reject a plain file here with a clear
    # message instead of letting the upload fail later.
    if not os.path.isdir(ckpt_path):
        if os.path.exists(ckpt_path):
            print(f"Error: Checkpoint path is not a directory: {ckpt_path}", file=sys.stderr)
        else:
            print(f"Error: Checkpoint path does not exist: {ckpt_path}", file=sys.stderr)
        return 1

    print(f"Preparing to upload {ckpt_path} to {args.repo_id}...")

    # One client holds the token (None -> cached CLI login) for every call below.
    api = HfApi(token=args.token)

    print(f"Creating repository {args.repo_id} (if not exists)...")
    try:
        # NOTE: with exist_ok=True an existing repository is reused as-is, so
        # --private only affects a repository newly created by this call.
        api.create_repo(
            repo_id=args.repo_id,
            private=args.private,
            exist_ok=True,
            repo_type="model",
        )
    except Exception as e:  # CLI boundary: report any Hub/network error and fail.
        print(f"Error creating repo: {e}", file=sys.stderr)
        print("Please check your token and permissions.", file=sys.stderr)
        return 1

    # Patterns to ignore when optimizer/training-resumption states are not wanted.
    ignore_patterns = []
    if not args.include_optimizer:
        ignore_patterns = [
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
            # Matches both rng_state.pth (single process) and
            # rng_state_<rank>.pth (distributed runs).
            "rng_state*.pth",
            "trainer_state.json",
            "training_args.bin",
        ]
        print("Excluding optimizer states from upload.")

    print("Starting upload...")
    try:
        api.upload_folder(
            folder_path=ckpt_path,
            repo_id=args.repo_id,
            repo_type="model",
            ignore_patterns=ignore_patterns,
        )
    except Exception as e:  # CLI boundary: report any Hub/network error and fail.
        print(f"Upload failed: {e}", file=sys.stderr)
        return 1

    print(f"Successfully uploaded to https://huggingface.co/{args.repo_id}")
    return 0

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status so shells
    # and CI can detect failures (None/0 -> success, non-zero -> failure).
    sys.exit(main())