feat(ml): improve test coverage (#7041)
* update e2e * tokenizer tests * more tests, remove unnecessary code * fix e2e setting * add tests for loading model * update workflow * fixed test
This commit is contained in:
@@ -21,4 +21,4 @@ def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any)
|
||||
case _:
|
||||
raise ValueError(f"Unknown model type {model_type}")
|
||||
|
||||
raise ValueError(f"Unknown ${model_type} model {model_name}")
|
||||
raise ValueError(f"Unknown {model_type} model {model_name}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
@@ -11,7 +10,6 @@ import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||
from typing_extensions import Buffer
|
||||
|
||||
import ann.ann
|
||||
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
|
||||
@@ -200,7 +198,7 @@ class InferenceModel(ABC):
|
||||
|
||||
@providers.setter
|
||||
def providers(self, providers: list[str]) -> None:
|
||||
log.debug(
|
||||
log.info(
|
||||
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
|
||||
)
|
||||
self._providers = providers
|
||||
@@ -217,7 +215,7 @@ class InferenceModel(ABC):
|
||||
|
||||
@provider_options.setter
|
||||
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
|
||||
log.info(f"Setting execution provider options to {provider_options}")
|
||||
log.debug(f"Setting execution provider options to {provider_options}")
|
||||
self._provider_options = provider_options
|
||||
|
||||
@property
|
||||
@@ -255,7 +253,7 @@ class InferenceModel(ABC):
|
||||
|
||||
@property
|
||||
def sess_options_default(self) -> ort.SessionOptions:
|
||||
sess_options = PicklableSessionOptions()
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.enable_cpu_mem_arena = False
|
||||
|
||||
# avoid thread contention between models
|
||||
@@ -287,15 +285,3 @@ class InferenceModel(ABC):
|
||||
@property
|
||||
def preferred_runtime_default(self) -> ModelRuntime:
|
||||
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
|
||||
|
||||
|
||||
# HF deep copies configs, so we need to make session options picklable
|
||||
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
|
||||
def __getstate__(self) -> bytes:
|
||||
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
|
||||
|
||||
def __setstate__(self, state: Buffer) -> None:
|
||||
self.__init__() # type: ignore[misc]
|
||||
attrs: list[tuple[str, Any]] = pickle.loads(state)
|
||||
for attr, val in attrs:
|
||||
setattr(self, attr, val)
|
||||
|
||||
@@ -80,20 +80,3 @@ class RevalidationPlugin(BasePlugin): # type: ignore[misc]
|
||||
key = client.build_key(key, namespace)
|
||||
if key in client._handlers:
|
||||
await client.expire(key, client.ttl)
|
||||
|
||||
async def post_multi_get(
|
||||
self,
|
||||
client: SimpleMemoryCache,
|
||||
keys: list[str],
|
||||
ret: list[Any] | None = None,
|
||||
namespace: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if ret is None:
|
||||
return
|
||||
|
||||
for key, val in zip(keys, ret):
|
||||
if namespace is not None:
|
||||
key = client.build_key(key, namespace)
|
||||
if val is not None and key in client._handlers:
|
||||
await client.expire(key, client.ttl)
|
||||
|
||||
@@ -144,9 +144,7 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
|
||||
|
||||
def _load(self) -> None:
|
||||
super()._load()
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
context_length: int = text_cfg.get("context_length", 77)
|
||||
pad_token: int = self.tokenizer_cfg["pad_token"]
|
||||
self._load_tokenizer()
|
||||
|
||||
size: list[int] | int = self.preprocess_cfg["size"]
|
||||
self.size = size[0] if isinstance(size, list) else size
|
||||
@@ -155,11 +153,19 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
|
||||
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
|
||||
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
|
||||
|
||||
def _load_tokenizer(self) -> Tokenizer:
|
||||
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
context_length: int = text_cfg.get("context_length", 77)
|
||||
pad_token: str = self.tokenizer_cfg["pad_token"]
|
||||
|
||||
self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
||||
|
||||
pad_id: int = self.tokenizer.token_to_id(pad_token)
|
||||
self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
|
||||
self.tokenizer.enable_truncation(max_length=context_length)
|
||||
|
||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
|
||||
Reference in New Issue
Block a user