optimized scrfd code

This commit is contained in:
mertalev
2024-06-09 23:03:34 -04:00
parent fb4fe5d40b
commit 8d2a849edc
6 changed files with 411 additions and 32 deletions
@@ -1,44 +1,33 @@
from pathlib import Path
from typing import Any
import numpy as np
import onnxruntime as ort
from insightface.model_zoo import RetinaFace
from numpy.typing import NDArray
from app.models.base import InferenceModel
from app.models.session import ort_has_batch_dim, ort_squeeze_outputs
from app.models.transforms import decode_cv2
from app.models.session import ort_has_batch_dim, ort_expand_outputs
from app.models.transforms import decode_pil
from app.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
from .scrfd import SCRFD
from PIL import Image
from PIL.ImageOps import pad
class FaceDetector(InferenceModel):
depends = []
identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
def __init__(
self,
model_name: str,
min_score: float = 0.7,
cache_dir: Path | str | None = None,
**model_kwargs: Any,
) -> None:
self.min_score = model_kwargs.pop("minScore", min_score)
super().__init__(model_name, cache_dir, **model_kwargs)
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
if isinstance(session, ort.InferenceSession) and ort_has_batch_dim(session):
ort_squeeze_outputs(session)
self.model = RetinaFace(session=session)
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
if isinstance(session, ort.InferenceSession) and not ort_has_batch_dim(session):
ort_expand_outputs(session)
self.model = SCRFD(session=session)
return session
def _predict(self, inputs: NDArray[np.uint8] | bytes, **kwargs: Any) -> FaceDetectionOutput:
inputs = decode_cv2(inputs)
def _predict(self, inputs: NDArray[np.uint8] | bytes | Image.Image, **kwargs: Any) -> FaceDetectionOutput:
inputs = self._transform(inputs)
bboxes, landmarks = self._detect(inputs)
[bboxes], [landmarks] = self.model.detect(inputs, threshold=kwargs.pop("minScore", 0.7))
return {
"boxes": bboxes[:, :4].round(),
"scores": bboxes[:, 4],
@@ -48,5 +37,7 @@ class FaceDetector(InferenceModel):
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
return self.model.detect(inputs) # type: ignore
def configure(self, **kwargs: Any) -> None:
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)
def _transform(self, inputs: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
image = decode_pil(inputs)
padded = pad(image, (640, 640), method=Image.Resampling.BICUBIC)
return np.array(padded, dtype=np.uint8)[None, ...]