This commit is contained in:
mertalev
2025-03-14 15:04:46 -04:00
parent 0f1a551842
commit bd9374e4a9
7 changed files with 238 additions and 38 deletions
@@ -1,5 +1,6 @@
import warnings
from pathlib import Path
from typing import Any
from .openclip import OpenCLIPModelConfig
from .openclip import to_onnx as openclip_to_onnx
@@ -45,7 +46,7 @@ def to_onnx(
return visual_path, textual_path
def _export_text_encoder(model: "MultilingualCLIP", output_path: Path | str, opset_version: int) -> None:
def _export_text_encoder(model: Any, output_path: Path | str, opset_version: int) -> None:
import torch
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
@@ -16,18 +16,20 @@ class OpenCLIPModelConfig:
# Look up the open_clip model configuration for `self.name` and return it.
# Raises ValueError("Unknown model ...") when the lookup yields None.
def model_config(self) -> dict[str, Any]:
import open_clip
# NOTE(review): this text looks like a diff whose +/- markers were stripped;
# the next two lines appear to be the pre-change and post-change versions of
# the same assignment — confirm before treating both as live code.
config = open_clip.get_model_config(self.name)
config: dict[str, Any] | None = open_clip.get_model_config(self.name)
# A None result is turned into an explicit error so the declared
# dict[str, Any] return type holds for every value actually returned.
if config is None:
raise ValueError(f"Unknown model {self.name}")
return config
# Read-only property: the "image_size" entry of the model config's
# "vision_cfg" section.
@property
def image_size(self) -> int:
# NOTE(review): the bare return below appears to be the pre-change line of a
# diff and the two lines after it its replacement — confirm against the
# real file; as plain Python the first return would make the rest unreachable.
return self.model_config["vision_cfg"]["image_size"]
# model_config is typed dict[str, Any], so the looked-up value is bound to an
# int-annotated local before being returned.
image_size: int = self.model_config["vision_cfg"]["image_size"]
return image_size
# Read-only property: the "context_length" entry of the model config's
# "text_cfg" section, falling back to 77 when that key is absent.
@property
def sequence_length(self) -> int:
# NOTE(review): the bare return below appears to be the pre-change line of a
# diff and the two lines after it its replacement — confirm against the
# real file; as plain Python the first return would make the rest unreachable.
return self.model_config["text_cfg"].get("context_length", 77)
# model_config is typed dict[str, Any], so the looked-up value is bound to an
# int-annotated local before being returned.
context_length: int = self.model_config["text_cfg"].get("context_length", 77)
return context_length
def to_onnx(
@@ -72,7 +74,7 @@ def to_onnx(
for param in model.parameters():
param.requires_grad_(False)
if visual_path is not None:
if visual_path is not None and output_dir_visual is not None:
if no_cache or not visual_path.exists():
save_config(
open_clip.get_model_preprocess_cfg(model),
@@ -83,7 +85,7 @@ def to_onnx(
else:
print(f"Model {visual_path} already exists, skipping")
if textual_path is not None:
if textual_path is not None and output_dir_textual is not None:
if no_cache or not textual_path.exists():
tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
@@ -94,7 +96,7 @@ def to_onnx(
def _export_image_encoder(
model: "open_clip.CLIP", model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
model: Any, model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
) -> None:
import torch
@@ -123,7 +125,7 @@ def _export_image_encoder(
def _export_text_encoder(
model: "open_clip.CLIP", model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
model: Any, model_cfg: OpenCLIPModelConfig, output_path: Path | str, opset_version: int
) -> None:
import torch
@@ -10,7 +10,7 @@ def _export_platform(
input_size_list: list[list[int]] | None = None,
fuse_matmul_softmax_matmul_to_sdpa: bool = True,
no_cache: bool = False,
):
) -> None:
from rknn.api import RKNN
input_path = model_dir / "model.onnx"
@@ -50,7 +50,7 @@ def _export_platforms(
inputs: list[str] | None = None,
input_size_list: list[list[int]] | None = None,
no_cache: bool = False,
):
) -> None:
fuse_matmul_softmax_matmul_to_sdpa = True
for soc in RKNN_SOCS:
try:
@@ -77,7 +77,7 @@ def _export_platforms(
)
def export(model_dir: Path, no_cache: bool = False):
def export(model_dir: Path, no_cache: bool = False) -> None:
textual = model_dir / "textual"
visual = model_dir / "visual"
detection = model_dir / "detection"