upload to hf

This commit is contained in:
mertalev
2025-03-13 17:31:41 -04:00
parent c57c562166
commit 9958ac9ec9
8 changed files with 128 additions and 81 deletions
@@ -1,9 +1,10 @@
from pathlib import Path
import typer
from exporters.constants import SOURCE_TO_METADATA, ModelSource
from exporters.constants import DELETE_PATTERNS, SOURCE_TO_METADATA, ModelSource
from exporters.onnx import export as onnx_export
from exporters.rknn import export as rknn_export
from tenacity import retry, stop_after_attempt
from typing_extensions import Annotated
app = typer.Typer(pretty_exceptions_show_locals=False)
@@ -45,7 +46,8 @@ def main(
hf_organization: str = "immich-app",
hf_auth_token: Annotated[str | None, typer.Option(envvar="HF_AUTH_TOKEN")] = None,
):
hf_model_name = model_name.replace("xlm-roberta-large", "XLM-Roberta-Large")
hf_model_name = model_name.split("/")[-1]
hf_model_name = hf_model_name.replace("xlm-roberta-large", "XLM-Roberta-Large")
hf_model_name = hf_model_name.replace("xlm-roberta-base", "XLM-Roberta-Base")
output_dir = output_dir / hf_model_name
match model_source:
@@ -75,21 +77,20 @@ def main(
from huggingface_hub import create_repo, upload_folder
repo_id = f"{hf_organization}/{hf_model_name}"
create_repo(repo_id, exist_ok=True, token=hf_auth_token)
# glob to delete old UUID blobs when reuploading models
uuid_char = "[a-fA-F0-9]"
uuid_glob = (
uuid_char * 8 + "-" + uuid_char * 4 + "-" + uuid_char * 4 + "-" + uuid_char * 4 + "-" + uuid_char * 12
)
upload_folder(
repo_id=repo_id,
folder_path=output_dir,
# remote repo files to be deleted before uploading
# deletion is in the same commit as the upload, so it's atomic
delete_patterns=[f"**/{uuid_glob}"],
token=hf_auth_token,
)
# Imported right before use: the decorator below is evaluated as soon as
# execution reaches it, so `wait_fixed` must already be bound here.
from tenacity import wait_fixed


# tenacity *calls* `wait` with the retry state to get the sleep duration, so
# it must be a wait strategy such as `wait_fixed(2)`; a bare int (`wait=2`)
# raises "TypeError: 'int' object is not callable" on the first failed attempt
# instead of pausing and retrying.
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def upload_model() -> None:
    """Create the target repo if needed and upload the exported model folder.

    Reads `repo_id`, `output_dir` and `hf_auth_token` from the enclosing
    scope. The whole sequence (repo creation + upload) is attempted up to
    three times, pausing two seconds between attempts.
    """
    # exist_ok=True keeps repeated attempts (and re-uploads) from failing
    # just because the repo is already there.
    create_repo(repo_id, exist_ok=True, token=hf_auth_token)
    upload_folder(
        repo_id=repo_id,
        folder_path=output_dir,
        # remote repo files to be deleted before uploading
        # deletion is in the same commit as the upload, so it's atomic
        delete_patterns=DELETE_PATTERNS,
        token=hf_auth_token,
    )


upload_model()
if __name__ == "__main__":