fix(ml): tokenization for webli models (#11881)
This commit is contained in:
@@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
|
||||
|
||||
from app.config import log
|
||||
from app.models.base import InferenceModel
|
||||
from app.models.transforms import clean_text
|
||||
from app.schemas import ModelSession, ModelTask, ModelType
|
||||
|
||||
|
||||
@@ -25,6 +26,8 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
||||
session = super()._load()
|
||||
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
|
||||
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
|
||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
return session
|
||||
@@ -56,6 +59,11 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
||||
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
|
||||
return model_cfg
|
||||
|
||||
@property
|
||||
def text_cfg(self) -> dict[str, Any]:
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
return text_cfg
|
||||
|
||||
@cached_property
|
||||
def tokenizer_file(self) -> dict[str, Any]:
|
||||
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
|
||||
@@ -73,8 +81,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
||||
|
||||
class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
||||
def _load_tokenizer(self) -> Tokenizer:
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
context_length: int = text_cfg.get("context_length", 77)
|
||||
context_length: int = self.text_cfg.get("context_length", 77)
|
||||
pad_token: str = self.tokenizer_cfg["pad_token"]
|
||||
|
||||
tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
||||
@@ -86,12 +93,14 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
||||
return tokenizer
|
||||
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {"text": np.array([tokens.ids], dtype=np.int32)}
|
||||
|
||||
|
||||
class MClipTextualEncoder(OpenClipTextualEncoder):
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {
|
||||
"input_ids": np.array([tokens.ids], dtype=np.int32),
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import string
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
@@ -7,6 +8,7 @@ from numpy.typing import NDArray
|
||||
from PIL import Image
|
||||
|
||||
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
|
||||
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
|
||||
|
||||
|
||||
def resize_pil(img: Image.Image, size: int) -> Image.Image:
|
||||
@@ -60,3 +62,10 @@ def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[
|
||||
if isinstance(image_bytes, Image.Image):
|
||||
return pil_to_cv2(image_bytes)
|
||||
return image_bytes
|
||||
|
||||
|
||||
def clean_text(text: str, canonicalize: bool = False) -> str:
|
||||
text = " ".join(text.split())
|
||||
if canonicalize:
|
||||
text = text.translate(_PUNCTUATION_TRANS).lower()
|
||||
return text
|
||||
|
||||
Reference in New Issue
Block a user