"""Upload a local checkpoint directory to a Hugging Face Hub model repository."""

import argparse
import os
import sys

from huggingface_hub import HfApi, create_repo, upload_folder

# Files skipped unless --include-optimizer is given: optimizer/scheduler/RNG
# state plus the trainer state and training-args files.
_TRAINING_STATE_PATTERNS = [
    "optimizer.pt",
    "scheduler.pt",
    "rng_state_*.pth",
    "trainer_state.json",
    "training_args.bin",
]


def _parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(
        description="Upload a checkpoint to Hugging Face Hub"
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
        help="Path to the checkpoint directory",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="Hugging Face repo ID (e.g., username/model-name)",
    )
    parser.add_argument(
        "--token",
        type=str,
        help="Hugging Face token (optional if logged in via CLI)",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Make the repository private",
    )
    parser.add_argument(
        "--include-optimizer",
        action="store_true",
        help="Include optimizer states (optimizer.pt, scheduler.pt, etc.)",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Create the target repo if needed and upload the checkpoint folder.

    Args:
        argv: Optional list of command-line arguments. When ``None`` the
            arguments are read from ``sys.argv``.

    Returns:
        Process exit status: ``0`` when the upload succeeded, ``1`` when the
        checkpoint path is invalid, the repository could not be created, or
        the upload failed.
    """
    args = _parse_args(argv)

    # abspath alone does not expand "~", so expand it explicitly first.
    ckpt_path = os.path.abspath(os.path.expanduser(args.ckpt_path))
    if not os.path.exists(ckpt_path):
        print(
            f"Error: Checkpoint path does not exist: {ckpt_path}",
            file=sys.stderr,
        )
        return 1
    # Reject regular files up front, before a repo is created remotely.
    if not os.path.isdir(ckpt_path):
        print(
            f"Error: Checkpoint path is not a directory: {ckpt_path}",
            file=sys.stderr,
        )
        return 1

    print(f"Preparing to upload {ckpt_path} to {args.repo_id}...")

    api = HfApi(token=args.token)

    # Create repo if it doesn't exist (exist_ok=True tolerates an existing one).
    try:
        print(f"Creating repository {args.repo_id} (if not exists)...")
        create_repo(
            repo_id=args.repo_id,
            token=args.token,
            private=args.private,
            exist_ok=True,
            repo_type="model",
        )
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        print(f"Error creating repo: {e}", file=sys.stderr)
        print("Please check your token and permissions.", file=sys.stderr)
        return 1

    # Patterns to ignore if we don't want optimizer states.
    ignore_patterns = []
    if not args.include_optimizer:
        ignore_patterns = list(_TRAINING_STATE_PATTERNS)
        print("Excluding optimizer states from upload.")

    print("Starting upload...")
    try:
        api.upload_folder(
            folder_path=ckpt_path,
            repo_id=args.repo_id,
            repo_type="model",
            ignore_patterns=ignore_patterns,
            token=args.token,
        )
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        print(f"Upload failed: {e}", file=sys.stderr)
        return 1

    print(f"Successfully uploaded to https://huggingface.co/{args.repo_id}")
    return 0


if __name__ == "__main__":
    sys.exit(main())