summaryrefslogtreecommitdiff
path: root/scripts/upload_to_hf.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/upload_to_hf.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'scripts/upload_to_hf.py')
-rw-r--r--scripts/upload_to_hf.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/scripts/upload_to_hf.py b/scripts/upload_to_hf.py
new file mode 100644
index 0000000..3c41011
--- /dev/null
+++ b/scripts/upload_to_hf.py
@@ -0,0 +1,69 @@
+import os
+import argparse
+from huggingface_hub import HfApi, create_repo, upload_folder
+
+def main():
+ parser = argparse.ArgumentParser(description="Upload a checkpoint to Hugging Face Hub")
+ parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the checkpoint directory")
+ parser.add_argument("--repo_id", type=str, required=True, help="Hugging Face repo ID (e.g., username/model-name)")
+ parser.add_argument("--token", type=str, help="Hugging Face token (optional if logged in via CLI)")
+ parser.add_argument("--private", action="store_true", help="Make the repository private")
+ parser.add_argument("--include-optimizer", action="store_true", help="Include optimizer states (optimizer.pt, scheduler.pt, etc.)")
+
+ args = parser.parse_args()
+
+ # Expand path
+ ckpt_path = os.path.abspath(args.ckpt_path)
+ if not os.path.exists(ckpt_path):
+ print(f"Error: Checkpoint path does not exist: {ckpt_path}")
+ return
+
+ print(f"Preparing to upload {ckpt_path} to {args.repo_id}...")
+
+ api = HfApi(token=args.token)
+
+ # Create repo if it doesn't exist
+ try:
+ print(f"Creating repository {args.repo_id} (if not exists)...")
+ create_repo(
+ repo_id=args.repo_id,
+ token=args.token,
+ private=args.private,
+ exist_ok=True,
+ repo_type="model"
+ )
+ except Exception as e:
+ print(f"Error creating repo: {e}")
+ print("Please check your token and permissions.")
+ return
+
+ # patterns to ignore if we don't want optimizer states
+ ignore_patterns = []
+ if not args.include_optimizer:
+ ignore_patterns = [
+ "optimizer.pt",
+ "scheduler.pt",
+ "rng_state_*.pth",
+ "trainer_state.json",
+ "training_args.bin"
+ ]
+ print("Excluding optimizer states from upload.")
+
+ print("Starting upload...")
+ try:
+ api.upload_folder(
+ folder_path=ckpt_path,
+ repo_id=args.repo_id,
+ repo_type="model",
+ ignore_patterns=ignore_patterns,
+ token=args.token
+ )
+ print(f"Successfully uploaded to https://huggingface.co/{args.repo_id}")
+ except Exception as e:
+ print(f"Upload failed: {e}")
+
+if __name__ == "__main__":
+ main()
+
+
+