feat(ml): improve test coverage (#7041)

* update e2e

* tokenizer tests

* more tests, remove unnecessary code

* fix e2e setting

* add tests for loading model

* update workflow

* fix test
This commit is contained in:
Mert
2024-02-11 17:58:56 -05:00
committed by GitHub
parent 6e853e2a9d
commit 0c4df216d7
8 changed files with 501 additions and 1636 deletions
+1 -1
View File
@@ -21,4 +21,4 @@ def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any)
case _:
raise ValueError(f"Unknown model type {model_type}")
raise ValueError(f"Unknown ${model_type} model {model_name}")
raise ValueError(f"Unknown {model_type} model {model_name}")
+3 -17
View File
@@ -1,6 +1,5 @@
from __future__ import annotations
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
@@ -11,7 +10,6 @@ import onnxruntime as ort
from huggingface_hub import snapshot_download
from onnx.shape_inference import infer_shapes
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from typing_extensions import Buffer
import ann.ann
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
@@ -200,7 +198,7 @@ class InferenceModel(ABC):
@providers.setter
def providers(self, providers: list[str]) -> None:
log.debug(
log.info(
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
)
self._providers = providers
@@ -217,7 +215,7 @@ class InferenceModel(ABC):
@provider_options.setter
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
log.info(f"Setting execution provider options to {provider_options}")
log.debug(f"Setting execution provider options to {provider_options}")
self._provider_options = provider_options
@property
@@ -255,7 +253,7 @@ class InferenceModel(ABC):
@property
def sess_options_default(self) -> ort.SessionOptions:
sess_options = PicklableSessionOptions()
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False
# avoid thread contention between models
@@ -287,15 +285,3 @@ class InferenceModel(ABC):
@property
def preferred_runtime_default(self) -> ModelRuntime:
    """Default runtime for this model: ARM NN when the native library is
    available and enabled via settings, otherwise ONNX."""
    armnn_usable = ann.ann.is_available and settings.ann
    if armnn_usable:
        return ModelRuntime.ARMNN
    return ModelRuntime.ONNX
# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions):  # type: ignore[misc]
    """ort.SessionOptions subclass that supports the pickle protocol.

    The base class is a native object with no pickle support; this snapshot
    approach copies every readable, non-callable attribute instead.
    """

    def __getstate__(self) -> bytes:
        # Serialize all non-callable public/introspectable attributes as
        # (name, value) pairs; callables (methods) cannot be meaningfully restored.
        return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])

    def __setstate__(self, state: Buffer) -> None:
        # Rebuild a fresh native SessionOptions first, then overwrite its
        # attributes with the pickled snapshot.
        self.__init__()  # type: ignore[misc]
        attrs: list[tuple[str, Any]] = pickle.loads(state)
        for attr, val in attrs:
            setattr(self, attr, val)
-17
View File
@@ -80,20 +80,3 @@ class RevalidationPlugin(BasePlugin): # type: ignore[misc]
key = client.build_key(key, namespace)
if key in client._handlers:
await client.expire(key, client.ttl)
async def post_multi_get(
    self,
    client: SimpleMemoryCache,
    keys: list[str],
    ret: list[Any] | None = None,
    namespace: str | None = None,
    **kwargs: Any,
) -> None:
    """Revalidate cache entries that were hit by a multi-get.

    When `ret` is None there was no result to act on, so this is a no-op.
    Otherwise each key is paired with its fetched value; keys are namespaced
    when a namespace is supplied, and every key whose value was found and
    which still has a registered handler gets its TTL refreshed.
    """
    if ret is None:
        return
    for raw_key, value in zip(keys, ret):
        cache_key = client.build_key(raw_key, namespace) if namespace is not None else raw_key
        if value is not None and cache_key in client._handlers:
            await client.expire(cache_key, client.ttl)
+9 -3
View File
@@ -144,9 +144,7 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
def _load(self) -> None:
super()._load()
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
context_length: int = text_cfg.get("context_length", 77)
pad_token: int = self.tokenizer_cfg["pad_token"]
self._load_tokenizer()
size: list[int] | int = self.preprocess_cfg["size"]
self.size = size[0] if isinstance(size, list) else size
@@ -155,11 +153,19 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
def _load_tokenizer(self) -> None:
    """Load and configure the text tokenizer for this CLIP model.

    Reads the context length and pad token from the model/tokenizer configs,
    loads the tokenizer from `self.tokenizer_file_path`, and enables padding
    and truncation so every encoded text is exactly `context_length` tokens.
    The tokenizer is stored on `self.tokenizer`; nothing is returned.
    """
    log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
    text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
    # Fall back to 77 when the model config does not specify a context length.
    context_length: int = text_cfg.get("context_length", 77)
    pad_token: str = self.tokenizer_cfg["pad_token"]
    self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
    # enable_padding needs the pad token's vocabulary id as well as its string form.
    pad_id: int = self.tokenizer.token_to_id(pad_token)
    self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
    self.tokenizer.enable_truncation(max_length=context_length)
    log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: