linting
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from exporters.constants import DELETE_PATTERNS, SOURCE_TO_METADATA, ModelSource
|
||||
from exporters.onnx import export as onnx_export
|
||||
from exporters.rknn import export as rknn_export
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .exporters.constants import DELETE_PATTERNS, SOURCE_TO_METADATA, ModelSource
|
||||
from .exporters.onnx import export as onnx_export
|
||||
from .exporters.rknn import export as rknn_export
|
||||
|
||||
app = typer.Typer(pretty_exceptions_show_locals=False)
|
||||
|
||||
|
||||
@@ -45,7 +46,7 @@ def main(
|
||||
no_cache: bool = False,
|
||||
hf_organization: str = "immich-app",
|
||||
hf_auth_token: Annotated[str | None, typer.Option(envvar="HF_AUTH_TOKEN")] = None,
|
||||
):
|
||||
) -> None:
|
||||
hf_model_name = model_name.split("/")[-1]
|
||||
hf_model_name = hf_model_name.replace("xlm-roberta-large", "XLM-Roberta-Large")
|
||||
hf_model_name = hf_model_name.replace("xlm-roberta-base", "XLM-Roberta-Base")
|
||||
@@ -79,7 +80,7 @@ def main(
|
||||
repo_id = f"{hf_organization}/{hf_model_name}"
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_fixed(5))
|
||||
def upload_model():
|
||||
def upload_model() -> None:
|
||||
create_repo(repo_id, exist_ok=True, token=hf_auth_token)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .openclip import OpenCLIPModelConfig
|
||||
from .openclip import to_onnx as openclip_to_onnx
|
||||
@@ -45,7 +46,7 @@ def to_onnx(
|
||||
return visual_path, textual_path
|
||||
|
||||
|
||||
def _export_text_encoder(model: "MultilingualCLIP", output_path: Path | str, opset_version: int) -> None:
|
||||
def _export_text_encoder(model: Any, output_path: Path | str, opset_version: int) -> None:
|
||||
import torch
|
||||
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
|
||||
|
||||
|
||||
@@ -16,18 +16,20 @@ class OpenCLIPModelConfig:
|
||||
def model_config(self) -> dict[str, Any]:
|
||||
import open_clip
|
||||
|
||||
config = open_clip.get_model_config(self.name)
|
||||
config: dict[str, Any] | None = open_clip.get_model_config(self.name)
|
||||
if config is None:
|
||||
raise ValueError(f"Unknown model {self.name}")
|
||||
return config
|
||||
|
||||
@property
|
||||
def image_size(self) -> int:
|
||||
return self.model_config["vision_cfg"]["image_size"]
|
||||
image_size: int = self.model_config["vision_cfg"]["image_size"]
|
||||
return image_size
|
||||
|
||||
@property
|
||||
def sequence_length(self) -> int:
|
||||
return self.model_config["text_cfg"].get("context_length", 77)
|
||||
context_length: int = self.model_config["text_cfg"].get("context_length", 77)
|
||||
return context_length
|
||||
|
||||
|
||||
def to_onnx(
|
||||
@@ -72,7 +74,7 @@ def to_onnx(
|
||||
for param in model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
if visual_path is not None:
|
||||
if visual_path is not None and output_dir_visual is not None:
|
||||
if no_cache or not visual_path.exists():
|
||||
save_config(
|
||||
open_clip.get_model_preprocess_cfg(model),
|
||||
@@ -83,7 +85,7 @@ def to_onnx(
|
||||
else:
|
||||
print(f"Model {visual_path} already exists, skipping")
|
||||
|
||||
if textual_path is not None:
|
||||
if textual_path is not None and output_dir_textual is not None:
|
||||
if no_cache or not textual_path.exists():
|
||||
tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
|
||||
AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
|
||||
@@ -94,7 +96,7 @@ def to_onnx(
|
||||
|
||||
|
||||
def _export_image_encoder(
|
||||
model: "open_clip.CLIP", model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
|
||||
model: Any, model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
|
||||
) -> None:
|
||||
import torch
|
||||
|
||||
@@ -123,7 +125,7 @@ def _export_image_encoder(
|
||||
|
||||
|
||||
def _export_text_encoder(
|
||||
model: "open_clip.CLIP", model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
|
||||
model: Any, model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
|
||||
) -> None:
|
||||
import torch
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ def _export_platform(
|
||||
input_size_list: list[list[int]] | None = None,
|
||||
fuse_matmul_softmax_matmul_to_sdpa: bool = True,
|
||||
no_cache: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
from rknn.api import RKNN
|
||||
|
||||
input_path = model_dir / "model.onnx"
|
||||
@@ -50,7 +50,7 @@ def _export_platforms(
|
||||
inputs: list[str] | None = None,
|
||||
input_size_list: list[list[int]] | None = None,
|
||||
no_cache: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
fuse_matmul_softmax_matmul_to_sdpa = True
|
||||
for soc in RKNN_SOCS:
|
||||
try:
|
||||
@@ -77,7 +77,7 @@ def _export_platforms(
|
||||
)
|
||||
|
||||
|
||||
def export(model_dir: Path, no_cache: bool = False):
|
||||
def export(model_dir: Path, no_cache: bool = False) -> None:
|
||||
textual = model_dir / "textual"
|
||||
visual = model_dir / "visual"
|
||||
detection = model_dir / "detection"
|
||||
|
||||
@@ -73,7 +73,7 @@ insightface = [
|
||||
]
|
||||
|
||||
|
||||
def export_models(models: list[str], source: ModelSource):
|
||||
def export_models(models: list[str], source: ModelSource) -> None:
|
||||
for model in models:
|
||||
try:
|
||||
print(f"Exporting model {model}")
|
||||
|
||||
Reference in New Issue
Block a user