This commit is contained in:
mertalev
2024-06-07 00:01:31 -04:00
parent 7e587c2703
commit 259386cf13
3 changed files with 41 additions and 28 deletions
@@ -7,6 +7,7 @@ from insightface.model_zoo import RetinaFace
from numpy.typing import NDArray
from app.models.base import InferenceModel
from app.models.session import ort_has_batch_dim, ort_squeeze_outputs
from app.models.transforms import decode_cv2
from app.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
@@ -27,7 +28,8 @@ class FaceDetector(InferenceModel):
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
self._squeeze_outputs(session)
if isinstance(session, ort.InferenceSession) and ort_has_batch_dim(session):
ort_squeeze_outputs(session)
self.model = RetinaFace(session=session)
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
@@ -46,15 +48,5 @@ class FaceDetector(InferenceModel):
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
return self.model.detect(inputs) # type: ignore
def _squeeze_outputs(self, session: ort.InferenceSession) -> None:
original_run = session.run
def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
out = [o.squeeze(axis=0) for o in out]
return out
session.run = run
def configure(self, **kwargs: Any) -> None:
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)