optimized scrfd code
This commit is contained in:
@@ -1,44 +1,33 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from insightface.model_zoo import RetinaFace
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from app.models.base import InferenceModel
|
||||
from app.models.session import ort_has_batch_dim, ort_squeeze_outputs
|
||||
from app.models.transforms import decode_cv2
|
||||
from app.models.session import ort_has_batch_dim, ort_expand_outputs
|
||||
from app.models.transforms import decode_pil
|
||||
from app.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
|
||||
|
||||
from .scrfd import SCRFD
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import pad
|
||||
|
||||
class FaceDetector(InferenceModel):
|
||||
depends = []
|
||||
identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
min_score: float = 0.7,
|
||||
cache_dir: Path | str | None = None,
|
||||
**model_kwargs: Any,
|
||||
) -> None:
|
||||
self.min_score = model_kwargs.pop("minScore", min_score)
|
||||
super().__init__(model_name, cache_dir, **model_kwargs)
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
session = self._make_session(self.model_path)
|
||||
if isinstance(session, ort.InferenceSession) and ort_has_batch_dim(session):
|
||||
ort_squeeze_outputs(session)
|
||||
self.model = RetinaFace(session=session)
|
||||
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
|
||||
if isinstance(session, ort.InferenceSession) and not ort_has_batch_dim(session):
|
||||
ort_expand_outputs(session)
|
||||
self.model = SCRFD(session=session)
|
||||
|
||||
return session
|
||||
|
||||
def _predict(self, inputs: NDArray[np.uint8] | bytes, **kwargs: Any) -> FaceDetectionOutput:
|
||||
inputs = decode_cv2(inputs)
|
||||
def _predict(self, inputs: NDArray[np.uint8] | bytes | Image.Image, **kwargs: Any) -> FaceDetectionOutput:
|
||||
inputs = self._transform(inputs)
|
||||
|
||||
bboxes, landmarks = self._detect(inputs)
|
||||
[bboxes], [landmarks] = self.model.detect(inputs, threshold=kwargs.pop("minScore", 0.7))
|
||||
return {
|
||||
"boxes": bboxes[:, :4].round(),
|
||||
"scores": bboxes[:, 4],
|
||||
@@ -48,5 +37,7 @@ class FaceDetector(InferenceModel):
|
||||
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
||||
return self.model.detect(inputs) # type: ignore
|
||||
|
||||
def configure(self, **kwargs: Any) -> None:
|
||||
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)
|
||||
def _transform(self, inputs: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
|
||||
image = decode_pil(inputs)
|
||||
padded = pad(image, (640, 640), method=Image.Resampling.BICUBIC)
|
||||
return np.array(padded, dtype=np.uint8)[None, ...]
|
||||
|
||||
Reference in New Issue
Block a user