diff options
Diffstat (limited to 'scripts/upload_to_hf.py')
| -rw-r--r-- | scripts/upload_to_hf.py | 69 |
1 files changed, 69 insertions, 0 deletions
"""Upload a checkpoint directory to the Hugging Face Hub.

Repository location of this script: scripts/upload_to_hf.py
"""
import argparse
import os
import sys

from huggingface_hub import HfApi, create_repo, upload_folder

# Files skipped unless --include-optimizer is passed.
_TRAINING_STATE_PATTERNS = (
    "optimizer.pt",
    "scheduler.pt",
    "rng_state_*.pth",
    "trainer_state.json",
    "training_args.bin",
)


def main(argv=None):
    """Parse command-line arguments and upload a checkpoint folder.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Returns:
        The process exit status: ``0`` on success, ``1`` when the checkpoint
        path is missing or not a directory, when the repository cannot be
        created, or when the upload itself fails.
    """
    parser = argparse.ArgumentParser(description="Upload a checkpoint to Hugging Face Hub")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to the checkpoint directory")
    parser.add_argument("--repo_id", type=str, required=True, help="Hugging Face repo ID (e.g., username/model-name)")
    parser.add_argument("--token", type=str, help="Hugging Face token (optional if logged in via CLI)")
    parser.add_argument("--private", action="store_true", help="Make the repository private")
    parser.add_argument("--include-optimizer", action="store_true", help="Include optimizer states (optimizer.pt, scheduler.pt, etc.)")

    args = parser.parse_args(argv)

    # Expand "~" as well as relative segments; the shell does not expand a
    # tilde in forms such as --ckpt_path=~/run.
    ckpt_path = os.path.abspath(os.path.expanduser(args.ckpt_path))
    if not os.path.exists(ckpt_path):
        print(f"Error: Checkpoint path does not exist: {ckpt_path}", file=sys.stderr)
        return 1
    # Reject plain files here, before any remote side effect: otherwise the
    # repository would be created and only the folder upload would fail.
    if not os.path.isdir(ckpt_path):
        print(f"Error: Checkpoint path is not a directory: {ckpt_path}", file=sys.stderr)
        return 1

    print(f"Preparing to upload {ckpt_path} to {args.repo_id}...")

    api = HfApi(token=args.token)

    # Create repo if it doesn't exist (exist_ok=True makes this a no-op for
    # an already existing repository).
    try:
        print(f"Creating repository {args.repo_id} (if not exists)...")
        create_repo(
            repo_id=args.repo_id,
            token=args.token,
            private=args.private,
            exist_ok=True,
            repo_type="model",
        )
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        print(f"Error creating repo: {e}", file=sys.stderr)
        print("Please check your token and permissions.", file=sys.stderr)
        return 1

    # Patterns to ignore if we don't want optimizer states.
    ignore_patterns = []
    if not args.include_optimizer:
        ignore_patterns = list(_TRAINING_STATE_PATTERNS)
        print("Excluding optimizer states from upload.")

    print("Starting upload...")
    try:
        api.upload_folder(
            folder_path=ckpt_path,
            repo_id=args.repo_id,
            repo_type="model",
            ignore_patterns=ignore_patterns,
            token=args.token,
        )
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        print(f"Upload failed: {e}", file=sys.stderr)
        return 1

    print(f"Successfully uploaded to https://huggingface.co/{args.repo_id}")
    return 0


if __name__ == "__main__":
    # Propagate failures to the shell so callers (CI, wrapper scripts) can
    # detect an unsuccessful upload.
    sys.exit(main())
