mertalev
2025-03-14 15:04:46 -04:00
parent 0f1a551842
commit bd9374e4a9
7 changed files with 238 additions and 38 deletions
@@ -1,12 +1,13 @@
 from pathlib import Path

 import typer
-from exporters.constants import DELETE_PATTERNS, SOURCE_TO_METADATA, ModelSource
-from exporters.onnx import export as onnx_export
-from exporters.rknn import export as rknn_export
 from tenacity import retry, stop_after_attempt, wait_fixed
 from typing_extensions import Annotated

+from .exporters.constants import DELETE_PATTERNS, SOURCE_TO_METADATA, ModelSource
+from .exporters.onnx import export as onnx_export
+from .exporters.rknn import export as rknn_export

 app = typer.Typer(pretty_exceptions_show_locals=False)
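Context for the import hunk above: relative imports resolve only when the module runs inside its package, so the script must now be invoked in -m form. A minimal illustration (the package name below is a placeholder, not taken from this diff):

# python export.py            -> ImportError: attempted relative import with no known parent package
# python -m mypackage.export  -> works: __package__ is set, so ".exporters" resolves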
@@ -45,7 +46,7 @@ def main(
     no_cache: bool = False,
     hf_organization: str = "immich-app",
     hf_auth_token: Annotated[str | None, typer.Option(envvar="HF_AUTH_TOKEN")] = None,
-):
+) -> None:
     hf_model_name = model_name.split("/")[-1]
     hf_model_name = hf_model_name.replace("xlm-roberta-large", "XLM-Roberta-Large")
     hf_model_name = hf_model_name.replace("xlm-roberta-base", "XLM-Roberta-Base")
@@ -79,7 +80,7 @@ def main(
     repo_id = f"{hf_organization}/{hf_model_name}"

     @retry(stop=stop_after_attempt(5), wait=wait_fixed(5))
-    def upload_model():
+    def upload_model() -> None:
         create_repo(repo_id, exist_ok=True, token=hf_auth_token)
         upload_folder(
             repo_id=repo_id,
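The tenacity decorator above retries upload_model up to five times, pausing five seconds between attempts. A minimal sketch of the same retry pattern (flaky_call is a hypothetical stand-in for the upload):

from tenacity import retry, stop_after_attempt, wait_fixed


@retry(stop=stop_after_attempt(5), wait=wait_fixed(5))
def flaky_call() -> None:
    # Any exception triggers a retry; after the fifth failed attempt,
    # tenacity raises RetryError wrapping the last exception.
    raise ConnectionError("transient failure")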
@@ -1,5 +1,6 @@
 import warnings
 from pathlib import Path
+from typing import Any

 from .openclip import OpenCLIPModelConfig
 from .openclip import to_onnx as openclip_to_onnx
@@ -45,7 +46,7 @@ def to_onnx(
     return visual_path, textual_path


-def _export_text_encoder(model: "MultilingualCLIP", output_path: Path | str, opset_version: int) -> None:
+def _export_text_encoder(model: Any, output_path: Path | str, opset_version: int) -> None:
     import torch
     from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
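Widening the parameter to Any sidesteps a forward reference to a class that is only imported inside the function body. An alternative, not used by this commit, keeps the precise type for static analysis only; a sketch assuming multilingual_clip is installed:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; never imported at runtime.
    from multilingual_clip.pt_multilingual_clip import MultilingualCLIP


def _export_text_encoder(model: "MultilingualCLIP", output_path: str, opset_version: int) -> None:
    ...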
@@ -16,18 +16,20 @@ class OpenCLIPModelConfig:
     def model_config(self) -> dict[str, Any]:
         import open_clip

-        config = open_clip.get_model_config(self.name)
+        config: dict[str, Any] | None = open_clip.get_model_config(self.name)
         if config is None:
             raise ValueError(f"Unknown model {self.name}")
         return config

     @property
     def image_size(self) -> int:
-        return self.model_config["vision_cfg"]["image_size"]
+        image_size: int = self.model_config["vision_cfg"]["image_size"]
+        return image_size

     @property
     def sequence_length(self) -> int:
-        return self.model_config["text_cfg"].get("context_length", 77)
+        context_length: int = self.model_config["text_cfg"].get("context_length", 77)
+        return context_length


 def to_onnx(
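The new local annotations exist because indexing a dict[str, Any] yields Any, which strict mypy will not silently return as int. A quick illustration of the config shape these properties read (the model name and printed values are assumptions based on open_clip's published configs):

import open_clip

config = open_clip.get_model_config("ViT-B-32")  # dict[str, Any] | None
if config is not None:
    print(config["vision_cfg"]["image_size"])            # e.g. 224
    print(config["text_cfg"].get("context_length", 77))  # e.g. 77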
@@ -72,7 +74,7 @@ def to_onnx(
     for param in model.parameters():
         param.requires_grad_(False)

-    if visual_path is not None:
+    if visual_path is not None and output_dir_visual is not None:
         if no_cache or not visual_path.exists():
             save_config(
                 open_clip.get_model_preprocess_cfg(model),
@@ -83,7 +85,7 @@ def to_onnx(
         else:
             print(f"Model {visual_path} already exists, skipping")

-    if textual_path is not None:
+    if textual_path is not None and output_dir_textual is not None:
         if no_cache or not textual_path.exists():
             tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
             AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
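Both guards above add an "is not None" check on the output directory so mypy can narrow it from "Path | None" to "Path" before it is used. The narrowing in isolation (names are hypothetical):

from pathlib import Path


def maybe_save(textual_path: Path | None, output_dir: Path | None) -> None:
    if textual_path is not None and output_dir is not None:
        # Both names are narrowed to plain Path within this block.
        print(textual_path.name, "->", output_dir)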
@@ -94,7 +96,7 @@ def to_onnx(


 def _export_image_encoder(
-    model: "open_clip.CLIP", model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
+    model: Any, model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
 ) -> None:
     import torch
@@ -123,7 +125,7 @@ def _export_image_encoder(

 def _export_text_encoder(
-    model: "open_clip.CLIP", model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
+    model: Any, model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
 ) -> None:
     import torch
@@ -10,7 +10,7 @@ def _export_platform(
     input_size_list: list[list[int]] | None = None,
     fuse_matmul_softmax_matmul_to_sdpa: bool = True,
     no_cache: bool = False,
-):
+) -> None:
     from rknn.api import RKNN

     input_path = model_dir / "model.onnx"
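For context, a minimal sketch of the ONNX-to-RKNN conversion flow this function wraps, assuming the rknn-toolkit2 API (paths and target platform are placeholders):

from rknn.api import RKNN

rknn = RKNN()
rknn.config(target_platform="rk3566")  # one of the supported SoCs
rknn.load_onnx(model="model.onnx")     # source ONNX graph
rknn.build(do_quantization=False)      # compile without quantization
rknn.export_rknn("model.rknn")         # write the converted model
rknn.release()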
@@ -50,7 +50,7 @@ def _export_platforms(
     inputs: list[str] | None = None,
     input_size_list: list[list[int]] | None = None,
     no_cache: bool = False,
-):
+) -> None:
     fuse_matmul_softmax_matmul_to_sdpa = True
     for soc in RKNN_SOCS:
         try:
@@ -77,7 +77,7 @@ def _export_platforms(
     )


-def export(model_dir: Path, no_cache: bool = False):
+def export(model_dir: Path, no_cache: bool = False) -> None:
     textual = model_dir / "textual"
     visual = model_dir / "visual"
     detection = model_dir / "detection"
@@ -73,7 +73,7 @@ insightface = [
 ]


-def export_models(models: list[str], source: ModelSource):
+def export_models(models: list[str], source: ModelSource) -> None:
     for model in models:
         try:
             print(f"Exporting model {model}")